diff --git a/Project.toml b/Project.toml index 694e89c..e7edab0 100644 --- a/Project.toml +++ b/Project.toml @@ -1,14 +1,13 @@ name = "Qrochet" uuid = "881a8f22-b5d0-48b0-96e5-a244b33f36d4" authors = ["Sergio Sánchez Ramírez "] -version = "0.1.1" +version = "0.1.2" [deps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Muscle = "21fe5c4b-a943-414d-bf3e-516f24900631" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Tenet = "85d41934-b9cd-44e1-8730-56d86f15f3ec" -ValSplit = "0625e100-946b-11ec-09cd-6328dd093154" [weakdeps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" @@ -24,9 +23,9 @@ QrochetYaoExt = "Yao" [compat] ChainRulesCore = "1.0" +ChainRulesTestUtils = "1" Muscle = "0.1" Quac = "0.3" -Tenet = "0.5" -ValSplit = "0.1" -Yao = "0.8" +Tenet = "0.6" +Yao = "0.8, 0.9" julia = "1.9" diff --git a/docs/src/quantum.md b/docs/src/quantum.md index 1b78a93..0c8dd03 100644 --- a/docs/src/quantum.md +++ b/docs/src/quantum.md @@ -11,8 +11,8 @@ nsites ## Queries ```@docs -Tenet.select(::Quantum, ::Val{:index}, ::Site) -Tenet.select(::Quantum, ::Val{:tensor}, ::Site) +Tenet.inds(::Quantum; kwargs...) +Tenet.tensors(::Quantum; kwargs...) ``` ## Connecting `Quantum` Tensor Networks diff --git a/examples/Project.toml b/examples/Project.toml new file mode 100644 index 0000000..9f3cf58 --- /dev/null +++ b/examples/Project.toml @@ -0,0 +1,16 @@ +[deps] +AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" +ClusterManagers = "34f1f09b-3a8b-5176-ab39-66d58a4d544e" +Dagger = "d58978e5-989f-55fb-8d15-ea34adc7bf54" +Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" +EinExprs = "b1794770-133b-4de1-afb4-526377e9f4c5" +IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e" +KaHyPar = "2a6221f6-aa48-11e9-3542-2d9e0ef01880" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" +Qrochet = "881a8f22-b5d0-48b0-96e5-a244b33f36d4" +Reactant = "3c362404-f566-11ee-1572-e11a4b42c853" +Revise = "295af30f-e4ad-537b-8983-00126c2a3abe" +Tenet = "85d41934-b9cd-44e1-8730-56d86f15f3ec" +TimespanLogging = "a526e669-04d3-4846-9525-c66122c55f63" +Yao = "5872b779-8223-5990-8dd0-5abbb0748c8c" diff --git a/examples/dagger.jl b/examples/dagger.jl new file mode 100644 index 0000000..91870b8 --- /dev/null +++ b/examples/dagger.jl @@ -0,0 +1,71 @@ +using Tenet +using Qrochet +using Yao: Yao +using EinExprs +using AbstractTrees +using Distributed +using Dagger +using TimespanLogging +using KaHyPar + +m = 10 +circuit = Yao.EasyBuild.rand_google53(m); +H = Quantum(circuit) +ψ = Product(fill([1, 0], Yao.nqubits(circuit))) +qtn = merge(Quantum(ψ), H, Quantum(ψ)') +tn = Tenet.TensorNetwork(qtn) + +contract_smaller_dims = 20 +target_size = 24 + +Tenet.transform!(tn, Tenet.ContractSimplification()) +path = einexpr( + tn, + optimizer = HyPar( + parts = 2, + imbalance = 0.41, + edge_scaler = (ind_size) -> 10 * Int(round(log2(ind_size))), + vertex_scaler = (prod_size) -> 100 * Int(round(exp2(prod_size))), + ), +); + +max_dims_path = @show maximum(ndims, Branches(path)) +flops_path = @show mapreduce(flops, +, Branches(path)) +@show log10(flops_path) + +grouppath = deepcopy(path); +function recursiveforeach!(f, expr) + f(expr) + foreach(arg -> recursiveforeach!(f, arg), args(expr)) +end +sizedict = merge(Iterators.map(i -> i.size, Leaves(path))...); +recursiveforeach!(grouppath) do expr + merge!(expr.size, sizedict) + if all(<(contract_smaller_dims) ∘ ndims, expr.args) + empty!(expr.args) + end +end + +max_dims_grouppath = maximum(ndims, Branches(grouppath)) +flops_grouppath = mapreduce(flops, +, Branches(grouppath)) +targetinds = findslices(SizeScorer(), grouppath, size = 2^(target_size)); + +subexprs = map(Leaves(grouppath)) do expr + EinExprs.select(path, tuple(head(expr)...)) |> only +end + +addprocs(3) +@everywhere using Dagger, Tenet + +disttn = Tenet.TensorNetwork( + map(subexprs) do subexpr + Tensor( + distribute( # data + parent(Tenet.contract(tn; path = subexpr)), + Blocks([i ∈ targetinds ? 1 : 2 for i in head(subexpr)]...), + ), + head(subexpr), # inds + ) + end, +) +@show Tenet.contract(disttn; path = grouppath) diff --git a/examples/distributed.jl b/examples/distributed.jl new file mode 100644 index 0000000..e9025a1 --- /dev/null +++ b/examples/distributed.jl @@ -0,0 +1,103 @@ +using Yao: Yao +using Qrochet +using Tenet +using EinExprs +using KaHyPar +using Random +using Distributed +using ClusterManagers +using AbstractTrees + +n = 64 +depth = 6 + +circuit = Yao.chain(n) + +for _ in 1:depth + perm = randperm(n) + + for (i, j) in Iterators.partition(perm, 2) + push!(circuit, Yao.put((i, j) => Yao.EasyBuild.FSimGate(2π * rand(), 2π * rand()))) + # push!(circuit, Yao.control(n, i, j => Yao.phase(2π * rand()))) + end +end + +H = Quantum(circuit) +ψ = zeros(Product, n) + +tn = TensorNetwork(merge(Quantum(ψ), H, Quantum(ψ)')) +transform!(tn, Tenet.ContractSimplification()) + +path = einexpr( + tn, + optimizer = HyPar( + parts = 2, + imbalance = 0.41, + edge_scaler = (ind_size) -> 10 * Int(round(log2(ind_size))), + vertex_scaler = (prod_size) -> 100 * Int(round(exp2(prod_size))), + ), +) + +@show maximum(ndims, Branches(path)) +@show maximum(length, Branches(path)) * sizeof(eltype(tn)) / 1024^3 + +@show log10(mapreduce(flops, +, Branches(path))) + +cutinds = findslices(SizeScorer(), path, size = 2^24) +cuttings = [[i => dim for dim in 1:size(tn, i)] for i in cutinds] + +# mock sliced path - valid for all slices +proj_inds = first.(cuttings) +slice_path = view(path.path, proj_inds...) + +expr = Tenet.codegen(Val(:outplace), slice_path) + +manager = SlurmManager(2 * 112 - 1) +addprocs(manager, cpus_per_task = 1, exeflags = "--project=$(Base.active_project())") +# @everywhere using LinearAlgebra +# @everywhere LinearAlgebra.BLAS.set_num_threads(2) + +@everywhere using Tenet, EinExprs, IterTools, LinearAlgebra, Reactant, AbstractTrees +@everywhere tn = $tn +@everywhere slice_path = $slice_path +@everywhere cuttings = $cuttings +@everywhere expr = $expr + +partial_results = map(enumerate(workers())) do (i, worker) + Distributed.@spawnat worker begin + # interleaved chunking without instantiation + it = takenth(Iterators.drop(Iterators.product(cuttings...), i - 1), nworkers()) + + f = @eval $expr + mock_slice = view(tn, first(it)...) + tensors′ = [ + Tensor(Reactant.ConcreteRArray(copy(parent(mock_slice[head(leaf)...]))), inds(mock_slice[head(leaf)...])) for leaf in Leaves(slice_path) + ] + g = Reactant.compile(f, Tuple(tensors′)) + + # local reduction of chunk + accumulator = zero(eltype(tn)) + + for proj_inds in it + slice = view(tn, proj_inds...) + tensors′ = [ + Tensor( + Reactant.ConcreteRArray(copy(parent(mock_slice[head(leaf)...]))), + inds(mock_slice[head(leaf)...]), + ) for leaf in Leaves(slice_path) + ] + res = only(g(tensors′...)) + + # avoid OOM due to garbage accumulation + GC.gc() + + accumulator += res + end + + return accumulator + end +end + +@show result = sum(Distributed.fetch.(partial_results)) + +rmprocs(workers()) diff --git a/src/Ansatz.jl b/src/Ansatz.jl index b6137ac..1a722b7 100644 --- a/src/Ansatz.jl +++ b/src/Ansatz.jl @@ -1,5 +1,4 @@ using Tenet -using ValSplit using LinearAlgebra """ @@ -25,11 +24,9 @@ for f in [ :noutputs, :inputs, :outputs, - :sites, :nsites, :nlanes, :socket, - :(Tenet.tensors), :(Tenet.arrays), :(Base.collect), ] @@ -46,33 +43,49 @@ alias(::A) where {A} = string(A) function Base.summary(io::IO, tn::A) where {A<:Ansatz} print(io, "$(alias(tn)) (inputs=$(ninputs(tn)), outputs=$(noutputs(tn)))") end -Base.show(io::IO, tn::A) where {A<:Ansatz} = Base.summary(io, tn) +Base.show(io::IO, tn::A) where {A<:Ansatz} = summary(io, tn) -@valsplit 2 Tenet.select(tn::Ansatz, query::Symbol, args...) = select(Quantum(tn), query, args...) +sites(tn::Ansatz; kwargs...) = sites(Quantum(tn); kwargs...) -function Tenet.select(tn::Ansatz, ::Val{:between}, site1::Site, site2::Site) +function Tenet.inds(tn::Ansatz; kwargs...) + if keys(kwargs) === (:bond,) + inds(tn, Val(:bond), kwargs[:bond]...) + else + inds(Quantum(tn); kwargs...) + end +end + +function Tenet.inds(tn::Ansatz, ::Val{:bond}, site1::Site, site2::Site) @assert site1 ∈ sites(tn) "Site $site1 not found" @assert site2 ∈ sites(tn) "Site $site2 not found" @assert site1 != site2 "Sites must be different" - tensor1 = select(Quantum(tn), :tensor, site1) - tensor2 = select(Quantum(tn), :tensor, site2) + tensor1 = tensors(tn; at = site1) + tensor2 = tensors(tn; at = site2) isdisjoint(inds(tensor1), inds(tensor2)) && return nothing + return only(inds(tensor1) ∩ inds(tensor2)) +end - TensorNetwork(tn)[only(inds(tensor1) ∩ inds(tensor2))] +function Tenet.tensors(tn::Ansatz; kwargs...) + if keys(kwargs) === (:between,) + tensors(tn, Val(:between), kwargs[:between]...) + else + tensors(Quantum(tn); kwargs...) + end end -function Tenet.select(tn::Ansatz, ::Val{:bond}, site1::Site, site2::Site) +function Tenet.tensors(tn::Ansatz, ::Val{:between}, site1::Site, site2::Site) @assert site1 ∈ sites(tn) "Site $site1 not found" @assert site2 ∈ sites(tn) "Site $site2 not found" @assert site1 != site2 "Sites must be different" - tensor1 = select(Quantum(tn), :tensor, site1) - tensor2 = select(Quantum(tn), :tensor, site2) + tensor1 = tensors(tn; at = site1) + tensor2 = tensors(tn; at = site2) isdisjoint(inds(tensor1), inds(tensor2)) && return nothing - return only(inds(tensor1) ∩ inds(tensor2)) + + TensorNetwork(tn)[only(inds(tensor1) ∩ inds(tensor2))] end struct MissingSchmidtCoefficientsException <: Base.Exception @@ -86,8 +99,11 @@ function Base.showerror(io::IO, e::MissingSchmidtCoefficientsException) end function LinearAlgebra.norm(ψ::Ansatz, p::Real = 2; kwargs...) - p != 2 && throw(ArgumentError("p=$p is not implemented yet")) + p == 2 || throw(ArgumentError("only L2-norm is implemented yet")) + + return LinearAlgebra.norm2(ψ; kwargs...) +end - # TODO: Replace with contract(hcat(ψ, ψ')...) when implemented +function LinearAlgebra.norm2(ψ::Ansatz; kwargs...) return contract(merge(TensorNetwork(ψ), TensorNetwork(ψ')); kwargs...) |> only |> sqrt |> abs end diff --git a/src/Ansatz/Chain.jl b/src/Ansatz/Chain.jl index 4e7d033..1c57c8f 100644 --- a/src/Ansatz/Chain.jl +++ b/src/Ansatz/Chain.jl @@ -129,11 +129,11 @@ rightsite(::Periodic, tn::Chain, site::Site) = Site(mod1(id(site) + 1, nlanes(tn leftindex(tn::Chain, site::Site) = leftindex(boundary(tn), tn, site) leftindex(::Open, tn::Chain, site::Site) = site == site"1" ? nothing : leftindex(Periodic(), tn, site) -leftindex(::Periodic, tn::Chain, site::Site) = select(tn, :bond, site, leftsite(tn, site)) +leftindex(::Periodic, tn::Chain, site::Site) = inds(tn; bond = (site, leftsite(tn, site))) rightindex(tn::Chain, site::Site) = rightindex(boundary(tn), tn, site) rightindex(::Open, tn::Chain, site::Site) = site == Site(nlanes(tn)) ? nothing : rightindex(Periodic(), tn, site) -rightindex(::Periodic, tn::Chain, site::Site) = select(tn, :bond, site, rightsite(tn, site)) +rightindex(::Periodic, tn::Chain, site::Site) = inds(tn; bond = (site, rightsite(tn, site))) Base.adjoint(chain::Chain) = Chain(adjoint(Quantum(chain)), boundary(chain)) @@ -248,14 +248,14 @@ function Tenet.contract!( direction::Symbol = :left, delete_Λ = true, ) - Λᵢ = select(tn, :between, site1, site2) + Λᵢ = tensors(tn; between = (site1, site2)) Λᵢ === nothing && return tn if direction === :right - Γᵢ₊₁ = select(tn, :tensor, site2) + Γᵢ₊₁ = tensors(tn; at = site2) replace!(TensorNetwork(tn), Γᵢ₊₁ => contract(Γᵢ₊₁, Λᵢ, dims = ())) elseif direction === :left - Γᵢ = select(tn, :tensor, site1) + Γᵢ = tensors(tn; at = site1) replace!(TensorNetwork(tn), Γᵢ => contract(Λᵢ, Γᵢ, dims = ())) else throw(ArgumentError("Unknown direction=:$direction")) @@ -322,13 +322,13 @@ Truncate the dimension of the virtual `bond`` of the [`Chain`](@ref) Tensor Netw - The bond must contain the Schmidt coefficients, i.e. a site canonization must be performed before calling `truncate!`. """ function truncate!(qtn::Chain, bond; threshold::Union{Nothing,Real} = nothing, maxdim::Union{Nothing,Int} = nothing) - # TODO replace for select(:between) + # TODO replace for tensors(; between) vind = rightindex(qtn, bond[1]) if vind != leftindex(qtn, bond[2]) throw(ArgumentError("Invalid bond $bond")) end - if vind ∉ inds(TensorNetwork(qtn), :hyper) + if vind ∉ inds(qtn; set = :hyper) throw(MissingSchmidtCoefficientsException(bond)) end @@ -357,7 +357,7 @@ end function isleftcanonical(qtn::Chain, site; atol::Real = 1e-12) right_ind = rightindex(qtn, site) - tensor = select(qtn, :tensor, site) + tensor = tensors(qtn; at = site) # we are at right-most site, we need to add an extra dummy dimension to the tensor if isnothing(right_ind) @@ -375,7 +375,7 @@ end function isrightcanonical(qtn::Chain, site; atol::Real = 1e-12) left_ind = leftindex(qtn, site) - tensor = select(qtn, :tensor, site) + tensor = tensors(qtn; at = site) # we are at left-most site, we need to add an extra dummy dimension to the tensor if isnothing(left_ind) @@ -413,15 +413,15 @@ function canonize!(::Open, tn::Chain) canonize_site!(tn, Site(i); direction = :right, method = :svd) # extract the singular values and contract them with the next tensor - Λᵢ = pop!(TensorNetwork(tn), select(tn, :between, Site(i), Site(i + 1))) - Aᵢ₊₁ = select(tn, :tensor, Site(i + 1)) + Λᵢ = pop!(TensorNetwork(tn), tensors(tn; between = (Site(i), Site(i + 1)))) + Aᵢ₊₁ = tensors(tn; at = Site(i + 1)) replace!(TensorNetwork(tn), Aᵢ₊₁ => contract(Aᵢ₊₁, Λᵢ, dims = ())) push!(Λ, Λᵢ) end for i in 2:nsites(tn) # tensors at i in "A" form, need to contract (Λᵢ)⁻¹ with A to get Γᵢ Λᵢ = Λ[i-1] # singular values start between site 1 and 2 - A = select(tn, :tensor, Site(i)) + A = tensors(tn; at = Site(i)) Γᵢ = contract(A, Tensor(diag(pinv(Diagonal(parent(Λᵢ)), atol = 1e-64)), inds(Λᵢ)), dims = ()) replace!(TensorNetwork(tn), A => Γᵢ) push!(TensorNetwork(tn), Λᵢ) @@ -462,7 +462,7 @@ to mixed-canonized form with the given center site. """ function LinearAlgebra.normalize!(tn::Chain, root::Site; p::Real = 2) mixed_canonize!(tn, root) - normalize!(select(Quantum(tn), :tensor, root), p) + normalize!(tensors(Quantum(tn); at = root), p) return tn end @@ -522,11 +522,11 @@ function evolve_1site!(qtn::Chain, gate::Dense) targetsite = only(inputs(gate))' # reindex contracting index - replace!(TensorNetwork(qtn), select(qtn, :index, targetsite) => contracting_index) - replace!(TensorNetwork(gate), select(gate, :index, targetsite') => contracting_index) + replace!(TensorNetwork(qtn), inds(qtn; at = targetsite) => contracting_index) + replace!(TensorNetwork(gate), inds(gate; at = targetsite') => contracting_index) # reindex output of gate to match TN sitemap - replace!(TensorNetwork(gate), select(gate, :index, only(outputs(gate))) => select(qtn, :index, targetsite)) + replace!(TensorNetwork(gate), inds(gate; at = only(outputs(gate))) => inds(qtn; at = targetsite)) # contract gate with TN merge!(TensorNetwork(qtn), TensorNetwork(gate)) @@ -542,23 +542,23 @@ function evolve_2site!(qtn::Chain, gate::Dense; threshold, maxdim, iscanonical = left_inds::Vector{Symbol} = !isnothing(leftindex(qtn, sitel)) ? [leftindex(qtn, sitel)] : Symbol[] right_inds::Vector{Symbol} = !isnothing(rightindex(qtn, siter)) ? [rightindex(qtn, siter)] : Symbol[] - virtualind::Symbol = select(qtn, :bond, bond...) + virtualind::Symbol = inds(qtn, :bond, bond...) iscanonical ? contract_2sitewf!(qtn, bond) : contract!(TensorNetwork(qtn), virtualind) # reindex contracting index contracting_inds = [gensym(:tmp) for _ in inputs(gate)] replace!(TensorNetwork(qtn), map(zip(inputs(gate), contracting_inds)) do (site, contracting_index) - select(qtn, :index, site') => contracting_index + inds(qtn; at = site') => contracting_index end) replace!(TensorNetwork(gate), map(zip(inputs(gate), contracting_inds)) do (site, contracting_index) - select(gate, :index, site) => contracting_index + inds(gate; at = site) => contracting_index end) # reindex output of gate to match TN sitemap for site in outputs(gate) - if select(qtn, :index, site) != select(gate, :index, site) - replace!(TensorNetwork(gate), select(gate, :index, site) => select(qtn, :index, site)) + if inds(qtn; at = site) != inds(gate; at = site) + replace!(TensorNetwork(gate), inds(gate; at = site) => inds(qtn; at = site)) end end @@ -567,8 +567,8 @@ function evolve_2site!(qtn::Chain, gate::Dense; threshold, maxdim, iscanonical = contract!(TensorNetwork(qtn), contracting_inds) # decompose using SVD - push!(left_inds, select(qtn, :index, sitel)) - push!(right_inds, select(qtn, :index, siter)) + push!(left_inds, inds(qtn; at = sitel)) + push!(right_inds, inds(qtn; at = siter)) if iscanonical unpack_2sitewf!(qtn, bond, left_inds, right_inds, virtualind) @@ -581,7 +581,7 @@ function evolve_2site!(qtn::Chain, gate::Dense; threshold, maxdim, iscanonical = # renormalize the bond if renormalize && iscanonical - λ = select(qtn, :between, bond...) + λ = tensors(qtn; between = bond) replace!(TensorNetwork(qtn), λ => normalize(λ)) elseif renormalize && !iscanonical normalize!(qtn, bond[1]) @@ -604,13 +604,13 @@ function contract_2sitewf!(ψ::Chain, bond) (0 < id(sitel) < nsites(ψ) || 0 < id(siter) < nsites(ψ)) || throw(ArgumentError("The sites in the bond must be between 1 and $(nsites(ψ))")) - Λᵢ₋₁ = id(sitel) == 1 ? nothing : select(ψ, :between, Site(id(sitel) - 1), sitel) - Λᵢ₊₁ = id(sitel) == nsites(ψ) - 1 ? nothing : select(ψ, :between, siter, Site(id(siter) + 1)) + Λᵢ₋₁ = id(sitel) == 1 ? nothing : tensors(ψ; between = (Site(id(sitel) - 1), sitel)) + Λᵢ₊₁ = id(sitel) == nsites(ψ) - 1 ? nothing : tensors(ψ; between = (siter, Site(id(siter) + 1))) !isnothing(Λᵢ₋₁) && contract!(ψ, :between, Site(id(sitel) - 1), sitel; direction = :right, delete_Λ = false) !isnothing(Λᵢ₊₁) && contract!(ψ, :between, siter, Site(id(siter) + 1); direction = :left, delete_Λ = false) - contract!(TensorNetwork(ψ), select(ψ, :bond, bond...)) + contract!(TensorNetwork(ψ), inds(ψ, :bond, bond...)) return ψ end @@ -628,11 +628,11 @@ function unpack_2sitewf!(ψ::Chain, bond, left_inds, right_inds, virtualind) (0 < id(sitel) < nsites(ψ) || 0 < id(site_r) < nsites(ψ)) || throw(ArgumentError("The sites in the bond must be between 1 and $(nsites(ψ))")) - Λᵢ₋₁ = id(sitel) == 1 ? nothing : select(ψ, :between, Site(id(sitel) - 1), sitel) - Λᵢ₊₁ = id(siter) == nsites(ψ) ? nothing : select(ψ, :between, siter, Site(id(siter) + 1)) + Λᵢ₋₁ = id(sitel) == 1 ? nothing : tensors(ψ; between = (Site(id(sitel) - 1), sitel)) + Λᵢ₊₁ = id(siter) == nsites(ψ) ? nothing : tensors(ψ; between = (siter, Site(id(siter) + 1))) # do svd of the θ tensor - θ = select(ψ, :tensor, sitel) + θ = tensors(ψ; at = sitel) U, s, Vt = svd(θ; left_inds, right_inds, virtualind) # contract with the inverse of Λᵢ and Λᵢ₊₂ @@ -652,6 +652,64 @@ function unpack_2sitewf!(ψ::Chain, bond, left_inds, right_inds, virtualind) return ψ end +Tenet.__check_index_sizes(qtn::Chain) = Tenet.__check_index_sizes(TensorNetwork(qtn)) + +evolve(ψ::Chain, mpo::Chain) = evolve!(copy(ψ), mpo) + +""" + evolve!(ψ::Chain, mpo::Chain) + +Applies a Matrix Product Operator (MPO) `mpo` to the [`Chain`](@ref). +""" +function evolve!(ψ::Chain, mpo::Chain) + L = nsites(ψ) + + Tenet.@unsafe_region ψ begin + for i in 1:L + contractedind = inds(ψ; at = Site(i)) + t = contract(tensors(ψ; at = Site(i)), tensors(mpo; at = Site(i)); dims = (contractedind,)) + physicalind = inds(mpo; at = Site(i)) + + # Fuse the two right legs of t into one + if i == 1 + wanted_inds = (physicalind, rightindex(ψ, Site(i)), rightindex(mpo, Site(i))) + new_inds = (contractedind, rightindex(ψ, Site(i))) + elseif i < L + wanted_inds = ( + physicalind, + leftindex(ψ, Site(i)), + leftindex(mpo, Site(i)), + rightindex(ψ, Site(i)), + rightindex(mpo, Site(i)), + ) + new_inds = (contractedind, leftindex(ψ, Site(i)), rightindex(ψ, Site(i))) + else + wanted_inds = (physicalind, leftindex(ψ, Site(i)), leftindex(mpo, Site(i))) + new_inds = (contractedind, leftindex(ψ, Site(i))) + end + + perm = Tenet.__find_index_permutation(wanted_inds, inds(t)) + t = permutedims(t, perm) + + t = Tensor( + reshape(t, tuple(size(t, 1), [size(t, k) * size(t, k + 1) for k in 2:2:length(wanted_inds)]...)), + new_inds, + ) + + replace!(TensorNetwork(ψ), tensors(ψ; at = Site(i)) => t) + + if i < L + d = size(TensorNetwork(mpo), rightindex(mpo, Site(i))) + Λᵢ = tensors(ψ; between = (Site(i), Site(i + 1))) + Λᵢ = Tensor(diag(kron(Matrix(LinearAlgebra.I, d, d), diagm(parent(Λᵢ)))), inds(Λᵢ)) + replace!(TensorNetwork(ψ), tensors(ψ; between = (Site(i), Site(i + 1))) => Λᵢ) + end + end + end + + return ψ +end + function expect(ψ::Chain, observables) # contract observable with TN ϕ = copy(ψ) diff --git a/src/Ansatz/Product.jl b/src/Ansatz/Product.jl index f9d6d38..31ac08f 100644 --- a/src/Ansatz/Product.jl +++ b/src/Ansatz/Product.jl @@ -41,6 +41,18 @@ function Product(::Operator, ::Open, arrays) Product(TensorNetwork(_tensors), sitemap) end +function Base.zeros(::Type{Product}, n::Integer; p::Int = 2, eltype = Bool) + Product(State(), Open(), fill(append!([one(eltype)], collect(Iterators.repeated(zero(eltype), p - 1))), n)) +end + +function Base.ones(::Type{Product}, n::Integer; p::Int = 2, eltype = Bool) + Product( + State(), + Open(), + fill(append!([zero(eltype), one(eltype)], collect(Iterators.repeated(zero(eltype), p - 2))), n), + ) +end + LinearAlgebra.norm(tn::Product, p::Real = 2) = LinearAlgebra.norm(socket(tn), tn, p) function LinearAlgebra.norm(::Union{State,Operator}, tn::Product, p::Real) mapreduce(*, tensors(tn)) do tensor diff --git a/src/Qrochet.jl b/src/Qrochet.jl index 5ab658c..b7d6bdc 100644 --- a/src/Qrochet.jl +++ b/src/Qrochet.jl @@ -33,6 +33,6 @@ export evolve!, expect, overlap # reexports from Tenet using Tenet -export select +export inds, tensors, arrays end diff --git a/src/Quantum.jl b/src/Quantum.jl index 815ed6f..8c5c898 100644 --- a/src/Quantum.jl +++ b/src/Quantum.jl @@ -1,5 +1,4 @@ using Tenet -using ValSplit # TODO Should we store here some information about quantum numbers? """ @@ -151,7 +150,15 @@ Base.show(io::IO, q::Quantum) = print(io, "Quantum (inputs=$(ninputs(q)), output Returns the sites of a [`Quantum`](@ref) Tensor Network. """ -sites(tn::Quantum) = collect(keys(tn.sites)) +function sites(tn::Quantum; kwargs...) + if isempty(kwargs) + collect(keys(tn.sites)) + elseif keys(kwargs) === (:at,) + findfirst(i -> i === kwargs[:at], tn.sites) + else + throw(MethodError(sites, (Quantum,), kwargs)) + end +end """ nsites(q::Quantum) @@ -181,7 +188,7 @@ nlanes(tn::Quantum) = length(lanes(tn)) Returns the index associated with a site in a [`Quantum`](@ref) Tensor Network. """ -Base.getindex(q::Quantum, site::Site) = q.sites[site] +Base.getindex(q::Quantum, site::Site) = inds(q; at = site) """ Socket @@ -232,25 +239,43 @@ function socket(q::Quantum) end # forward `TensorNetwork` methods -for f in [:(Tenet.tensors), :(Tenet.arrays), :(Base.collect)] +for f in [:(Tenet.arrays), :(Base.collect)] @eval $f(@nospecialize tn::Quantum) = $f(TensorNetwork(tn)) end -@valsplit 2 Tenet.select(tn::Quantum, query::Symbol, args...) = error("Query ':$query' not defined") - """ - select(q::Quantum, :index, site::Site) + inds(tn::Quantum, set::Symbol = :all, args...; kwargs...) -Selects the index associated with a site in a [`Quantum`](@ref) Tensor Network. +Options: + + - `:at`: index at a site """ -Tenet.select(tn::Quantum, ::Val{:index}, site::Site) = tn[site] +function Tenet.inds(tn::Quantum; kwargs...) + if keys(kwargs) === (:at,) + inds(tn, Val(:at), kwargs[:at]) + else + inds(TensorNetwork(tn); kwargs...) + end +end + +Tenet.inds(tn::Quantum, ::Val{:at}, site::Site) = tn.sites[site] """ - select(q::Quantum, :tensor, site::Site) + tensors(tn::Quantum, query::Symbol, args...; kwargs...) -Selects the tensor associated with a site in a [`Quantum`](@ref) Tensor Network. +Options: + + - `:at`: tensor at a site """ -Tenet.select(tn::Quantum, ::Val{:tensor}, site::Site) = select(TensorNetwork(tn), :any, tn[site]) |> only +function Tenet.tensors(tn::Quantum; kwargs...) + if keys(kwargs) === (:at,) + tensors(tn, Val(:at), kwargs[:at]) + else + tensors(TensorNetwork(tn); kwargs...) + end +end + +Tenet.tensors(tn::Quantum, ::Val{:at}, site::Site) = only(tensors(tn; intersects = inds(tn; at = site))) function reindex!(a::Quantum, ioa, b::Quantum, iob) ioa ∈ [:inputs, :outputs] || error("Invalid argument: :$ioa") @@ -264,7 +289,7 @@ function reindex!(a::Quantum, ioa, b::Quantum, iob) end replacements = map(sitesb) do site - select(b, :index, site) => select(a, :index, ioa != iob ? site' : site) + inds(b; at = site) => inds(a; at = ioa != iob ? site' : site) end if issetequal(first.(replacements), last.(replacements)) @@ -274,7 +299,7 @@ function reindex!(a::Quantum, ioa, b::Quantum, iob) replace!(TensorNetwork(b), replacements...) for site in sitesb - b.sites[site] = select(a, :index, ioa != iob ? site' : site) + b.sites[site] = inds(a; at = ioa != iob ? site' : site) end b @@ -312,11 +337,11 @@ function Base.merge(a::Quantum, b::Quantum) sites = Dict{Site,Symbol}() for site in inputs(a) - sites[site] = select(a, :index, site) + sites[site] = inds(a; at = site) end for site in outputs(b) - sites[site] = select(b, :index, site) + sites[site] = inds(b; at = site) end Quantum(tn, sites) diff --git a/test/Ansatz/Chain_test.jl b/test/Ansatz/Chain_test.jl index 2269434..0eb7bcb 100644 --- a/test/Ansatz/Chain_test.jl +++ b/test/Ansatz/Chain_test.jl @@ -66,7 +66,7 @@ @test size(TensorNetwork(truncated), rightindex(truncated, Site(2))) == 1 @test size(TensorNetwork(truncated), leftindex(truncated, Site(3))) == 1 - singular_values = select(qtn, :between, Site(2), Site(3)) + singular_values = tensors(qtn; between = (Site(2), Site(3))) truncated = Qrochet.truncate(qtn, [Site(2), Site(3)]; threshold = singular_values[2] + 0.1) @test size(TensorNetwork(truncated), rightindex(truncated, Site(2))) == 1 @test size(TensorNetwork(truncated), leftindex(truncated, Site(3))) == 1 @@ -117,10 +117,10 @@ for i in 1:4 contract_some = contract(canonized, :between, Site(i), Site(i + 1)) - Bᵢ = select(contract_some, :tensor, Site(i)) + Bᵢ = tensors(contract_some; at = Site(i)) @test isapprox(contract(TensorNetwork(contract_some)), contract(TensorNetwork(qtn))) - @test_throws ArgumentError select(contract_some, :between, Site(i), Site(i + 1)) + @test_throws MethodError tensors(contract_some, :between, Site(i), Site(i + 1)) @test isrightcanonical(contract_some, Site(i)) @test isleftcanonical( @@ -128,8 +128,8 @@ Site(i + 1), ) - Γᵢ = select(canonized, :tensor, Site(i)) - Λᵢ₊₁ = select(canonized, :between, Site(i), Site(i + 1)) + Γᵢ = tensors(canonized; at = Site(i)) + Λᵢ₊₁ = tensors(canonized; between = (Site(i), Site(i + 1))) @test Bᵢ ≈ contract(Γᵢ, Λᵢ₊₁; dims = ()) end end @@ -144,28 +144,28 @@ canonized = canonize_site(qtn, site"1"; direction = :right, method = method) @test isleftcanonical(canonized, site"1") @test isapprox( - contract(transform(TensorNetwork(canonized), Tenet.HyperindConverter())), + contract(transform(TensorNetwork(canonized), Tenet.HyperFlatten())), contract(TensorNetwork(qtn)), ) canonized = canonize_site(qtn, site"2"; direction = :right, method = method) @test isleftcanonical(canonized, site"2") @test isapprox( - contract(transform(TensorNetwork(canonized), Tenet.HyperindConverter())), + contract(transform(TensorNetwork(canonized), Tenet.HyperFlatten())), contract(TensorNetwork(qtn)), ) canonized = canonize_site(qtn, site"2"; direction = :left, method = method) @test isrightcanonical(canonized, site"2") @test isapprox( - contract(transform(TensorNetwork(canonized), Tenet.HyperindConverter())), + contract(transform(TensorNetwork(canonized), Tenet.HyperFlatten())), contract(TensorNetwork(qtn)), ) canonized = canonize_site(qtn, site"3"; direction = :left, method = method) @test isrightcanonical(canonized, site"3") @test isapprox( - contract(transform(TensorNetwork(canonized), Tenet.HyperindConverter())), + contract(transform(TensorNetwork(canonized), Tenet.HyperFlatten())), contract(TensorNetwork(qtn)), ) end @@ -182,13 +182,13 @@ @test length(tensors(canonized)) == 9 # 5 tensors + 4 singular values vectors @test isapprox( - contract(transform(TensorNetwork(canonized), Tenet.HyperindConverter())), + contract(transform(TensorNetwork(canonized), Tenet.HyperFlatten())), contract(TensorNetwork(qtn)), ) @test isapprox(norm(qtn), norm(canonized)) # Extract the singular values between each adjacent pair of sites in the canonized chain - Λ = [select(canonized, :between, Site(i), Site(i + 1)) for i in 1:4] + Λ = [tensors(canonized; between = (Site(i), Site(i + 1))) for i in 1:4] @test map(λ -> sum(abs2, λ), Λ) ≈ ones(length(Λ)) * norm(canonized)^2 for i in 1:5 @@ -198,7 +198,7 @@ @test isleftcanonical(canonized, Site(i)) elseif i == 5 # in the limits of the chain, we get the norm of the state contract!(canonized, :between, Site(i - 1), Site(i); direction = :right) - tensor = select(canonized, :tensor, Site(i)) + tensor = tensors(canonized; at = Site(i)) replace!(TensorNetwork(canonized), tensor => tensor / norm(canonized)) @test isleftcanonical(canonized, Site(i)) else @@ -212,7 +212,7 @@ if i == 1 # in the limits of the chain, we get the norm of the state contract!(canonized, :between, Site(i), Site(i + 1); direction = :left) - tensor = select(canonized, :tensor, Site(i)) + tensor = tensors(canonized; at = Site(i)) replace!(TensorNetwork(canonized), tensor => tensor / norm(canonized)) @test isrightcanonical(canonized, Site(i)) elseif i == 5 @@ -235,7 +235,7 @@ @test isrightcanonical(canonized, Site(5)) @test isapprox( - contract(transform(TensorNetwork(canonized), Tenet.HyperindConverter())), + contract(transform(TensorNetwork(canonized), Tenet.HyperFlatten())), contract(TensorNetwork(qtn)), ) end diff --git a/test/integration/Quac_test.jl b/test/integration/Quac_test.jl index 3a644bf..89016a2 100644 --- a/test/integration/Quac_test.jl +++ b/test/integration/Quac_test.jl @@ -13,11 +13,11 @@ # all open indices are sites siteinds = getindex.((qftqtn,), sites(qftqtn)) - @test issetequal(inds(TensorNetwork(qftqtn), :open), siteinds) + @test issetequal(inds(TensorNetwork(qftqtn); set = :open), siteinds) # all inner indices are not sites # TODO too strict condition. remove? notsiteinds = setdiff(inds(TensorNetwork(qftqtn)), siteinds) - @test_skip issetequal(inds(TensorNetwork(qftqtn), :inner), notsiteinds) + @test_skip issetequal(inds(TensorNetwork(qftqtn); set = :inner), notsiteinds) end end