Fix handling of OrderedDict with optree and related documentation. #20481

Draft
hertschuh wants to merge 1 commit into base: master

Conversation

@hertschuh (Collaborator)

The tree API had specific but contradictory documentation calling out the handling of `OrderedDict`s. However, the `optree` implementation did not honor this documentation for `flatten` (it used the key order, not the sequence order), although it did for `pack_sequence_as`. As a result, not only did `flatten` behave differently with `optree` and `dm-tree`, but `pack_sequence_as(flatten(...))` also failed to round-trip the original structure (see the sketch after the list below). The `optree` implementation already had all the machinery needed to handle `OrderedDict`s per spec; it was used for `pack_sequence_as` but not for `flatten`. This also fixes the discrepancy in the behavior for `namedtuple`s.

- Fixed contradicting documentation in `flatten` and `pack_sequence_as` related to the handling of `OrderedDict`s.
- Fixed references to `unflatten_as`, which doesn't exist.
- Removed most `if optree` tests in `tree_test.py`; such branches should not exist, since `optree` and `dm-tree` must behave consistently.
- Fixed unit tests which were incorrectly flattening the result of `flatten_with_path`.
- Fixed unintended use of `tree` instead of `keras.tree` in a unit test.
- Ran unit tests for all backends with `dm-tree` uninstalled.
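
For illustration, here is a minimal sketch of the pre-fix discrepancy described above (hypothetical values, assuming the default `optree`-backed implementation):

from collections import OrderedDict

from keras import tree

# An OrderedDict whose insertion order ("b" first) differs from its
# sorted key order ("a" first).
od = OrderedDict([("b", 2), ("a", 1)])

# Pre-fix, the optree-backed flatten used the sorted key order...
flat = tree.flatten(od)  # -> [1, 2]

# ...while pack_sequence_as filled values back in sequence (insertion)
# order, so the round trip scrambled the values:
restored = tree.pack_sequence_as(od, flat)
# -> OrderedDict([("b", 1), ("a", 2)]), which no longer equals od
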
@codecov-commenter commented Nov 11, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 76.40%. Comparing base (b0b9d04) to head (def69ec).

❗ There is a different number of reports uploaded between BASE (b0b9d04) and HEAD (def69ec).

HEAD has 2 fewer uploads than BASE:

| Flag | BASE (b0b9d04) | HEAD (def69ec) |
| --- | --- | --- |
| keras | 4 | 3 |
| keras-torch | 1 | 0 |
Additional details and impacted files
@@            Coverage Diff             @@
##           master   #20481      +/-   ##
==========================================
- Coverage   82.07%   76.40%   -5.68%     
==========================================
  Files         515      515              
  Lines       47504    47512       +8     
  Branches     7454     7457       +3     
==========================================
- Hits        38991    36303    -2688     
- Misses       6703     9452    +2749     
+ Partials     1810     1757      -53     
| Flag | Coverage | Δ |
| --- | --- | --- |
| keras | 76.33% <100.00%> | -5.60% ⬇️ |
| keras-jax | 65.02% <100.00%> | +<0.01% ⬆️ |
| keras-numpy | 59.98% <100.00%> | +<0.01% ⬆️ |
| keras-tensorflow | 66.04% <100.00%> | +<0.01% ⬆️ |
| keras-torch | ? | |

Flags with carried forward coverage won't be shown.

@fchollet (Member) left a comment

Thanks for the fix!

from absl.testing import parameterized

from keras.src import backend
from keras.src import metrics as losses_module
from keras.src import metrics as metrics_module
from keras.src import ops
from keras.src import testing
from keras.src import tree
@fchollet (Member)

This seems to be triggering a failure on the torch side https://github.com/keras-team/keras/actions/runs/11785192949/job/32825931002?pr=20481

(inability to recognize a NamedTuple as a loss tuple)

@hertschuh (Collaborator, Author)

Looking into it. For some reason, this test required `dm-tree`.

Testing further: with `optree` (the default) uninstalled, a number of tests fail even before this change.

So some tests (and maybe some code) are still tree-implementation dependent.

@hertschuh marked this pull request as draft November 12, 2024 16:42
@nicolaspi (Contributor)

Wouldn't it be better to wrap `OrderedDict` rather than re-implement `flatten`, which is written in C++ in `optree`?

e.g.:

from collections import OrderedDict
import optree
from keras import tree

class WrappedOrderedDict(OrderedDict):
  pass

def flatten(d):
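  # optree's registered flatten function may return (children, metadata)
  # or (children, metadata, entries); here, children are emitted in
  # sorted key order while metadata records the original insertion order.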
  values = []
  keys = []
  for key in sorted(d.keys()):
    values.append(d[key])
    keys.append(key)
  return values, list(d.keys()), keys

def unflatten(metadata, children):
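  # The children arrive in sorted-key order; metadata holds the original
  # insertion order, so the OrderedDict is rebuilt accordingly.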
  index = {key: i for i, key in enumerate(sorted(metadata))}
  return OrderedDict({key: children[index[key]] for key in metadata})


optree.register_pytree_node(
    WrappedOrderedDict,
    flatten,
    unflatten,
    namespace="keras",
)


def ordereddict_pytree_test():
  # Create an OrderedDict with deliberately unsorted keys
  ordered_d = OrderedDict([('c', 3), ('a', 1), ('b', 2)])
  
  def wrap(s):
    if isinstance(s, OrderedDict):
      return WrappedOrderedDict(s)
    return None

  def unwrap(s):
    if isinstance(s, WrappedOrderedDict):
      return OrderedDict(s)
    return None

  d = tree.traverse(wrap, ordered_d, top_down=False)
  flat_d = tree.flatten(d)
  flat_d_paths = tree.flatten_with_path(d)
  assert flat_d == [1, 2, 3]
  assert [p[0] for p, v in flat_d_paths] == ["a", "b", "c"]
  
  tree_struct = tree.traverse(wrap, ordered_d, top_down=False)
  wrapped_d = tree.pack_sequence_as(tree_struct, flat_d)
  orig_d = tree.traverse(unwrap, wrapped_d, top_down=False)
  assert isinstance(orig_d, OrderedDict)
  assert list(orig_d.keys()) == list(ordered_d.keys())
  assert list(orig_d.values()) == list(ordered_d.values())


ordereddict_pytree_test()

@hertschuh (Collaborator, Author)

> Wouldn't it be better to wrap `OrderedDict` rather than re-implement `flatten`, which is written in C++ in `optree`?

Hi Nicolas,

Thank you for the suggestion. I actually completely scrapped this PR and decided to take a different approach: the `optree` behavior will be the reference behavior. The goal is indeed to maximize the use of the C++ implementation of `optree`, since it is the default and `dm-tree` is only a fallback.
