
tr is not jit compatible because of isherm #34

Open
BoxiLi opened this issue Feb 15, 2024 · 1 comment

Comments

@BoxiLi
Member

BoxiLi commented Feb 15, 2024

`tr` first checks whether the matrix is Hermitian and, if it is, returns only the real part of the trace, so the type of the output depends on that check:

```python
...
out = _data.trace(self._data)
# This ensures that trace can return something that is not a number such
# as a `tensorflow.Tensor` in qutip-tensorflow.
return out.real if (self.isherm
                and hasattr(out, "real")
                ) else out
```
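To make the type change concrete, here is a minimal sketch in plain JAX that mimics the branch in `tr` outside of `jit`. The function `tr_like` is hypothetical (it is not part of qutip or qutip-jax); it checks hermiticity the same way as `_isherm`, and the default tolerance is an assumption.

```python
import jax.numpy as jnp


def tr_like(matrix, tol=1e-12):
    """Hypothetical stand-in for Qobj.tr: real trace if Hermitian, else complex."""
    out = jnp.trace(matrix)
    # Python-level branch on a value computed from the matrix entries.
    isherm = bool(jnp.allclose(matrix, matrix.T.conj(), atol=tol, rtol=0))
    return out.real if isherm else out


herm = jnp.array([[1.0, 1.0 - 1.0j], [1.0 + 1.0j, 1.0]])
non_herm = jnp.array([[1.0, 2.0j], [0.0, 1.0]])
print(jnp.iscomplexobj(tr_like(herm)))      # False: only the real part is returned
print(jnp.iscomplexobj(tr_like(non_herm)))  # True: the complex trace is returned
```

Eagerly this works because `bool(...)` sees a concrete value; the output dtype, however, is decided by the matrix entries, which is exactly what `jit` cannot allow.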

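However, to determine the `isherm` property, `jnp.allclose` is used to check whether the matrix is close to Hermitian. Inside `jit` this returns a traced boolean rather than a Python `bool`, so the `self.isherm and ...` condition in `tr` cannot be evaluated: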

```python
from functools import partial
from jax import jit
import jax.numpy as jnp
@partial(jit, static_argnames=["tol"])
def _isherm(matrix, tol):
    return jnp.allclose(matrix, matrix.T.conj(), atol=tol, rtol=0)
```

In principle, following the philosophy of `jit`, any property of a matrix that is used to branch the computation should be determined without evaluating the matrix entries. Perhaps for `JaxArray` we should just leave `isherm` as `False` when it cannot be derived explicitly?
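One reading of this proposal is that hermiticity should come only from information known without looking at the array values (for example a flag set at construction), with "unknown" treated as `False`. A minimal sketch under that assumption; `HermFlaggedArray` and `tr_static` are hypothetical names, not qutip-jax API.

```python
import jax
import jax.numpy as jnp


class HermFlaggedArray:
    """Hypothetical wrapper: isherm is a static Python value, never computed from data."""

    def __init__(self, data, isherm=None):
        self.data = data
        # None means "not known"; it is never resolved by inspecting the values.
        self._isherm = isherm

    @property
    def isherm(self):
        # Unknown hermiticity is reported as False instead of calling jnp.allclose.
        return bool(self._isherm)


def tr_static(a):
    out = jnp.trace(a.data)
    # The branch depends only on a Python bool, so it is safe inside jit.
    return out.real if a.isherm else out


@jax.jit
def tmp(x):
    m = HermFlaggedArray(jnp.array([[1.0, x], [jnp.conjugate(x), 1.0]]))
    return tr_static(m)


result = tmp(1.0 - 1.0j)
```

The cost of this choice is that, under `jit`, a Hermitian matrix whose flag was never set returns a complex trace (here `2+0j`) rather than a real one; passing `isherm=True` explicitly restores the real output.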

Example:

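```python
import qutip
import qutip_jax  # registers the "jax" data type with qutip
import jax
import jax.numpy as np
qutip.settings.core["default_dtype"] = "jax"
@jax.jit
def tmp(a):
    m = qutip.Qobj(np.array([[1., a], [np.conjugate(a), 1.]]))
    # Fails at trace time: tr() branches on m.isherm, a traced value under jit.
    return m.tr()
tmp(1.-1.j)
```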
BoxiLi changed the title from "dag is not jit compatible because of isherm" to "tr is not jit compatible because of isherm" on Feb 15, 2024.
@BoxiLi
Member Author

BoxiLi commented Feb 15, 2024

What I don't understand is why `dag` seems to work fine in this case:

```python
@jax.jit
def tmp(a):
    m = qutip.Qobj(np.array([[1., a], [np.conjugate(a), 1.]]))
    m.dag()
    return 0.
tmp(1.-1.j)
```

Here `dag` checks `_isherm` instead of `isherm`:

```python
def dag(self):
    """Get the Hermitian adjoint of the quantum object."""
    if self._isherm:
        return self.copy()
```
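A likely explanation, assuming (as in the qutip 5 `Qobj` source) that `_isherm` is a cached attribute initialised to `None` and filled in only by the `isherm` property: `if self._isherm:` then tests `None`, a plain Python value, so no traced array is converted to `bool`. `tr` instead reads `isherm`, which evaluates the traced data. The sketch below reproduces that difference with a hypothetical `CachedHerm` class; it is not the real `Qobj`.

```python
import jax
import jax.numpy as jnp


class CachedHerm:
    """Hypothetical mimic of the Qobj caching pattern (not the real class)."""

    def __init__(self, data, isherm=None):
        self.data = data
        self._isherm = isherm  # cached flag, None until computed

    @property
    def isherm(self):
        # Computes the flag from the data on first access, like Qobj.isherm.
        if self._isherm is None:
            self._isherm = jnp.allclose(self.data, self.data.T.conj())
        return self._isherm

    def dag_check(self):
        # Mirrors `if self._isherm:` in dag: reads the cache only.
        return bool(self._isherm)

    def tr_check(self):
        # Mirrors `self.isherm and ...` in tr: may evaluate the data.
        return bool(self.isherm)


@jax.jit
def via_dag(x):
    m = CachedHerm(jnp.array([[1.0, x], [jnp.conjugate(x), 1.0]]))
    m.dag_check()  # bool(None) is False; no tracer involved
    return 0.0


@jax.jit
def via_tr(x):
    m = CachedHerm(jnp.array([[1.0, x], [jnp.conjugate(x), 1.0]]))
    m.tr_check()  # bool(tracer) raises under jit
    return 0.0


dag_result = via_dag(1.0 - 1.0j)  # runs without error

tr_failed = False
try:
    via_tr(1.0 - 1.0j)
except jax.errors.ConcretizationTypeError:
    tr_failed = True
```

If `isherm` had already been accessed earlier in the same traced function, `_isherm` would hold a traced value and the `dag` check would presumably fail in the same way.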
