"""
Training with ZCS on JAX
"""
import argparse
from functools import partial
from time import time

import jax
import jax.numpy as jnp
import optax
from flax.training import train_state
from tqdm import trange

from src.model import DeepONet
from src.utils import mse_to_zeros, DiffRecData
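

# How ZCS works here: every trunk coordinate is shifted by the same pair of
# scalar leaf variables (z_x, z_t), which are held at zero.  Differentiating
# the network output w.r.t. these two scalars is then equivalent to
# differentiating w.r.t. the physical coordinates x and t, so each derivative
# costs a single jax.jvp pass instead of a full Jacobian over all points.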
def compute_u_pde(forward_zcs_fn, z_x, z_t, source_input):
    """ fourth-order diffusion-reaction-type equation with ZCS """
    # constants (diffusion and reaction coefficients)
    d, k = 0.01, 0.01

    # grad functions with jvp: the tangent (1., 0.) differentiates w.r.t.
    # z_x (i.e., x) and (0., 1.) w.r.t. z_t (i.e., t); nesting jvp calls
    # yields the higher-order x-derivatives up to u_xxxx
    def get_u_and_u_t(z_x_, z_t_):
        return jax.jvp(forward_zcs_fn, (z_x_, z_t_), (0., 1.))

    def get_u_x(z_x_, z_t_):
        return jax.jvp(forward_zcs_fn, (z_x_, z_t_), (1., 0.))[1]

    def get_u_xx(z_x_, z_t_):
        return jax.jvp(get_u_x, (z_x_, z_t_), (1., 0.))[1]

    def get_u_xxx(z_x_, z_t_):
        return jax.jvp(get_u_xx, (z_x_, z_t_), (1., 0.))[1]

    def get_u_xxxx(z_x_, z_t_):
        return jax.jvp(get_u_xxx, (z_x_, z_t_), (1., 0.))[1]

    # u and pde residual: u_t - d * u_xxxx + k * u^2 - f = 0
    u, u_t = get_u_and_u_t(z_x, z_t)
    u_xxxx = get_u_xxxx(z_x, z_t)
    pde = u_t - d * u_xxxx + k * u ** 2 - source_input
    return u, pde
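

# Illustrative sketch (an assumption, not part of the original pipeline):
# a finite-difference cross-check of the first ZCS derivative.  The helper
# name `_zcs_derivative_check` and its `eps` parameter are hypothetical.
def _zcs_derivative_check(forward_zcs_fn, z_x, z_t, eps=1e-3):
    """ compare the jvp-based u_x against central finite differences """
    u_x_jvp = jax.jvp(forward_zcs_fn, (z_x, z_t), (1., 0.))[1]
    u_x_fd = (forward_zcs_fn(z_x + eps, z_t)
              - forward_zcs_fn(z_x - eps, z_t)) / (2. * eps)
    return jnp.max(jnp.abs(u_x_jvp - u_x_fd))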


@partial(jax.jit, static_argnames=['n_points_pde', 'n_points_bc', 'model_interface'])
def train_step(state, branch_input, trunk_input, source_input,
               n_points_pde, n_points_bc, model_interface,
               z_x, z_t):  # define zcs scalars outside for jit
""" train for a single step """
def loss_fn(params_):
""" loss function, for AD w.r.t. network weights """
def forward_zcs_fn(z_x_, z_t_):
""" forward function, for AD w.r.t. zcs scalars """
z_xt = jnp.stack((z_x_, z_t_))
trunk_in_zcs = trunk_input + z_xt[None, :]
return model_interface.apply({'params': params_},
branch_in=branch_input, trunk_in=trunk_in_zcs)
u_val, pde_val = compute_u_pde(forward_zcs_fn, z_x, z_t, source_input)
pde_loss_ = mse_to_zeros(pde_val[:, :n_points_pde])
bc_loss_ = mse_to_zeros(u_val[:, n_points_pde:n_points_pde + n_points_bc])
ic_loss_ = mse_to_zeros(u_val[:, n_points_pde + n_points_bc:])
return pde_loss_ + bc_loss_ + ic_loss_, (pde_loss_, bc_loss_, ic_loss_)

    # loss and gradients w.r.t. network weights
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, (pde_loss, bc_loss, ic_loss)), grads = grad_fn(state.params)
    # update
    state = state.apply_gradients(grads=grads)
    return state, loss, pde_loss, bc_loss, ic_loss
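
# Note: n_points_pde, n_points_bc and the model are static arguments of the
# jitted train_step, so changing any of them triggers recompilation; the
# batch tensors and the zcs scalars are ordinary traced arguments.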


def run():
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('-i', '--iterations', type=int, default=1000,
                        help='number of iterations')
    parser.add_argument('-M', '--n-functions', type=int, default=50,
                        help='number of functions in a batch')
    parser.add_argument('-N', '--n-points', type=int, default=4000,
                        help='number of collocation points in a batch')
    args = parser.parse_args()

    # load data
    data = DiffRecData(data_path='./data')
    # number of functions and points in a batch
    # Note: the default values come from the diff_rec example in DeepXDE-ZCS,
    # for comparison with the PyTorch, TensorFlow and Paddle backends
    N_FUNCTIONS = args.n_functions  # noqa
    N_POINTS_PDE = args.n_points  # noqa

    # train state
    # DeepONet in cartesian-product mode: output has shape (n_functions, n_points)
    model = DeepONet(branch_features=(data.n_features, 128, 128, 128),
                     trunk_features=(data.n_dims, 128, 128, 128),
                     cartesian_prod=True)
    branch_in, trunk_in, _, _ = data.sample_batch(N_FUNCTIONS, N_POINTS_PDE)
    params = model.init(jax.random.PRNGKey(0), branch_in, trunk_in)['params']
    optimizer = optax.adam(learning_rate=0.0005)
    the_state = train_state.TrainState.create(
        apply_fn=model.apply, params=params, tx=optimizer)

    #################
    # training loop #
    #################
    # zcs variables
    z_x, z_t = jnp.zeros(()), jnp.zeros(())
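    # z_x and z_t stay identically zero, so the shifted trunk input equals the
    # original coordinates; they exist only as differentiation leaves for jvp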
    pbar = trange(args.iterations, desc='Training')
    t_first, t_total = 0., 0.
    for it in pbar:
        # sample data
        branch_in, trunk_in, source_in, _ = data.sample_batch(
            N_FUNCTIONS, N_POINTS_PDE, seed=it, train=True)
        # update
        t0 = time()  # wall time excluding data sampling
        the_state, loss_val, pde_loss_val, bc_loss_val, ic_loss_val = \
            train_step(the_state, branch_in, trunk_in, source_in,
                       N_POINTS_PDE, data.n_bc, model,
                       z_x, z_t)
        # formatting the losses pulls them to host, which also syncs the device
        pbar.set_postfix_str(f"L_pde={pde_loss_val:.4e}, "
                             f"L_bc={bc_loss_val:.4e}, "
                             f"L_ic={ic_loss_val:.4e}, "
                             f"L={loss_val:.4e}")
        if it == 0:
            # the first step includes jit compilation; time it separately
            t_first += time() - t0
        else:
            t_total += time() - t0
    print(f'Jit-compile done in {t_first:.1f} seconds')
    print(f'Training done in {t_total:.1f} seconds')

    ##############
    # evaluation #
    ##############
    # the fourth-order equation has no reference solution to validate against
    print('No true solution for validation')
if __name__ == "__main__":
run()