diff --git a/mjx/mujoco/mjx/_src/collision_driver.py b/mjx/mujoco/mjx/_src/collision_driver.py index 721b4345cf..0944311b85 100644 --- a/mjx/mujoco/mjx/_src/collision_driver.py +++ b/mjx/mujoco/mjx/_src/collision_driver.py @@ -29,6 +29,7 @@ from mujoco.mjx._src.collision_convex import convex_convex from mujoco.mjx._src.collision_convex import plane_convex from mujoco.mjx._src.collision_convex import sphere_convex +from mujoco.mjx._src.warp_capsule_capsule import capsule_capsule_opt from mujoco.mjx._src.collision_primitive import capsule_capsule from mujoco.mjx._src.collision_primitive import plane_capsule from mujoco.mjx._src.collision_primitive import plane_sphere @@ -52,7 +53,7 @@ (GeomType.SPHERE, GeomType.CAPSULE): sphere_capsule, (GeomType.SPHERE, GeomType.BOX): sphere_convex, (GeomType.SPHERE, GeomType.MESH): sphere_convex, - (GeomType.CAPSULE, GeomType.CAPSULE): capsule_capsule, + (GeomType.CAPSULE, GeomType.CAPSULE): capsule_capsule_opt, (GeomType.CAPSULE, GeomType.BOX): capsule_convex, (GeomType.CAPSULE, GeomType.MESH): capsule_convex, (GeomType.BOX, GeomType.BOX): convex_convex, diff --git a/mjx/mujoco/mjx/_src/jax_warp.py b/mjx/mujoco/mjx/_src/jax_warp.py new file mode 100644 index 0000000000..58a1fcae0f --- /dev/null +++ b/mjx/mujoco/mjx/_src/jax_warp.py @@ -0,0 +1,294 @@ +# Copyright 2024 NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import warp as wp +import ctypes +import jax.numpy as jp +import numpy as np + +jax_warp_p = None + +# Holder for the custom callback to keep it alive. +cc_callback = None +registered_kernels = [None] +registered_kernel_to_id = {} + +def jax_kernel(wp_kernel): + if jax_warp_p == None: + # Create and register the primitive + create_jax_warp_primitive() + if not wp_kernel in registered_kernel_to_id: + id = len(registered_kernels) + registered_kernels.append(wp_kernel) + registered_kernel_to_id[wp_kernel] = id + else: + id = registered_kernel_to_id[wp_kernel] + def bind(*args): + return jax_warp_p.bind(*args, kernel=id) + return bind + +def base_type(t): + while True: + if hasattr(t, 'dtype'): + t = t.dtype + elif hasattr(t, '_type_'): + t = t._type_ + else: + return t + +def warp_custom_callback(stream, buffers, opaque, opaque_len): + # The descriptor is the form + # ||| + # Example: 42|16,32|fv3,2;fm33,2;f,2|f,2;fv3,2 + [kernel_id_str, dim_str, inputs_str, outputs_str] = opaque.decode().split('|') + + # Get the kernel from the registry. + kernel = registered_kernels[int(kernel_id_str)] + + # Parse the dimensions. + dims = [int(d) for d in dim_str.split(',')] + + # The inputs and outputs are of the following form + # [ + # E.g., fm33,16,32 is a 16x32 array of 3x3 float matrices. + # The individual input/output descriptors are semicolon-separated. + args = [] + for i, a in enumerate(inputs_str.split(';') + outputs_str.split(';')): + dtype = None + # Parse base type. + if a[0] == 'f': + dtype = wp.float32 + else: + raise Exception(f'Unknown base type "{a[0]}"') + # Parse vector/matrix type and skip the comma. + if a[1] == 'v': + dtype = wp.types.vector(length=int(a[2]), dtype = dtype) + # Index 3 is comma, let us skip. + assert a[3] == ',' + a = a[4:] + elif a[1] == 'm': + dtype = wp.types.matrix(shape=(int(a[2]), int(a[3])), dtype = dtype) + # Index 4 is comma, let us skip. + assert a[4] == ',' + a = a[5:] + else: + # Index 1 is comma, let us skip. + assert a[1] == ',' + a = a[2:] + # Parse the array shape. + shape = [int(s) for s in a.split(',')] + assert len(shape) > 0, 'Currently only arrays are supported' + # Add the array to the arg list. + args.append(wp.array(ptr = buffers[i], dtype=dtype, shape=shape, owner=False, copy=False)) + + # Launch the kernel on the provided stream. + stream = wp.Stream(cuda_stream=ctypes.c_void_p(stream)) + wp.launch(kernel, dims, args, stream=stream, device="cuda") + +def create_jax_warp_primitive(): + from functools import reduce + import jax + from jax._src.interpreters import batching + from jax.interpreters import mlir + from jax.interpreters.mlir import ir + from jaxlib.hlo_helpers import custom_call + + global jax_warp_p + global cc_callback + + # Create and register the primitive. + # TODO add default implementation that calls the kernel via warp. + jax_warp_p = jax.core.Primitive("jax_warp") + jax_warp_p.multiple_results = True + + # TODO Just launch the kernel directly, but make sure the argument + # shapes are massaged the same way as below so that vmap works. + def impl(*args): + raise Exception('Not implementes') + jax_warp_p.def_impl(impl) + + # Auto-batching. Make sure all the arguments are fully broadcasted + # so that Warp is not confused about dimensions. + def vectorized_multi_batcher(args, dims, **params): + # Figure out the number of outputs. + wp_kernel = registered_kernels[params['kernel']] + output_count = len(wp_kernel.adj.args) - len(args) + shape, dim = next((a.shape, d) for a, d in zip(args, dims) + if d is not None) + size = shape[dim] + args = [batching.bdim_at_front(a, d, size) if len(a.shape) else a + for a, d in zip(args, dims)] + # Create the batched primitive. + return jax_warp_p.bind(*args, **params), [dims[0]] * output_count + batching.primitive_batchers[jax_warp_p] = vectorized_multi_batcher + + def get_mat_vec_shape(warp_type): + if wp.types.type_is_matrix(warp_type.dtype) or wp.types.type_is_vector(warp_type.dtype): + return warp_type.dtype._shape_ + return [] + + def strip_vecmat_dimensions(warp_arg, actual_shape): + shape = get_mat_vec_shape(warp_arg.type) + for i, s in enumerate(reversed(shape)): + item = actual_shape[- i - 1] + if s != item: + raise Exception(f'The vector/matric shape for argument {warp_arg.label} does not match') + return actual_shape[:len(actual_shape) - len(shape)] + + def collapse_into_leading_dimension(warp_arg, actual_shape): + if len(actual_shape) < warp_arg.type.ndim: + raise Exception(f'Argument {warp_arg.label} has too few non-matrix/vector dimensions') + index_rest = len(actual_shape) - warp_arg.type.ndim + 1 + leading_size = reduce(lambda x, y: x * y,actual_shape[:index_rest]) + return [leading_size] + actual_shape[index_rest:] + + # Infer array dimensions from input type. + def infer_dimensions(warp_arg, actual_shape): + actual_shape = strip_vecmat_dimensions(warp_arg, actual_shape) + return collapse_into_leading_dimension(warp_arg, actual_shape) + + # Abstract evaluation. + def jax_warp_abstract(*args, kernel=None): + wp_kernel = registered_kernels[kernel] + dtype = jax.dtypes.canonicalize_dtype(args[0].dtype) + # All the extra arguments to the warp kernel are outputs. + outputs = [ o.type for o in wp_kernel.adj.args[len(args):] ] + # Let's just use the first input dimension to infer the output's dimensions. + dims = strip_vecmat_dimensions(wp_kernel.adj.args[0], list(args[0].shape)) + return [ jax.core.ShapedArray(list(dims) + list(get_mat_vec_shape(o)), dtype) for o in outputs ] + jax_warp_p.def_abstract_eval(jax_warp_abstract) + + # Lowering to MLIR. + + # Create python-land custom call target. + CCALLFUNC = ctypes.CFUNCTYPE(ctypes.c_voidp, ctypes.c_void_p, + ctypes.POINTER(ctypes.c_void_p), + ctypes.c_char_p, ctypes.c_size_t) + cc_callback = CCALLFUNC(warp_custom_callback) + ccall_address = ctypes.cast(cc_callback, ctypes.c_void_p) + + # Put the custom call into a capsule, as required by XLA. + PyCapsule_Destructor = ctypes.CFUNCTYPE(None, ctypes.py_object) + PyCapsule_New = ctypes.pythonapi.PyCapsule_New + PyCapsule_New.restype = ctypes.py_object + PyCapsule_New.argtypes = (ctypes.c_void_p, ctypes.c_char_p, PyCapsule_Destructor) + capsule = PyCapsule_New(ccall_address.value, + b"xla._CUSTOM_CALL_TARGET", PyCapsule_Destructor(0)) + + # Register the callback in XLA. + jax.lib.xla_client.register_custom_call_target("warp_call", capsule, platform="gpu") + + def default_layout(shape): + return range(len(shape) - 1, -1, -1) + + def warp_call_lowering(ctx, *args, kernel=None): + if not kernel: + raise Exception('Unknown kernel ' + str(kernel)) + wp_kernel = registered_kernels[kernel] + + # TODO This may not be necessary, but it is perhaps better not to be + # mucking with kernel loading while already running the workload. + module = wp_kernel.module + device = "cuda" + if not module.load(device): + raise Exception("Could not load kernel on device") + + # Infer dimensions from the first input. + warp_arg0 = wp_kernel.adj.args[0] + actual_shape0 = ir.RankedTensorType(args[0].type).shape + dims = strip_vecmat_dimensions(warp_arg0, actual_shape0) + warp_dims = collapse_into_leading_dimension(warp_arg0, dims) + + # Figure out the types and sizes and create the descriptor for the inputs. + i = 0 + input_descriptors = [] + operand_layouts = [] + for actual, warg in zip(args, wp_kernel.adj.args): + # Check supported cases. + wtype = warg.type + if not wp.types.is_array(wtype): + raise Exception('Only arrays are supported') + if base_type(wtype) == 'f': + if str(ir.RankedTensorType(actual.type).element_type) != 'f32': + raise Exception(f'Unexpected base type for {warg.label}') + else: + raise Exception(f'Currently only float32 is supported') + + # Add the base type to the descriptor. + desc = base_type(warg.type) + # Add vector/matrix types. + shape = [] + if wp.types.type_is_matrix(wtype.dtype): + desc += 'm' + if wp.types.type_is_vector(wtype.dtype): + desc += 'v' + # Get matrix/vector shapes and check that they fit. + if wp.types.type_is_matrix(wtype.dtype) or wp.types.type_is_vector(wtype.dtype): + shape = wtype.dtype._shape_ + desc += ''.join([str(s) for s in shape]) + # Infer array dimension (by removing the vector/matrix dimensions and + # collapsing the initial dimensions). + array_shape = infer_dimensions(warg, ir.RankedTensorType(actual.type).shape) + desc += ',' + ','.join([str(s) for s in array_shape]) + input_descriptors.append(desc) + operand_layouts.append(default_layout(ir.RankedTensorType(actual.type).shape)) + i += 1 + + # Infer dimensions from the first input. + output_descriptors = [] + result_types = [] + result_layouts = [] + for warg in wp_kernel.adj.args[len(args):]: + wtype = warg.type + # Add base type to descriptor. + desc = base_type(warg.type) + # Add vector/matrix types to descriptor if needed. + shape = [] + if wp.types.type_is_matrix(wtype.dtype): + desc += 'm' + if wp.types.type_is_vector(wtype.dtype): + desc += 'v' + # Get matrix/vector shapes and check that they fit. + if wp.types.type_is_matrix(wtype.dtype) or wp.types.type_is_vector(wtype.dtype): + shape = wtype.dtype._shape_ + desc += ''.join([str(s) for s in shape]) + # Add the dimensions (harvested from the first input). + desc += ',' + ','.join([str(s) for s in warp_dims]) + output_descriptors.append(desc) + result_shape = list(dims) + list(shape) + result_types.append((ir.RankedTensorType.get(result_shape, ir.F32Type.get()))) + result_layouts.append(default_layout(result_shape)) + + # Build the full descriptor. + descriptor = str(kernel) + "|" + ','.join([str(d) for d in warp_dims]) + "|" + descriptor += ';'.join(input_descriptors) + '|' + descriptor += ';'.join(output_descriptors) + print("Descriptor for custom call: ", descriptor) + + out = custom_call( + b"warp_call", + result_types=result_types, + operands=args, + backend_config=descriptor.encode('utf-8'), + operand_layouts=operand_layouts, + result_layouts=result_layouts, + ).results + return out + + mlir.register_lowering( + jax_warp_p, + warp_call_lowering, + platform="gpu", + ) diff --git a/mjx/mujoco/mjx/_src/warp_capsule_capsule.py b/mjx/mujoco/mjx/_src/warp_capsule_capsule.py new file mode 100644 index 0000000000..bf170e93a2 --- /dev/null +++ b/mjx/mujoco/mjx/_src/warp_capsule_capsule.py @@ -0,0 +1,165 @@ +# Copyright 2024 NVIDIA CORPORATION. +# Copyright 2023 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import warp as wp +from typing import Optional, Tuple, Union + +import jax +import jax.numpy as jp +import numpy as np +from mujoco.mjx._src.jax_warp import jax_kernel +from mujoco.mjx._src.collision_base import GeomInfo + +wp.init() + +@wp.func +def orthogonals(a: wp.vec3): + y, z = wp.vec3(0., 1., 0.), wp.vec3(0., 0., 1.) + b = wp.select((-0.5 < a[1]) and (a[1] < 0.5), z, y) + b = b - a * wp.dot(a, b) + b = wp.normalize(b) + return b, wp.cross(a, b) + +@wp.func +def make_frame(a: wp.vec3) -> wp.mat33: + """Makes a right-handed 3D frame given a direction.""" + a = wp.normalize(a) + b, c = orthogonals(a) + return wp.mat33(a, b, c) + +@wp.func +def closest_segment_point(a: wp.vec3, b: wp.vec3, pt: wp.vec3): + """Returns the closest point on the a-b line segment to a point pt.""" + ab = b - a + t = wp.dot(pt - a, ab) / (wp.dot(ab, ab) + 1e-6) + return a + wp.clamp(t, 0.0, 1.0) * ab + +@wp.func +def closest_segment_point_and_dist(a: wp.vec3, b: wp.vec3, pt: wp.vec3): + """Returns closest point on the line segment and the distance squared.""" + closest = closest_segment_point(a, b, pt) + dist = wp.dot(pt - closest, pt - closest) + return closest, dist + +@wp.func +def closest_segment_to_segment_points(a0: wp.vec3, a1: wp.vec3, b0: wp.vec3, b1: wp.vec3): + """Returns closest points between two line segments.""" + # Gets the closest segment points by first finding the closest points + # between two lines. Points are then clipped to be on the line segments + # and edge cases with clipping are handled. + len_a = wp.length(a1 - a0) + dir_a = wp.normalize(a1 - a0) + len_b = wp.length(b1 - b0) + dir_b = wp.normalize(b1 - b0) + + # Segment mid-points. + half_len_a = len_a * 0.5 + half_len_b = len_b * 0.5 + a_mid = a0 + dir_a * half_len_a + b_mid = b0 + dir_b * half_len_b + + # Translation between two segment mid-points. + trans = a_mid - b_mid + + # Parametrize points on each line as follows: + # point_on_a = a_mid + t_a * dir_a + # point_on_b = b_mid + t_b * dir_b + # and analytically minimize the distance between the two points. + dira_dot_dirb = wp.dot(dir_a, dir_b) + dira_dot_trans = wp.dot(dir_a, trans) + dirb_dot_trans = wp.dot(dir_b, trans) + denom = 1.0 - dira_dot_dirb * dira_dot_dirb + + orig_t_a = (-dira_dot_trans + dira_dot_dirb * dirb_dot_trans) / (denom + 1e-6) + orig_t_b = dirb_dot_trans + orig_t_a * dira_dot_dirb + t_a = wp.clamp(orig_t_a, -half_len_a, half_len_a) + t_b = wp.clamp(orig_t_b, -half_len_b, half_len_b) + + best_a = a_mid + dir_a * t_a + best_b = b_mid + dir_b * t_b + + # Resolve edge cases where both closest points are clipped to the segment + # endpoints by recalculating the closest segment points for the current + # clipped points, and then picking the pair of points with smallest + # distance. An example of this edge case is when lines intersect but line + # segments don't. + new_a, d1 = closest_segment_point_and_dist(a0, a1, best_b) + new_b, d2 = closest_segment_point_and_dist(b0, b1, best_a) + best_a = wp.select(d1 < d2, best_a, new_a) + best_b = wp.select(d1 < d2, new_b, best_b) + + return best_a, best_b + +@wp.func +def _sphere_sphere(pos1: wp.vec3, radius1: float, pos2: wp.vec3, radius2: float): + """Returns the penetration, contact point, and normal between two spheres.""" + dist = wp.length(pos2 - pos1) + n = wp.normalize(pos2 - pos1) + n = wp.select(dist == 0.0, n, wp.vec3(1.0, 0.0, 0.0)) + dist = dist - (radius1 + radius2) + pos = pos1 + n * (radius1 + dist * 0.5) + return dist, pos, n + +@wp.func +def capsule_capsule(cap1_pos: wp.vec3, cap1_mat: wp.mat33, cap1_size: wp.vec3, + cap2_pos: wp.vec3, cap2_mat: wp.mat33, cap2_size: wp.vec3): + """Calculates one contact between two capsules.""" + axis1, length1, axis2, length2 = ( + wp.transpose(cap1_mat)[2], + cap1_size[1], + wp.transpose(cap2_mat)[2], + cap2_size[1], + ) + seg1, seg2 = axis1 * length1, axis2 * length2 + pt1, pt2 = closest_segment_to_segment_points( + cap1_pos - seg1, + cap1_pos + seg1, + cap2_pos - seg2, + cap2_pos + seg2, + ) + radius1, radius2 = cap1_size[0], cap2_size[0] + dist, pos, n = _sphere_sphere(pt1, radius1, pt2, radius2) + return dist, pos, make_frame(n) + +@wp.kernel +def capsule_capsule_kernel( + cap1_pos: wp.array(dtype=wp.vec3), + cap1_mat: wp.array(dtype=wp.mat33), + cap1_size: wp.array(dtype=wp.vec3), + cap2_pos: wp.array(dtype=wp.vec3), + cap2_mat: wp.array(dtype=wp.mat33), + cap2_size: wp.array(dtype=wp.vec3), + dist_out: wp.array(dtype=float), + pos_out: wp.array(dtype=wp.vec3), + frame_out: wp.array(dtype=wp.mat33)): + tid = wp.tid() + dist, pos, frame = capsule_capsule(cap1_pos[tid], cap1_mat[tid], cap1_size[tid], + cap2_pos[tid], cap2_mat[tid], cap2_size[tid]) + dist_out[tid] = dist + pos_out[tid] = pos + frame_out[tid] = wp.transpose(frame) + +capsule_capsule_warp = jax_kernel(capsule_capsule_kernel) +# @jax.jit + +def capsule_capsule_opt(g1: GeomInfo, g2: GeomInfo): + # Hack: Lift to operate on 1-element array and leave it to the batching machinery + # to lift to multiple dimensions. + (jpos1, jmat1, jsize1) = jax.tree_map(lambda x: jp.expand_dims(x, axis=0), jax.tree.flatten(g1)[0]) + (jpos2, jmat2, jsize2) = jax.tree_map(lambda x: jp.expand_dims(x, axis=0), jax.tree.flatten(g2)[0]) + return capsule_capsule_warp(jpos1, jmat1, jsize1, jpos2, jmat2, jsize2) + +capsule_capsule_opt.ncon = 1 diff --git a/mjx/pyproject.toml b/mjx/pyproject.toml index b4af482d8e..0d8f073142 100644 --- a/mjx/pyproject.toml +++ b/mjx/pyproject.toml @@ -34,6 +34,7 @@ dependencies = [ "mujoco>=3.1.3.dev0", "scipy", "trimesh", + "warp-lang", ] [project.scripts]