You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Following the end of the discussion in #26 - I am working on the qutip-jax-dia branch and tried to implement cubic splines from jax_cosmo for a time dependent hamiltonian simulation (master equation with static collapse operators) which can be massively speed up with jax for re-running the same simulation with different TD params. I can run the simulation without jitting the function sim() on a CPU but it's very slow (as to be expected).
I know qutip-jax-dia is in beta beta, but maybe some of the clever people here have some suggestions as to why I can't jit the function.
For reference I'm working on osx-arm64 with an M1 chip.
The error message I get is:
Traceback (most recent call last):
File "/Users/janoleernst/Desktop/DPhil/Simulations/Code/rl-qc/qu_sim_speed_benchmarks/jitting_qutip_sim.py", line 105, in <module>
ys = fast_sim(single_sample)
File "/Users/janoleernst/Desktop/DPhil/Simulations/Code/rl-qc/qu_sim_speed_benchmarks/jitting_qutip_sim.py", line 65, in sim
result=qt.mesolve(H=[[H_p, jax.jit(pump)], [H_s, jax.jit(stokes)], [H_dp, jax.jit(delta_pump)], [H_d2, jax.jit(delta_stokes)]],
File "/Users/janoleernst/anaconda3/envs/qutip-jax/lib/python3.10/site-packages/qutip/solver/mesolve.py", line 128, in mesolve
H = QobjEvo(H, args=args, tlist=tlist)
File "qutip/core/cy/qobjevo.pyx", line 242, in qutip.core.cy.qobjevo.QobjEvo.__init__
File "qutip/core/cy/qobjevo.pyx", line 807, in qutip.core.cy.qobjevo.QobjEvo.compress
File "qutip/core/cy/qobjevo.pyx", line 761, in qutip.core.cy.qobjevo.QobjEvo._compress_merge_qobj
jax.errors.TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[]..
The error occurred while tracing the function sim at /Users/janoleernst/Desktop/DPhil/Simulations/Code/rl-qc/qu_sim_speed_benchmarks/jitting_qutip_sim.py:42 for jit. This value became a tracer due to JAX operations on these lines:
operation a:c128[2,4] = pjit[name=atleast_2d jaxpr={ lambda ; b:c128[2,4]. let in (b,) }] c
from line /Users/janoleernst/Desktop/DPhil/Simulations/Code/rl-qc/qu_sim_speed_benchmarks/jitting_qutip_sim.py:65 (sim)
operation a:c128[2,4] = pjit[name=atleast_2d jaxpr={ lambda ; b:c128[2,4]. let in (b,) }] c
from line /Users/janoleernst/Desktop/DPhil/Simulations/Code/rl-qc/qu_sim_speed_benchmarks/jitting_qutip_sim.py:65 (sim)
operation a:bool[] = pjit[
name=allclose
jaxpr={ lambda ; b:c128[2,4] c:f64[] d:f64[]. let
e:bool[2,4] = pjit[
name=isclose
jaxpr={ lambda ; f:c128[2,4] g:f64[] h:f64[] i:f64[]. let
j:c128[] = convert_element_type[
new_dtype=complex128
weak_type=False
] g
k:f64[] = convert_element_type[new_dtype=float64 weak_type=False] h
l:f64[] = convert_element_type[new_dtype=float64 weak_type=False] i
m:c128[2,4] = sub f j
n:f64[2,4] = abs m
o:f64[] = abs j
p:f64[] = mul k o
q:f64[] = add l p
r:bool[2,4] = le n q
s:bool[2,4] = pjit[
name=isinf
jaxpr={ lambda ; t:c128[2,4]. let
u:f64[2,4] = real t
v:f64[2,4] = imag t
w:f64[2,4] = abs u
x:bool[2,4] = eq w inf
y:f64[2,4] = abs v
z:bool[2,4] = eq y inf
ba:bool[2,4] = or x z
in (ba,) }
] f
bb:bool[] = pjit[
name=isinf
jaxpr={ lambda ; bc:c128[]. let
bd:f64[] = real bc
be:f64[] = imag bc
bf:f64[] = abs bd
bg:bool[] = eq bf inf
bh:f64[] = abs be
bi:bool[] = eq bh inf
bj:bool[] = or bg bi
in (bj,) }
] j
bk:bool[2,4] = or s bb
bl:bool[2,4] = and s bb
bm:bool[2,4] = not bk
bn:bool[2,4] = and r bm
bo:bool[2,4] = eq f j
bp:bool[2,4] = and bl bo
bq:bool[2,4] = or bn bp
br:bool[2,4] = ne f f
bs:bool[] = ne j j
bt:bool[2,4] = or br bs
bu:bool[2,4] = not bt
bv:bool[2,4] = and bq bu
in (bv,) }
] b c 1e-05 d
bw:bool[] = reduce_and[axes=(0, 1)] e
in (bw,) }
] bx by bz
from line /Users/janoleernst/Desktop/DPhil/Simulations/Code/rl-qc/qu_sim_speed_benchmarks/jitting_qutip_sim.py:65 (sim)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError
Full code is as follows (note that I am just using some constant functions to test the whole thing:
import qutip as qt
import time
import numpy as np
import qutip_jax
import jax
import jax.numpy as jnp
import jax_cosmo.scipy.interpolate as inter
from diffrax import diffeqsolve, ODETerm, Dopri5, PIDController
from scipy.signal.windows import blackman
with qt.CoreOptions(default_dtype="jaxdia"):
#initial state of the system
rho0=qt.fock_dm(4,0)
#simulation params
n_steps = 10
T=1
resolution=10
omega_0=30
Omega_0 = omega_0
gamma = 1.
#Pump and Stokes Hamiltonians
H_p = 0.5 * (qt.projection(4,1,0) + qt.projection(4,0,1))
H_s = 0.5 * (qt.projection(4,2,1) + qt.projection(4,1,2))
# Detuning Hamiltonian
H_dp = qt.projection(4,1,1)
H_d2 = qt.projection(4,2,2)
time_list=jnp.linspace(0.,T, resolution*n_steps, dtype=jnp.float64)
delta=0.0*omega_0
Delta=0.0*omega_0
#delta=Omega_0*0.15
#Delta=-23.5*delta #uncomment if you want to introduce a bias
H_d = Delta * qt.projection(4,1,1) + delta * qt.projection(4,2,2)
#Artificial environment Lindblad operator:
L = jnp.sqrt(gamma) * qt.projection(4,3,1)
# Define function that takes input arrays and transforms them into cubic splines
def cubic_spline(input_array):
return inter.InterpolatedUnivariateSpline(time_list, input_array)
def sim(single_action_sample):
# Carries out a full episode of system dynamics for a single action sample within a batch
delta_p, delta_2, omega_p, omega_s = single_action_sample
# Define the control signals
pump = cubic_spline(omega_p)
stokes = cubic_spline(omega_s)
delta_pump = cubic_spline(delta_p)
delta_stokes = cubic_spline(delta_2)
# Define mesovle options
options = {
"method": "diffrax",
"normalize_output": True,
"stepsize_controller" : PIDController(rtol=1e-5, atol=1e-5),
"solver": Dopri5()
}
#start_time=time.time()
result=qt.mesolve(H=[[H_p, jax.jit(pump)], [H_s, jax.jit(stokes)], [H_dp, jax.jit(delta_pump)], [H_d2, jax.jit(delta_stokes)]],
rho0=rho0, tlist=time_list, c_ops=L, options=options)
#final_time=time.time()
#print(f"Time taken for single step: {final_time-start_time}")
#reward is expectation value of final desired state
reward = result.states[-1][2,2]
return reward
# Test the environment
if __name__ == "__main__":
# Test the step function
amp_stokes=jnp.array(50*blackman(100), dtype=jnp.complex128)
amp_pump=jnp.array(50*blackman(100), dtype=jnp.complex128)
det_stokes=jnp.array(-10*blackman(100), dtype=jnp.complex128)
det_pump=jnp.array(0*100, dtype=jnp.complex128)
single_sample = jnp.array(
[det_pump, det_stokes,amp_pump, amp_stokes], dtype=jnp.complex128
)
#Test the sim function
start=time.time()
res=sim(single_sample)
print(f"Time taken: {time.time()-start} ")
fast_sim=jax.jit(sim)
start=time.time()
ys = fast_sim(single_sample)
print(f"Time taken jitted: {time.time()-start} ")
print(ys)
Many thanks!
The text was updated successfully, but these errors were encountered:
mesolve does not support jit. It does safety checks, manage metadata, use cython, etc. which does not work well inside jit. That's why most example spit the setup (solver = MESolver(...)) and the computations (solver.run) and only the second is inside the jit compiled function.
With spline coefficient, the separation of setup and computation is harder if you want to reuse the solver. You would need to pass the InterpolatedUnivariateSpline as args.
Following the end of the discussion in #26 - I am working on the qutip-jax-dia branch and tried to implement cubic splines from jax_cosmo for a time dependent hamiltonian simulation (master equation with static collapse operators) which can be massively speed up with jax for re-running the same simulation with different TD params. I can run the simulation without jitting the function sim() on a CPU but it's very slow (as to be expected).
I know qutip-jax-dia is in beta beta, but maybe some of the clever people here have some suggestions as to why I can't jit the function.
For reference I'm working on osx-arm64 with an M1 chip.
The error message I get is:
Full code is as follows (note that I am just using some constant functions to test the whole thing:
Many thanks!
The text was updated successfully, but these errors were encountered: