
Definition of qutip.data.sqrtm.add_specialisations( [(JaxArray, JaxArray, sqrtm_jaxarray),] ) raising error #67

Open
ArturDomingues opened this issue Sep 10, 2024 · 2 comments


@ArturDomingues

I was trying to use qutip-jax and got an error while importing it (shown below). Looking in qutip/core/data, there is no sqrtm there, but it does exist in qutip/core/data/expm.py. With that in mind, I think the fix is just to change

qutip.data.sqrtm.add_specialisations(
    [(JaxArray, JaxArray, sqrtm_jaxarray),]
)

to

qutip.data.expm.sqrtm.add_specialisations(
    [(JaxArray, JaxArray, sqrtm_jaxarray),]
)

in unary.py.
Here is the error I mentioned:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[4], line 3
      1 import jax.numpy as jnp
      2 import qutip
----> 3 import qutip_jax
      5 with qutip.CoreOptions(default_dtype="jax"):
      6     excited = qutip.basis(dim, 4, dtype="jax"), qutip.basis(dim, 3, dtype="jax"), qutip.basis(dim, 5, dtype="jax")

File c:\Users\artur\anaconda3\envs\qutipjax\Lib\site-packages\qutip_jax\__init__.py:33
     30 del is_jax_array
     32 from .binops import *
---> 33 from .unary import *
     34 from .permute import *
     35 from .reshape import *

File c:\Users\artur\anaconda3\envs\qutipjax\Lib\site-packages\qutip_jax\unary.py:195
    181 qutip.data.expm.add_specialisations(
    182     [
    183         (JaxArray, JaxArray, expm_jaxarray),
    184     ]
    185 )
    188 qutip.data.inv.add_specialisations(
    189     [
    190         (JaxArray, JaxArray, inv_jaxarray),
    191     ]
    192 )
--> 195 qutip.data.sqrtm.add_specialisations(
    196     [(JaxArray, JaxArray, sqrtm_jaxarray),]
    197 )
    200 qutip.data.project.add_specialisations(
    201     [
    202         (JaxArray, JaxArray, project_jaxarray),
    203     ]
    204 )

AttributeError: module 'qutip.core.data' has no attribute 'sqrtm'
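The traceback shows the import failing because `unary.py` registers a JAX specialisation on `qutip.data.sqrtm`, and the installed qutip build does not define that dispatcher. As an illustration only, here is a minimal sketch of a guarded registration that skips a missing dispatcher instead of crashing at import time. `register_if_available`, `FakeDispatcher` and the `SimpleNamespace` stand-ins are hypothetical names, not part of qutip or qutip-jax; the only assumption about the real API is the `add_specialisations([(out, in, func)])` call shape already visible in `unary.py`.

```python
from types import SimpleNamespace


def register_if_available(data_module, name, specialisations):
    """Hypothetical helper (not part of qutip-jax): register
    specialisations on the dispatcher `data_module.<name>` only if the
    installed qutip actually defines it. Returns True if registered."""
    dispatcher = getattr(data_module, name, None)
    if dispatcher is None:
        # The installed qutip lacks this dispatcher; skip it instead of
        # raising AttributeError while qutip_jax is being imported.
        return False
    dispatcher.add_specialisations(specialisations)
    return True


class FakeDispatcher:
    """Stand-in for a qutip dispatcher; it only records registrations."""

    def __init__(self):
        self.registered = []

    def add_specialisations(self, specialisations):
        self.registered.extend(specialisations)


# Stand-ins for qutip.data: one build without sqrtm, one with it.
old_data = SimpleNamespace(expm=FakeDispatcher())
new_data = SimpleNamespace(expm=FakeDispatcher(), sqrtm=FakeDispatcher())

spec = [("JaxArray", "JaxArray", "sqrtm_jaxarray")]
print(register_if_available(old_data, "sqrtm", spec))  # False
print(register_if_available(new_data, "sqrtm", spec))  # True
```

A silent skip means `sqrtm` would simply not dispatch to JAX on older qutip builds, so emitting a warning (for example with `warnings.warn`) may be preferable; installing matching qutip and qutip-jax versions remains the actual fix.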

I installed qutip-jax following the instructions at this link, using:

pip install qutip --pre
pip install git+https://github.com/qutip/qutip-jax.git
@Ericgig
Member

Ericgig commented Sep 10, 2024

To use the development version of qutip-jax, you will need to install both from source:

pip install git+https://github.com/qutip/qutip.git
pip install git+https://github.com/qutip/qutip-jax.git

If you don't have Cython working to compile qutip, you can use the released versions of both packages instead:

pip install qutip qutip-jax
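Whichever route is used, one way to confirm that the two packages match is to print the installed versions and check whether `qutip.data` exposes `sqrtm`. The sketch below relies only on the standard library's `importlib.metadata`; the helper names `installed_version` and `has_sqrtm_dispatcher` are hypothetical, and the check simply mirrors the attribute lookup that failed in the traceback.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(dist_name):
    """Return the installed version string of a distribution, or None."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


def has_sqrtm_dispatcher():
    """True/False depending on whether qutip.data defines sqrtm;
    None if qutip cannot be imported at all."""
    try:
        import qutip
    except ImportError:
        return None
    # qutip.data is the module the traceback reports as qutip.core.data;
    # use getattr so very old qutip versions without it do not raise.
    data = getattr(qutip, "data", None)
    return data is not None and hasattr(data, "sqrtm")


for dist in ("qutip", "qutip-jax", "jax"):
    print(f"{dist}: {installed_version(dist)}")
print("qutip.data.sqrtm available:", has_sqrtm_dispatcher())
```

If this reports `False` while qutip-jax was installed from GitHub, the installed qutip is older than the qutip-jax source expects, which is the mismatch described in this thread.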

@ArturDomingues
Author

OK, got it. This should be made explicit in the installation instructions.
