Can't checkpoint Numpy RNG state #1071

Open
rainx0r opened this issue Aug 16, 2024 · 6 comments

Comments

@rainx0r

rainx0r commented Aug 16, 2024

Hi. I'm trying to use orbax-checkpoint to checkpoint my full experiment state, not just my network weights. Part of that state is the state of NumPy RNGs, which looks as follows:

>>> import numpy as np
>>> rng = np.random.default_rng(42)
>>> state = rng.__getstate__()
>>> state
{'bit_generator': 'PCG64', 'state': {'state': 274674114334540486603088602300644985544, 'inc': 332724090758049132448979897138935081983}, 'has_uint32': 0, 'uinteger': 0}

When trying to checkpoint this with orbax-checkpoint, I get the following error:

import orbax.checkpoint as ocp

dir = ocp.test_utils.erase_and_create_empty("/tmp/string-checkpoint-reprod")

ckpt_manager = ocp.CheckpointManager(
    dir,
    options=ocp.CheckpointManagerOptions(
        max_to_keep=5,
        create=True,
    ),
)

ckpt_manager.save(0, args=ocp.args.NumpyRandomKeySave(state))

ValueError: Error parsing object member "dtype": Unsupported data type: "object" [source locations='tensorstore/internal/json_binding/json_binding.h:384\ntensorstore/internal/json_binding/json_binding.h:524\ntensorstore/internal/json_binding/json_binding.h:861\ntensorstore/internal/json_binding/json_binding.h:825']

I looked through the API and the docs, and it does seem like string leaf nodes should be supported. But I think the problem is that, for some reason, orbax-checkpoint doesn't see the string as a str (which is what StringHandler is registered for); it sees it as an object (which is technically not false).

Edit: The issue seems to be that the large integers in the NumPy random key state get turned into dtype=object NumPy arrays, which is not Orbax's fault. But perhaps Orbax should be able to handle this in some way?

I noticed in the changelog entry that introduces NumpyRandomKeyCheckpointHandler that it's intended for numpy.random.get_state(). But that's just the global RNG state and more or less a legacy feature. As far as I know, current NumPy best practice is to instantiate and use individual Generator objects through the new API. These seem to have a different kind of state, and that's where the issue is.

Here is a Colab notebook reproducing the issue on orbax-checkpoint==0.5.23.

rainx0r changed the title from "Can't checkpoint a PyTree with string leaf nodes" to "Can't checkpoint a PyTree with string leaf" on Aug 16, 2024
@rainx0r
Author

rainx0r commented Aug 16, 2024

I stepped through the Orbax source with a debugger and realised that the problem isn't the string; StringHandler handles it correctly. The problem is the really large integers in the NumPy RNG state. It looks like when ScalarHandler handles them, they get turned into NumPy arrays of dtype object.
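The dtype fallback can be reproduced with NumPy alone. The snippet below is a minimal sketch that does not touch Orbax at all; it only assumes that the scalar ends up being passed through something like np.asarray, and it uses the PCG64 "state" value quoted earlier in this issue.

```python
import numpy as np

# Largest value any native NumPy integer dtype can hold.
UINT64_MAX = np.iinfo(np.uint64).max

# The PCG64 "state" integer quoted earlier in this issue.
big = 274674114334540486603088602300644985544

small_arr = np.asarray(12345)  # fits in a native integer dtype
big_arr = np.asarray(big)      # too wide, so NumPy falls back to dtype=object

print(small_arr.dtype)
print(big > UINT64_MAX)
print(big_arr.dtype)
```

An object array just wraps the Python int, which is presumably why a storage backend that only knows fixed-width dtypes rejects it.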

As a result, I'm heavily reframing my issue: it's not about being unable to checkpoint string leaves, but about being unable to checkpoint NumPy random key states properly. Apologies for any confusion.

I have looked through the source of NumpyRandomKeySave and NumpyRandomKeyCheckpointHandler, but they seem to defer to PyTreeSave / PyTreeCheckpointHandler anyway, which don't currently work for this kind of state.

rainx0r changed the title from "Can't checkpoint a PyTree with string leaf" to "Can't checkpoint Numpy RNG state" on Aug 16, 2024
@niketkumar
Collaborator

How are you planning to construct the Generator object back from the restored state?

@rainx0r
Author

rainx0r commented Aug 27, 2024

Generator objects have a __setstate__() method that takes the state dict exported by the Generator's __getstate__() method, and that's how I currently have it implemented. I'm not entirely sure whether these two methods are part of the Generator's public API or intended to be used directly. Another strategy that avoids any potentially "private" methods is to read the state through the Generator's .bit_generator.state attribute and then restore it by assigning back to that attribute.
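A minimal sketch of the second strategy, using only the .bit_generator.state attribute. The seed and the number of draws are arbitrary choices for illustration.

```python
import numpy as np

rng = np.random.default_rng(42)
rng.random(3)  # advance the stream so the saved state is not the initial one

# Read the state as a plain dict through the bit generator.
saved = rng.bit_generator.state

# Build a fresh Generator (also PCG64 by default) and assign the state back.
restored = np.random.default_rng()
restored.bit_generator.state = saved

# Both generators should now produce the same stream.
expected = rng.random(5)
actual = restored.random(5)
print(np.array_equal(expected, actual))
```

Assigning the state only works if the target bit generator is of the same type as the one named in the dict's "bit_generator" entry.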

Also, for now I've resorted to using JsonSave() and JsonRestore() for these RNG states.

@niketkumar
Collaborator

Thanks for sharing the details.

Will using numpy.random.get_state(legacy=False) meet your requirements? In that case, Orbax already supports it. Please take a look at this unit test:

def test_save_and_restore_numpy_random_key_nonlegacy(self):

Alternatively, using Json should suffice too.

@rainx0r
Author

rainx0r commented Aug 28, 2024

The problem isn't really the format: yes, the format returned by numpy.random.get_state(legacy=False) is the same as the one you get by accessing Generator.bit_generator.state (or calling Generator.__getstate__()). The problem is that the BitGenerator used by the new Generator objects is not MT19937, which is what the older global NumPy random state uses, but PCG64, whose state involves extremely large integers (wider than 64 bits) that no native NumPy integer dtype can hold. From what I can tell, NumPy arrays are Orbax's default way of serialising scalars.
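To make the difference concrete, here is a small sketch comparing the two state layouts side by side. It only inspects NumPy's bit generators directly; the seed 42 is an arbitrary choice matching the example at the top of this issue.

```python
import numpy as np

mt_state = np.random.MT19937(42).state
pcg_state = np.random.PCG64(42).state

# MT19937: the bulk of the state is an array of 32-bit words,
# which maps cleanly onto a native NumPy dtype.
mt_key = mt_state["state"]["key"]
print(mt_key.dtype, mt_key.shape)

# PCG64: "state" and "inc" are plain Python ints that need more than 64 bits.
pcg_value = pcg_state["state"]["state"]
print(type(pcg_value).__name__, pcg_value.bit_length())
print(pcg_state["state"]["inc"].bit_length())
```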

Serialising with JSON works because it essentially just writes the state into a string, and Python can read it back from that string without issue. I suppose I can stick with that, or put the NumPy Generator objects into my checkpoint PyTree directly instead of their state, and then write a TypeHandler for numpy.random.Generator that extracts / sets the state and serialises it with JSON.
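For reference, a rough sketch of the JSON round trip using only the standard library json module rather than Orbax's JsonSave / JsonRestore. The helper names generator_to_json and generator_from_json are hypothetical, and the sketch assumes a PCG64-backed Generator (the MT19937 state contains an ndarray, which json.dumps would reject as-is).

```python
import json

import numpy as np


def generator_to_json(rng: np.random.Generator) -> str:
    """Hypothetical helper: dump a PCG64 Generator's state to a JSON string."""
    # Python's json module writes arbitrarily large ints without loss.
    return json.dumps(rng.bit_generator.state)


def generator_from_json(text: str) -> np.random.Generator:
    """Hypothetical helper: rebuild a PCG64 Generator from a JSON string."""
    state = json.loads(text)
    if state["bit_generator"] != "PCG64":
        raise ValueError(f"unsupported bit generator: {state['bit_generator']}")
    bit_gen = np.random.PCG64()
    bit_gen.state = state
    return np.random.Generator(bit_gen)


rng = np.random.default_rng(42)
text = generator_to_json(rng)
clone = generator_from_json(text)

# The clone should continue the stream exactly where the original is.
original_draws = rng.integers(0, 100, size=10)
clone_draws = clone.integers(0, 100, size=10)
print(np.array_equal(original_draws, clone_draws))
```

A TypeHandler for numpy.random.Generator would presumably wrap something like these two helpers, but that part depends on Orbax's handler interface and is not sketched here.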

@niketkumar
Collaborator

Thanks for clarifying the difference between MT19937 and PCG64!

A JSON-based solution is ideal for this scenario. I will look into it.
