From ce556753ea13a250e3d17725c65f1c1f836cfb87 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jofre=20Vall=C3=A8s?= Date: Thu, 2 May 2024 12:40:30 +0200 Subject: [PATCH 01/18] Implement evolve(mps, mpo) --- src/Ansatz/Chain.jl | 47 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/src/Ansatz/Chain.jl b/src/Ansatz/Chain.jl index 4e7d033..274c1ee 100644 --- a/src/Ansatz/Chain.jl +++ b/src/Ansatz/Chain.jl @@ -652,6 +652,53 @@ function unpack_2sitewf!(ψ::Chain, bond, left_inds, right_inds, virtualind) return ψ end +""" + evolve!(ψ::Chain, mpo::Chain) + +Applies a Matrix Product Operator (MPO) `mpo` to the [`Chain`](@ref). +""" +function evolve(ψ::Chain, mpo::Chain) + updated_tensors = Tensor[] + Λ = Tensor[] + L = nsites(ψ) + + for i in 1:L + t = contract(select(ψ, :tensor, Site(i)), select(mpo, :tensor, Site(i)); dims=(select(ψ, :index, Site(i)),)) + physicalind = select(mpo, :index, Site(i)) + + # Fuse the two right legs of t into one + if i == 1 + wanted_inds = (physicalind, rightindex(ψ, Site(i)), rightindex(mpo, Site(i))) + new_inds = (physicalind, rightindex(ψ, Site(i))) + elseif i < L + wanted_inds = (physicalind, leftindex(ψ, Site(i)), leftindex(mpo, Site(i)), rightindex(ψ, Site(i)), rightindex(mpo, Site(i))) + new_inds = (physicalind, leftindex(ψ, Site(i)), rightindex(ψ, Site(i))) + else + wanted_inds = (physicalind, leftindex(ψ, Site(i)), leftindex(mpo, Site(i))) + new_inds = (physicalind, leftindex(ψ, Site(i))) + end + + perm = Tenet.__find_index_permutation(wanted_inds, inds(t)) + t = PermutedDimsArray(parent(t), perm) + + t = Tensor(reshape(t, tuple(size(t, 1), [size(t, k) * size(t, k + 1) for k in 2:2:length(wanted_inds)]...)), new_inds) + push!(updated_tensors, t) + + if i < L + Λᵢ = select(ψ, :between, Site(i), Site(i + 1)) + Λᵢ = Tensor(diag(kron(Matrix(LinearAlgebra.I, d, d), diagm(parent(Λᵢ)))), inds(Λᵢ)) + push!(Λ, Λᵢ) + end + end + + ψ_ev = MPS(updated_tensors) + for i in 1:L-1 + push!(TensorNetwork(ψ_ev), Tensor(parent(Λ[i]), (rightindex(ψ_ev, Site(i)),))) + end + + return ψ_ev +end + function expect(ψ::Chain, observables) # contract observable with TN ϕ = copy(ψ) From 35e6117447bf2012cc1ecb514d7c22e56291ef2d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jofre=20Vall=C3=A8s?= Date: Thu, 2 May 2024 12:57:39 +0200 Subject: [PATCH 02/18] Add missing variable definition --- src/Ansatz/Chain.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/Ansatz/Chain.jl b/src/Ansatz/Chain.jl index 274c1ee..5ba8378 100644 --- a/src/Ansatz/Chain.jl +++ b/src/Ansatz/Chain.jl @@ -685,6 +685,7 @@ function evolve(ψ::Chain, mpo::Chain) push!(updated_tensors, t) if i < L + d = size(TensorNetwork(mpo), rightindex(mpo, Site(i))) Λᵢ = select(ψ, :between, Site(i), Site(i + 1)) Λᵢ = Tensor(diag(kron(Matrix(LinearAlgebra.I, d, d), diagm(parent(Λᵢ)))), inds(Λᵢ)) push!(Λ, Λᵢ) From b126744dbeb51320e679bbe9c71e9f9ba1c005d3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jofre=20Vall=C3=A8s?= Date: Thu, 2 May 2024 13:08:21 +0200 Subject: [PATCH 03/18] Format code --- src/Ansatz/Chain.jl | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/src/Ansatz/Chain.jl b/src/Ansatz/Chain.jl index 5ba8378..09ed2a4 100644 --- a/src/Ansatz/Chain.jl +++ b/src/Ansatz/Chain.jl @@ -663,7 +663,7 @@ function evolve(ψ::Chain, mpo::Chain) L = nsites(ψ) for i in 1:L - t = contract(select(ψ, :tensor, Site(i)), select(mpo, :tensor, Site(i)); dims=(select(ψ, :index, Site(i)),)) + t = contract(select(ψ, :tensor, Site(i)), select(mpo, :tensor, Site(i)); dims = (select(ψ, :index, Site(i)),)) physicalind = select(mpo, :index, Site(i)) # Fuse the two right legs of t into one @@ -671,7 +671,13 @@ function evolve(ψ::Chain, mpo::Chain) wanted_inds = (physicalind, rightindex(ψ, Site(i)), rightindex(mpo, Site(i))) new_inds = (physicalind, rightindex(ψ, Site(i))) elseif i < L - wanted_inds = (physicalind, leftindex(ψ, Site(i)), leftindex(mpo, Site(i)), rightindex(ψ, Site(i)), rightindex(mpo, Site(i))) + wanted_inds = ( + physicalind, + leftindex(ψ, Site(i)), + leftindex(mpo, Site(i)), + rightindex(ψ, Site(i)), + rightindex(mpo, Site(i)), + ) new_inds = (physicalind, leftindex(ψ, Site(i)), rightindex(ψ, Site(i))) else wanted_inds = (physicalind, leftindex(ψ, Site(i)), leftindex(mpo, Site(i))) @@ -681,7 +687,10 @@ function evolve(ψ::Chain, mpo::Chain) perm = Tenet.__find_index_permutation(wanted_inds, inds(t)) t = PermutedDimsArray(parent(t), perm) - t = Tensor(reshape(t, tuple(size(t, 1), [size(t, k) * size(t, k + 1) for k in 2:2:length(wanted_inds)]...)), new_inds) + t = Tensor( + reshape(t, tuple(size(t, 1), [size(t, k) * size(t, k + 1) for k in 2:2:length(wanted_inds)]...)), + new_inds, + ) push!(updated_tensors, t) if i < L From 8fc17abb0cb399e1e657759128ec85e996cd41f6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jofre=20Vall=C3=A8s?= Date: Tue, 4 Jun 2024 12:07:38 +0200 Subject: [PATCH 04/18] Update evolve code --- src/Ansatz/Chain.jl | 29 +++++++++++++++++++---------- 1 file changed, 19 insertions(+), 10 deletions(-) diff --git a/src/Ansatz/Chain.jl b/src/Ansatz/Chain.jl index 09ed2a4..5e14081 100644 --- a/src/Ansatz/Chain.jl +++ b/src/Ansatz/Chain.jl @@ -652,24 +652,27 @@ function unpack_2sitewf!(ψ::Chain, bond, left_inds, right_inds, virtualind) return ψ end +evolve(ψ::Chain, mpo::Chain) = evolve!(copy(ψ), mpo) + """ evolve!(ψ::Chain, mpo::Chain) Applies a Matrix Product Operator (MPO) `mpo` to the [`Chain`](@ref). """ -function evolve(ψ::Chain, mpo::Chain) +function evolve!(ψ::Chain, mpo::Chain) updated_tensors = Tensor[] Λ = Tensor[] L = nsites(ψ) for i in 1:L - t = contract(select(ψ, :tensor, Site(i)), select(mpo, :tensor, Site(i)); dims = (select(ψ, :index, Site(i)),)) + contractedind = select(ψ, :index, Site(i)) + t = contract(select(ψ, :tensor, Site(i)), select(mpo, :tensor, Site(i)); dims = (contractedind,)) physicalind = select(mpo, :index, Site(i)) # Fuse the two right legs of t into one if i == 1 wanted_inds = (physicalind, rightindex(ψ, Site(i)), rightindex(mpo, Site(i))) - new_inds = (physicalind, rightindex(ψ, Site(i))) + new_inds = (contractedind, rightindex(ψ, Site(i))) elseif i < L wanted_inds = ( physicalind, @@ -678,14 +681,14 @@ function evolve(ψ::Chain, mpo::Chain) rightindex(ψ, Site(i)), rightindex(mpo, Site(i)), ) - new_inds = (physicalind, leftindex(ψ, Site(i)), rightindex(ψ, Site(i))) + new_inds = (contractedind, leftindex(ψ, Site(i)), rightindex(ψ, Site(i))) else wanted_inds = (physicalind, leftindex(ψ, Site(i)), leftindex(mpo, Site(i))) - new_inds = (physicalind, leftindex(ψ, Site(i))) + new_inds = (contractedind, leftindex(ψ, Site(i))) end perm = Tenet.__find_index_permutation(wanted_inds, inds(t)) - t = PermutedDimsArray(parent(t), perm) + t = permutedims(t, perm) t = Tensor( reshape(t, tuple(size(t, 1), [size(t, k) * size(t, k + 1) for k in 2:2:length(wanted_inds)]...)), @@ -701,12 +704,18 @@ function evolve(ψ::Chain, mpo::Chain) end end - ψ_ev = MPS(updated_tensors) - for i in 1:L-1 - push!(TensorNetwork(ψ_ev), Tensor(parent(Λ[i]), (rightindex(ψ_ev, Site(i)),))) + + for i in 1:L + i < L && pop!(TensorNetwork(ψ), select(ψ, :between, Site(i), Site(i + 1))) + pop!(TensorNetwork(ψ), select(ψ, :tensor, Site(i))) end - return ψ_ev + for i in 1:L + i < L && push!(TensorNetwork(ψ), Λ[i]) + push!(TensorNetwork(ψ), updated_tensors[i]) + end + + return ψ end function expect(ψ::Chain, observables) From ff48821f07dbdc6e9241b2fcdf983a60f524079f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jofre=20Vall=C3=A8s?= Date: Tue, 4 Jun 2024 12:10:41 +0200 Subject: [PATCH 05/18] Format code --- src/Ansatz/Chain.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/Ansatz/Chain.jl b/src/Ansatz/Chain.jl index 5e14081..9e77c93 100644 --- a/src/Ansatz/Chain.jl +++ b/src/Ansatz/Chain.jl @@ -704,7 +704,6 @@ function evolve!(ψ::Chain, mpo::Chain) end end - for i in 1:L i < L && pop!(TensorNetwork(ψ), select(ψ, :between, Site(i), Site(i + 1))) pop!(TensorNetwork(ψ), select(ψ, :tensor, Site(i))) From 08457cde09b4c447f2584a74947b3ede15f72c8f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Sun, 5 May 2024 01:42:58 +0200 Subject: [PATCH 06/18] Upgrade `Tenet` to v0.6 - Refactor `select` to `tensor`,`inds` methods - Remove `ValSplit` dependency --- Project.toml | 4 +-- docs/src/quantum.md | 4 +-- src/Ansatz.jl | 38 ++++++++++++++-------- src/Ansatz/Chain.jl | 60 +++++++++++++++++------------------ src/Qrochet.jl | 2 +- src/Quantum.jl | 47 ++++++++++++++++++--------- test/Ansatz/Chain_test.jl | 28 ++++++++-------- test/integration/Quac_test.jl | 4 +-- 8 files changed, 107 insertions(+), 80 deletions(-) diff --git a/Project.toml b/Project.toml index 694e89c..4eaa086 100644 --- a/Project.toml +++ b/Project.toml @@ -8,7 +8,6 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Muscle = "21fe5c4b-a943-414d-bf3e-516f24900631" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Tenet = "85d41934-b9cd-44e1-8730-56d86f15f3ec" -ValSplit = "0625e100-946b-11ec-09cd-6328dd093154" [weakdeps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" @@ -26,7 +25,6 @@ QrochetYaoExt = "Yao" ChainRulesCore = "1.0" Muscle = "0.1" Quac = "0.3" -Tenet = "0.5" -ValSplit = "0.1" +Tenet = "0.6" Yao = "0.8" julia = "1.9" diff --git a/docs/src/quantum.md b/docs/src/quantum.md index 1b78a93..0c8dd03 100644 --- a/docs/src/quantum.md +++ b/docs/src/quantum.md @@ -11,8 +11,8 @@ nsites ## Queries ```@docs -Tenet.select(::Quantum, ::Val{:index}, ::Site) -Tenet.select(::Quantum, ::Val{:tensor}, ::Site) +Tenet.inds(::Quantum; kwargs...) +Tenet.tensors(::Quantum; kwargs...) ``` ## Connecting `Quantum` Tensor Networks diff --git a/src/Ansatz.jl b/src/Ansatz.jl index b6137ac..f7c9ae0 100644 --- a/src/Ansatz.jl +++ b/src/Ansatz.jl @@ -1,5 +1,4 @@ using Tenet -using ValSplit using LinearAlgebra """ @@ -29,7 +28,6 @@ for f in [ :nsites, :nlanes, :socket, - :(Tenet.tensors), :(Tenet.arrays), :(Base.collect), ] @@ -46,33 +44,47 @@ alias(::A) where {A} = string(A) function Base.summary(io::IO, tn::A) where {A<:Ansatz} print(io, "$(alias(tn)) (inputs=$(ninputs(tn)), outputs=$(noutputs(tn)))") end -Base.show(io::IO, tn::A) where {A<:Ansatz} = Base.summary(io, tn) - -@valsplit 2 Tenet.select(tn::Ansatz, query::Symbol, args...) = select(Quantum(tn), query, args...) +Base.show(io::IO, tn::A) where {A<:Ansatz} = summary(io, tn) + +function Tenet.inds(tn::Ansatz; kwargs...) + if keys(kwargs) === (:bond,) + inds(tn, Val(:bond), kwargs[:bond]...) + else + inds(Quantum(tn); kwargs...) + end +end -function Tenet.select(tn::Ansatz, ::Val{:between}, site1::Site, site2::Site) +function Tenet.inds(tn::Ansatz, ::Val{:bond}, site1::Site, site2::Site) @assert site1 ∈ sites(tn) "Site $site1 not found" @assert site2 ∈ sites(tn) "Site $site2 not found" @assert site1 != site2 "Sites must be different" - tensor1 = select(Quantum(tn), :tensor, site1) - tensor2 = select(Quantum(tn), :tensor, site2) + tensor1 = tensors(tn; at = site1) + tensor2 = tensors(tn; at = site2) isdisjoint(inds(tensor1), inds(tensor2)) && return nothing + return only(inds(tensor1) ∩ inds(tensor2)) +end - TensorNetwork(tn)[only(inds(tensor1) ∩ inds(tensor2))] +function Tenet.tensors(tn::Ansatz; kwargs...) + if keys(kwargs) === (:between,) + tensors(tn, Val(:between), kwargs[:between]...) + else + tensors(Quantum(tn); kwargs...) + end end -function Tenet.select(tn::Ansatz, ::Val{:bond}, site1::Site, site2::Site) +function Tenet.tensors(tn::Ansatz, ::Val{:between}, site1::Site, site2::Site) @assert site1 ∈ sites(tn) "Site $site1 not found" @assert site2 ∈ sites(tn) "Site $site2 not found" @assert site1 != site2 "Sites must be different" - tensor1 = select(Quantum(tn), :tensor, site1) - tensor2 = select(Quantum(tn), :tensor, site2) + tensor1 = tensors(tn; at = site1) + tensor2 = tensors(tn; at = site2) isdisjoint(inds(tensor1), inds(tensor2)) && return nothing - return only(inds(tensor1) ∩ inds(tensor2)) + + TensorNetwork(tn)[only(inds(tensor1) ∩ inds(tensor2))] end struct MissingSchmidtCoefficientsException <: Base.Exception diff --git a/src/Ansatz/Chain.jl b/src/Ansatz/Chain.jl index 9e77c93..b382ef7 100644 --- a/src/Ansatz/Chain.jl +++ b/src/Ansatz/Chain.jl @@ -129,11 +129,11 @@ rightsite(::Periodic, tn::Chain, site::Site) = Site(mod1(id(site) + 1, nlanes(tn leftindex(tn::Chain, site::Site) = leftindex(boundary(tn), tn, site) leftindex(::Open, tn::Chain, site::Site) = site == site"1" ? nothing : leftindex(Periodic(), tn, site) -leftindex(::Periodic, tn::Chain, site::Site) = select(tn, :bond, site, leftsite(tn, site)) +leftindex(::Periodic, tn::Chain, site::Site) = inds(tn; bond = (site, leftsite(tn, site))) rightindex(tn::Chain, site::Site) = rightindex(boundary(tn), tn, site) rightindex(::Open, tn::Chain, site::Site) = site == Site(nlanes(tn)) ? nothing : rightindex(Periodic(), tn, site) -rightindex(::Periodic, tn::Chain, site::Site) = select(tn, :bond, site, rightsite(tn, site)) +rightindex(::Periodic, tn::Chain, site::Site) = inds(tn; bond = (site, rightsite(tn, site))) Base.adjoint(chain::Chain) = Chain(adjoint(Quantum(chain)), boundary(chain)) @@ -248,14 +248,14 @@ function Tenet.contract!( direction::Symbol = :left, delete_Λ = true, ) - Λᵢ = select(tn, :between, site1, site2) + Λᵢ = tensors(tn; between = (site1, site2)) Λᵢ === nothing && return tn if direction === :right - Γᵢ₊₁ = select(tn, :tensor, site2) + Γᵢ₊₁ = tensors(tn; at = site2) replace!(TensorNetwork(tn), Γᵢ₊₁ => contract(Γᵢ₊₁, Λᵢ, dims = ())) elseif direction === :left - Γᵢ = select(tn, :tensor, site1) + Γᵢ = tensors(tn; at = site1) replace!(TensorNetwork(tn), Γᵢ => contract(Λᵢ, Γᵢ, dims = ())) else throw(ArgumentError("Unknown direction=:$direction")) @@ -322,13 +322,13 @@ Truncate the dimension of the virtual `bond`` of the [`Chain`](@ref) Tensor Netw - The bond must contain the Schmidt coefficients, i.e. a site canonization must be performed before calling `truncate!`. """ function truncate!(qtn::Chain, bond; threshold::Union{Nothing,Real} = nothing, maxdim::Union{Nothing,Int} = nothing) - # TODO replace for select(:between) + # TODO replace for tensors(; between) vind = rightindex(qtn, bond[1]) if vind != leftindex(qtn, bond[2]) throw(ArgumentError("Invalid bond $bond")) end - if vind ∉ inds(TensorNetwork(qtn), :hyper) + if vind ∉ inds(qtn; set = :hyper) throw(MissingSchmidtCoefficientsException(bond)) end @@ -357,7 +357,7 @@ end function isleftcanonical(qtn::Chain, site; atol::Real = 1e-12) right_ind = rightindex(qtn, site) - tensor = select(qtn, :tensor, site) + tensor = tensors(qtn; at = site) # we are at right-most site, we need to add an extra dummy dimension to the tensor if isnothing(right_ind) @@ -375,7 +375,7 @@ end function isrightcanonical(qtn::Chain, site; atol::Real = 1e-12) left_ind = leftindex(qtn, site) - tensor = select(qtn, :tensor, site) + tensor = tensors(qtn; at = site) # we are at left-most site, we need to add an extra dummy dimension to the tensor if isnothing(left_ind) @@ -413,15 +413,15 @@ function canonize!(::Open, tn::Chain) canonize_site!(tn, Site(i); direction = :right, method = :svd) # extract the singular values and contract them with the next tensor - Λᵢ = pop!(TensorNetwork(tn), select(tn, :between, Site(i), Site(i + 1))) - Aᵢ₊₁ = select(tn, :tensor, Site(i + 1)) + Λᵢ = pop!(TensorNetwork(tn), tensors(tn; between = (Site(i), Site(i + 1)))) + Aᵢ₊₁ = tensors(tn; at = Site(i + 1)) replace!(TensorNetwork(tn), Aᵢ₊₁ => contract(Aᵢ₊₁, Λᵢ, dims = ())) push!(Λ, Λᵢ) end for i in 2:nsites(tn) # tensors at i in "A" form, need to contract (Λᵢ)⁻¹ with A to get Γᵢ Λᵢ = Λ[i-1] # singular values start between site 1 and 2 - A = select(tn, :tensor, Site(i)) + A = tensors(tn; at = Site(i)) Γᵢ = contract(A, Tensor(diag(pinv(Diagonal(parent(Λᵢ)), atol = 1e-64)), inds(Λᵢ)), dims = ()) replace!(TensorNetwork(tn), A => Γᵢ) push!(TensorNetwork(tn), Λᵢ) @@ -462,7 +462,7 @@ to mixed-canonized form with the given center site. """ function LinearAlgebra.normalize!(tn::Chain, root::Site; p::Real = 2) mixed_canonize!(tn, root) - normalize!(select(Quantum(tn), :tensor, root), p) + normalize!(tensors(Quantum(tn); at = root), p) return tn end @@ -522,11 +522,11 @@ function evolve_1site!(qtn::Chain, gate::Dense) targetsite = only(inputs(gate))' # reindex contracting index - replace!(TensorNetwork(qtn), select(qtn, :index, targetsite) => contracting_index) - replace!(TensorNetwork(gate), select(gate, :index, targetsite') => contracting_index) + replace!(TensorNetwork(qtn), inds(qtn; at = targetsite) => contracting_index) + replace!(TensorNetwork(gate), inds(gate; at = targetsite') => contracting_index) # reindex output of gate to match TN sitemap - replace!(TensorNetwork(gate), select(gate, :index, only(outputs(gate))) => select(qtn, :index, targetsite)) + replace!(TensorNetwork(gate), inds(gate; at = only(outputs(gate))) => inds(qtn; at = targetsite)) # contract gate with TN merge!(TensorNetwork(qtn), TensorNetwork(gate)) @@ -542,23 +542,23 @@ function evolve_2site!(qtn::Chain, gate::Dense; threshold, maxdim, iscanonical = left_inds::Vector{Symbol} = !isnothing(leftindex(qtn, sitel)) ? [leftindex(qtn, sitel)] : Symbol[] right_inds::Vector{Symbol} = !isnothing(rightindex(qtn, siter)) ? [rightindex(qtn, siter)] : Symbol[] - virtualind::Symbol = select(qtn, :bond, bond...) + virtualind::Symbol = inds(qtn, :bond, bond...) iscanonical ? contract_2sitewf!(qtn, bond) : contract!(TensorNetwork(qtn), virtualind) # reindex contracting index contracting_inds = [gensym(:tmp) for _ in inputs(gate)] replace!(TensorNetwork(qtn), map(zip(inputs(gate), contracting_inds)) do (site, contracting_index) - select(qtn, :index, site') => contracting_index + inds(qtn; at = site') => contracting_index end) replace!(TensorNetwork(gate), map(zip(inputs(gate), contracting_inds)) do (site, contracting_index) - select(gate, :index, site) => contracting_index + inds(gate; at = site) => contracting_index end) # reindex output of gate to match TN sitemap for site in outputs(gate) - if select(qtn, :index, site) != select(gate, :index, site) - replace!(TensorNetwork(gate), select(gate, :index, site) => select(qtn, :index, site)) + if inds(qtn; at = site) != inds(gate; at = site) + replace!(TensorNetwork(gate), inds(gate; at = site) => inds(qtn; at = site)) end end @@ -567,8 +567,8 @@ function evolve_2site!(qtn::Chain, gate::Dense; threshold, maxdim, iscanonical = contract!(TensorNetwork(qtn), contracting_inds) # decompose using SVD - push!(left_inds, select(qtn, :index, sitel)) - push!(right_inds, select(qtn, :index, siter)) + push!(left_inds, inds(qtn; at = sitel)) + push!(right_inds, inds(qtn; at = siter)) if iscanonical unpack_2sitewf!(qtn, bond, left_inds, right_inds, virtualind) @@ -581,7 +581,7 @@ function evolve_2site!(qtn::Chain, gate::Dense; threshold, maxdim, iscanonical = # renormalize the bond if renormalize && iscanonical - λ = select(qtn, :between, bond...) + λ = tensors(qtn; between = bond) replace!(TensorNetwork(qtn), λ => normalize(λ)) elseif renormalize && !iscanonical normalize!(qtn, bond[1]) @@ -604,13 +604,13 @@ function contract_2sitewf!(ψ::Chain, bond) (0 < id(sitel) < nsites(ψ) || 0 < id(siter) < nsites(ψ)) || throw(ArgumentError("The sites in the bond must be between 1 and $(nsites(ψ))")) - Λᵢ₋₁ = id(sitel) == 1 ? nothing : select(ψ, :between, Site(id(sitel) - 1), sitel) - Λᵢ₊₁ = id(sitel) == nsites(ψ) - 1 ? nothing : select(ψ, :between, siter, Site(id(siter) + 1)) + Λᵢ₋₁ = id(sitel) == 1 ? nothing : tensors(ψ; between = (Site(id(sitel) - 1), sitel)) + Λᵢ₊₁ = id(sitel) == nsites(ψ) - 1 ? nothing : tensors(ψ; between = (siter, Site(id(siter) + 1))) !isnothing(Λᵢ₋₁) && contract!(ψ, :between, Site(id(sitel) - 1), sitel; direction = :right, delete_Λ = false) !isnothing(Λᵢ₊₁) && contract!(ψ, :between, siter, Site(id(siter) + 1); direction = :left, delete_Λ = false) - contract!(TensorNetwork(ψ), select(ψ, :bond, bond...)) + contract!(TensorNetwork(ψ), inds(ψ, :bond, bond...)) return ψ end @@ -628,11 +628,11 @@ function unpack_2sitewf!(ψ::Chain, bond, left_inds, right_inds, virtualind) (0 < id(sitel) < nsites(ψ) || 0 < id(site_r) < nsites(ψ)) || throw(ArgumentError("The sites in the bond must be between 1 and $(nsites(ψ))")) - Λᵢ₋₁ = id(sitel) == 1 ? nothing : select(ψ, :between, Site(id(sitel) - 1), sitel) - Λᵢ₊₁ = id(siter) == nsites(ψ) ? nothing : select(ψ, :between, siter, Site(id(siter) + 1)) + Λᵢ₋₁ = id(sitel) == 1 ? nothing : tensors(ψ; between = (Site(id(sitel) - 1), sitel)) + Λᵢ₊₁ = id(siter) == nsites(ψ) ? nothing : tensors(ψ; between = (siter, Site(id(siter) + 1))) # do svd of the θ tensor - θ = select(ψ, :tensor, sitel) + θ = tensors(ψ; at = sitel) U, s, Vt = svd(θ; left_inds, right_inds, virtualind) # contract with the inverse of Λᵢ and Λᵢ₊₂ diff --git a/src/Qrochet.jl b/src/Qrochet.jl index 5ab658c..b7d6bdc 100644 --- a/src/Qrochet.jl +++ b/src/Qrochet.jl @@ -33,6 +33,6 @@ export evolve!, expect, overlap # reexports from Tenet using Tenet -export select +export inds, tensors, arrays end diff --git a/src/Quantum.jl b/src/Quantum.jl index 815ed6f..eaa5432 100644 --- a/src/Quantum.jl +++ b/src/Quantum.jl @@ -1,5 +1,4 @@ using Tenet -using ValSplit # TODO Should we store here some information about quantum numbers? """ @@ -181,7 +180,7 @@ nlanes(tn::Quantum) = length(lanes(tn)) Returns the index associated with a site in a [`Quantum`](@ref) Tensor Network. """ -Base.getindex(q::Quantum, site::Site) = q.sites[site] +Base.getindex(q::Quantum, site::Site) = inds(q; at = site) """ Socket @@ -232,25 +231,43 @@ function socket(q::Quantum) end # forward `TensorNetwork` methods -for f in [:(Tenet.tensors), :(Tenet.arrays), :(Base.collect)] +for f in [:(Tenet.arrays), :(Base.collect)] @eval $f(@nospecialize tn::Quantum) = $f(TensorNetwork(tn)) end -@valsplit 2 Tenet.select(tn::Quantum, query::Symbol, args...) = error("Query ':$query' not defined") - """ - select(q::Quantum, :index, site::Site) + inds(tn::Quantum, set::Symbol = :all, args...; kwargs...) + +Options: -Selects the index associated with a site in a [`Quantum`](@ref) Tensor Network. + - `:at`: index at a site """ -Tenet.select(tn::Quantum, ::Val{:index}, site::Site) = tn[site] +function Tenet.inds(tn::Quantum; kwargs...) + if keys(kwargs) === (:at,) + inds(tn, Val(:at), kwargs[:at]) + else + inds(TensorNetwork(tn); kwargs...) + end +end + +Tenet.inds(tn::Quantum, ::Val{:at}, site::Site) = tn.sites[site] """ - select(q::Quantum, :tensor, site::Site) + tensors(tn::Quantum, query::Symbol, args...; kwargs...) -Selects the tensor associated with a site in a [`Quantum`](@ref) Tensor Network. +Options: + + - `:at`: tensor at a site """ -Tenet.select(tn::Quantum, ::Val{:tensor}, site::Site) = select(TensorNetwork(tn), :any, tn[site]) |> only +function Tenet.tensors(tn::Quantum; kwargs...) + if keys(kwargs) === (:at,) + tensors(tn, Val(:at), kwargs[:at]) + else + tensors(TensorNetwork(tn); kwargs...) + end +end + +Tenet.tensors(tn::Quantum, ::Val{:at}, site::Site) = only(tensors(tn; intersects = inds(tn; at = site))) function reindex!(a::Quantum, ioa, b::Quantum, iob) ioa ∈ [:inputs, :outputs] || error("Invalid argument: :$ioa") @@ -264,7 +281,7 @@ function reindex!(a::Quantum, ioa, b::Quantum, iob) end replacements = map(sitesb) do site - select(b, :index, site) => select(a, :index, ioa != iob ? site' : site) + inds(b; at = site) => inds(a; at = ioa != iob ? site' : site) end if issetequal(first.(replacements), last.(replacements)) @@ -274,7 +291,7 @@ function reindex!(a::Quantum, ioa, b::Quantum, iob) replace!(TensorNetwork(b), replacements...) for site in sitesb - b.sites[site] = select(a, :index, ioa != iob ? site' : site) + b.sites[site] = inds(a; at = ioa != iob ? site' : site) end b @@ -312,11 +329,11 @@ function Base.merge(a::Quantum, b::Quantum) sites = Dict{Site,Symbol}() for site in inputs(a) - sites[site] = select(a, :index, site) + sites[site] = inds(a; at = site) end for site in outputs(b) - sites[site] = select(b, :index, site) + sites[site] = inds(b; at = site) end Quantum(tn, sites) diff --git a/test/Ansatz/Chain_test.jl b/test/Ansatz/Chain_test.jl index 2269434..0eb7bcb 100644 --- a/test/Ansatz/Chain_test.jl +++ b/test/Ansatz/Chain_test.jl @@ -66,7 +66,7 @@ @test size(TensorNetwork(truncated), rightindex(truncated, Site(2))) == 1 @test size(TensorNetwork(truncated), leftindex(truncated, Site(3))) == 1 - singular_values = select(qtn, :between, Site(2), Site(3)) + singular_values = tensors(qtn; between = (Site(2), Site(3))) truncated = Qrochet.truncate(qtn, [Site(2), Site(3)]; threshold = singular_values[2] + 0.1) @test size(TensorNetwork(truncated), rightindex(truncated, Site(2))) == 1 @test size(TensorNetwork(truncated), leftindex(truncated, Site(3))) == 1 @@ -117,10 +117,10 @@ for i in 1:4 contract_some = contract(canonized, :between, Site(i), Site(i + 1)) - Bᵢ = select(contract_some, :tensor, Site(i)) + Bᵢ = tensors(contract_some; at = Site(i)) @test isapprox(contract(TensorNetwork(contract_some)), contract(TensorNetwork(qtn))) - @test_throws ArgumentError select(contract_some, :between, Site(i), Site(i + 1)) + @test_throws MethodError tensors(contract_some, :between, Site(i), Site(i + 1)) @test isrightcanonical(contract_some, Site(i)) @test isleftcanonical( @@ -128,8 +128,8 @@ Site(i + 1), ) - Γᵢ = select(canonized, :tensor, Site(i)) - Λᵢ₊₁ = select(canonized, :between, Site(i), Site(i + 1)) + Γᵢ = tensors(canonized; at = Site(i)) + Λᵢ₊₁ = tensors(canonized; between = (Site(i), Site(i + 1))) @test Bᵢ ≈ contract(Γᵢ, Λᵢ₊₁; dims = ()) end end @@ -144,28 +144,28 @@ canonized = canonize_site(qtn, site"1"; direction = :right, method = method) @test isleftcanonical(canonized, site"1") @test isapprox( - contract(transform(TensorNetwork(canonized), Tenet.HyperindConverter())), + contract(transform(TensorNetwork(canonized), Tenet.HyperFlatten())), contract(TensorNetwork(qtn)), ) canonized = canonize_site(qtn, site"2"; direction = :right, method = method) @test isleftcanonical(canonized, site"2") @test isapprox( - contract(transform(TensorNetwork(canonized), Tenet.HyperindConverter())), + contract(transform(TensorNetwork(canonized), Tenet.HyperFlatten())), contract(TensorNetwork(qtn)), ) canonized = canonize_site(qtn, site"2"; direction = :left, method = method) @test isrightcanonical(canonized, site"2") @test isapprox( - contract(transform(TensorNetwork(canonized), Tenet.HyperindConverter())), + contract(transform(TensorNetwork(canonized), Tenet.HyperFlatten())), contract(TensorNetwork(qtn)), ) canonized = canonize_site(qtn, site"3"; direction = :left, method = method) @test isrightcanonical(canonized, site"3") @test isapprox( - contract(transform(TensorNetwork(canonized), Tenet.HyperindConverter())), + contract(transform(TensorNetwork(canonized), Tenet.HyperFlatten())), contract(TensorNetwork(qtn)), ) end @@ -182,13 +182,13 @@ @test length(tensors(canonized)) == 9 # 5 tensors + 4 singular values vectors @test isapprox( - contract(transform(TensorNetwork(canonized), Tenet.HyperindConverter())), + contract(transform(TensorNetwork(canonized), Tenet.HyperFlatten())), contract(TensorNetwork(qtn)), ) @test isapprox(norm(qtn), norm(canonized)) # Extract the singular values between each adjacent pair of sites in the canonized chain - Λ = [select(canonized, :between, Site(i), Site(i + 1)) for i in 1:4] + Λ = [tensors(canonized; between = (Site(i), Site(i + 1))) for i in 1:4] @test map(λ -> sum(abs2, λ), Λ) ≈ ones(length(Λ)) * norm(canonized)^2 for i in 1:5 @@ -198,7 +198,7 @@ @test isleftcanonical(canonized, Site(i)) elseif i == 5 # in the limits of the chain, we get the norm of the state contract!(canonized, :between, Site(i - 1), Site(i); direction = :right) - tensor = select(canonized, :tensor, Site(i)) + tensor = tensors(canonized; at = Site(i)) replace!(TensorNetwork(canonized), tensor => tensor / norm(canonized)) @test isleftcanonical(canonized, Site(i)) else @@ -212,7 +212,7 @@ if i == 1 # in the limits of the chain, we get the norm of the state contract!(canonized, :between, Site(i), Site(i + 1); direction = :left) - tensor = select(canonized, :tensor, Site(i)) + tensor = tensors(canonized; at = Site(i)) replace!(TensorNetwork(canonized), tensor => tensor / norm(canonized)) @test isrightcanonical(canonized, Site(i)) elseif i == 5 @@ -235,7 +235,7 @@ @test isrightcanonical(canonized, Site(5)) @test isapprox( - contract(transform(TensorNetwork(canonized), Tenet.HyperindConverter())), + contract(transform(TensorNetwork(canonized), Tenet.HyperFlatten())), contract(TensorNetwork(qtn)), ) end diff --git a/test/integration/Quac_test.jl b/test/integration/Quac_test.jl index 3a644bf..89016a2 100644 --- a/test/integration/Quac_test.jl +++ b/test/integration/Quac_test.jl @@ -13,11 +13,11 @@ # all open indices are sites siteinds = getindex.((qftqtn,), sites(qftqtn)) - @test issetequal(inds(TensorNetwork(qftqtn), :open), siteinds) + @test issetequal(inds(TensorNetwork(qftqtn); set = :open), siteinds) # all inner indices are not sites # TODO too strict condition. remove? notsiteinds = setdiff(inds(TensorNetwork(qftqtn)), siteinds) - @test_skip issetequal(inds(TensorNetwork(qftqtn), :inner), notsiteinds) + @test_skip issetequal(inds(TensorNetwork(qftqtn); set = :inner), notsiteinds) end end From 81f4e5de4cd58c312131201e95d86de142af9490 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= <15837247+mofeing@users.noreply.github.com> Date: Sun, 5 May 2024 02:37:35 +0200 Subject: [PATCH 07/18] Bump version to 0.1.2 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 4eaa086..9498317 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Qrochet" uuid = "881a8f22-b5d0-48b0-96e5-a244b33f36d4" authors = ["Sergio Sánchez Ramírez "] -version = "0.1.1" +version = "0.1.2" [deps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" From 22c09bb732f4a0d726efb06f5126fdcf2c725645 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Mon, 20 May 2024 18:01:23 +0200 Subject: [PATCH 08/18] Refactor `sites` methods --- src/Ansatz.jl | 3 ++- src/Quantum.jl | 10 +++++++++- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/src/Ansatz.jl b/src/Ansatz.jl index f7c9ae0..ff64360 100644 --- a/src/Ansatz.jl +++ b/src/Ansatz.jl @@ -24,7 +24,6 @@ for f in [ :noutputs, :inputs, :outputs, - :sites, :nsites, :nlanes, :socket, @@ -46,6 +45,8 @@ function Base.summary(io::IO, tn::A) where {A<:Ansatz} end Base.show(io::IO, tn::A) where {A<:Ansatz} = summary(io, tn) +sites(tn::Ansatz; kwargs...) = sites(Quantum(tn); kwargs...) + function Tenet.inds(tn::Ansatz; kwargs...) if keys(kwargs) === (:bond,) inds(tn, Val(:bond), kwargs[:bond]...) diff --git a/src/Quantum.jl b/src/Quantum.jl index eaa5432..8c5c898 100644 --- a/src/Quantum.jl +++ b/src/Quantum.jl @@ -150,7 +150,15 @@ Base.show(io::IO, q::Quantum) = print(io, "Quantum (inputs=$(ninputs(q)), output Returns the sites of a [`Quantum`](@ref) Tensor Network. """ -sites(tn::Quantum) = collect(keys(tn.sites)) +function sites(tn::Quantum; kwargs...) + if isempty(kwargs) + collect(keys(tn.sites)) + elseif keys(kwargs) === (:at,) + findfirst(i -> i === kwargs[:at], tn.sites) + else + throw(MethodError(sites, (Quantum,), kwargs)) + end +end """ nsites(q::Quantum) From 71ac3c2c63b55dca980a78f0f82a2b30a7c93109 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Fri, 24 May 2024 19:45:00 +0200 Subject: [PATCH 09/18] CompatHelper: bump compat for Yao in [weakdeps] to 0.9, (keep existing compat) (#38) Co-authored-by: CompatHelper Julia --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 9498317..cf8dbb5 100644 --- a/Project.toml +++ b/Project.toml @@ -26,5 +26,5 @@ ChainRulesCore = "1.0" Muscle = "0.1" Quac = "0.3" Tenet = "0.6" -Yao = "0.8" +Yao = "0.8, 0.9" julia = "1.9" From 2e90905daf1bc6d44e9791ba77bc693cf892c890 Mon Sep 17 00:00:00 2001 From: CompatHelper Julia Date: Fri, 24 May 2024 17:42:24 +0000 Subject: [PATCH 10/18] CompatHelper: add new compat entry for ChainRulesTestUtils in [weakdeps] at version 1, (keep existing compat) --- Project.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/Project.toml b/Project.toml index cf8dbb5..e7edab0 100644 --- a/Project.toml +++ b/Project.toml @@ -23,6 +23,7 @@ QrochetYaoExt = "Yao" [compat] ChainRulesCore = "1.0" +ChainRulesTestUtils = "1" Muscle = "0.1" Quac = "0.3" Tenet = "0.6" From 16ea9fa2c269351f484c78cc708d87771df6d1ee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Sun, 2 Jun 2024 07:23:24 +0200 Subject: [PATCH 11/18] Add `zeros`,`ones` methods to `Product` --- src/Ansatz/Product.jl | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/Ansatz/Product.jl b/src/Ansatz/Product.jl index f9d6d38..0efeac7 100644 --- a/src/Ansatz/Product.jl +++ b/src/Ansatz/Product.jl @@ -41,6 +41,14 @@ function Product(::Operator, ::Open, arrays) Product(TensorNetwork(_tensors), sitemap) end +function Base.zeros(::Type{Product}, n; p::Int = 2, eltype = Bool) + Product(State(), Open(), fill(append!([one(eltype)], Iterators.repeated(zero(eltype), p - 1)), n)) +end + +function Base.ones(::Type{Product}, n; p::Int = 2, eltype = Bool) + Product(State(), Open(), fill(append!([zero(eltype), one(eltype)], Iterators.repeated(zero(eltype), p - 2)), n)) +end + LinearAlgebra.norm(tn::Product, p::Real = 2) = LinearAlgebra.norm(socket(tn), tn, p) function LinearAlgebra.norm(::Union{State,Operator}, tn::Product, p::Real) mapreduce(*, tensors(tn)) do tensor From 94f9afc2b0c36933d4fe1deb4779e79b657d4e88 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Sun, 2 Jun 2024 07:31:23 +0200 Subject: [PATCH 12/18] Fix warning in `LinearAlgebra.norm` method --- src/Ansatz.jl | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/Ansatz.jl b/src/Ansatz.jl index ff64360..7739752 100644 --- a/src/Ansatz.jl +++ b/src/Ansatz.jl @@ -98,9 +98,6 @@ function Base.showerror(io::IO, e::MissingSchmidtCoefficientsException) print(io, "Can't access the spectrum on bond $(e.bond)") end -function LinearAlgebra.norm(ψ::Ansatz, p::Real = 2; kwargs...) - p != 2 && throw(ArgumentError("p=$p is not implemented yet")) - - # TODO: Replace with contract(hcat(ψ, ψ')...) when implemented - return contract(merge(TensorNetwork(ψ), TensorNetwork(ψ')); kwargs...) |> only |> sqrt |> abs +function LinearAlgebra.norm2(ψ::Ansatz; kwargs...) + return contract(TensorNetwork(merge(Quantum(ψ), Quantum(ψ'))); kwargs...) |> only |> sqrt |> abs end From 7beeb55f251ae0e00e5bf3f5c6f272e669df7c17 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Sun, 2 Jun 2024 12:28:32 +0200 Subject: [PATCH 13/18] Fix `LinearAlgebra.norm` call --- src/Ansatz.jl | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/Ansatz.jl b/src/Ansatz.jl index 7739752..270f11e 100644 --- a/src/Ansatz.jl +++ b/src/Ansatz.jl @@ -98,6 +98,12 @@ function Base.showerror(io::IO, e::MissingSchmidtCoefficientsException) print(io, "Can't access the spectrum on bond $(e.bond)") end +function LinearAlgebra.norm(ψ::Ansatz, p::Real = 2; kwargs...) + p == 2 || throw(ArgumentError("only L2-norm is implemented yet")) + + return LinearAlgebra.norm2(ψ; kwargs...) +end + function LinearAlgebra.norm2(ψ::Ansatz; kwargs...) return contract(TensorNetwork(merge(Quantum(ψ), Quantum(ψ'))); kwargs...) |> only |> sqrt |> abs end From 0aaeca892a9ccdbdb246e252f22c3fcdfdbd492d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Sun, 2 Jun 2024 12:32:14 +0200 Subject: [PATCH 14/18] Fix `merge` in `LinearAlgebra.norm2` --- src/Ansatz.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Ansatz.jl b/src/Ansatz.jl index 270f11e..1a722b7 100644 --- a/src/Ansatz.jl +++ b/src/Ansatz.jl @@ -105,5 +105,5 @@ function LinearAlgebra.norm(ψ::Ansatz, p::Real = 2; kwargs...) end function LinearAlgebra.norm2(ψ::Ansatz; kwargs...) - return contract(TensorNetwork(merge(Quantum(ψ), Quantum(ψ'))); kwargs...) |> only |> sqrt |> abs + return contract(merge(TensorNetwork(ψ), TensorNetwork(ψ')); kwargs...) |> only |> sqrt |> abs end From 9d197b5712916bebfc87b1854a081665452703d8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Tue, 4 Jun 2024 07:48:49 +0200 Subject: [PATCH 15/18] Add distributed contraction example with Dagger --- examples/Project.toml | 11 +++++++ examples/dagger.jl | 71 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 82 insertions(+) create mode 100644 examples/Project.toml create mode 100644 examples/dagger.jl diff --git a/examples/Project.toml b/examples/Project.toml new file mode 100644 index 0000000..163dd21 --- /dev/null +++ b/examples/Project.toml @@ -0,0 +1,11 @@ +[deps] +AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" +Dagger = "d58978e5-989f-55fb-8d15-ea34adc7bf54" +Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" +EinExprs = "b1794770-133b-4de1-afb4-526377e9f4c5" +KaHyPar = "2a6221f6-aa48-11e9-3542-2d9e0ef01880" +Qrochet = "881a8f22-b5d0-48b0-96e5-a244b33f36d4" +Reactant = "3c362404-f566-11ee-1572-e11a4b42c853" +Tenet = "85d41934-b9cd-44e1-8730-56d86f15f3ec" +TimespanLogging = "a526e669-04d3-4846-9525-c66122c55f63" +Yao = "5872b779-8223-5990-8dd0-5abbb0748c8c" diff --git a/examples/dagger.jl b/examples/dagger.jl new file mode 100644 index 0000000..91870b8 --- /dev/null +++ b/examples/dagger.jl @@ -0,0 +1,71 @@ +using Tenet +using Qrochet +using Yao: Yao +using EinExprs +using AbstractTrees +using Distributed +using Dagger +using TimespanLogging +using KaHyPar + +m = 10 +circuit = Yao.EasyBuild.rand_google53(m); +H = Quantum(circuit) +ψ = Product(fill([1, 0], Yao.nqubits(circuit))) +qtn = merge(Quantum(ψ), H, Quantum(ψ)') +tn = Tenet.TensorNetwork(qtn) + +contract_smaller_dims = 20 +target_size = 24 + +Tenet.transform!(tn, Tenet.ContractSimplification()) +path = einexpr( + tn, + optimizer = HyPar( + parts = 2, + imbalance = 0.41, + edge_scaler = (ind_size) -> 10 * Int(round(log2(ind_size))), + vertex_scaler = (prod_size) -> 100 * Int(round(exp2(prod_size))), + ), +); + +max_dims_path = @show maximum(ndims, Branches(path)) +flops_path = @show mapreduce(flops, +, Branches(path)) +@show log10(flops_path) + +grouppath = deepcopy(path); +function recursiveforeach!(f, expr) + f(expr) + foreach(arg -> recursiveforeach!(f, arg), args(expr)) +end +sizedict = merge(Iterators.map(i -> i.size, Leaves(path))...); +recursiveforeach!(grouppath) do expr + merge!(expr.size, sizedict) + if all(<(contract_smaller_dims) ∘ ndims, expr.args) + empty!(expr.args) + end +end + +max_dims_grouppath = maximum(ndims, Branches(grouppath)) +flops_grouppath = mapreduce(flops, +, Branches(grouppath)) +targetinds = findslices(SizeScorer(), grouppath, size = 2^(target_size)); + +subexprs = map(Leaves(grouppath)) do expr + EinExprs.select(path, tuple(head(expr)...)) |> only +end + +addprocs(3) +@everywhere using Dagger, Tenet + +disttn = Tenet.TensorNetwork( + map(subexprs) do subexpr + Tensor( + distribute( # data + parent(Tenet.contract(tn; path = subexpr)), + Blocks([i ∈ targetinds ? 1 : 2 for i in head(subexpr)]...), + ), + head(subexpr), # inds + ) + end, +) +@show Tenet.contract(disttn; path = grouppath) From 024bbc167575d09db13591107b426dd0767da90b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Tue, 4 Jun 2024 18:28:27 +0200 Subject: [PATCH 16/18] Fix `zeros`,`ones` in `Product` --- src/Ansatz/Product.jl | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/Ansatz/Product.jl b/src/Ansatz/Product.jl index 0efeac7..31ac08f 100644 --- a/src/Ansatz/Product.jl +++ b/src/Ansatz/Product.jl @@ -41,12 +41,16 @@ function Product(::Operator, ::Open, arrays) Product(TensorNetwork(_tensors), sitemap) end -function Base.zeros(::Type{Product}, n; p::Int = 2, eltype = Bool) - Product(State(), Open(), fill(append!([one(eltype)], Iterators.repeated(zero(eltype), p - 1)), n)) +function Base.zeros(::Type{Product}, n::Integer; p::Int = 2, eltype = Bool) + Product(State(), Open(), fill(append!([one(eltype)], collect(Iterators.repeated(zero(eltype), p - 1))), n)) end -function Base.ones(::Type{Product}, n; p::Int = 2, eltype = Bool) - Product(State(), Open(), fill(append!([zero(eltype), one(eltype)], Iterators.repeated(zero(eltype), p - 2)), n)) +function Base.ones(::Type{Product}, n::Integer; p::Int = 2, eltype = Bool) + Product( + State(), + Open(), + fill(append!([zero(eltype), one(eltype)], collect(Iterators.repeated(zero(eltype), p - 2))), n), + ) end LinearAlgebra.norm(tn::Product, p::Real = 2) = LinearAlgebra.norm(socket(tn), tn, p) From 8a5c99414d62237d949beee7dd159d4248e5b70a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Tue, 4 Jun 2024 18:28:50 +0200 Subject: [PATCH 17/18] Add `Distributed` example --- examples/Project.toml | 5 ++ examples/distributed.jl | 103 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 108 insertions(+) create mode 100644 examples/distributed.jl diff --git a/examples/Project.toml b/examples/Project.toml index 163dd21..9f3cf58 100644 --- a/examples/Project.toml +++ b/examples/Project.toml @@ -1,11 +1,16 @@ [deps] AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" +ClusterManagers = "34f1f09b-3a8b-5176-ab39-66d58a4d544e" Dagger = "d58978e5-989f-55fb-8d15-ea34adc7bf54" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" EinExprs = "b1794770-133b-4de1-afb4-526377e9f4c5" +IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e" KaHyPar = "2a6221f6-aa48-11e9-3542-2d9e0ef01880" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" Qrochet = "881a8f22-b5d0-48b0-96e5-a244b33f36d4" Reactant = "3c362404-f566-11ee-1572-e11a4b42c853" +Revise = "295af30f-e4ad-537b-8983-00126c2a3abe" Tenet = "85d41934-b9cd-44e1-8730-56d86f15f3ec" TimespanLogging = "a526e669-04d3-4846-9525-c66122c55f63" Yao = "5872b779-8223-5990-8dd0-5abbb0748c8c" diff --git a/examples/distributed.jl b/examples/distributed.jl new file mode 100644 index 0000000..e9025a1 --- /dev/null +++ b/examples/distributed.jl @@ -0,0 +1,103 @@ +using Yao: Yao +using Qrochet +using Tenet +using EinExprs +using KaHyPar +using Random +using Distributed +using ClusterManagers +using AbstractTrees + +n = 64 +depth = 6 + +circuit = Yao.chain(n) + +for _ in 1:depth + perm = randperm(n) + + for (i, j) in Iterators.partition(perm, 2) + push!(circuit, Yao.put((i, j) => Yao.EasyBuild.FSimGate(2π * rand(), 2π * rand()))) + # push!(circuit, Yao.control(n, i, j => Yao.phase(2π * rand()))) + end +end + +H = Quantum(circuit) +ψ = zeros(Product, n) + +tn = TensorNetwork(merge(Quantum(ψ), H, Quantum(ψ)')) +transform!(tn, Tenet.ContractSimplification()) + +path = einexpr( + tn, + optimizer = HyPar( + parts = 2, + imbalance = 0.41, + edge_scaler = (ind_size) -> 10 * Int(round(log2(ind_size))), + vertex_scaler = (prod_size) -> 100 * Int(round(exp2(prod_size))), + ), +) + +@show maximum(ndims, Branches(path)) +@show maximum(length, Branches(path)) * sizeof(eltype(tn)) / 1024^3 + +@show log10(mapreduce(flops, +, Branches(path))) + +cutinds = findslices(SizeScorer(), path, size = 2^24) +cuttings = [[i => dim for dim in 1:size(tn, i)] for i in cutinds] + +# mock sliced path - valid for all slices +proj_inds = first.(cuttings) +slice_path = view(path.path, proj_inds...) + +expr = Tenet.codegen(Val(:outplace), slice_path) + +manager = SlurmManager(2 * 112 - 1) +addprocs(manager, cpus_per_task = 1, exeflags = "--project=$(Base.active_project())") +# @everywhere using LinearAlgebra +# @everywhere LinearAlgebra.BLAS.set_num_threads(2) + +@everywhere using Tenet, EinExprs, IterTools, LinearAlgebra, Reactant, AbstractTrees +@everywhere tn = $tn +@everywhere slice_path = $slice_path +@everywhere cuttings = $cuttings +@everywhere expr = $expr + +partial_results = map(enumerate(workers())) do (i, worker) + Distributed.@spawnat worker begin + # interleaved chunking without instantiation + it = takenth(Iterators.drop(Iterators.product(cuttings...), i - 1), nworkers()) + + f = @eval $expr + mock_slice = view(tn, first(it)...) + tensors′ = [ + Tensor(Reactant.ConcreteRArray(copy(parent(mock_slice[head(leaf)...]))), inds(mock_slice[head(leaf)...])) for leaf in Leaves(slice_path) + ] + g = Reactant.compile(f, Tuple(tensors′)) + + # local reduction of chunk + accumulator = zero(eltype(tn)) + + for proj_inds in it + slice = view(tn, proj_inds...) + tensors′ = [ + Tensor( + Reactant.ConcreteRArray(copy(parent(mock_slice[head(leaf)...]))), + inds(mock_slice[head(leaf)...]), + ) for leaf in Leaves(slice_path) + ] + res = only(g(tensors′...)) + + # avoid OOM due to garbage accumulation + GC.gc() + + accumulator += res + end + + return accumulator + end +end + +@show result = sum(Distributed.fetch.(partial_results)) + +rmprocs(workers()) From 2d882011ec7978b8631114189993ed3f203911d4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jofre=20Vall=C3=A8s?= Date: Tue, 11 Jun 2024 10:29:21 +0200 Subject: [PATCH 18/18] Refactor code --- src/Ansatz/Chain.jl | 85 +++++++++++++++++++++------------------------ 1 file changed, 39 insertions(+), 46 deletions(-) diff --git a/src/Ansatz/Chain.jl b/src/Ansatz/Chain.jl index b382ef7..1c57c8f 100644 --- a/src/Ansatz/Chain.jl +++ b/src/Ansatz/Chain.jl @@ -652,6 +652,8 @@ function unpack_2sitewf!(ψ::Chain, bond, left_inds, right_inds, virtualind) return ψ end +Tenet.__check_index_sizes(qtn::Chain) = Tenet.__check_index_sizes(TensorNetwork(qtn)) + evolve(ψ::Chain, mpo::Chain) = evolve!(copy(ψ), mpo) """ @@ -660,60 +662,51 @@ evolve(ψ::Chain, mpo::Chain) = evolve!(copy(ψ), mpo) Applies a Matrix Product Operator (MPO) `mpo` to the [`Chain`](@ref). """ function evolve!(ψ::Chain, mpo::Chain) - updated_tensors = Tensor[] - Λ = Tensor[] L = nsites(ψ) - for i in 1:L - contractedind = select(ψ, :index, Site(i)) - t = contract(select(ψ, :tensor, Site(i)), select(mpo, :tensor, Site(i)); dims = (contractedind,)) - physicalind = select(mpo, :index, Site(i)) - - # Fuse the two right legs of t into one - if i == 1 - wanted_inds = (physicalind, rightindex(ψ, Site(i)), rightindex(mpo, Site(i))) - new_inds = (contractedind, rightindex(ψ, Site(i))) - elseif i < L - wanted_inds = ( - physicalind, - leftindex(ψ, Site(i)), - leftindex(mpo, Site(i)), - rightindex(ψ, Site(i)), - rightindex(mpo, Site(i)), + Tenet.@unsafe_region ψ begin + for i in 1:L + contractedind = inds(ψ; at = Site(i)) + t = contract(tensors(ψ; at = Site(i)), tensors(mpo; at = Site(i)); dims = (contractedind,)) + physicalind = inds(mpo; at = Site(i)) + + # Fuse the two right legs of t into one + if i == 1 + wanted_inds = (physicalind, rightindex(ψ, Site(i)), rightindex(mpo, Site(i))) + new_inds = (contractedind, rightindex(ψ, Site(i))) + elseif i < L + wanted_inds = ( + physicalind, + leftindex(ψ, Site(i)), + leftindex(mpo, Site(i)), + rightindex(ψ, Site(i)), + rightindex(mpo, Site(i)), + ) + new_inds = (contractedind, leftindex(ψ, Site(i)), rightindex(ψ, Site(i))) + else + wanted_inds = (physicalind, leftindex(ψ, Site(i)), leftindex(mpo, Site(i))) + new_inds = (contractedind, leftindex(ψ, Site(i))) + end + + perm = Tenet.__find_index_permutation(wanted_inds, inds(t)) + t = permutedims(t, perm) + + t = Tensor( + reshape(t, tuple(size(t, 1), [size(t, k) * size(t, k + 1) for k in 2:2:length(wanted_inds)]...)), + new_inds, ) - new_inds = (contractedind, leftindex(ψ, Site(i)), rightindex(ψ, Site(i))) - else - wanted_inds = (physicalind, leftindex(ψ, Site(i)), leftindex(mpo, Site(i))) - new_inds = (contractedind, leftindex(ψ, Site(i))) - end - perm = Tenet.__find_index_permutation(wanted_inds, inds(t)) - t = permutedims(t, perm) + replace!(TensorNetwork(ψ), tensors(ψ; at = Site(i)) => t) - t = Tensor( - reshape(t, tuple(size(t, 1), [size(t, k) * size(t, k + 1) for k in 2:2:length(wanted_inds)]...)), - new_inds, - ) - push!(updated_tensors, t) - - if i < L - d = size(TensorNetwork(mpo), rightindex(mpo, Site(i))) - Λᵢ = select(ψ, :between, Site(i), Site(i + 1)) - Λᵢ = Tensor(diag(kron(Matrix(LinearAlgebra.I, d, d), diagm(parent(Λᵢ)))), inds(Λᵢ)) - push!(Λ, Λᵢ) + if i < L + d = size(TensorNetwork(mpo), rightindex(mpo, Site(i))) + Λᵢ = tensors(ψ; between = (Site(i), Site(i + 1))) + Λᵢ = Tensor(diag(kron(Matrix(LinearAlgebra.I, d, d), diagm(parent(Λᵢ)))), inds(Λᵢ)) + replace!(TensorNetwork(ψ), tensors(ψ; between = (Site(i), Site(i + 1))) => Λᵢ) + end end end - for i in 1:L - i < L && pop!(TensorNetwork(ψ), select(ψ, :between, Site(i), Site(i + 1))) - pop!(TensorNetwork(ψ), select(ψ, :tensor, Site(i))) - end - - for i in 1:L - i < L && push!(TensorNetwork(ψ), Λ[i]) - push!(TensorNetwork(ψ), updated_tensors[i]) - end - return ψ end