-
Notifications
You must be signed in to change notification settings - Fork 648
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #4377 from google:nnx-pytree-optimization
PiperOrigin-RevId: 697044650
- Loading branch information
Showing
5 changed files
with
368 additions
and
24 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
# Copyright 2024 The Flax Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
# %% | ||
import jax | ||
import jax.numpy as jnp | ||
import numpy as np | ||
import optax | ||
from time import time | ||
|
||
from flax import nnx | ||
|
||
from absl import flags | ||
from absl import app | ||
|
||
FLAGS = flags.FLAGS | ||
flags.DEFINE_enum('mode', 'all', ['all', 'nnx', 'jax'], 'Mode to run the script in') | ||
flags.DEFINE_integer('total_steps', 100, 'Total number of training steps') | ||
flags.DEFINE_integer('width', 32, 'Hidden layer size') | ||
flags.DEFINE_integer('depth', 5, 'Depth of the model') | ||
|
||
|
||
|
||
class Linear(nnx.Module): | ||
def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs): | ||
self.list = [ | ||
nnx.Param(jax.random.uniform(rngs.params(), (din, dout))), | ||
nnx.Param(jnp.zeros((dout,))), | ||
] | ||
self.dict = { | ||
'w': nnx.Param(jax.random.uniform(rngs.params(), (din, dout))), | ||
'b': nnx.Param(jnp.zeros((dout,))), | ||
} | ||
|
||
|
||
|
||
class MLP(nnx.Module): | ||
def __init__(self, depth, *, rngs: nnx.Rngs): | ||
self.intermediates = [ | ||
Linear(10, 10, rngs=rngs) for _ in range(depth) | ||
] | ||
|
||
|
||
def main(argv): | ||
print(argv) | ||
mode: str = FLAGS.mode | ||
total_steps: int = FLAGS.total_steps | ||
width: int = FLAGS.width | ||
depth: int = FLAGS.depth | ||
|
||
print(f'{mode=}, {total_steps=}, {width=}') | ||
|
||
X = np.linspace(0, 1, 100)[:, None] | ||
Y = 0.8 * X**2 + 0.1 + np.random.normal(0, 0.1, size=X.shape) | ||
|
||
model = MLP(depth=depth, rngs=nnx.Rngs(0)) | ||
tx = optax.sgd(1e-3) | ||
optimizer = nnx.Optimizer(model, tx) | ||
|
||
#------------------------------------------------------------ | ||
# NNX | ||
#------------------------------------------------------------ | ||
if mode in ['all', 'nnx']: | ||
@nnx.jit | ||
def step_nnx(model: MLP, optimizer: nnx.Optimizer): | ||
pass | ||
|
||
t0 = time() | ||
for _ in range(total_steps): | ||
step_nnx(model, optimizer) | ||
|
||
total_time = time() - t0 | ||
time_per_step = total_time / total_steps | ||
time_per_layer = time_per_step / depth | ||
print("### NNX ###") | ||
print('total time:', total_time) | ||
print(f'time per step: {time_per_step * 1e6:.2f} µs') | ||
print(f'time per layer: {time_per_layer * 1e6:.2f} µs') | ||
|
||
|
||
#------------------------------------------------------------ | ||
# JAX | ||
#------------------------------------------------------------ | ||
|
||
if mode in ['all', 'jax']: | ||
@jax.jit | ||
def step_jax(graphdef, state): | ||
return graphdef, state | ||
|
||
graphdef, state = nnx.split((model, optimizer)) | ||
t0 = time() | ||
for _ in range(total_steps): | ||
graphdef, state = step_jax(graphdef, state) | ||
|
||
total_time = time() - t0 | ||
time_per_step = total_time / total_steps | ||
time_per_layer = time_per_step / depth | ||
print("### JAX ###") | ||
print('total time:', total_time) | ||
print(f'time per step: {time_per_step * 1e6:.2f} µs') | ||
print(f'time per layer: {time_per_layer * 1e6:.2f} µs') | ||
print() | ||
|
||
|
||
|
||
if __name__ == '__main__': | ||
app.run(main) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,168 @@ | ||
# Copyright 2024 The Flax Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
# %% | ||
import jax | ||
import jax.numpy as jnp | ||
import numpy as np | ||
import optax | ||
from time import time | ||
|
||
from flax import nnx | ||
|
||
from absl import flags | ||
from absl import app | ||
|
||
FLAGS = flags.FLAGS | ||
flags.DEFINE_enum('mode', 'nnx', ['nnx', 'jax'], 'Mode to run the script in') | ||
flags.DEFINE_integer('total_steps', 10_000, 'Total number of training steps') | ||
flags.DEFINE_integer('batch_size', 32, 'Batch size') | ||
flags.DEFINE_integer('width', 32, 'Hidden layer size') | ||
flags.DEFINE_integer('depth', 5, 'Depth of the model') | ||
|
||
|
||
def dataset(X, Y, batch_size): | ||
while True: | ||
idx = np.random.choice(len(X), size=batch_size) | ||
yield X[idx], Y[idx] | ||
|
||
|
||
class Linear(nnx.Module): | ||
def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs): | ||
self.w = nnx.Param(jax.random.uniform(rngs.params(), (din, dout))) | ||
self.b = nnx.Param(jnp.zeros((dout,))) | ||
|
||
def __call__(self, x): | ||
return x @ self.w + self.b | ||
|
||
|
||
class Count(nnx.Variable): | ||
pass | ||
|
||
|
||
class MLP(nnx.Module): | ||
def __init__(self, din, dhidden, dout, depth, *, rngs: nnx.Rngs): | ||
self.count = Count(jnp.array(0)) | ||
self.linear_in = Linear(din, dhidden, rngs=rngs) | ||
self.intermediates = [ | ||
Linear(dhidden, dhidden, rngs=rngs) for _ in range(depth - 2) | ||
] | ||
self.linear_out = Linear(dhidden, dout, rngs=rngs) | ||
|
||
def __call__(self, x): | ||
self.count.value += 1 | ||
x = nnx.relu(self.linear_in(x)) | ||
for layer in self.intermediates: | ||
x = nnx.relu(layer(x)) | ||
x = self.linear_out(x) | ||
return x | ||
|
||
|
||
def main(argv): | ||
print(argv) | ||
mode: str = FLAGS.mode | ||
total_steps: int = FLAGS.total_steps | ||
batch_size: int = FLAGS.batch_size | ||
width: int = FLAGS.width | ||
depth: int = FLAGS.depth | ||
|
||
print(f'{mode=}, {total_steps=}, {batch_size=}, {width=}') | ||
|
||
if mode not in ['nnx', 'jax']: | ||
raise ValueError(f'Invalid mode: {mode}') | ||
|
||
X = np.linspace(0, 1, 100)[:, None] | ||
Y = 0.8 * X**2 + 0.1 + np.random.normal(0, 0.1, size=X.shape) | ||
|
||
model = MLP(din=1, dhidden=width, dout=1, depth=depth, rngs=nnx.Rngs(0)) | ||
tx = optax.sgd(1e-3) | ||
optimizer = nnx.Optimizer(model, tx) | ||
t0 = time() | ||
|
||
if mode == 'nnx': | ||
|
||
@nnx.jit | ||
def train_step_nnx(model: MLP, optimizer: nnx.Optimizer, batch): | ||
x, y = batch | ||
|
||
def loss_fn(model: MLP): | ||
y_pred = model(x) | ||
return jnp.mean((y - y_pred) ** 2) | ||
|
||
grads: nnx.State = nnx.grad(loss_fn)(model) | ||
optimizer.update(grads) | ||
|
||
@nnx.jit | ||
def test_step_nnx(model: MLP, batch): | ||
x, y = batch | ||
y_pred = model(x) | ||
loss = jnp.mean((y - y_pred) ** 2) | ||
return {'loss': loss} | ||
|
||
for step, batch in enumerate(dataset(X, Y, batch_size)): | ||
train_step_nnx(model, optimizer, batch) | ||
|
||
if step % 1000 == 0: | ||
logs = test_step_nnx(model, (X, Y)) | ||
print(f"step: {step}, loss: {logs['loss']}") | ||
|
||
if step >= total_steps - 1: | ||
break | ||
else: | ||
|
||
@jax.jit | ||
def train_step_jax(graphdef, state, batch): | ||
model, optimizer = nnx.merge(graphdef, state) | ||
x, y = batch | ||
|
||
def loss_fn(model: MLP): | ||
y_pred = model(x) | ||
return jnp.mean((y - y_pred) ** 2) | ||
|
||
grads = nnx.grad(loss_fn)(model) | ||
optimizer.update(grads) | ||
|
||
return nnx.state((model, optimizer)) | ||
|
||
@jax.jit | ||
def test_step_jax(graphdef, state, batch): | ||
model, optimizer = nnx.merge(graphdef, state) | ||
x, y = batch | ||
y_pred = model(x) | ||
loss = jnp.mean((y - y_pred) ** 2) | ||
state = nnx.state((model, optimizer)) | ||
return state, {'loss': loss} | ||
|
||
graphdef, state = nnx.split((model, optimizer)) | ||
|
||
for step, batch in enumerate(dataset(X, Y, batch_size)): | ||
state = train_step_jax(graphdef, state, batch) | ||
|
||
if step % 1000 == 0: | ||
state, logs = test_step_jax(graphdef, state, (X, Y)) | ||
print(f"step: {step}, loss: {logs['loss']}") | ||
|
||
if step >= total_steps - 1: | ||
break | ||
|
||
model, optimizer = nnx.merge(graphdef, state) | ||
|
||
total_time = time() - t0 | ||
print('total time:', total_time) | ||
print(f'time per step: {total_time / total_steps * 1e6:.2f} µs') | ||
print('times called:', model.count.value) | ||
|
||
|
||
if __name__ == '__main__': | ||
app.run(main) |
Oops, something went wrong.