You signed in with another tab or window. Reload to refresh your session. You signed out in another tab or window. Reload to refresh your session. You switched accounts on another tab or window. Reload to refresh your session. Dismiss alert
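When `device=None` (the GPU case), the following (extremely long) JAX error appears. It is raised before any data is processed: the error still occurs when the .pkl files already exist. The root failure, at the end of the trace, is `RuntimeError: jaxlib/gpu/solver_kernels.cc:45: operation gpusolverDnCreate(&handle) failed: cuSolver internal error`. It is raised while lowering the batched SVD called from `jnp_pair_rmsd`.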
KeyError: ('jnp_pair_rmsd', { lambda ; a:f32[14000,51,3] b:f32[100,51,3]. let
c:f32[14000,51,3] = copy a
d:f32[100,51,3] = copy b
e:f32[14000,3] = reduce_sum[axes=(1,)] c
f:f32[14000,3] = div e 51.0
g:f32[14000,1,3] = broadcast_in_dim[
broadcast_dimensions=(0, 2)
shape=(14000, 1, 3)
] f
h:f32[14000,51,3] = sub c g
i:f32[100,3] = reduce_sum[axes=(1,)] d
j:f32[100,3] = div i 51.0
k:f32[100,1,3] = broadcast_in_dim[
broadcast_dimensions=(0, 2)
shape=(100, 1, 3)
] j
l:f32[100,51,3] = sub d k
m:f32[14000,3,51] = transpose[permutation=(0, 2, 1)] h
n:f32[14000,3,100,3] = dot_general[dimension_numbers=(([2], [1]), ([], []))] m
l
o:f32[100,14000,3,3] p:f32[100,14000,3,3] = xla_call[
call_jaxpr={ lambda ; q:f32[14000,3,100,3]. let
r:f32[100,14000,3,3] = transpose[permutation=(2, 0, 1, 3)] q
s:f32[100,14000,3] t:f32[100,14000,3,3] u:f32[100,14000,3,3] = svd[
compute_uv=True
full_matrices=False
] r
in (t, u) }
name=svd
] n
v:f32[100,14000,3,3] = transpose[permutation=(0, 1, 3, 2)] p
w:f32[100,14000,3,3] = transpose[permutation=(0, 1, 3, 2)] o
x:f32[100,14000,3,3] = dot_general[
dimension_numbers=(([3], [2]), ([0, 1], [0, 1]))
] v w
y:f32[100,14000] = custom_jvp_call[
call_jaxpr={ lambda ; z:f32[100,14000,3,3]. let
ba:f32[100,14000] = xla_call[
call_jaxpr={ lambda ; bb:f32[100,14000,3,3]. let
bc:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 0
bd:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 0
be:i32[2] = concatenate[dimension=0] bc bd
bf:f32[100,14000] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1), collapsed_slice_dims=(2, 3), start_index_map=(2, 3))
fill_value=None
indices_are_sorted=True
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(100, 14000, 1, 1)
unique_indices=True
] bb be
bg:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 1
bh:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 1
bi:i32[2] = concatenate[dimension=0] bg bh
bj:f32[100,14000] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1), collapsed_slice_dims=(2, 3), start_index_map=(2, 3))
fill_value=None
indices_are_sorted=True
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(100, 14000, 1, 1)
unique_indices=True
] bb bi
bk:f32[100,14000] = mul bf bj
bl:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 2
bm:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 2
bn:i32[2] = concatenate[dimension=0] bl bm
bo:f32[100,14000] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1), collapsed_slice_dims=(2, 3), start_index_map=(2, 3))
fill_value=None
indices_are_sorted=True
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(100, 14000, 1, 1)
unique_indices=True
] bb bn
bp:f32[100,14000] = mul bk bo
bq:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 0
br:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 1
bs:i32[2] = concatenate[dimension=0] bq br
bt:f32[100,14000] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1), collapsed_slice_dims=(2, 3), start_index_map=(2, 3))
fill_value=None
indices_are_sorted=True
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(100, 14000, 1, 1)
unique_indices=True
] bb bs
bu:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 1
bv:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 2
bw:i32[2] = concatenate[dimension=0] bu bv
bx:f32[100,14000] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1), collapsed_slice_dims=(2, 3), start_index_map=(2, 3))
fill_value=None
indices_are_sorted=True
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(100, 14000, 1, 1)
unique_indices=True
] bb bw
by:f32[100,14000] = mul bt bx
bz:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 2
ca:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 0
cb:i32[2] = concatenate[dimension=0] bz ca
cc:f32[100,14000] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1), collapsed_slice_dims=(2, 3), start_index_map=(2, 3))
fill_value=None
indices_are_sorted=True
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(100, 14000, 1, 1)
unique_indices=True
] bb cb
cd:f32[100,14000] = mul by cc
ce:f32[100,14000] = add bp cd
cf:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 0
cg:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 2
ch:i32[2] = concatenate[dimension=0] cf cg
ci:f32[100,14000] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1), collapsed_slice_dims=(2, 3), start_index_map=(2, 3))
fill_value=None
indices_are_sorted=True
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(100, 14000, 1, 1)
unique_indices=True
] bb ch
cj:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 1
ck:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 0
cl:i32[2] = concatenate[dimension=0] cj ck
cm:f32[100,14000] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1), collapsed_slice_dims=(2, 3), start_index_map=(2, 3))
fill_value=None
indices_are_sorted=True
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(100, 14000, 1, 1)
unique_indices=True
] bb cl
cn:f32[100,14000] = mul ci cm
co:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 2
cp:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 1
cq:i32[2] = concatenate[dimension=0] co cp
cr:f32[100,14000] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1), collapsed_slice_dims=(2, 3), start_index_map=(2, 3))
fill_value=None
indices_are_sorted=True
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(100, 14000, 1, 1)
unique_indices=True
] bb cq
cs:f32[100,14000] = mul cn cr
ct:f32[100,14000] = add ce cs
cu:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 0
cv:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 2
cw:i32[2] = concatenate[dimension=0] cu cv
cx:f32[100,14000] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1), collapsed_slice_dims=(2, 3), start_index_map=(2, 3))
fill_value=None
indices_are_sorted=True
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(100, 14000, 1, 1)
unique_indices=True
] bb cw
cy:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 1
cz:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 1
da:i32[2] = concatenate[dimension=0] cy cz
db:f32[100,14000] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1), collapsed_slice_dims=(2, 3), start_index_map=(2, 3))
fill_value=None
indices_are_sorted=True
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(100, 14000, 1, 1)
unique_indices=True
] bb da
dc:f32[100,14000] = mul cx db
dd:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 2
de:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 0
df:i32[2] = concatenate[dimension=0] dd de
dg:f32[100,14000] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1), collapsed_slice_dims=(2, 3), start_index_map=(2, 3))
fill_value=None
indices_are_sorted=True
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(100, 14000, 1, 1)
unique_indices=True
] bb df
dh:f32[100,14000] = mul dc dg
di:f32[100,14000] = sub ct dh
dj:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 0
dk:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 0
dl:i32[2] = concatenate[dimension=0] dj dk
dm:f32[100,14000] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1), collapsed_slice_dims=(2, 3), start_index_map=(2, 3))
fill_value=None
indices_are_sorted=True
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(100, 14000, 1, 1)
unique_indices=True
] bb dl
dn:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 1
do:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 2
dp:i32[2] = concatenate[dimension=0] dn do
dq:f32[100,14000] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1), collapsed_slice_dims=(2, 3), start_index_map=(2, 3))
fill_value=None
indices_are_sorted=True
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(100, 14000, 1, 1)
unique_indices=True
] bb dp
dr:f32[100,14000] = mul dm dq
ds:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 2
dt:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 1
du:i32[2] = concatenate[dimension=0] ds dt
dv:f32[100,14000] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1), collapsed_slice_dims=(2, 3), start_index_map=(2, 3))
fill_value=None
indices_are_sorted=True
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(100, 14000, 1, 1)
unique_indices=True
] bb du
dw:f32[100,14000] = mul dr dv
dx:f32[100,14000] = sub di dw
dy:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 0
dz:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 1
ea:i32[2] = concatenate[dimension=0] dy dz
eb:f32[100,14000] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1), collapsed_slice_dims=(2, 3), start_index_map=(2, 3))
fill_value=None
indices_are_sorted=True
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(100, 14000, 1, 1)
unique_indices=True
] bb ea
ec:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 1
ed:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 0
ee:i32[2] = concatenate[dimension=0] ec ed
ef:f32[100,14000] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1), collapsed_slice_dims=(2, 3), start_index_map=(2, 3))
fill_value=None
indices_are_sorted=True
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(100, 14000, 1, 1)
unique_indices=True
] bb ee
eg:f32[100,14000] = mul eb ef
eh:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 2
ei:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 2
ej:i32[2] = concatenate[dimension=0] eh ei
ek:f32[100,14000] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1), collapsed_slice_dims=(2, 3), start_index_map=(2, 3))
fill_value=None
indices_are_sorted=True
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(100, 14000, 1, 1)
unique_indices=True
] bb ej
el:f32[100,14000] = mul eg ek
em:f32[100,14000] = sub dx el
in (em,) }
name=det
] z
in (ba,) }
jvp_jaxpr_thunk=<function _memoize.<locals>.memoized at 0x7f04d7a14900>
num_consts=0
] x
en:f32[100,14000] = xla_call[
call_jaxpr={ lambda ; eo:f32[100,14000]. let
ep:f32[100,14000] = sign eo
in (ep,) }
name=sign
] y
eq:i32[3,3] = iota[dimension=0 dtype=int32 shape=(3, 3)]
er:i32[3,3] = add eq 0
es:i32[3,3] = iota[dimension=1 dtype=int32 shape=(3, 3)]
et:bool[3,3] = eq er es
eu:f32[3,3] = convert_element_type[new_dtype=float32 weak_type=False] et
ev:i32[] = add -1 3
ew:i32[] = convert_element_type[new_dtype=int32 weak_type=False] ev
ex:i32[] = add -1 3
ey:i32[] = convert_element_type[new_dtype=int32 weak_type=False] ex
ez:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] ew
fa:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] ey
fb:i32[2] = concatenate[dimension=0] ez fa
fc:f32[14000,3,3] = broadcast_in_dim[
broadcast_dimensions=(1, 2)
shape=(14000, 3, 3)
] eu
fd:f32[100,14000,3,3] = broadcast_in_dim[
broadcast_dimensions=(1, 2, 3)
shape=(100, 14000, 3, 3)
] fc
fe:f32[100,14000,3,3] = scatter[
dimension_numbers=ScatterDimensionNumbers(update_window_dims=(0, 1), inserted_window_dims=(2, 3), scatter_dims_to_operand_dims=(2, 3))
indices_are_sorted=True
mode=GatherScatterMode.FILL_OR_DROP
unique_indices=True
update_consts=()
update_jaxpr={ lambda ; ff:f32[] fg:f32[]. let in (fg,) }
] fd fb en
fh:f32[100,14000,3,3] = transpose[permutation=(0, 1, 3, 2)] p
fi:f32[100,14000,3,3] = dot_general[
dimension_numbers=(([3], [2]), ([0, 1], [0, 1]))
] fh fe
fj:f32[100,14000,3,3] = transpose[permutation=(0, 1, 3, 2)] o
fk:f32[100,14000,3,3] = dot_general[
dimension_numbers=(([3], [2]), ([0, 1], [0, 1]))
] fi fj
fl:f32[100,51,14000,3] = dot_general[
dimension_numbers=(([2], [2]), ([0], [0]))
] l fk
fm:f32[100,14000,51,3] = transpose[permutation=(0, 2, 1, 3)] fl
fn:f32[1,14000,51,3] = broadcast_in_dim[
broadcast_dimensions=(1, 2, 3)
shape=(1, 14000, 51, 3)
] h
fo:f32[100,14000,51,3] = sub fn fm
fp:f32[100,14000,51,3] = integer_pow[y=2] fo
fq:f32[100,14000,51] = reduce_sum[axes=(3,)] fp
fr:f32[100,14000] = reduce_sum[axes=(2,)] fq
fs:f32[100,14000] = div fr 51.0
ft:f32[100,14000] = sqrt fs
in (ft,) }, ())
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/jax/interpreters/mlir.py", line 1218, in _lower_jaxpr_to_fun_cached
func_op = ctx.cached_call_jaxpr_lowerings[key]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^
KeyError: ('svd', { lambda ; a:f32[14000,3,100,3]. let
b:f32[100,14000,3,3] = transpose[permutation=(2, 0, 1, 3)] a
c:f32[100,14000,3] d:f32[100,14000,3,3] e:f32[100,14000,3,3] = svd[
compute_uv=True
full_matrices=False
] b
in (d, e) }, ())
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "<frozen runpy>", line 173, in _run_module_as_main
File "<frozen runpy>", line 65, in _run_code
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/ipykernel_launcher.py", line 0, in <module>
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/traitlets/config/application.py", line 1035, in launch_instance @classmethod
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/ipykernel/kernelapp.py", line 717, in start
def start(self):
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/tornado/platform/asyncio.py", line 194, in start
def start(self) -> None:
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/asyncio/base_events.py", line 593, in run_forever
def run_forever(self):
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/asyncio/base_events.py", line 1845, in _run_once
def _run_once(self):
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/asyncio/events.py", line 78, in _run
def _run(self):
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/ipykernel/kernelbase.py", line 507, in dispatch_queue
async def dispatch_queue(self):
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/ipykernel/kernelbase.py", line 493, in process_one
async def process_one(self, wait=True):
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/ipykernel/kernelbase.py", line 363, in dispatch_shell
async def dispatch_shell(self, msg):
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/ipykernel/kernelbase.py", line 695, in execute_request
async def execute_request(self, stream, ident, parent):
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/ipykernel/ipkernel.py", line 339, in do_execute
async def do_execute(
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/ipykernel/zmqshell.py", line 543, in run_cell
def run_cell(self, *args, **kwargs):
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 2976, in run_cell
def run_cell(
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3018, in _run_cell
def _run_cell(
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/IPython/core/async_helpers.py", line 120, in _pseudo_sync_runner
def _pseudo_sync_runner(coro):
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3117, in run_cell_async
async def run_cell_async(
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3349, in run_ast_nodes
async def run_ast_nodes(
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3472, in run_code
async def run_code(self, code_obj, result=None, *, async_=False):
File "/scratch/local/jobs/7534193/ipykernel_3608298/886065729.py", line 0, in <module>
File "/project2/andrewferguson/berlaga/peptoids/multiwalker/dMap_JAX/dmap_JAX/rmsd_jax.py", line 67, in run_rmsd
def run_rmsd(traj_jax_array, nref_frames, batch_ref_frame_size=100, output_file_prefix="", device=None, traj2_jax_array=None, overwrite=False):
File "/project2/andrewferguson/berlaga/peptoids/multiwalker/dMap_JAX/dmap_JAX/rmsd_jax.py", line 45, in get_pairwise_rmsd_traj
def get_pairwise_rmsd_traj(traj, ref_index):
File "/project2/andrewferguson/berlaga/peptoids/multiwalker/dMap_JAX/dmap_JAX/rmsd_jax.py", line 10, in jnp_pair_rmsd
def jnp_pair_rmsd(ref, target):
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/jax/_src/numpy/linalg.py", line 69, in svd
@_wraps(np.linalg.svd)
The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3508, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File "/scratch/local/jobs/7534193/ipykernel_3608298/886065729.py", line 8, in <module>
rjax.run_rmsd(jnp.array(test_trajs),
File "/project2/andrewferguson/berlaga/peptoids/multiwalker/dMap_JAX/dmap_JAX/rmsd_jax.py", line 91, in run_rmsd
prmsd = jit(vmap(get_pairwise_rmsd_traj, in_axes=(None, 0)), device = set_device)(cut_traj, ref_indices)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/jax/_src/api.py", line 564, in cache_miss
execute = dispatch.xla_call_impl_lazy(fun, *tracers, **params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/jax/_src/dispatch.py", line 241, in _xla_call_impl_lazy
return xla_callable(fun, device, backend, name, donated_invars, keep_unused,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/jax/_src/linear_util.py", line 301, in memoized_fun
ans = call(fun, *args)
^^^^^^^^^^^^^^^^
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/jax/_src/dispatch.py", line 357, in _xla_callable_uncached
computation = sharded_lowering(fun, device, backend, name, donated_invars,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/jax/_src/dispatch.py", line 348, in sharded_lowering
return pxla.lower_sharding_computation(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/jax/_src/profiler.py", line 314, in wrapper
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/jax/interpreters/pxla.py", line 2933, in lower_sharding_computation
lowering_result = mlir.lower_jaxpr_to_module(
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/jax/interpreters/mlir.py", line 718, in lower_jaxpr_to_module
lower_jaxpr_to_fun(
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/jax/interpreters/mlir.py", line 1007, in lower_jaxpr_to_fun
out_vals, tokens_out = jaxpr_subcomp(ctx.replace(name_stack=callee_name_stack),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/jax/interpreters/mlir.py", line 1141, in jaxpr_subcomp
ans = rule(rule_ctx, *map(_unwrap_singleton_ir_values, in_nodes),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/jax/interpreters/mlir.py", line 1251, in _xla_call_lower
out_nodes, tokens = _call_lowering(
^^^^^^^^^^^^^^^
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/jax/interpreters/mlir.py", line 1237, in _call_lowering
symbol_name = _lower_jaxpr_to_fun_cached(ctx, fn_name, call_jaxpr, effects).name.value
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/jax/interpreters/mlir.py", line 1220, in _lower_jaxpr_to_fun_cached
func_op = lower_jaxpr_to_fun(ctx, fn_name, call_jaxpr, effects)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/jax/interpreters/mlir.py", line 1007, in lower_jaxpr_to_fun
out_vals, tokens_out = jaxpr_subcomp(ctx.replace(name_stack=callee_name_stack),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/jax/interpreters/mlir.py", line 1141, in jaxpr_subcomp
ans = rule(rule_ctx, *map(_unwrap_singleton_ir_values, in_nodes),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/jax/interpreters/mlir.py", line 1251, in _xla_call_lower
out_nodes, tokens = _call_lowering(
^^^^^^^^^^^^^^^
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/jax/interpreters/mlir.py", line 1237, in _call_lowering
symbol_name = _lower_jaxpr_to_fun_cached(ctx, fn_name, call_jaxpr, effects).name.value
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/jax/interpreters/mlir.py", line 1220, in _lower_jaxpr_to_fun_cached
func_op = lower_jaxpr_to_fun(ctx, fn_name, call_jaxpr, effects)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/jax/interpreters/mlir.py", line 1007, in lower_jaxpr_to_fun
out_vals, tokens_out = jaxpr_subcomp(ctx.replace(name_stack=callee_name_stack),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/jax/interpreters/mlir.py", line 1141, in jaxpr_subcomp
ans = rule(rule_ctx, *map(_unwrap_singleton_ir_values, in_nodes),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/jax/_src/lax/linalg.py", line 1666, in _svd_cpu_gpu_lowering
s, u, vt, info = gesvd_impl(operand_aval.dtype, operand,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/jaxlib/gpu_solver.py", line 333, in _gesvd_hlo
lwork, opaque = gpu_solver.build_gesvdj_descriptor(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: jaxlib/gpu/solver_kernels.cc:45: operation gpusolverDnCreate(&handle) failed: cuSolver internal error
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 2105, in showtraceback
stb = self.InteractiveTB.structured_traceback(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/IPython/core/ultratb.py", line 1428, in structured_traceback
return FormattedTB.structured_traceback(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/IPython/core/ultratb.py", line 1319, in structured_traceback
return VerboseTB.structured_traceback(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/IPython/core/ultratb.py", line 1191, in structured_traceback
formatted_exceptions += self.format_exception_as_a_whole(etype, evalue, etb, lines_of_context,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/IPython/core/ultratb.py", line 1087, in format_exception_as_a_whole
frames.append(self.format_record(record))
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/IPython/core/ultratb.py", line 969, in format_record
frame_info.lines, Colors, self.has_colors, lvals
^^^^^^^^^^^^^^^^
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/IPython/core/ultratb.py", line 792, in lines
return self._sd.lines
^^^^^^^^^^^^^^
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/stack_data/utils.py", line 144, in cached_property_wrapper
value = obj.__dict__[self.func.__name__] = self.func(obj)
^^^^^^^^^^^^^^
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/stack_data/core.py", line 734, in lines
pieces = self.included_pieces
^^^^^^^^^^^^^^^^^^^^
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/stack_data/utils.py", line 144, in cached_property_wrapper
value = obj.__dict__[self.func.__name__] = self.func(obj)
^^^^^^^^^^^^^^
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/stack_data/core.py", line 681, in included_pieces
pos = scope_pieces.index(self.executing_piece)
^^^^^^^^^^^^^^^^^^^^
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/stack_data/utils.py", line 144, in cached_property_wrapper
value = obj.__dict__[self.func.__name__] = self.func(obj)
^^^^^^^^^^^^^^
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/stack_data/core.py", line 660, in executing_piece
return only(
^^^^^
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/executing/executing.py", line 190, in only
raise NotOneValueFound('Expected one value, found 0')
The text was updated successfully, but these errors were encountered:
When `device=None` (the GPU case), the following (extremely long) JAX error appears. It is raised before any data is processed: the error still occurs when the .pkl files already exist:
KeyError: ('jnp_pair_rmsd', { lambda ; a:f32[14000,51,3] b:f32[100,51,3]. let
c:f32[14000,51,3] = copy a
d:f32[100,51,3] = copy b
e:f32[14000,3] = reduce_sum[axes=(1,)] c
f:f32[14000,3] = div e 51.0
g:f32[14000,1,3] = broadcast_in_dim[
broadcast_dimensions=(0, 2)
shape=(14000, 1, 3)
] f
h:f32[14000,51,3] = sub c g
i:f32[100,3] = reduce_sum[axes=(1,)] d
j:f32[100,3] = div i 51.0
k:f32[100,1,3] = broadcast_in_dim[
broadcast_dimensions=(0, 2)
shape=(100, 1, 3)
] j
l:f32[100,51,3] = sub d k
m:f32[14000,3,51] = transpose[permutation=(0, 2, 1)] h
n:f32[14000,3,100,3] = dot_general[dimension_numbers=(([2], [1]), ([], []))] m
l
o:f32[100,14000,3,3] p:f32[100,14000,3,3] = xla_call[
call_jaxpr={ lambda ; q:f32[14000,3,100,3]. let
r:f32[100,14000,3,3] = transpose[permutation=(2, 0, 1, 3)] q
s:f32[100,14000,3] t:f32[100,14000,3,3] u:f32[100,14000,3,3] = svd[
compute_uv=True
full_matrices=False
] r
in (t, u) }
name=svd
] n
v:f32[100,14000,3,3] = transpose[permutation=(0, 1, 3, 2)] p
w:f32[100,14000,3,3] = transpose[permutation=(0, 1, 3, 2)] o
x:f32[100,14000,3,3] = dot_general[
dimension_numbers=(([3], [2]), ([0, 1], [0, 1]))
] v w
y:f32[100,14000] = custom_jvp_call[
call_jaxpr={ lambda ; z:f32[100,14000,3,3]. let
ba:f32[100,14000] = xla_call[
call_jaxpr={ lambda ; bb:f32[100,14000,3,3]. let
bc:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 0
bd:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 0
be:i32[2] = concatenate[dimension=0] bc bd
bf:f32[100,14000] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1), collapsed_slice_dims=(2, 3), start_index_map=(2, 3))
fill_value=None
indices_are_sorted=True
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(100, 14000, 1, 1)
unique_indices=True
] bb be
bg:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 1
bh:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 1
bi:i32[2] = concatenate[dimension=0] bg bh
bj:f32[100,14000] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1), collapsed_slice_dims=(2, 3), start_index_map=(2, 3))
fill_value=None
indices_are_sorted=True
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(100, 14000, 1, 1)
unique_indices=True
] bb bi
bk:f32[100,14000] = mul bf bj
bl:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 2
bm:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 2
bn:i32[2] = concatenate[dimension=0] bl bm
bo:f32[100,14000] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1), collapsed_slice_dims=(2, 3), start_index_map=(2, 3))
fill_value=None
indices_are_sorted=True
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(100, 14000, 1, 1)
unique_indices=True
] bb bn
bp:f32[100,14000] = mul bk bo
bq:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 0
br:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 1
bs:i32[2] = concatenate[dimension=0] bq br
bt:f32[100,14000] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1), collapsed_slice_dims=(2, 3), start_index_map=(2, 3))
fill_value=None
indices_are_sorted=True
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(100, 14000, 1, 1)
unique_indices=True
] bb bs
bu:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 1
bv:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 2
bw:i32[2] = concatenate[dimension=0] bu bv
bx:f32[100,14000] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1), collapsed_slice_dims=(2, 3), start_index_map=(2, 3))
fill_value=None
indices_are_sorted=True
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(100, 14000, 1, 1)
unique_indices=True
] bb bw
by:f32[100,14000] = mul bt bx
bz:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 2
ca:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 0
cb:i32[2] = concatenate[dimension=0] bz ca
cc:f32[100,14000] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1), collapsed_slice_dims=(2, 3), start_index_map=(2, 3))
fill_value=None
indices_are_sorted=True
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(100, 14000, 1, 1)
unique_indices=True
] bb cb
cd:f32[100,14000] = mul by cc
ce:f32[100,14000] = add bp cd
cf:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 0
cg:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 2
ch:i32[2] = concatenate[dimension=0] cf cg
ci:f32[100,14000] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1), collapsed_slice_dims=(2, 3), start_index_map=(2, 3))
fill_value=None
indices_are_sorted=True
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(100, 14000, 1, 1)
unique_indices=True
] bb ch
cj:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 1
ck:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 0
cl:i32[2] = concatenate[dimension=0] cj ck
cm:f32[100,14000] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1), collapsed_slice_dims=(2, 3), start_index_map=(2, 3))
fill_value=None
indices_are_sorted=True
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(100, 14000, 1, 1)
unique_indices=True
] bb cl
cn:f32[100,14000] = mul ci cm
co:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 2
cp:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 1
cq:i32[2] = concatenate[dimension=0] co cp
cr:f32[100,14000] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1), collapsed_slice_dims=(2, 3), start_index_map=(2, 3))
fill_value=None
indices_are_sorted=True
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(100, 14000, 1, 1)
unique_indices=True
] bb cq
cs:f32[100,14000] = mul cn cr
ct:f32[100,14000] = add ce cs
cu:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 0
cv:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 2
cw:i32[2] = concatenate[dimension=0] cu cv
cx:f32[100,14000] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1), collapsed_slice_dims=(2, 3), start_index_map=(2, 3))
fill_value=None
indices_are_sorted=True
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(100, 14000, 1, 1)
unique_indices=True
] bb cw
cy:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 1
cz:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 1
da:i32[2] = concatenate[dimension=0] cy cz
db:f32[100,14000] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1), collapsed_slice_dims=(2, 3), start_index_map=(2, 3))
fill_value=None
indices_are_sorted=True
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(100, 14000, 1, 1)
unique_indices=True
] bb da
dc:f32[100,14000] = mul cx db
dd:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 2
de:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 0
df:i32[2] = concatenate[dimension=0] dd de
dg:f32[100,14000] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1), collapsed_slice_dims=(2, 3), start_index_map=(2, 3))
fill_value=None
indices_are_sorted=True
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(100, 14000, 1, 1)
unique_indices=True
] bb df
dh:f32[100,14000] = mul dc dg
di:f32[100,14000] = sub ct dh
dj:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 0
dk:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 0
dl:i32[2] = concatenate[dimension=0] dj dk
dm:f32[100,14000] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1), collapsed_slice_dims=(2, 3), start_index_map=(2, 3))
fill_value=None
indices_are_sorted=True
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(100, 14000, 1, 1)
unique_indices=True
] bb dl
dn:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 1
do:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 2
dp:i32[2] = concatenate[dimension=0] dn do
dq:f32[100,14000] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1), collapsed_slice_dims=(2, 3), start_index_map=(2, 3))
fill_value=None
indices_are_sorted=True
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(100, 14000, 1, 1)
unique_indices=True
] bb dp
dr:f32[100,14000] = mul dm dq
ds:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 2
dt:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 1
du:i32[2] = concatenate[dimension=0] ds dt
dv:f32[100,14000] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1), collapsed_slice_dims=(2, 3), start_index_map=(2, 3))
fill_value=None
indices_are_sorted=True
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(100, 14000, 1, 1)
unique_indices=True
] bb du
dw:f32[100,14000] = mul dr dv
dx:f32[100,14000] = sub di dw
dy:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 0
dz:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 1
ea:i32[2] = concatenate[dimension=0] dy dz
eb:f32[100,14000] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1), collapsed_slice_dims=(2, 3), start_index_map=(2, 3))
fill_value=None
indices_are_sorted=True
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(100, 14000, 1, 1)
unique_indices=True
] bb ea
ec:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 1
ed:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 0
ee:i32[2] = concatenate[dimension=0] ec ed
ef:f32[100,14000] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1), collapsed_slice_dims=(2, 3), start_index_map=(2, 3))
fill_value=None
indices_are_sorted=True
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(100, 14000, 1, 1)
unique_indices=True
] bb ee
eg:f32[100,14000] = mul eb ef
eh:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 2
ei:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 2
ej:i32[2] = concatenate[dimension=0] eh ei
ek:f32[100,14000] = gather[
dimension_numbers=GatherDimensionNumbers(offset_dims=(0, 1), collapsed_slice_dims=(2, 3), start_index_map=(2, 3))
fill_value=None
indices_are_sorted=True
mode=GatherScatterMode.PROMISE_IN_BOUNDS
slice_sizes=(100, 14000, 1, 1)
unique_indices=True
] bb ej
el:f32[100,14000] = mul eg ek
em:f32[100,14000] = sub dx el
in (em,) }
name=det
] z
in (ba,) }
jvp_jaxpr_thunk=<function _memoize.<locals>.memoized at 0x7f04d7a14900>
num_consts=0
] x
en:f32[100,14000] = xla_call[
call_jaxpr={ lambda ; eo:f32[100,14000]. let
ep:f32[100,14000] = sign eo
in (ep,) }
name=sign
] y
eq:i32[3,3] = iota[dimension=0 dtype=int32 shape=(3, 3)]
er:i32[3,3] = add eq 0
es:i32[3,3] = iota[dimension=1 dtype=int32 shape=(3, 3)]
et:bool[3,3] = eq er es
eu:f32[3,3] = convert_element_type[new_dtype=float32 weak_type=False] et
ev:i32[] = add -1 3
ew:i32[] = convert_element_type[new_dtype=int32 weak_type=False] ev
ex:i32[] = add -1 3
ey:i32[] = convert_element_type[new_dtype=int32 weak_type=False] ex
ez:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] ew
fa:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] ey
fb:i32[2] = concatenate[dimension=0] ez fa
fc:f32[14000,3,3] = broadcast_in_dim[
broadcast_dimensions=(1, 2)
shape=(14000, 3, 3)
] eu
fd:f32[100,14000,3,3] = broadcast_in_dim[
broadcast_dimensions=(1, 2, 3)
shape=(100, 14000, 3, 3)
] fc
fe:f32[100,14000,3,3] = scatter[
dimension_numbers=ScatterDimensionNumbers(update_window_dims=(0, 1), inserted_window_dims=(2, 3), scatter_dims_to_operand_dims=(2, 3))
indices_are_sorted=True
mode=GatherScatterMode.FILL_OR_DROP
unique_indices=True
update_consts=()
update_jaxpr={ lambda ; ff:f32[] fg:f32[]. let in (fg,) }
] fd fb en
fh:f32[100,14000,3,3] = transpose[permutation=(0, 1, 3, 2)] p
fi:f32[100,14000,3,3] = dot_general[
dimension_numbers=(([3], [2]), ([0, 1], [0, 1]))
] fh fe
fj:f32[100,14000,3,3] = transpose[permutation=(0, 1, 3, 2)] o
fk:f32[100,14000,3,3] = dot_general[
dimension_numbers=(([3], [2]), ([0, 1], [0, 1]))
] fi fj
fl:f32[100,51,14000,3] = dot_general[
dimension_numbers=(([2], [2]), ([0], [0]))
] l fk
fm:f32[100,14000,51,3] = transpose[permutation=(0, 2, 1, 3)] fl
fn:f32[1,14000,51,3] = broadcast_in_dim[
broadcast_dimensions=(1, 2, 3)
shape=(1, 14000, 51, 3)
] h
fo:f32[100,14000,51,3] = sub fn fm
fp:f32[100,14000,51,3] = integer_pow[y=2] fo
fq:f32[100,14000,51] = reduce_sum[axes=(3,)] fp
fr:f32[100,14000] = reduce_sum[axes=(2,)] fq
fs:f32[100,14000] = div fr 51.0
ft:f32[100,14000] = sqrt fs
in (ft,) }, ())
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/jax/interpreters/mlir.py", line 1218, in _lower_jaxpr_to_fun_cached
func_op = ctx.cached_call_jaxpr_lowerings[key]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^
KeyError: ('svd', { lambda ; a:f32[14000,3,100,3]. let
b:f32[100,14000,3,3] = transpose[permutation=(2, 0, 1, 3)] a
c:f32[100,14000,3] d:f32[100,14000,3,3] e:f32[100,14000,3,3] = svd[
compute_uv=True
full_matrices=False
] b
in (d, e) }, ())
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "<frozen runpy>", line 173, in _run_module_as_main
File "<frozen runpy>", line 65, in _run_code
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/ipykernel_launcher.py", line 0, in <module>
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/traitlets/config/application.py", line 1035, in launch_instance
@classmethod
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/ipykernel/kernelapp.py", line 717, in start
def start(self):
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/tornado/platform/asyncio.py", line 194, in start
def start(self) -> None:
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/asyncio/base_events.py", line 593, in run_forever
def run_forever(self):
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/asyncio/base_events.py", line 1845, in _run_once
def _run_once(self):
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/asyncio/events.py", line 78, in _run
def _run(self):
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/ipykernel/kernelbase.py", line 507, in dispatch_queue
async def dispatch_queue(self):
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/ipykernel/kernelbase.py", line 493, in process_one
async def process_one(self, wait=True):
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/ipykernel/kernelbase.py", line 363, in dispatch_shell
async def dispatch_shell(self, msg):
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/ipykernel/kernelbase.py", line 695, in execute_request
async def execute_request(self, stream, ident, parent):
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/ipykernel/ipkernel.py", line 339, in do_execute
async def do_execute(
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/ipykernel/zmqshell.py", line 543, in run_cell
def run_cell(self, *args, **kwargs):
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 2976, in run_cell
def run_cell(
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3018, in _run_cell
def _run_cell(
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/IPython/core/async_helpers.py", line 120, in _pseudo_sync_runner
def _pseudo_sync_runner(coro):
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3117, in run_cell_async
async def run_cell_async(
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3349, in run_ast_nodes
async def run_ast_nodes(
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3472, in run_code
async def run_code(self, code_obj, result=None, *, async_=False):
File "/scratch/local/jobs/7534193/ipykernel_3608298/886065729.py", line 0, in <module>
File "/project2/andrewferguson/berlaga/peptoids/multiwalker/dMap_JAX/dmap_JAX/rmsd_jax.py", line 67, in run_rmsd
def run_rmsd(traj_jax_array, nref_frames, batch_ref_frame_size=100, output_file_prefix="", device=None, traj2_jax_array=None, overwrite=False):
File "/project2/andrewferguson/berlaga/peptoids/multiwalker/dMap_JAX/dmap_JAX/rmsd_jax.py", line 45, in get_pairwise_rmsd_traj
def get_pairwise_rmsd_traj(traj, ref_index):
File "/project2/andrewferguson/berlaga/peptoids/multiwalker/dMap_JAX/dmap_JAX/rmsd_jax.py", line 10, in jnp_pair_rmsd
def jnp_pair_rmsd(ref, target):
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/jax/_src/numpy/linalg.py", line 69, in svd
@_wraps(np.linalg.svd)
jax._src.source_info_util.JaxStackTraceBeforeTransformation: RuntimeError: jaxlib/gpu/solver_kernels.cc:45: operation gpusolverDnCreate(&handle) failed: cuSolver internal error
The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3508, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File "/scratch/local/jobs/7534193/ipykernel_3608298/886065729.py", line 8, in <module>
rjax.run_rmsd(jnp.array(test_trajs),
File "/project2/andrewferguson/berlaga/peptoids/multiwalker/dMap_JAX/dmap_JAX/rmsd_jax.py", line 91, in run_rmsd
prmsd = jit(vmap(get_pairwise_rmsd_traj, in_axes=(None, 0)), device = set_device)(cut_traj, ref_indices)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/jax/_src/api.py", line 564, in cache_miss
execute = dispatch.xla_call_impl_lazy(fun, *tracers, **params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/jax/_src/dispatch.py", line 241, in _xla_call_impl_lazy
return xla_callable(fun, device, backend, name, donated_invars, keep_unused,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/jax/_src/linear_util.py", line 301, in memoized_fun
ans = call(fun, *args)
^^^^^^^^^^^^^^^^
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/jax/_src/dispatch.py", line 357, in _xla_callable_uncached
computation = sharded_lowering(fun, device, backend, name, donated_invars,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/jax/_src/dispatch.py", line 348, in sharded_lowering
return pxla.lower_sharding_computation(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/jax/_src/profiler.py", line 314, in wrapper
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/jax/interpreters/pxla.py", line 2933, in lower_sharding_computation
lowering_result = mlir.lower_jaxpr_to_module(
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/jax/interpreters/mlir.py", line 718, in lower_jaxpr_to_module
lower_jaxpr_to_fun(
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/jax/interpreters/mlir.py", line 1007, in lower_jaxpr_to_fun
out_vals, tokens_out = jaxpr_subcomp(ctx.replace(name_stack=callee_name_stack),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/jax/interpreters/mlir.py", line 1141, in jaxpr_subcomp
ans = rule(rule_ctx, *map(_unwrap_singleton_ir_values, in_nodes),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/jax/interpreters/mlir.py", line 1251, in _xla_call_lower
out_nodes, tokens = _call_lowering(
^^^^^^^^^^^^^^^
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/jax/interpreters/mlir.py", line 1237, in _call_lowering
symbol_name = _lower_jaxpr_to_fun_cached(ctx, fn_name, call_jaxpr, effects).name.value
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/jax/interpreters/mlir.py", line 1220, in _lower_jaxpr_to_fun_cached
func_op = lower_jaxpr_to_fun(ctx, fn_name, call_jaxpr, effects)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/jax/interpreters/mlir.py", line 1007, in lower_jaxpr_to_fun
out_vals, tokens_out = jaxpr_subcomp(ctx.replace(name_stack=callee_name_stack),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/jax/interpreters/mlir.py", line 1141, in jaxpr_subcomp
ans = rule(rule_ctx, *map(_unwrap_singleton_ir_values, in_nodes),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/jax/interpreters/mlir.py", line 1251, in _xla_call_lower
out_nodes, tokens = _call_lowering(
^^^^^^^^^^^^^^^
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/jax/interpreters/mlir.py", line 1237, in _call_lowering
symbol_name = _lower_jaxpr_to_fun_cached(ctx, fn_name, call_jaxpr, effects).name.value
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/jax/interpreters/mlir.py", line 1220, in _lower_jaxpr_to_fun_cached
func_op = lower_jaxpr_to_fun(ctx, fn_name, call_jaxpr, effects)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/jax/interpreters/mlir.py", line 1007, in lower_jaxpr_to_fun
out_vals, tokens_out = jaxpr_subcomp(ctx.replace(name_stack=callee_name_stack),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/jax/interpreters/mlir.py", line 1141, in jaxpr_subcomp
ans = rule(rule_ctx, *map(_unwrap_singleton_ir_values, in_nodes),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/jax/_src/lax/linalg.py", line 1666, in _svd_cpu_gpu_lowering
s, u, vt, info = gesvd_impl(operand_aval.dtype, operand,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/jaxlib/gpu_solver.py", line 333, in _gesvd_hlo
lwork, opaque = gpu_solver.build_gesvdj_descriptor(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: jaxlib/gpu/solver_kernels.cc:45: operation gpusolverDnCreate(&handle) failed: cuSolver internal error
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 2105, in showtraceback
stb = self.InteractiveTB.structured_traceback(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/IPython/core/ultratb.py", line 1428, in structured_traceback
return FormattedTB.structured_traceback(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/IPython/core/ultratb.py", line 1319, in structured_traceback
return VerboseTB.structured_traceback(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/IPython/core/ultratb.py", line 1191, in structured_traceback
formatted_exceptions += self.format_exception_as_a_whole(etype, evalue, etb, lines_of_context,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/IPython/core/ultratb.py", line 1087, in format_exception_as_a_whole
frames.append(self.format_record(record))
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/IPython/core/ultratb.py", line 969, in format_record
frame_info.lines, Colors, self.has_colors, lvals
^^^^^^^^^^^^^^^^
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/IPython/core/ultratb.py", line 792, in lines
return self._sd.lines
^^^^^^^^^^^^^^
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/stack_data/utils.py", line 144, in cached_property_wrapper
value = obj.__dict__[self.func.__name__] = self.func(obj)
^^^^^^^^^^^^^^
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/stack_data/core.py", line 734, in lines
pieces = self.included_pieces
^^^^^^^^^^^^^^^^^^^^
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/stack_data/utils.py", line 144, in cached_property_wrapper
value = obj.__dict__[self.func.__name__] = self.func(obj)
^^^^^^^^^^^^^^
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/stack_data/core.py", line 681, in included_pieces
pos = scope_pieces.index(self.executing_piece)
^^^^^^^^^^^^^^^^^^^^
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/stack_data/utils.py", line 144, in cached_property_wrapper
value = obj.__dict__[self.func.__name__] = self.func(obj)
^^^^^^^^^^^^^^
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/stack_data/core.py", line 660, in executing_piece
return only(
^^^^^
File "/project2/andrewferguson/berlaga/conda_env/envs/djax/lib/python3.11/site-packages/executing/executing.py", line 190, in only
raise NotOneValueFound('Expected one value, found 0')
The text was updated successfully, but these errors were encountered: