diff --git a/flax/nnx/graph.py b/flax/nnx/graph.py
index 3b3565ef29..2a92b8b5ad 100644
--- a/flax/nnx/graph.py
+++ b/flax/nnx/graph.py
@@ -24,7 +24,6 @@
 import numpy as np
 import typing_extensions as tpe
 
-from flax.core.frozen_dict import FrozenDict
 from flax.nnx import filterlib, reprlib
 from flax.nnx.proxy_caller import (
   ApplyCaller,
@@ -183,7 +182,7 @@ def get_node_impl_for_type(x: type[Node]) -> NodeImpl[Node, tp.Any, tp.Any]:
   return _node_impl_for_type[x]
 
 
-class _HashableMapping(tp.Mapping[HA, HB], tp.Hashable):
+class HashableMapping(tp.Mapping[HA, HB], tp.Hashable):
   def __init__(self, mapping: tp.Mapping[HA, HB] | tp.Iterable[tuple[HA, HB]]):
     self._mapping = dict(mapping)
 
@@ -204,7 +203,7 @@ def __hash__(self) -> int:
 
   def __eq__(self, other: tp.Any) -> bool:
     return (
-      isinstance(other, _HashableMapping) and self._mapping == other._mapping
+      isinstance(other, HashableMapping) and self._mapping == other._mapping
     )
 
   def __repr__(self) -> str:
@@ -246,7 +245,7 @@ def __treescope_repr__(self, path, subtree_renderer):
 class VariableDef(reprlib.Representable):
   type: type[Variable]
   index: int
-  metadata: FrozenDict[str, tp.Any]
+  metadata: HashableMapping[str, tp.Any]
 
   def __nnx_repr__(self):
     yield reprlib.Object(type=type(self))
@@ -272,7 +271,7 @@ def __treescope_repr__(self, path, subtree_renderer):
 jax.tree_util.register_static(VariableDef)
 
 
-@dataclasses.dataclass(frozen=True, repr=False)
+@dataclasses.dataclass(frozen=True, repr=False, slots=True)
 class NodeDef(GraphDef[Node], reprlib.Representable):
   """A dataclass that denotes the tree structure of a
   :class:`Module`. A ``GraphDef`` can be generated by either
@@ -281,11 +280,11 @@ class NodeDef(GraphDef[Node], reprlib.Representable):
   type: tp.Type[Node]
   index: int
   attributes: tuple[Key, ...]
-  subgraphs: _HashableMapping[Key, NodeDef[tp.Any] | NodeRef[tp.Any]]
-  static_fields: _HashableMapping[Key, tp.Any]
-  leaves: _HashableMapping[Key, VariableDef | NodeRef[tp.Any]]
+  subgraphs: HashableMapping[Key, NodeDef[tp.Any] | NodeRef[tp.Any]]
+  static_fields: HashableMapping[Key, tp.Any]
+  leaves: HashableMapping[Key, VariableDef | NodeRef[tp.Any]]
   metadata: tp.Any
-  index_mapping: FrozenDict[Index, Index] | None
+  index_mapping: HashableMapping[Index, Index] | None
 
   @classmethod
   def create(
@@ -303,11 +302,11 @@ def create(
       type=type,
       index=index,
       attributes=attributes,
-      subgraphs=_HashableMapping(subgraphs),
-      static_fields=_HashableMapping(static_fields),
-      leaves=_HashableMapping(leaves),
+      subgraphs=HashableMapping(subgraphs),
+      static_fields=HashableMapping(static_fields),
+      leaves=HashableMapping(leaves),
       metadata=metadata,
-      index_mapping=FrozenDict(index_mapping)
+      index_mapping=HashableMapping(index_mapping)
       if index_mapping is not None
       else None,
     )
@@ -424,7 +423,7 @@ def _graph_flatten(
         flat_state[(*path, key)] = value.to_state()
         variable_index = ref_index[value] = len(ref_index)
         variabledef = VariableDef(
-          type(value), variable_index, FrozenDict(value.get_metadata())
+          type(value), variable_index, HashableMapping(value.get_metadata())
         )
         leaves.append((key, variabledef))
     else:
@@ -794,7 +793,7 @@ def split(
       if ctx.index_ref is not None and isinstance(graphdef, NodeDef):
         index_to_index = compose_mapping(ctx.index_ref, self.ref_index)
         graphdef = dataclasses.replace(
-          graphdef, index_mapping=FrozenDict(index_to_index)
+          graphdef, index_mapping=HashableMapping(index_to_index)
         )
 
     return graphdef, *states
@@ -984,7 +983,7 @@ def split(
     if self.index_ref is not None and isinstance(graphdef, NodeDef):
       index_to_index = compose_mapping(self.index_ref, ref_index)
       graphdef = dataclasses.replace(
-        graphdef, index_mapping=FrozenDict(index_to_index)
+        graphdef, index_mapping=HashableMapping(index_to_index)
      )
 
     self.flatten_end(ref_index)
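As context for the `graph.py` hunks above: `NodeDef` and `VariableDef` are registered as static pytree nodes, so every field they carry (including the mapping-valued ones) must be hashable, and the renamed `HashableMapping` now fills that role instead of `FrozenDict`. Below is a minimal, self-contained sketch of that requirement, not flax code: it assumes `HashableMapping` is importable from `flax.nnx.graph` after the rename, and the `StaticDef` class and `scale` function are hypothetical.

```python
import dataclasses

import jax
import jax.numpy as jnp

from flax.nnx.graph import HashableMapping  # assumed public after the rename above


# Hypothetical stand-in for NodeDef/VariableDef: registered as a *static*
# (leaf-free) pytree node, so the whole object, metadata included, must be
# hashable because it becomes part of jax.jit's cache key.
@dataclasses.dataclass(frozen=True)
class StaticDef:
  metadata: HashableMapping


jax.tree_util.register_static(StaticDef)


@jax.jit
def scale(static: StaticDef, x):
  # `static` carries no array leaves; only `x` is traced.
  return x * len(static.metadata)


static = StaticDef(HashableMapping({"a": 1, "b": 2}))
print(scale(static, jnp.ones(3)))
```

Because `StaticDef` has no array leaves, it lives entirely in the treedef and is hashed into the jit cache key, which is why an unhashable plain `dict` in `metadata` would fail here.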
diff --git a/flax/nnx/transforms/autodiff.py b/flax/nnx/transforms/autodiff.py
index b86823c527..7194bc33ea 100644
--- a/flax/nnx/transforms/autodiff.py
+++ b/flax/nnx/transforms/autodiff.py
@@ -19,7 +19,6 @@
 
 from flax import struct
-from flax.core.frozen_dict import FrozenDict
 from flax.nnx import (
   extract,
   filterlib,
@@ -428,7 +427,7 @@ def _custom_vjp_split_fn(
   nondiff_argnums: tuple[int, ...] = struct.field(pytree_node=False)
   tangent_tree_node_args: tuple[tp.Any, ...] = struct.field(pytree_node=False)
 
-def _extract_index_mappings(x, *, index_mappings: deque[FrozenDict]):
+def _extract_index_mappings(x, *, index_mappings: deque[graph.HashableMapping]):
   if isinstance(x, graph.NodeDef):
     assert x.index_mapping is not None
     index_mappings.append(x.index_mapping)
@@ -466,7 +465,9 @@ def __call__(self, *pure_args):
       (args_out, out), ctxtag=self.ctxtag
     )
     # remove index_mapping from NodeDef's but store them in global context
-    index_mappings: deque[FrozenDict] = extract.get_broadcast_state(self.ctxtag)
+    index_mappings: deque[graph.HashableMapping] = extract.get_broadcast_state(
+      self.ctxtag
+    )
 
     pure_args_out, pure_out = jax.tree.map(
       functools.partial(_extract_index_mappings, index_mappings=index_mappings),
@@ -519,8 +520,8 @@ def __call__(self, *pure_args):
 
     if update_context_active:
       # remove index_mapping from NodeDef's but store them in global context
-      index_mappings: deque[FrozenDict] = extract.get_broadcast_state(
-        self.ctxtag
+      index_mappings: deque[graph.HashableMapping] = (
+        extract.get_broadcast_state(self.ctxtag)
       )
       pure_args_out, pure_out = jax.tree.map(
         functools.partial(
@@ -631,7 +632,7 @@ def __call__(
       for i, x in enumerate(tree_node_args)
       if i not in self.jax_nondiff_argnums
     )
-    index_mappings: deque[FrozenDict] = deque()
+    index_mappings: deque[graph.HashableMapping] = deque()
     with extract.broadcast_state(self.ctxtag, index_mappings):
       if self.fwd is None or self.bwd is None or self.symbolic_zeros is None:
         raise ValueError()
@@ -663,7 +664,7 @@ def __call__(
     # insert index_mappings
     def _insert_index_mappings(x):
       if isinstance(x, graph.NodeDef):
-        index_mapping: FrozenDict = index_mappings.popleft()
+        index_mapping: graph.HashableMapping = index_mappings.popleft()
         return dataclasses.replace(x, index_mapping=index_mapping)
       return x
 
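The `autodiff.py` hunks above keep the existing strip-and-restore pattern, only retyped: `_extract_index_mappings` pops each `NodeDef.index_mapping` into a deque before the structure crosses the `custom_vjp` boundary, and `_insert_index_mappings` re-attaches them afterwards. Here is a simplified, self-contained sketch of that pattern, with a hypothetical `Box` container standing in for `NodeDef`:

```python
import dataclasses
from collections import deque

import jax


# Hypothetical stand-in for NodeDef: `index_mapping` must be stripped before
# the structure crosses a transform boundary and re-attached afterwards.
@dataclasses.dataclass(frozen=True)
class Box:
  value: int
  index_mapping: tuple | None = None


def strip(x, *, store: deque):
  # Stash the mapping and return a mapping-free copy of the container.
  if isinstance(x, Box) and x.index_mapping is not None:
    store.append(x.index_mapping)
    return dataclasses.replace(x, index_mapping=None)
  return x


def restore(x, *, store: deque):
  # Re-attach the mappings in the same order they were removed.
  if isinstance(x, Box) and x.index_mapping is None:
    return dataclasses.replace(x, index_mapping=store.popleft())
  return x


store: deque = deque()
tree = {"a": Box(1, ((0, 1),)), "b": Box(2, ((2, 3),))}
stripped = jax.tree.map(lambda x: strip(x, store=store), tree)
restored = jax.tree.map(lambda x: restore(x, store=store), stripped)
assert restored == tree
```

The deque preserves insertion order, so as long as both traversals visit the containers in the same pytree order, each mapping returns to the def it came from.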
diff --git a/flax/nnx/transforms/iteration.py b/flax/nnx/transforms/iteration.py
index 20366c3e1f..1ea33d78d5 100644
--- a/flax/nnx/transforms/iteration.py
+++ b/flax/nnx/transforms/iteration.py
@@ -650,7 +650,7 @@ def check_carry_same_references(key_path, arg, out):
 
 def _extract_index_mappings(
   pure_carry_arg_out,
-  carry_index_mappings: list[FrozenDict[int, int]],
+  carry_index_mappings: list[graph.HashableMapping[int, int]],
   /,
 ):
   def extract_index_mappings(x):
@@ -675,7 +675,7 @@ def extract_index_mappings(x):
 
 def _insert_index_mappings(
   pure_carry_arg_out,
-  carry_index_mappings: deque[FrozenDict[int, int]],
+  carry_index_mappings: deque[graph.HashableMapping[int, int]],
   /,
 ):
   def insert_index_mappings(x):
@@ -1096,7 +1096,7 @@ def __call__(
 
     # next we have to remove all the index_mappings from the NodeDefs
     # in the carry outputs because they are not present in the inputs
-    carry_index_mappings: list[FrozenDict[int, int]] = []
+    carry_index_mappings: list[graph.HashableMapping[int, int]] = []
     pure_carry_arg_out = _extract_index_mappings(
       pure_carry_arg_out, carry_index_mappings
     )
@@ -1347,10 +1347,12 @@ def per_node_def(nd: graph.NodeDef | tp.Any):
 
       return per_node_def(ns._graphdef)
 
-    return dataclasses.replace(ns, _graphdef=dataclasses.replace(
-      ns._graphdef,
-      index_mapping=FrozenDict(global_index_mapping)
-    ))
+    return dataclasses.replace(
+      ns,
+      _graphdef=dataclasses.replace(
+        ns._graphdef, index_mapping=graph.HashableMapping(global_index_mapping)
+      ),
+    )
 
   return jax.tree.map(per_node_state, tree, is_leaf=lambda x: isinstance(x, extract.NodeStates))
diff --git a/uv.lock b/uv.lock
index 7f3e9ab32c..b4a0aaa65d 100644
--- a/uv.lock
+++ b/uv.lock
@@ -81,6 +81,12 @@ dependencies = [
     { name = "etils", version = "1.9.2", source = { registry = "https://pypi.org/simple" }, extra = ["epath"], marker = "python_full_version >= '3.11' and platform_system != 'Darwin'" },
 ]
 wheels = [
+    { url = "https://files.pythonhosted.org/packages/ff/9b/fe3cc94350cf082d3fb70a1393b259cd1d9bce5212f14f53deea1008b94b/array_record-0.5.1-1-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:3dbfac79589b53ad765d247b4b6b6c108623053950a8ae36d8a5f2bfec396bd1", size = 2140349 },
+    { url = "https://files.pythonhosted.org/packages/ce/fd/a241172b054f0c496cc575a6081e2b457ef3cf520e652ee22f3035714535/array_record-0.5.1-1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:a0911cca3f71aa6724ae08c351e486acc2dcdc098df0e4ae9aa920f16aee2385", size = 2200584 },
+    { url = "https://files.pythonhosted.org/packages/7f/b9/ab118be4efaae976db4dbffbf4d9479151509668261d95beaa80a956a757/array_record-0.5.1-1-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:e6b297b9241d10f072f00a85e97c8743c9e623be20e413ab3403b9326ed98890", size = 2140482 },
+    { url = "https://files.pythonhosted.org/packages/d8/5e/9379b00e5b17ea280845b82c492cab9298eb658ea9d40c21f0fd064a4dd5/array_record-0.5.1-1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e39b2001fbed6f6d621a5f2079609037167ee06bf977fd6c37d225043c39a015", size = 2200598 },
+    { url = "https://files.pythonhosted.org/packages/45/9b/74eb64c839871cb3adfb254246e42be8a7ce636debe9ab9a3748cb0c484b/array_record-0.5.1-1-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:c4e6e5cef45a82641f4bb008c2a1409cd043f46dd3f0e5a2e7f232416435186d", size = 2140093 },
+    { url = "https://files.pythonhosted.org/packages/b7/4d/8ed8fbef16144db66b92e3fcbcb4656edaa5cf538d20fe7913c1caa78b68/array_record-0.5.1-1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:927f5f0bdbb141e75d370ade9ce784514babcb78f86d23badbab2d7fd6b7cd48", size = 2200996 },
     { url = "https://files.pythonhosted.org/packages/76/85/f8e77e0ee6644ab3585de1b73a183e6831ded6e7b791f21a3de5f6e29aeb/array_record-0.5.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b9f2e304e59a17af9f5bf2a86b93ad4700d0eeb85d742a884aa38dc0b54dda5b", size = 2135133 },
     { url = "https://files.pythonhosted.org/packages/9e/da/a7c513f35d4878888ca5d1e8548324e90414106ece7b44908002c800a22f/array_record-0.5.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:897362036f2920093eff3d729c2a6e1844e3077f513d6bd29640cd02f98e07c7", size = 2195378 },
     { url = "https://files.pythonhosted.org/packages/61/7f/e0329a2aad1cf96e2b797e55e744af94c3d8d1969240c0153660214477c0/array_record-0.5.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6ebe99f37e3a797322f4f5cfc6902b5e852012ba2729fac628aad6affb225247", size = 2135268 },
"sha256:6ebe99f37e3a797322f4f5cfc6902b5e852012ba2729fac628aad6affb225247", size = 2135268 }, @@ -767,7 +773,7 @@ wheels = [ [[package]] name = "flax" -version = "0.10.1" +version = "0.10.0" source = { editable = "." } dependencies = [ { name = "jax" }, @@ -2255,7 +2261,7 @@ wheels = [ [[package]] name = "orbax-checkpoint" -version = "0.8.0" +version = "0.9.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "absl-py" }, @@ -2272,9 +2278,9 @@ dependencies = [ { name = "tensorstore" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/66/48/54339d92c2b37f2ddea72501653f1b85a85ca2f19f4102b4b966260c2700/orbax_checkpoint-0.8.0.tar.gz", hash = "sha256:0754ecc2e5fc858e62bbcf610606502d8e1c9ada7295d9bb49cc172f884b0b1e", size = 206396 } +sdist = { url = "https://files.pythonhosted.org/packages/48/ce/4a3386e9e4bc95ab638c2779e259a231ffc2ae1fc2b67a68f4d6d8a794f2/orbax_checkpoint-0.9.1.tar.gz", hash = "sha256:edf76d8fc482a9a0296645522f6d13ee09c499af0ef2f9369899cbfee31c7f88", size = 212247 } wheels = [ - { url = "https://files.pythonhosted.org/packages/28/35/1a3ec885f192884867c1325920171d67ca2fa9122837ea96af284a2a2f05/orbax_checkpoint-0.8.0-py3-none-any.whl", hash = "sha256:df8e353feb7f4eeba9f5b16f704699df54c3c44c5c6ec4d4d117c40bf27830cc", size = 286357 }, + { url = "https://files.pythonhosted.org/packages/8a/41/19de59b4a8581ad364970b2429481838cea8d190f50bc8f672ca329b64e0/orbax_checkpoint-0.9.1-py3-none-any.whl", hash = "sha256:d33e23e63b7ffcf66de3fb5daa3d11bf24a34e66d133eb4a037788f22155857e", size = 296514 }, ] [[package]]