Can't checkpoint Numpy RNG state #1071
Comments
I stepped through the orbax source with a debugger and realised that the problem isn't the string, it gets handled by As a result I'm heavily reframing my issue: it is not about an inability to checkpoint string leaves, but about checkpointing numpy random key states properly. Apologies for any confusion. I have looked through
How are you planning to construct the Generator object back from the restored state?
The Generator objects have a Also currently I resorted to using
Thanks for sharing the details. Will using
Alternatively, using JSON should suffice too.
The problem isn't really the format, as yes the format returned by Serialising with JSON works because it essentially just writes the state into a string, and Python can read it back from that string without issue. I suppose I can stick with that, or put the numpy Generator objects into my checkpoint PyTree directly instead of their state and then write a TypeHandler for
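For anyone landing here, the JSON workaround mentioned above can be sketched roughly as follows. This is only a minimal sketch, not orbax API: `rng_state_to_json` and `rng_from_json` are hypothetical helper names, and it assumes the default PCG64 bit generator, whose state dict contains only strings and plain Python integers. An MT19937 state holds a NumPy array and would need converting to a list first.

```python
import json

import numpy as np


def rng_state_to_json(rng: np.random.Generator) -> str:
    """Hypothetical helper: dump the bit generator state to a JSON string.

    Python ints are arbitrary precision, so the 128-bit PCG64 values
    survive the round trip unchanged.
    """
    return json.dumps(rng.bit_generator.state)


def rng_from_json(payload: str) -> np.random.Generator:
    """Hypothetical helper: rebuild a Generator from the JSON string."""
    state = json.loads(payload)
    # The state dict records the bit generator class name, e.g. "PCG64".
    bit_gen = getattr(np.random, state["bit_generator"])()
    bit_gen.state = state  # overwrite the freshly seeded state
    return np.random.Generator(bit_gen)


rng = np.random.default_rng(1234)
rng.normal(size=3)  # advance the generator so the state is not the initial one

payload = rng_state_to_json(rng)
restored = rng_from_json(payload)

# Both generators should now produce the same stream.
expected = rng.integers(0, 1_000_000, size=5)
actual = restored.integers(0, 1_000_000, size=5)
```

The resulting string can then be stored next to the rest of the checkpoint in whatever way is convenient.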
Thanks for clarifying the difference between MT19937 and PCG64! A JSON-based solution is ideal for this scenario. I will look into it.
Hi. So I'm trying to use orbax-checkpoint to checkpoint my full experiment state, not just my network weights, and part of this state is NumPy RNG states, which look as follows:
When trying to checkpoint this with orbax-checkpoint, I get the following error:
I looked through the API and the docs and it does seem like string leaf nodes should be supported. But I think the problem is that for some reason orbax-checkpoint doesn't even see the string as `str` (which is what `StringHandler` is registered for), it sees it as an `object` (which is technically not false).
Edit: The issue seems to be that the large integers in the numpy random key state get turned into `dtype=object` numpy arrays, which is not orbax's fault. But perhaps this should be able to be handled in some way by orbax?
I noticed in the changelog where `NumpyRandomKeyCheckpointHandler` is introduced that it's intended for `numpy.random.get_state()`, but that's just the global rng state and more or less a legacy feature, and (as far as I know) current numpy rng best practice is to instantiate and use individual rng objects through the new API, which seem to have a different kind of state and that's where the issue is.
Here is a Colab notebook reproducing the issue on `orbax-checkpoint==0.5.23`.
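To illustrate the two observations above without orbax involved at all, here is a small NumPy-only sketch. It is not the code from the Colab notebook, and the variable names are just for illustration. It shows that the new-style PCG64 state holds Python integers too large for any fixed-width integer dtype, so NumPy falls back to `dtype=object`, while the legacy global state is a tuple built around a `uint32` array.

```python
import numpy as np

# New-style API: the state is a nested dict with arbitrary-precision ints.
rng = np.random.default_rng(0)
new_state = rng.bit_generator.state
inner = new_state["state"]["state"]  # 128-bit PCG64 state as a Python int

# Converting such an integer to an array cannot use int64/uint64,
# so NumPy falls back to an object array.
inner_as_array = np.asarray(inner)
explicit_big = np.asarray(2**100)  # same effect with an explicit large int

# Legacy API: the global state is a tuple whose key is a uint32 array.
legacy_state = np.random.get_state()
legacy_name, legacy_key = legacy_state[0], legacy_state[1]

print(new_state["bit_generator"], inner_as_array.dtype)
print(legacy_name, legacy_key.dtype, legacy_key.shape)
```

The legacy tuple maps naturally onto fixed-width arrays, while the new-style state does not unless the large integers are encoded some other way (for example as strings or JSON).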