-
Notifications
You must be signed in to change notification settings - Fork 19.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix handling of OrderedDict
with optree and related documentation.
#20481
base: master
Are you sure you want to change the base?
Conversation
The tree API had specific but contradicting documentation calling out the handling of `OrderedDict`s. However, the behavior of the `optree` implementation did not honor this documentation (using the key order, not the sequence order) for `flatten`, although it did for `pack_sequence_as`. The result was that not only did `flatten` not behave the same with `optree` and `dm-tree`, but also `pack_sequence_as(flatten(...))` was not idempotent. The `optree` implementation did have all the machinery needed to handle `OrderedDict`s per spec, which was used for `pack_sequence_as`, but not `flatten`. This also fixes the discrepancy in the behavior for `namedtuple`s. - Fixed contradicting documentation in `flatten` and `pack_sequence_as` related to the handling of `OrderedDict`s. - Fixed references to `unflatten_as`, which doesn't exist. - Removed most `if optree` tests in `tree_test.py`, which should not exist for consistency between `optree` and `dm-tree`. - Fixed unit tests which were incorrectly flattening the result of `flatten_with_path`. - Fixed unintented use of `tree` instead of `keras.tree` in unit test. - Ran unit tests for all backends with `dm-tree` uninstalled.
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## master #20481 +/- ##
==========================================
- Coverage 82.07% 76.40% -5.68%
==========================================
Files 515 515
Lines 47504 47512 +8
Branches 7454 7457 +3
==========================================
- Hits 38991 36303 -2688
- Misses 6703 9452 +2749
+ Partials 1810 1757 -53
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Sentry. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the fix!
from absl.testing import parameterized | ||
|
||
from keras.src import backend | ||
from keras.src import metrics as losses_module | ||
from keras.src import metrics as metrics_module | ||
from keras.src import ops | ||
from keras.src import testing | ||
from keras.src import tree |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This seems to be triggering a failure on the torch side https://github.com/keras-team/keras/actions/runs/11785192949/job/32825931002?pr=20481
(inability to recognize a NamedTuple as a loss tuple)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looking into it. For some reason, this test required dm-tree
.
Testing further, I tested with optree
(the default) uninstalled and before this change, and a number of tests fail.
So some tests (and maybe some code) are still tree implementation dependent.
Wouldn't it be better to wrap e.g.: from collections import OrderedDict
import optree
from keras import tree
class WrappedOrderedDict(OrderedDict):
pass
def flatten(d):
values = []
keys = []
for key in sorted(d.keys()):
values.append(d[key])
keys.append(key)
return values, list(d.keys()), keys
def unflatten(metadata, children):
index = {key: i for i, key in enumerate(sorted(metadata))}
return OrderedDict({key: children[index[key]] for key in metadata})
optree.register_pytree_node(
WrappedOrderedDict,
flatten,
unflatten,
namespace="keras",
)
def ordereddict_pytree_test():
# Create an OrderedDict with deliberately unsorted keys
ordered_d = OrderedDict([('c', 3), ('a', 1), ('b', 2)])
def wrap(s):
if isinstance(s, OrderedDict):
return WrappedOrderedDict(s)
return None
def unwrap(s):
if isinstance(s, WrappedOrderedDict):
return OrderedDict(s)
return None
d = tree.traverse(wrap, ordered_d, top_down=False)
flat_d = tree.flatten(d)
flat_d_paths = tree.flatten_with_path(d)
assert flat_d == [1, 2, 3]
assert [p[0] for p, v in flat_d_paths] == ["a", "b", "c"]
tree_struct = tree.traverse(wrap, ordered_d, top_down=False)
wrapped_d = tree.pack_sequence_as(tree_struct, flat_d)
orig_d = tree.traverse(unwrap, wrapped_d, top_down=False)
assert isinstance(orig_d, OrderedDict)
assert list(orig_d.keys()) == list(ordered_d.keys())
assert list(orig_d.values()) == list(ordered_d.values())
ordereddict_pytree_test() |
Hi Nicolas, Thank you for the suggestion. I actually completely scratched this PR and decided to use a different approach. The |
The tree API had specific but contradicting documentation calling out the handling of
OrderedDict
s. However, the behavior of theoptree
implementation did not honor this documentation (using the key order, not the sequence order) forflatten
, although it did forpack_sequence_as
. The result was that not only didflatten
not behave the same withoptree
anddm-tree
, but alsopack_sequence_as(flatten(...))
was not idempotent. Theoptree
implementation did have all the machinery needed to handleOrderedDict
s per spec, which was used forpack_sequence_as
, but notflatten
. This also fixes the discrepancy in the behavior fornamedtuple
s.flatten
andpack_sequence_as
related to the handling ofOrderedDict
s.unflatten_as
, which doesn't exist.if optree
tests intree_test.py
, which should not exist for consistency betweenoptree
anddm-tree
.flatten_with_path
.tree
instead ofkeras.tree
in unit test.dm-tree
uninstalled.