Skip to content

Commit

Permalink
Merge pull request #138 from tud-amr/ft-inputs-refactor
Browse files Browse the repository at this point in the history
Refactoring input handling for casadi function wrapper.
  • Loading branch information
maxspahn authored Jul 26, 2024
2 parents 3f2b34a + ca0585d commit 4545f21
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 59 deletions.
1 change: 1 addition & 0 deletions examples/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
*.pbz2
*.c
panda_local
MUJOCO_LOG.TXT
7 changes: 3 additions & 4 deletions examples/panda_capsules_cuboid.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,10 +150,9 @@ def set_planner(goal: GoalComposition, degrees_of_freedom: int = 7, obstacle_res
tf_capsule_origin = forward_kinematics.casadi(
q, link_name, link_transformation=tf,
)
# planner.add_capsule_sphere_geometry(
# "obst_1", f"capsule_{i}", tf_capsule_origin, length
# )
#todo: WHEN UNCOMMENTING THIS, i GET ACTIONS NAN, AN ERROR SOMEWHERE!
planner.add_capsule_sphere_geometry(
"obst_1", f"capsule_{i}", tf_capsule_origin, length
)
planner.add_capsule_cuboid_geometry(
"obst_cuboid_1", f"capsule_{i}", tf_capsule_origin, length
)
Expand Down
3 changes: 1 addition & 2 deletions examples/panda_ring.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def set_planner(goal: GoalComposition, degrees_of_freedom: int = 7, obstacle_res
[-0.0175, 3.7525],
[-2.8973, 2.8973]
]
collision_links = ['panda_link1', 'panda_link4', 'panda_link6', 'vacuum_link']
collision_links = ['panda_link4', 'panda_link6', 'vacuum_link']
self_collision_pairs = {}
# The planner hides all the logic behind the function set_components.
planner.set_components(
Expand Down Expand Up @@ -207,7 +207,6 @@ def run_panda_ring_example(n_steps=5000, render=True, serialize=False, planner=N
weight_goal_0=ob_robot["FullSensor"]["goals"][obstacle_resolution_ring+3]["weight"],
x_goal_1=ob_robot["FullSensor"]["goals"][obstacle_resolution_ring+4]["position"],
weight_goal_1=ob_robot["FullSensor"]["goals"][obstacle_resolution_ring+4]["weight"],
radius_body_panda_link1=0.1,
radius_body_panda_link4=0.1,
radius_body_panda_link6=0.15,
radius_body_vacuum_link=0.1,
Expand Down
95 changes: 43 additions & 52 deletions fabrics/helpers/casadiFunctionWrapper.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import pdb
import casadi as ca
from fabrics.helpers.variables import Variables
import numpy as np
Expand All @@ -23,68 +22,95 @@ def __init__(self, name: str, variables: Variables, expressions: dict):
self.create_function()

def create_function(self):
self._input_keys = sorted(tuple(self._inputs.keys()))
self._input_sizes = {i: self._inputs[i].size() for i in self._inputs}
self._list_expressions = [self._expressions[i] for i in sorted(self._expressions.keys())]
input_expressions = [self._inputs[i] for i in self._input_keys]
self._function = ca.Function(self._name, input_expressions, self._list_expressions)
input_values = []
input_keys = []
expression_keys = []
expression_values = []
for input_key, input_value in self._inputs.items():
input_keys.append(input_key)
input_values.append(input_value)
for expression_key, expression_value in self._expressions.items():
expression_keys.append(expression_key)
expression_values.append(expression_value)
self._function = ca.Function(self._name, input_values, expression_values, input_keys, expression_keys)

def function(self) -> ca.Function:
return self._function

def serialize(self, file_name):
with bz2.BZ2File(file_name, 'w') as f:
pickle.dump(self._function.serialize(), f)
pickle.dump(list(self._expressions.keys()), f)
pickle.dump(self._input_keys, f)
pickle.dump(self._argument_dictionary, f)

def evaluate(self, **kwargs):
self.process_inputs(**kwargs)
try:
output_dict = self._function(**self._argument_dictionary)
except NotImplementedError:
expected_inputs = list(self._inputs.keys())
received_inputs = list(self._argument_dictionary.keys())
unique_expected = [x for x in expected_inputs if x not in received_inputs]
unique_received = [x for x in received_inputs if x not in expected_inputs]

msg = "Inputs do not match\n"
msg += f"Found unexpected inputs: {unique_received}\n"
msg += f"Found missing inputs: {unique_expected}\n"
raise InputMissmatchError(msg)
for key, value in output_dict.items():
if value.size() == (1, 1):
output_dict[key] = np.array(value)[:, 0]
elif value.size()[1] == 1:
output_dict[key] = np.array(value)[:, 0]
else:
output_dict[key] = np.array(value)
return output_dict

def process_inputs(self, **kwargs):
for key in kwargs: # pragma no cover
if key == 'x_obst' or key == 'x_obsts':
obstacle_dictionary = {}
for j, x_obst_j in enumerate(kwargs[key]):
obstacle_dictionary[f'x_obst_{j}'] = x_obst_j
self._argument_dictionary.update(obstacle_dictionary)
if key == 'radius_obst' or key == 'radius_obsts':
elif key == 'radius_obst' or key == 'radius_obsts':
radius_dictionary = {}
for j, radius_obst_j in enumerate(kwargs[key]):
radius_dictionary[f'radius_obst_{j}'] = radius_obst_j
self._argument_dictionary.update(radius_dictionary)
if key == 'x_obst_dynamic' or key == 'x_obsts_dynamic':
elif key == 'x_obst_dynamic' or key == 'x_obsts_dynamic':
obstacle_dyn_dictionary = {}
for j, x_obst_dyn_j in enumerate(kwargs[key]):
obstacle_dyn_dictionary[f'x_obst_dynamic_{j}'] = x_obst_dyn_j
self._argument_dictionary.update(obstacle_dyn_dictionary)
if key == 'xdot_obst_dynamic' or key == 'xdot_obsts_dynamic':
elif key == 'xdot_obst_dynamic' or key == 'xdot_obsts_dynamic':
xdot_dyn_dictionary = {}
for j, xdot_obst_dyn_j in enumerate(kwargs[key]):
xdot_dyn_dictionary[f'xdot_obst_dynamic_{j}'] = xdot_obst_dyn_j
self._argument_dictionary.update(xdot_dyn_dictionary)
if key == 'xddot_obst_dynamic' or key == 'xddot_obsts_dynamic':
elif key == 'xddot_obst_dynamic' or key == 'xddot_obsts_dynamic':
xddot_dyn_dictionary = {}
for j, xddot_obst_dyn_j in enumerate(kwargs[key]):
xddot_dyn_dictionary[f'xddot_obst_dynamic_{j}'] = xddot_obst_dyn_j
self._argument_dictionary.update(xddot_dyn_dictionary)
if key == 'radius_obst_dynamic' or key == 'radius_obsts_dynamic':
elif key == 'radius_obst_dynamic' or key == 'radius_obsts_dynamic':
radius_dyn_dictionary = {}
for j, radius_obst_dyn_j in enumerate(kwargs[key]):
radius_dyn_dictionary[f'radius_obst_dynamic_{j}'] = radius_obst_dyn_j
self._argument_dictionary.update(radius_dyn_dictionary)
if key == 'x_obst_cuboid' or key == 'x_obsts_cuboid':
elif key == 'x_obst_cuboid' or key == 'x_obsts_cuboid':
x_obst_cuboid_dictionary = {}
for j, x_obst_cuboid_j in enumerate(kwargs[key]):
x_obst_cuboid_dictionary[f'x_obst_cuboid_{j}'] = x_obst_cuboid_j
self._argument_dictionary.update(x_obst_cuboid_dictionary)
if key == 'size_obst_cuboid' or key == 'size_obsts_cuboid':
elif key == 'size_obst_cuboid' or key == 'size_obsts_cuboid':
size_obst_cuboid_dictionary = {}
for j, size_obst_cuboid_j in enumerate(kwargs[key]):
size_obst_cuboid_dictionary[f'size_obst_cuboid_{j}'] = size_obst_cuboid_j
self._argument_dictionary.update(size_obst_cuboid_dictionary)
if key.startswith('radius_body') and key.endswith('links'):
elif key.startswith('radius_body') and key.endswith('links'):
# Radius bodies can be passed using a dictionary where the keys are simple integers.
radius_body_dictionary = {}
body_size_inputs = [input_exp for input_exp in self._input_keys if input_exp.startswith('radius_body')]
body_size_inputs = [input_exp for input_exp in list(self._inputs.keys()) if input_exp.startswith('radius_body')]
for link_nr, radius_body_j in kwargs[key].items():
try:
key = [body_size_input for body_size_input in body_size_inputs if str(link_nr) in body_size_input][0]
Expand All @@ -94,36 +120,6 @@ def evaluate(self, **kwargs):
self._argument_dictionary.update(radius_body_dictionary)
else:
self._argument_dictionary[key] = kwargs[key]
input_arrays = []
try:
for i in self._input_keys:
"""
if not self._argument_dictionary[i].size == self._input_sizes[i][0] * self._input_sizes[i][1]:
raise InputMissmatchError(f"Size of input argument {i} with size {self._argument_dictionary[i].size} does not match size required {self._input_sizes[i][0]}")
"""
input_arrays.append(self._argument_dictionary[i])
input_arrays = [self._argument_dictionary[i] for i in self._input_keys]
except KeyError as e:
msg = f"Key {e} is not contained in the inputs\n"
msg += f"Possible keys are {self._input_keys}\n"
msg += f"You provided {list(kwargs.keys())}\n"
raise InputMissmatchError(msg)
try:
list_array_outputs = self._function(*input_arrays)
except RuntimeError as runtime_error:
raise InputMissmatchError(runtime_error.args)
output_dict = {}
if isinstance(list_array_outputs, ca.DM):
return {list(self._expressions.keys())[0]: np.array(list_array_outputs)[:, 0]}
for i, key in enumerate(sorted(self._expressions.keys())):
raw_output = list_array_outputs[i]
if raw_output.size() == (1, 1):
output_dict[key] = np.array(raw_output)[:, 0]
elif raw_output.size()[1] == 1:
output_dict[key] = np.array(raw_output)[:, 0]
else:
output_dict[key] = np.array(raw_output)
return output_dict


class CasadiFunctionWrapper_deserialized(CasadiFunctionWrapper):
Expand All @@ -133,12 +129,7 @@ def __init__(self, file_name: str):
logging.info(f"Initializing casadiFunctionWrapper from {file_name}")
data = bz2.BZ2File(file_name, 'rb')
self._function = ca.Function().deserialize(cPickle.load(data))
expression_keys = cPickle.load(data)
self._input_keys = cPickle.load(data)
self._argument_dictionary = cPickle.load(data)
self._expressions = {}
for key in expression_keys:
self._expressions[key] = []
self._isload = True


2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "fabrics"
version = "0.9.5"
version = "0.9.6"
description = "Optimization fabrics in python."
authors = ["Max Spahn <[email protected]>"]
readme = "README.md"
Expand Down

0 comments on commit 4545f21

Please sign in to comment.