
TypeError: write(): incompatible function arguments. #1039

Open
FranzKnut opened this issue Aug 5, 2024 · 2 comments

Comments

@FranzKnut

After upgrading to jax==0.4.31, I get the following error when trying to save a model with PyTreeCheckpointer.
Downgrading to 0.4.30 fixes it for now.
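Because the failure depends on which versions are installed together, it helps to record them exactly when reporting or pinning. A minimal sketch using only the standard library; the helper name `installed_versions` is hypothetical, and the names passed to it are assumed to be the usual PyPI distribution names for the packages that appear in the traceback.

```python
from importlib import metadata


def installed_versions(packages):
    """Return {distribution name: installed version, or None if not installed}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    # Distributions whose interaction produces the error in this thread
    # (assumed PyPI names).
    report = installed_versions(["jax", "jaxlib", "orbax-checkpoint", "tensorstore"])
    for name, version in report.items():
        print(f"{name}=={version}" if version else f"{name}: not installed")
```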

  File ".../site-packages/orbax/checkpoint/checkpointer.py", line 151, in save
    self._handler.save(tmpdir, args=ckpt_args)
  File ".../site-packages/orbax/checkpoint/pytree_checkpoint_handler.py", line 500, in save
    super().save(directory, args=args)
  File ".../site-packages/orbax/checkpoint/base_pytree_checkpoint_handler.py", line 615, in save
    asyncio.run(async_save(directory, *args, **kwargs))
  File "/home/julian/.pyenv/versions/3.11.6/lib/python3.11/asyncio/runners.py", line 190, in run
    return runner.run(main)
           ^^^^^^^^^^^^^^^^
  File "/home/julian/.pyenv/versions/3.11.6/lib/python3.11/asyncio/runners.py", line 118, in run
    return self._loop.run_until_complete(task)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/julian/.pyenv/versions/3.11.6/lib/python3.11/asyncio/base_events.py", line 653, in run_until_complete
    return future.result()
           ^^^^^^^^^^^^^^^
  File ".../site-packages/orbax/checkpoint/base_pytree_checkpoint_handler.py", line 608, in async_save
    commit_futures = await self.async_save(*args, **kwargs)  # pytype: disable=bad-return-type
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../site-packages/orbax/checkpoint/pytree_checkpoint_handler.py", line 489, in async_save
    return await super().async_save(directory, args=args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../site-packages/orbax/checkpoint/base_pytree_checkpoint_handler.py", line 568, in async_save
    commit_futures = await asyncio.gather(*serialize_ops)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../site-packages/orbax/checkpoint/type_handlers.py", line 1350, in serialize
    await asyncio.gather(*synchronous_ops)
  File ".../site-packages/jax/experimental/array_serialization/serialization.py", line 304, in async_serialize
    return await asyncio.gather(*future_write_state)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../site-packages/jax/experimental/array_serialization/serialization.py", line 284, in _write_array
    write_future = t[shard.index].write(
                   ^^^^^^^^^^^^^^^^^^^^^
TypeError: write(): incompatible function arguments. The following argument types are supported:
    1. (self: tensorstore.TensorStore, source: Union[tensorstore.TensorStore, numpy.typing.ArrayLike]) -> tensorstore.WriteFutures

Invoked with: TensorStore({
  'base': {
    'assume_metadata': True,
    'driver': 'zarr',
    'dtype': 'float32',
    'kvstore': {
      'base': {
        'driver': 'file',
        'path': '.../7afc2a53292a409e98a1a41b.orbax-checkpoint-tmp-0/ocdbt.process_0/',
      },
      'cache_pool': 'cache_pool#ocdbt',
      'config': {
        'compression': {'id': 'zstd'},
        'max_decoded_node_bytes': 100000000,
        'max_inline_value_bytes': 1024,
        'uuid': 'da97f553e0a14b1d4ee7c13b8c741510',
        'version_tree_arity_log2': 4,
      },
      'driver': 'ocdbt',
      'experimental_read_coalescing_interval': '1ms',
      'experimental_read_coalescing_merged_bytes': 500000000000,
      'experimental_read_coalescing_threshold_bytes': 1000000,
      'path': '0.batch_stats.encoder.BatchNorm_0.mean/',
    },
    'metadata': {
      'chunks': [8],
      'compressor': {'id': 'zstd', 'level': 1},
      'dimension_separator': '.',
      'dtype': '<f4',
      'fill_value': None,
      'filters': None,
      'order': 'C',
      'shape': [8],
      'zarr_format': 2,
    },
    'recheck_cached_data': False,
    'recheck_cached_metadata': False,
  },
  'context': {
    'cache_pool': {},
    'cache_pool#ocdbt': {'total_bytes_limit': 100000000},
    'data_copy_concurrency': {},
    'file_io_concurrency': {'limit': 128},
    'file_io_sync': True,
    'ocdbt_coordinator': {},
  },
  'driver': 'cast',
  'dtype': 'float32',
  'transform': {'input_exclusive_max': [[8]], 'input_inclusive_min': [0]},
}), array([-0.81411403, -0.17954189,  0.06563368, -0.30694914, -0.9737644 ,
        0.6138028 ,  0.8517958 ,  0.22747818], dtype=float32); kwargs: can_reference_source_data_indefinitely=True
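The last frame shows the mismatch: jax's `_write_array` calls `write()` on a TensorStore with the keyword argument `can_reference_source_data_indefinitely=True`, while, according to the error message, the installed tensorstore binding accepts only `source`. The same class of `TypeError` (an unsupported keyword argument) can be reproduced in plain Python. This is a hypothetical illustration: `OldStore`, `NewStore`, and `write_compat` are made-up names, not tensorstore or jax APIs, and compiled (pybind11) methods may not expose a signature that `inspect` can read.

```python
import inspect


class OldStore:
    """Hypothetical stand-in for a binding whose write() accepts only `source`."""

    def write(self, source):
        return ("written", list(source))


class NewStore:
    """Hypothetical stand-in for a binding that also accepts the extra keyword."""

    def write(self, source, can_reference_source_data_indefinitely=False):
        return ("written", list(source), can_reference_source_data_indefinitely)


def write_compat(store, source, **optional_kwargs):
    """Call store.write(), dropping keyword arguments its signature does not accept.

    Illustration only: inspect.signature raises ValueError for callables that
    expose no readable signature, which is common for compiled extensions.
    """
    params = inspect.signature(store.write).parameters
    accepts_var_kw = any(
        p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()
    )
    supported = {
        key: value
        for key, value in optional_kwargs.items()
        if accepts_var_kw or key in params
    }
    return store.write(source, **supported)


if __name__ == "__main__":
    # Passing the keyword straight to the old-style method fails, as in the traceback.
    try:
        OldStore().write([1.0], can_reference_source_data_indefinitely=True)
    except TypeError as err:
        print("TypeError:", err)

    # The shim only forwards the keyword when the callee declares it.
    print(write_compat(OldStore(), [1.0], can_reference_source_data_indefinitely=True))
    print(write_compat(NewStore(), [1.0], can_reference_source_data_indefinitely=True))
```

Silently dropping an argument can change behavior, so the practical fixes discussed in this thread are aligning versions (pinning jax to 0.4.30, or upgrading Orbax) rather than shimming the call.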
@cpgaffney1
Collaborator

Could you try upgrading Orbax? At head, we no longer depend on jax/experimental/array_serialization/serialization.py.

@noahzhy

noahzhy commented Oct 7, 2024

Same problem, same solution: 0.4.30 works for me.
