
Commit

Merge pull request #4351 from jlperla:lbfgs_support
PiperOrigin-RevId: 692253971
Flax Authors committed Nov 1, 2024
2 parents 591cd40 + 342adde commit d8b1a92
Showing 2 changed files with 106 additions and 3 deletions.
8 changes: 5 additions & 3 deletions flax/nnx/training/optimizer.py
@@ -193,7 +193,7 @@ def __init__(
     self.opt_state = _wrap_optimizer_state(tx.init(nnx.state(model, wrt)))
     self.wrt = wrt

-  def update(self, grads):
+  def update(self, grads, **kwargs):
     """Updates ``step``, ``params``, ``opt_state`` and ``**kwargs`` in return value.
     The ``grads`` must be derived from ``nnx.grad(..., wrt=self.wrt)``, where the
     gradients are with respect to the same :class:`Variable` types as defined in
@@ -249,14 +249,16 @@ def update(self, grads):
     Args:
       grads: the gradients derived from ``nnx.grad``.
+      **kwargs: additional keyword arguments passed to the tx.update, to support
+        ``GradientTransformationExtraArgs``, such as ``optax.scale_by_backtracking_linesearch``.
     """
     params = nnx.state(self.model, self.wrt)
     opt_state = _opt_state_variables_to_state(self.opt_state)

-    updates, new_opt_state = self.tx.update(grads, opt_state, params)
+    updates, new_opt_state = self.tx.update(grads, opt_state, params, **kwargs)
     new_params = optax.apply_updates(params, updates)
     assert isinstance(new_params, nnx.State)

     self.step.value += 1
     nnx.update(self.model, new_params)
-    _update_opt_state(self.opt_state, new_opt_state)
+    _update_opt_state(self.opt_state, new_opt_state)
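For readers of this diff: the forwarded ``**kwargs`` are what let ``nnx.Optimizer.update`` drive ``GradientTransformationExtraArgs`` such as ``optax.lbfgs``, whose default zoom linesearch expects ``value``, ``grad`` and ``value_fn`` on every update. Below is a minimal usage sketch that is not part of the commit; the ``nnx.Linear`` model, the data and the squared-error loss are chosen only for illustration, and the call pattern mirrors the tests added below.

# Usage sketch (not part of the diff): passing linesearch extra args through
# Optimizer.update so they reach self.tx.update(...).
import jax.numpy as jnp
import optax
from flax import nnx

model = nnx.Linear(2, 4, rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(model, optax.lbfgs())

x = jnp.ones((1, 2))
y = jnp.ones((1, 4))
loss_fn = lambda model, x, y: ((model(x) - y) ** 2).mean()

# The linesearch re-evaluates the loss from a params State, so rebuild the
# model from its graphdef inside value_fn.
graphdef = nnx.graphdef(model)
value_fn = lambda state: loss_fn(nnx.merge(graphdef, state), x, y)

loss = loss_fn(optimizer.model, x, y)
grads = nnx.grad(loss_fn)(optimizer.model, x, y)
# value, grad and value_fn flow through the new **kwargs into tx.update.
optimizer.update(grads, grad=grads, value=loss, value_fn=value_fn)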
101 changes: 101 additions & 0 deletions tests/nnx/optimizer_test.py
@@ -128,6 +128,58 @@ def nnx_jit_train_step(optimizer: nnx.Optimizer, x, y):

     self.assertTrue(new_loss < initial_loss)

+
+  @parameterized.product(
+    module_cls=[nnx.Linear, Model],
+    jit_decorator=[lambda f: f, nnx.jit, jax.jit],
+    optimizer=[optax.lbfgs],
+  )
+  def test_jit_linesearch(self, module_cls, jit_decorator, optimizer):
+    x = jax.random.normal(jax.random.key(0), (1, 2))
+    y = jnp.ones((1, 4))
+    model = module_cls(2, 4, rngs=nnx.Rngs(0))
+    tx = optimizer(
+      1e-3
+    )
+    state = nnx.Optimizer(model, tx)
+
+    if jit_decorator == jax.jit:
+      model_static, model_state = nnx.split(state.model)
+      loss_fn = lambda graphdef, state, x, y: (
+        (nnx.merge(graphdef, state)(x) - y) ** 2
+      ).mean()
+      initial_loss = loss_fn(model_static, model_state, x, y)
+
+      def jax_jit_train_step(graphdef, state, x, y):
+        state = nnx.merge(graphdef, state)
+        model_static, model_state = nnx.split(state.model)
+        grads = jax.grad(loss_fn, argnums=1)(model_static, model_state, x, y)
+        state.update(grads, grad=grads, value=initial_loss, value_fn=lambda state: loss_fn(model_static, state, x, y))
+        return nnx.split(state)
+
+      graphdef, state = jit_decorator(jax_jit_train_step)(
+        *nnx.split(state), x, y
+      )
+      state = nnx.merge(graphdef, state)
+      new_loss = loss_fn(*nnx.split(state.model), x, y)
+
+    else:
+      graphdef = nnx.graphdef(model)
+      loss_fn = lambda model, x, y: ((model(x) - y) ** 2).mean()
+
+      loss_fn_split = lambda state: loss_fn(nnx.merge(graphdef, state), x, y)
+
+      initial_loss = loss_fn(state.model, x, y)
+
+      def nnx_jit_train_step(optimizer: nnx.Optimizer, x, y):
+        grads = nnx.grad(loss_fn)(optimizer.model, x, y)
+        optimizer.update(grads, grad=grads, value=initial_loss, value_fn=loss_fn_split)
+
+      jit_decorator(nnx_jit_train_step)(state, x, y)
+      new_loss = loss_fn(state.model, x, y)
+
+    self.assertTrue(new_loss < initial_loss)
+
   @parameterized.product(
     module_cls=[nnx.Linear, Model],
     optimizer=[optax.sgd, optax.adam],
@@ -203,6 +255,55 @@ def test_wrt_update(self, variable):
         )
       )

+  @parameterized.parameters(
+    {'variable': nnx.Param},
+    #{'variable': nnx.LoRAParam},
+    {'variable': (nnx.Param, nnx.LoRAParam)},
+  )
+  def test_wrt_update_linesearch(self, variable):
+    in_features = 4
+    out_features = 10
+    model = nnx.LoRA(
+      in_features=in_features,
+      lora_rank=2,
+      out_features=out_features,
+      base_module=Model(
+        in_features=in_features, out_features=out_features, rngs=nnx.Rngs(0)
+      ),
+      rngs=nnx.Rngs(1),
+    )
+    state = nnx.Optimizer(model, optax.lbfgs(), wrt=variable)
+    prev_variables, prev_other_variables = nnx.state(model, variable, ...)
+
+    x = jnp.ones((1, 4))
+    y = jnp.ones((1, 10))
+    loss_fn = lambda model, x, y: ((model(x) - y) ** 2).mean()
+
+    grads = nnx.grad(loss_fn, argnums=nnx.DiffState(0, variable))(
+      state.model, x, y
+    )
+    initial_loss = loss_fn(model, x, y)
+    graphdef = nnx.graphdef(model)
+    loss_fn_split = lambda state: loss_fn(nnx.merge(graphdef, state), x, y)
+
+    state.update(grads, grad=grads, value_fn=loss_fn_split, value=initial_loss)
+    self.assertTrue(loss_fn(model, x, y) < initial_loss)
+
+    # make sure only the Variables filtered in `wrt` are changed, and the others are unchanged
+    variables, other_variables = nnx.state(model, variable, ...)
+    self.assertTrue(
+      jax.tree.all(
+        jax.tree.map(lambda x, y: (x != y).all(), prev_variables, variables)
+      )
+    )
+    if other_variables:
+      self.assertTrue(
+        jax.tree.all(
+          jax.tree.map(
+            lambda x, y: (x == y).all(), prev_other_variables, other_variables
+          )
+        )
+      )

 if __name__ == '__main__':
   absltest.main()