Summary
Make the fixed-step solvers compilable/differentiable with respect to the `t_span` and `t_eval` arguments.
This isn't an extremely urgent issue, but addressing it would round out the features of these solvers.
Details
Issue #122 describes a bug whose root cause is that the JAX solvers in dynamics cannot be compiled when `t_eval` is not `None`. The fix PR, #125, resolves this by updating `jax_odeint` and the diffrax solver wrapper so that they can be compiled with respect to `t_eval`.
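For comparison, an adaptive solver that takes the evaluation times as an ordinary array argument can be compiled with respect to them. The sketch below uses JAX's own `jax.experimental.ode.odeint` (not the dynamics wrapper) with a toy right-hand side `rhs` chosen for illustration; under `jax.jit`, `t_eval` is traced, so only its shape and dtype are fixed at compile time.

```python
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint


def rhs(y, t):
    # odeint expects func(y, t, *args); this toy system is dy/dt = -y.
    return -y


@jax.jit
def solve(y0, t_eval):
    # t_eval is a traced array here: calling solve again with different
    # values of the same shape reuses the compiled function.
    return odeint(rhs, y0, t_eval, rtol=1e-6, atol=1e-6)


y0 = jnp.array(1.0, dtype=jnp.float32)
ys = solve(y0, jnp.array([0.0, 0.5, 1.0]))
ys_other = solve(y0, jnp.array([0.0, 0.25, 2.0]))
```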
As described in #122, however, making the fixed-step JAX solvers built into dynamics compilable with respect to both `t_span` and `t_eval` is non-trivial, because their looping structure depends on the values of `t_span` and `t_eval` (the number of steps is derived from them), and under compilation those values are abstract tracers rather than concrete numbers. As a result, #125 is only a partial fix: for the JAX fixed-step solvers, the problem is avoided rather than fundamentally fixed.
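To illustrate why a value-dependent loop structure blocks compilation, here is a toy fixed-step Euler solver. `euler_fixed_step` is hypothetical and is not code from dynamics; it only mimics the pattern of turning the values in `t_span` into a Python-level step count. It runs eagerly, but under `jax.jit` the call to `int()` receives a tracer and JAX raises a `ConcretizationTypeError`.

```python
import jax
import jax.numpy as jnp


def euler_fixed_step(rhs, y0, t_span, max_dt):
    """Hypothetical fixed-step Euler solver, for illustration only."""
    t0, tf = t_span[0], t_span[1]
    # The loop length is computed from the *values* of t_span. int() needs a
    # concrete number, which is unavailable when t_span is traced under jit.
    n_steps = int(jnp.ceil((tf - t0) / max_dt))
    dt = (tf - t0) / n_steps

    def step(carry, _):
        t, y = carry
        return (t + dt, y + dt * rhs(t, y)), None

    (_, y_final), _ = jax.lax.scan(step, (t0, y0), None, length=n_steps)
    return y_final


rhs = lambda t, y: -y  # dy/dt = -y
y0 = jnp.array(1.0, dtype=jnp.float32)
t_span = jnp.array([0.0, 1.0])

# Eager evaluation works: t_span is concrete, so n_steps is a Python int.
y_eager = euler_fixed_step(rhs, y0, t_span, 0.01)

# Compiling with t_span as a traced argument fails during tracing.
try:
    jax.jit(euler_fixed_step, static_argnums=(0, 3))(rhs, y0, t_span, 0.01)
    compiled_ok = True
except jax.errors.ConcretizationTypeError:
    compiled_ok = False
```

Marking `t_span` as static would make this compile, but then every new value of `t_span` triggers a recompilation, and the solver is no longer differentiable with respect to it.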
To make the fixed-step solvers fully compilable/differentiable with respect to the `t_span` and `t_eval` arguments, the functions `fixed_step_solver_template_jax` and `fixed_step_lmde_solver_parallel_template_jax` need to be updated to use more advanced JAX control flow. Preserving differentiability with respect to other parameters may also require defining custom differentiation rules (`vjp` and `jvp` rules).
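One possible direction for this control flow, sketched under our own assumptions rather than as the planned fix: replace the value-dependent step count with a static per-interval budget `max_steps`, compute the actual step count and step size from the traced `t_eval` inside `jax.lax.scan`, and mask out the surplus steps with `jnp.where`. The function `masked_fixed_step_solve`, its Euler update, and the `max_steps` parameter are all hypothetical, and this is not how the dynamics templates are currently written.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnums=(0, 3))
def masked_fixed_step_solve(rhs, y0, t_eval, max_steps, max_dt):
    """Hypothetical Euler solver compilable w.r.t. t_eval (t_eval[0] is t0).

    Assumption: each interval needs at most `max_steps` steps of size
    <= max_dt. This static bound replaces the value-dependent step count.
    """

    def solve_interval(y, interval):
        t_start, t_end = interval
        # Step count and step size are traced values, not Python ints.
        n_active = jnp.maximum(jnp.ceil((t_end - t_start) / max_dt), 1.0)
        dt = (t_end - t_start) / n_active

        def step(carry, i):
            t, y = carry
            y_new = y + dt * rhs(t, y)
            active = i < n_active
            # Steps past n_active leave (t, y) unchanged.
            return (jnp.where(active, t + dt, t), jnp.where(active, y_new, y)), None

        (_, y_end), _ = jax.lax.scan(step, (t_start, y), jnp.arange(max_steps))
        return y_end, y_end

    # Outer scan walks the intervals [t_eval[k], t_eval[k + 1]] in order.
    _, ys = jax.lax.scan(solve_interval, y0, (t_eval[:-1], t_eval[1:]))
    return jnp.concatenate([y0[None], ys], axis=0)


rhs = lambda t, y: -y  # dy/dt = -y
y0 = jnp.array(1.0, dtype=jnp.float32)
t_eval = jnp.array([0.0, 0.5, 1.0])

ys = masked_fixed_step_solve(rhs, y0, t_eval, 100, 0.01)
# Reverse-mode gradient of the final state with respect to the evaluation times.
dy_dt_eval = jax.grad(
    lambda te: masked_fixed_step_solve(rhs, y0, te, 100, 0.01)[-1]
)(t_eval)
```

The trade-off is that every interval always executes `max_steps` iterations, and an interval needing more than `max_steps` steps of size `max_dt` is silently truncated in this sketch. A `jax.lax.while_loop` would avoid the fixed budget, but `while_loop` supports only forward-mode differentiation, not reverse-mode, which is one reason custom `vjp` rules may become necessary.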