Skip to content

Commit

Permalink
deprecation fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
PhilipVinc committed May 17, 2024
1 parent 0832c3a commit a167b6d
Show file tree
Hide file tree
Showing 7 changed files with 19 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def projective_measurement(phi, psi, p, key_meas, key_spin):
key_spin, subkey_spin = jax.random.split(key_spin)

params = flax.core.unfreeze(psi.parameters)
params = jax.tree_map(lambda x: jnp.array(x), params)
params = jax.tree_util.tree_map(lambda x: jnp.array(x), params)
if jax.random.uniform(subkey_spin) < prob_up.real:
params["orbital_down"] = params["orbital_down"].at[i].set(1e-12)
else:
Expand Down Expand Up @@ -128,7 +128,7 @@ def dynamics_with_measurements(

# ZZ diagonal term
params = flax.core.unfreeze(psi.parameters)
params = jax.tree_map(lambda x: jnp.array(x), params)
params = jax.tree_util.tree_map(lambda x: jnp.array(x), params)
for l, m in g.edges():
params["theta_zz"] = (
params["theta_zz"]
Expand All @@ -154,7 +154,7 @@ def dynamics_with_measurements(

# ZZ diagonal term
params = flax.core.unfreeze(psi.parameters)
params = jax.tree_map(lambda x: jnp.array(x), params)
params = jax.tree_util.tree_map(lambda x: jnp.array(x), params)
for l, m in g.edges():
params["theta_zz"] = (
params["theta_zz"]
Expand Down
4 changes: 2 additions & 2 deletions netket_fidelity/driver/ptvmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def __init__(
U_dagger=None,
preconditioner: PreconditionerT = identity_preconditioner,
is_unitary=False,
sample_Upsi=False,
sample_Upsi=False,
cv_coeff=None,
):
self._dt = dt
Expand All @@ -41,7 +41,7 @@ def __init__(
U_dagger=U_dagger,
preconditioner=preconditioner,
is_unitary=is_unitary,
sample_Upsi=sample_Upsi,
sample_Upsi=sample_Upsi,
cv_coeff=cv_coeff,
)

Expand Down
10 changes: 5 additions & 5 deletions netket_fidelity/infidelity/logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def InfidelityOperator(
the function :class:`netket_fidelity.infidelity.InfidelityUPsi` .
This works only with the operators provdided in the package.
We remark that sampling from :math:`U|\phi\rangle` requires to compute connected elements of
We remark that sampling from :math:`U|\phi\rangle` requires to compute connected elements of
:math:`U` and so is more expensive than sampling from an autonomous state.
The choice of this estimator is specified by passing :code:`sample_Upsi=True`,
while the flag argument :code:`is_unitary` indicates whether :math:`U` is unitary or not.
Expand All @@ -82,7 +82,7 @@ def InfidelityOperator(
This estimator is more efficient since it does not require to sample from
:math:`U|\phi\rangle`, but only from :math:`|\phi\rangle`.
This choice of the estimator is the default and it works only
with `is_unitary==True` (besides :code:`sample_Upsi=False` ).
with `is_unitary==True` (besides :code:`sample_Upsi=False` ).
When :math:`|\Phi⟩ = |\phi⟩` the two estimators coincides.
To reduce the variance of the estimator, the Control Variates (CV) method can be applied. This consists
Expand All @@ -100,8 +100,8 @@ def InfidelityOperator(
c* = \frac{\rm{Cov}_{χ}\left[ |1-I_{loc}|^2, \rm{Re}\left[1-I_{loc}\right]\right]}{
\rm{Var}_{χ}\left[ |1-I_{loc}|^2\right] },
where :math:`\rm{Cov}\left\cdot, \cdot\right]` indicates the covariance and :math:`\rm{Var}\left[\cdot\right]` the variance.
In the relevant limit :math:`|\Psi⟩ \rightarrow|\Phi⟩`, we have :math:`c^\star \rightarrow -1/2`. The value :math:`-1/2` is
where :math:`\rm{Cov}\left\cdot, \cdot\right]` indicates the covariance and :math:`\rm{Var}\left[\cdot\right]` the variance.
In the relevant limit :math:`|\Psi⟩ \rightarrow|\Phi⟩`, we have :math:`c^\star \rightarrow -1/2`. The value :math:`-1/2` is
adopted as default value for c in the infidelity
estimator. To not apply CV, set c=0.
Expand All @@ -110,7 +110,7 @@ def InfidelityOperator(
U: operator :math:`\hat{U}`.
U_dagger: dagger operator :math:`\hat{U^\dagger}`.
cv_coeff: Control Variates coefficient c.
is_unitary: flag specifiying the unitarity of :math:`\hat{U}`. If True with
is_unitary: flag specifiying the unitarity of :math:`\hat{U}`. If True with
:code:`sample_Upsi=False`, the second estimator is used.
dtype: The dtype of the output of expectation value and gradient.
sample_Upsi: flag specifiying whether to sample from |ϕ⟩ or from U|ϕ⟩. If False with `is_unitary=False` , an error occurs.
Expand Down
4 changes: 2 additions & 2 deletions netket_fidelity/infidelity/overlap/exact.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,8 @@ def expect_fun(params):
F, F_vjp_fun = nkjax.vjp(expect_fun, params, conjugate=True)

F_grad = F_vjp_fun(jnp.ones_like(F))[0]
F_grad = jax.tree_map(lambda x: mpi.mpi_mean_jax(x)[0], F_grad)
I_grad = jax.tree_map(lambda x: -x, F_grad)
F_grad = jax.tree_util.tree_map(lambda x: mpi.mpi_mean_jax(x)[0], F_grad)
I_grad = jax.tree_util.tree_map(lambda x: -x, F_grad)
I_stats = Stats(mean=1 - F, error_of_mean=0.0, variance=0.0)

return I_stats, I_grad
4 changes: 2 additions & 2 deletions netket_fidelity/infidelity/overlap/expect.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,8 @@ def kernel_fun(params, params_t, σ, σ_t):
)

F_grad = F_vjp_fun(jnp.ones_like(F))[0]
F_grad = jax.tree_map(lambda x: mpi.mpi_mean_jax(x)[0], F_grad)
I_grad = jax.tree_map(lambda x: -x, F_grad)
F_grad = jax.tree_util.tree_map(lambda x: mpi.mpi_mean_jax(x)[0], F_grad)
I_grad = jax.tree_util.tree_map(lambda x: -x, F_grad)
I_stats = F_stats.replace(mean=1 - F)

return I_stats, I_grad
4 changes: 2 additions & 2 deletions netket_fidelity/infidelity/overlap_U/expect.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,8 +159,8 @@ def kernel_fun(params, params_t, σ, σ_t):
)

F_grad = F_vjp_fun(jnp.ones_like(F))[0]
F_grad = jax.tree_map(lambda x: mpi.mpi_mean_jax(x)[0], F_grad)
I_grad = jax.tree_map(lambda x: -x, F_grad)
F_grad = jax.tree_util.tree_map(lambda x: mpi.mpi_mean_jax(x)[0], F_grad)
I_grad = jax.tree_util.tree_map(lambda x: -x, F_grad)
I_stats = F_stats.replace(mean=1 - F)

return I_stats, I_grad
6 changes: 3 additions & 3 deletions test/_infidelity_exact.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,6 @@ def _infidelity_exact(params_new, vstate, U):
)

else:
return 1 - jnp.absolute(state_new.conj().T @ U.to_sparse() @ state_old) ** 2 / (
(state_new.conj().T @ state_new) * (state_old.conj().T @ state_old)
)
return 1 - jnp.absolute(
state_new.conj().T @ (U.to_sparse() @ state_old)
) ** 2 / ((state_new.conj().T @ state_new) * (state_old.conj().T @ state_old))

0 comments on commit a167b6d

Please sign in to comment.