You signed in with another tab or window. Reload to refresh your session. You signed out in another tab or window. Reload to refresh your session. You switched accounts on another tab or window. Reload to refresh your session. Dismiss alert
After upgrading to jax==0.4.31, I am seeing the following error when trying to save a model using the PyTreeCheckpointer.
Downgrading to jax==0.4.30 fixed it for now.
File ".../site-packages/orbax/checkpoint/checkpointer.py", line 151, in save
self._handler.save(tmpdir, args=ckpt_args)
File ".../site-packages/orbax/checkpoint/pytree_checkpoint_handler.py", line 500, in save
super().save(directory, args=args)
File ".../site-packages/orbax/checkpoint/base_pytree_checkpoint_handler.py", line 615, in save
asyncio.run(async_save(directory, *args, **kwargs))
File "/home/julian/.pyenv/versions/3.11.6/lib/python3.11/asyncio/runners.py", line 190, in run
return runner.run(main)
^^^^^^^^^^^^^^^^
File "/home/julian/.pyenv/versions/3.11.6/lib/python3.11/asyncio/runners.py", line 118, in run
return self._loop.run_until_complete(task)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/julian/.pyenv/versions/3.11.6/lib/python3.11/asyncio/base_events.py", line 653, in run_until_complete
return future.result()
^^^^^^^^^^^^^^^
File ".../site-packages/orbax/checkpoint/base_pytree_checkpoint_handler.py", line 608, in async_save
commit_futures = await self.async_save(*args, **kwargs) # pytype: disable=bad-return-type
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".../site-packages/orbax/checkpoint/pytree_checkpoint_handler.py", line 489, in async_save
return await super().async_save(directory, args=args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".../site-packages/orbax/checkpoint/base_pytree_checkpoint_handler.py", line 568, in async_save
commit_futures = await asyncio.gather(*serialize_ops)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".../site-packages/orbax/checkpoint/type_handlers.py", line 1350, in serialize
await asyncio.gather(*synchronous_ops)
File ".../site-packages/jax/experimental/array_serialization/serialization.py", line 304, in async_serialize
return await asyncio.gather(*future_write_state)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".../site-packages/jax/experimental/array_serialization/serialization.py", line 284, in _write_array
write_future = t[shard.index].write(
^^^^^^^^^^^^^^^^^^^^^
TypeError: write(): incompatible function arguments. The following argument types are supported:
1. (self: tensorstore.TensorStore, source: Union[tensorstore.TensorStore, numpy.typing.ArrayLike]) -> tensorstore.WriteFutures
Invoked with: TensorStore({
'base': {
'assume_metadata': True,
'driver': 'zarr',
'dtype': 'float32',
'kvstore': {
'base': {
'driver': 'file',
'path': '.../7afc2a53292a409e98a1a41b.orbax-checkpoint-tmp-0/ocdbt.process_0/',
},
'cache_pool': 'cache_pool#ocdbt',
'config': {
'compression': {'id': 'zstd'},
'max_decoded_node_bytes': 100000000,
'max_inline_value_bytes': 1024,
'uuid': 'da97f553e0a14b1d4ee7c13b8c741510',
'version_tree_arity_log2': 4,
},
'driver': 'ocdbt',
'experimental_read_coalescing_interval': '1ms',
'experimental_read_coalescing_merged_bytes': 500000000000,
'experimental_read_coalescing_threshold_bytes': 1000000,
'path': '0.batch_stats.encoder.BatchNorm_0.mean/',
},
'metadata': {
'chunks': [8],
'compressor': {'id': 'zstd', 'level': 1},
'dimension_separator': '.',
'dtype': '<f4',
'fill_value': None,
'filters': None,
'order': 'C',
'shape': [8],
'zarr_format': 2,
},
'recheck_cached_data': False,
'recheck_cached_metadata': False,
},
'context': {
'cache_pool': {},
'cache_pool#ocdbt': {'total_bytes_limit': 100000000},
'data_copy_concurrency': {},
'file_io_concurrency': {'limit': 128},
'file_io_sync': True,
'ocdbt_coordinator': {},
},
'driver': 'cast',
'dtype': 'float32',
'transform': {'input_exclusive_max': [[8]], 'input_inclusive_min': [0]},
}), array([-0.81411403, -0.17954189, 0.06563368, -0.30694914, -0.9737644 ,
0.6138028 , 0.8517958 , 0.22747818], dtype=float32); kwargs: can_reference_source_data_indefinitely=True
The text was updated successfully, but these errors were encountered:
After upgrading to jax==0.4.31 I am seeing this error when trying to save a model using the PyTreeCheckpointer.
Downgrading to 0.4.30 fixed it for now.
The text was updated successfully, but these errors were encountered: