_validate_params fails on zero-sized arrays #1309

Open
hrbigelow opened this issue Nov 8, 2024 · 1 comment
@hrbigelow

Hi,

@niketkumar @cpgaffney1,

cc @dionhaefner

The following script attempts to serialize a zero-sized array, but it fails validation in _validate_params.

I believe the problem is that _validate_params expects to find, for every 'foo/.zarray' entry, a matching data entry 'foo/0'. However, this code produces the tensorstore entries 'a/0', 'a/.zarray', and 'z/.zarray', but not 'z/0', since there is no data in the z tensor.

I'm actually not sure whether tensorstore is supposed to save an entry 'z/0' in this case, or what the intended behavior should be.
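
To check that, here is a rough sketch that writes a zero-sized zarr array with tensorstore directly and lists what ends up on disk. The exact ts.open arguments are my guess at the tensorstore Python API, and I haven't run this against orbax's OCDBT setup, but if only '.zarray' appears it would line up with the missing 'z/0' entry above.

```python
import os
import tempfile

import tensorstore as ts

with tempfile.TemporaryDirectory() as path:
  # Create a zero-length zarr array backed by the local filesystem.
  ts.open(
      {'driver': 'zarr', 'kvstore': {'driver': 'file', 'path': path}},
      create=True,
      dtype=ts.float32,
      shape=[0],
  ).result()
  # With no elements to write, I would expect only the '.zarray' metadata
  # entry here and no chunk entry such as '0'.
  print(sorted(os.listdir(path)))
```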

Any insight would be greatly appreciated!

```python
import jax.numpy as jnp
import jax.tree_util as jtu
import tempfile
import orbax.checkpoint as ocp

target = {
    'a': jnp.array([1, 2, 3], jnp.int32),
    'z': jnp.zeros((0,)),
}

orbax_checkpointer = ocp.Checkpointer(
  ocp.PyTreeCheckpointHandler()
)

with tempfile.TemporaryDirectory() as ckpt_path:
  overwrite = True
  save_args = jtu.tree_map(lambda _: ocp.SaveArgs(), target)
  orbax_checkpointer.save(ckpt_path, target, save_args=save_args, force=overwrite)
```

Running this fails during finalization with the following traceback:

```
(jax_env) henry@henry-gs65:orbax$ python flax4309.py
Traceback (most recent call last):
  File "/home/henry/ai/projects/orbax/flax4309.py", line 18, in <module>
    orbax_checkpointer.save(ckpt_path, target, save_args=save_args, force=overwrite)
  File "/home/henry/ai/projects/orbax/checkpoint/orbax/checkpoint/checkpointer.py", line 216, in save
    self._handler.finalize(tmpdir.get())
  File "/home/henry/ai/projects/orbax/checkpoint/orbax/checkpoint/_src/handlers/pytree_checkpoint_handler.py", line 1004, in finalize
    self._handler_impl.finalize(directory)
  File "/home/henry/ai/projects/orbax/checkpoint/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py", line 806, in finalize
    asyncio_utils.run_sync(
  File "/home/henry/ai/projects/orbax/checkpoint/orbax/checkpoint/_src/asyncio_utils.py", line 50, in run_sync
    return asyncio.run(coro)
           ^^^^^^^^^^^^^^^^^
  File "/home/henry/miniconda3/lib/python3.11/asyncio/runners.py", line 190, in run
    return runner.run(main)
           ^^^^^^^^^^^^^^^^
  File "/home/henry/miniconda3/lib/python3.11/asyncio/runners.py", line 118, in run
    return self._loop.run_until_complete(task)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/henry/miniconda3/lib/python3.11/asyncio/base_events.py", line 653, in run_until_complete
    return future.result()
           ^^^^^^^^^^^^^^^
  File "/home/henry/ai/projects/orbax/checkpoint/orbax/checkpoint/type_handlers.py", line 704, in merge_ocdbt_per_process_files
    await _validate_params(directory, ts_context, use_zarr3=use_zarr3)
  File "/home/henry/ai/projects/orbax/checkpoint/orbax/checkpoint/type_handlers.py", line 625, in _validate_params
    raise ValueError(
ValueError: Save failed: 1/2 params are missing in checkpoint:
z.
Tensorstore KvStore: KvStore({
  'base': {
    'driver': 'file',
    'path': '/tmp/tmpbxi1zpec.orbax-checkpoint-tmp-0/',
  },
  'cache_pool': 'cache_pool#ocdbt',
  'config': {
    'compression': {'id': 'zstd'},
    'max_decoded_node_bytes': 100000000,
    'max_inline_value_bytes': 1024,
    'uuid': '3ef941407cca4f778414e9e92b15dedb',
    'version_tree_arity_log2': 4,
  },
  'context': {
    'cache_pool#ocdbt': {'total_bytes_limit': 100000000},
    'data_copy_concurrency': {},
    'file_io_concurrency': {'limit': 128},
    'file_io_sync': True,
    'ocdbt_coordinator': {},
  },
  'driver': 'ocdbt',
  'experimental_read_coalescing_interval': '1ms',
  'experimental_read_coalescing_merged_bytes': 500000000000,
  'experimental_read_coalescing_threshold_bytes': 1000000,
}).
```

@cpgaffney1
Collaborator

Thanks for spotting this. Zero-sized array handling is not well defined, and we have no tests (internal or external) for it. We will clarify the intended behavior, add tests, resolve the validation issue, and get back to you.
