Skip to content

Commit

Permalink
Make module wrapper transparent with respect to the scope
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 699104144
  • Loading branch information
Conchylicultor authored and The kauldron Authors committed Nov 27, 2024
1 parent b94b8d1 commit 045b8b9
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 2 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ Changelog follow the https://keepachangelog.com/ standard (at least the headers)

## [Unreleased]

* Add `kd.nn.WrapperModule` to make a inner-module transparent with
respect of .

## [1.0.0] - 2024-11-21

* `kd.kontext.Path` now supports tensor slicing. So for example using keys like
Expand Down
1 change: 1 addition & 0 deletions kauldron/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@

# Modules
from kauldron.modules.adapter import ExternalModule
from kauldron.modules.adapter import WrapperModule
from kauldron.modules.misc import Dropout
from kauldron.modules.misc import DummyModel
from kauldron.modules.misc import Identity
Expand Down
21 changes: 19 additions & 2 deletions kauldron/modules/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,25 @@
from kauldron.utils import train_property


class ExternalModule(nn.Module):
class WrapperModule(nn.Module):
"""Base class to wrapper a module.
The wrapper module transparent with respect to the inner parameters (
`{'params': inner_params}` instead of nesting
`{'params': {'model': inner_params}}`).
"""

model: nn.Module

def __post_init__(self):
super().__post_init__()
# Share scope, to make the wrapper module transparent with respect to the
# parameters (instead of nesting `{'params': {'model': params}}`).
if self.scope is not None:
nn.share_scope(self, self.model)


class ExternalModule(WrapperModule):
"""Module that is defined outside Kauldron.
This is a **very** thin wrapper around `flax.linen.Module` that add:
Expand Down Expand Up @@ -52,7 +70,6 @@ class ExternalModule(nn.Module):
can be inverted with `~` (e.g. `train_kwarg_name='~deterministic'`)
"""

model: nn.Module
keys: str | dict[str, str]
train_kwarg_name: Optional[str] = None

Expand Down
17 changes: 17 additions & 0 deletions kauldron/modules/adapter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

"""Test."""

from typing import Any

from flax import linen as nn
import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -48,3 +50,18 @@ def test_external():

assert not np.array_equal(out_train, inputs)
np.testing.assert_array_equal(out_eval, inputs)


def test_wrapper():
class MyWrapper(kd.nn.WrapperModule):

def __call__(self, *args: Any, **kwargs: Any) -> Any:
return self.model(*args, **kwargs)

model = MyWrapper(
model=nn.Dense(2),
)

inputs = jnp.ones((5,))
params = model.init(jax.random.PRNGKey(0), inputs)
assert list(params['params']) == ['kernel', 'bias']

0 comments on commit 045b8b9

Please sign in to comment.