
Array slice indices are not compatible with jit #62

Open
rochisha0 opened this issue Jul 24, 2024 · 0 comments


@rochisha0 (Contributor)

Issue

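```python
import qutip

# `_is_zero(x, tol)` and `_is_conj(a, b, tol)` are tolerance-based helpers
# defined alongside this function (not shown here).


def isherm_jaxdia(matrix, tol=None):
    # A non-square matrix cannot be Hermitian.
    if matrix.shape[0] != matrix.shape[1]:
        return False
    # `tol or ...` would silently replace an explicit tol=0.0 with the default.
    if tol is None:
        tol = qutip.settings.core["atol"]
    done = []
    for offset, data in zip(matrix.offsets, matrix.data):
        # Already compared as the mirror of an earlier diagonal.
        if offset in done:
            continue
        # Valid range of `data` for this diagonal.
        start = max(0, offset)
        end = min(matrix.shape[1], matrix.shape[0] + offset)
        if -offset not in matrix.offsets:
            # No mirror diagonal stored: this diagonal must be all zeros.
            if not _is_zero(data[start:end], tol):
                return False
        else:
            # Compare against the conjugate of the mirror diagonal -offset.
            idx = matrix.offsets.index(-offset)
            done.append(-offset)
            st = max(0, -offset)
            et = min(matrix.shape[1], matrix.shape[0] - offset)
            if not _is_conj(data[start:end], matrix.data[idx, st:et], tol):
                return False
    return True
```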
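To make the slice bounds concrete, here is a minimal dense reference check. It assumes the `scipy.sparse` DIA layout, in which `data[k, j]` holds `A[j - d, j]` for `d = offsets[k]` (entries indexed by column); the `start`/`end` bounds in `isherm_jaxdia` are consistent with that layout, but this is an assumption about qutip-jax's storage. `dia_to_dense` and `isherm_dense` are hypothetical helpers, not qutip or qutip-jax API, and are meant only as an oracle for testing a faster implementation.

```python
import jax.numpy as jnp


def dia_to_dense(offsets, data, n):
    """Build a dense n x n matrix from DIA storage (hypothetical helper).

    Assumes the scipy.sparse DIA convention: data[k, j] == A[j - d, j]
    for d = offsets[k], with `offsets` a tuple of Python ints.
    """
    dense = jnp.zeros((n, n), dtype=data.dtype)
    for d, row in zip(offsets, data):
        start, end = max(0, d), min(n, n + d)  # columns that hold diagonal d
        cols = jnp.arange(start, end)
        dense = dense.at[cols - d, cols].add(row[start:end])
    return dense


def isherm_dense(offsets, data, n, tol=1e-6):
    # Reference check: A is Hermitian iff A equals its conjugate transpose.
    dense = dia_to_dense(offsets, data, n)
    return bool(jnp.allclose(dense, dense.conj().T, atol=tol, rtol=0.0))
```

For example, with `offsets = (-1, 0, 1)` the rows `[2+1j, -4j, 0]`, `[1, 3, 5]` and `[0, 2-1j, 4j]` encode the Hermitian matrix `[[1, 2-1j, 0], [2+1j, 3, 4j], [0, -4j, 5]]`.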

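`isherm_jaxdia` is not compatible with `jax.jit` because it slices arrays with Python slice syntax (`data[start:end]`, `matrix.data[idx, st:et]`) whose bounds are computed from the diagonal offsets. When those offsets are traced values inside a JIT-compiled function, the bounds are traced too, and JAX only accepts static start/stop/step values in slice indexing.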
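The failure can be reproduced in isolation with a toy function; `slice_row` is a hypothetical stand-in for the `data[start:end]` pattern above. Eagerly the bounds are concrete ints and slicing works; under `jax.jit` they are traced, the output length `end - start` is unknown at trace time, and JAX raises an error. Marking the bounds as static with `static_argnums` also works, at the cost of one compilation per distinct `(start, end)` pair.

```python
import jax
import jax.numpy as jnp


def slice_row(row, start, end):
    # Plain Python slicing, as used by data[start:end] in isherm_jaxdia.
    return row[start:end]


row = jnp.arange(5.0)

# Eager: start/end are concrete Python ints, so this just works.
eager = slice_row(row, 1, 4)

# Under jit, start/end are tracers, so the result shape is not static.
try:
    jax.jit(slice_row)(row, 1, 4)
    jit_error = None
except Exception as err:  # JAX rejects the traced slice bounds
    jit_error = err

# Static bounds are ordinary Python ints again at trace time.
static_ok = jax.jit(slice_row, static_argnums=(1, 2))(row, 1, 4)
```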

Solution
Use `jax.lax.dynamic_slice` to perform the slicing inside the JIT-compiled function. `dynamic_slice` accepts traced start indices, so it is compatible with JAX's JIT compilation; the slice sizes, however, must still be static.
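One way to apply this is sketched below under several assumptions: a square matrix (the non-square case is rejected before tracing, since shapes are static), unique offsets passed as an integer array so they may be traced, and the `scipy.sparse` DIA layout `data[k, j] == A[j - d, j]`. Because `dynamic_slice` needs a static size, each mirror row is padded by `n` zeros on both sides and a fixed-length window of `n` is taken at the traced start `n - d`. Python `if`/`return False` on traced comparison results is a second blocker under `jit`, so the sketch reduces per-diagonal results with `jnp.all` instead. The names `_diag_ok` and `isherm_dia_sketch` are hypothetical, and the elementwise absolute tolerance may differ from what `_is_zero`/`_is_conj` actually do.

```python
import jax
import jax.numpy as jnp


def _diag_ok(d, row, offsets, data, tol):
    """Check diagonal `d` against its mirror `-d` (works with a traced d)."""
    n = row.shape[0]
    cols = jnp.arange(n)
    # Columns j that hold an entry of diagonal d: 0 <= j - d < n.
    valid = (cols - d >= 0) & (cols - d < n)

    # Locate the mirror diagonal -d without Python control flow.
    match = offsets == -d
    has_mirror = jnp.any(match)
    mirror = data[jnp.argmax(match.astype(jnp.int32))]

    # shifted[j] == mirror[j - d], or 0 where j - d is out of range.
    # Traced start index, static size n: what dynamic_slice allows.
    padded = jnp.pad(mirror, (n, n))
    shifted = jax.lax.dynamic_slice(padded, (n - d,), (n,))
    expected = jnp.where(has_mirror, jnp.conj(shifted), 0)

    # Hermitian: A[j - d, j] == conj(A[j, j - d]); no mirror means zeros.
    close = jnp.abs(row - expected) <= tol
    return jnp.all(jnp.where(valid, close, True))


@jax.jit
def isherm_dia_sketch(offsets, data, tol=1e-6):
    # Returns a traced boolean instead of returning early.
    check = jax.vmap(lambda d, row: _diag_ok(d, row, offsets, data, tol))
    return jnp.all(check(offsets, data))
```

`dynamic_slice` clamps a start index that would run past the end of the operand instead of raising, so the padding matters: for every `|d| < n` the start `n - d` stays in range and no silent clamping occurs (diagonals with `|d| >= n` have no valid columns, so their result does not depend on the window). Each mirror pair is checked twice, once from each side, which keeps the code branch-free.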
