-
Notifications
You must be signed in to change notification settings - Fork 13
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Change TransformState to NamedTuple #106
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good to me, but in my opinion it would be good to check how this affects memory consumption and allocation - since you allocate a new state whenever you update it, which may be quite inefficient
This PR shouldn't affect memory consumption, it just changes the handling of the algorithm states to a better convention. It would be good to have some numerics on memory consumption and the pros and cons of using There is also an element of horses-for-courses since for MCMC-style where you want to collect samples along a trajectory you need |
I'm not sure about this: the previous inplace behaviour would change the elements of a previously allocated object. Now, you re-allocate a new object every time |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good to me apart from a docstring to change
We define a new |
This is quite a large PR unfortunately but fixes #83 and also cleans up the
TransformState
code by moving from mutabledataclass
to immutableNamedTuple
which is also better encapsulation practice for functional code.One thing that's nice is that
NamedTuple
is already added to the pytree registry for optree and torch (fixing #83), so we don't need to do that manually as before.There was an issue with the
aux
handling as modifyingaux
is not possible to do in-place as we don't know the structure ofaux
beforelog_posterior
is called, alsoaux
is not guaranteed to be aTensorTree
and could contain strings etc.The proposed fix is to return a new state with all other attributes modified in-place (i.e. pointers to old state) but
aux
replaced.