Skip to content

Commit

Permalink
Fix truncate! function (#254)
Browse files Browse the repository at this point in the history
* Fix truncate function and add small test

* Fix mixed_canonize! function

* Add form tests

* Format code

* Fix code

* Fix orthog_center field in MixedCanonical form

* Enhance mixed_canonize! tests

* Add recanonize kwarg for truncate(::Canonical, ...) function

* Small fixes on check_form functions

* Small fixes on tests

* Format code

* Add comment
  • Loading branch information
jofrevalles authored Nov 20, 2024
1 parent 1a9540d commit 4b8c356
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 35 deletions.
32 changes: 22 additions & 10 deletions src/Ansatz.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ struct NonCanonical <: Form end
left of the orthogonality center are left-canonical and the tensors to the right are right-canonical.
"""
struct MixedCanonical <: Form
orthog_center::Union{Site,Vector{Site}}
orthog_center::Union{Site,Vector{<:Site}}
end

"""
Expand Down Expand Up @@ -255,8 +255,8 @@ Truncate the dimension of the virtual `bond`` of an [`Ansatz`](@ref) Tensor Netw
- Either `threshold` or `maxdim` must be provided. If both are provided, `maxdim` is used.
"""
function truncate!(tn::AbstractAnsatz, bond; threshold=nothing, maxdim=nothing)
return truncate!(form(tn), tn, bond; threshold, maxdim)
function truncate!(tn::AbstractAnsatz, bond; threshold=nothing, maxdim=nothing, kwargs...)
return truncate!(form(tn), tn, bond; threshold, maxdim, kwargs...)
end

"""
Expand Down Expand Up @@ -290,14 +290,18 @@ function truncate!(::NonCanonical, tn::AbstractAnsatz, bond; threshold, maxdim,

spectrum = parent(tensors(tn; bond))

maxdim = isnothing(maxdim) ? size(tn, virtualind) : maxdim
maxdim = isnothing(maxdim) ? size(tn, virtualind) : min(maxdim, length(spectrum))

extent = if isnothing(threshold)
1:maxdim
else
1:something(findfirst(1:maxdim) do i
# Find the first index where the condition is met
found_index = findfirst(1:maxdim) do i
abs(spectrum[i]) < threshold
end - 1, maxdim)
end

# If no index is found, return 1:length(spectrum), otherwise calculate the range
1:(isnothing(found_index) ? maxdim : found_index - 1)
end

slice!(tn, virtualind, extent)
Expand All @@ -308,13 +312,21 @@ end
function truncate!(::MixedCanonical, tn::AbstractAnsatz, bond; threshold, maxdim)
# move orthogonality center to bond
mixed_canonize!(tn, bond)
return truncate!(NonCanonical(), tn, bond; threshold, maxdim, compute_local_svd=false)
return truncate!(NonCanonical(), tn, bond; threshold, maxdim, compute_local_svd=true)
end

function truncate!(::Canonical, tn::AbstractAnsatz, bond; threshold, maxdim)
"""
truncate!(::Canonical, tn::AbstractAnsatz, bond; threshold, maxdim, recanonize=true)
Truncate the dimension of the virtual `bond` of a [`Canonical`](@ref) Tensor Network by keeping the `maxdim` largest
**Schmidt coefficients** or those larger than `threshold`, and then recanonizes the Tensor Network if `recanonize` is `true`.
"""
function truncate!(::Canonical, tn::AbstractAnsatz, bond; threshold, maxdim, recanonize=true)
truncate!(NonCanonical(), tn, bond; threshold, maxdim, compute_local_svd=false)
# requires a sweep to recanonize the TN
return canonize!(tn)

recanonize && canonize!(tn)

return tn
end

overlap(a::AbstractAnsatz, b::AbstractAnsatz) = contract(merge(a, copy(b)'))
Expand Down
36 changes: 27 additions & 9 deletions src/MPS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -126,14 +126,26 @@ function MPS(::Canonical, arrays, λ; order=defaultorder(MPS), check=true)
return mps
end

"""
check_form(mps::AbstractMPO)
Check if the tensors in the mps are in the proper [`Form`](@ref).
"""
check_form(mps::AbstractMPO) = check_form(form(mps), mps)

function check_form(config::MixedCanonical, mps::AbstractMPO)
orthog_center = config.orthog_center

left, right = if orthog_center isa Site
id(orthog_center) .+ (0, 0) # So left and right get the same value
elseif orthog_center isa Vector{<:Site}
extrema(id.(orthog_center))
end

for i in 1:nsites(mps)
if i < id(orthog_center) # Check left-canonical tensors
if i < left # Check left-canonical tensors
isisometry(mps, Site(i); dir=:right) || throw(ArgumentError("Tensors are not left-canonical"))
elseif i > id(orthog_center) # Check right-canonical tensors
elseif i > right # Check right-canonical tensors
isisometry(mps, Site(i); dir=:left) || throw(ArgumentError("Tensors are not right-canonical"))
end
end
Expand All @@ -143,8 +155,7 @@ end

function check_form(::Canonical, mps::AbstractMPO)
for i in 1:nsites(mps)
if i > 1
!isisometry(contract(mps; between=(Site(i - 1), Site(i)), direction=:right), Site(i); dir=:right)
if i > 1 && !isisometry(contract(mps; between=(Site(i - 1), Site(i)), direction=:right), Site(i); dir=:right)
throw(ArgumentError("Can not form a left-canonical tensor in Site($i) from Γ and λ contraction."))
end

Expand All @@ -157,6 +168,8 @@ function check_form(::Canonical, mps::AbstractMPO)
return true
end

check_form(::NonCanonical, mps::AbstractMPO) = true

"""
MPO(arrays::Vector{<:AbstractArray}; order=defaultorder(MPO))
Expand Down Expand Up @@ -504,19 +517,24 @@ end
# TODO dispatch on form
# TODO generalize to AbstractAnsatz
function mixed_canonize!(tn::AbstractMPO, orthog_center)
left, right = if orthog_center isa Site
id(orthog_center) .+ (-1, 1)
elseif orthog_center isa Vector{<:Site}
extrema(id.(orthog_center)) .+ (-1, 1)
else
throw(ArgumentError("`orthog_center` must be a `Site` or a `Vector{Site}`"))
end

# left-to-right QR sweep (left-canonical tensors)
for i in 1:(id(orthog_center) - 1)
for i in 1:left
canonize_site!(tn, Site(i); direction=:right, method=:qr)
end

# right-to-left QR sweep (right-canonical tensors)
for i in nsites(tn):-1:(id(orthog_center) + 1)
for i in nsites(tn):-1:right
canonize_site!(tn, Site(i); direction=:left, method=:qr)
end

# center SVD sweep to get singular values
# canonize_site!(tn, orthog_center; direction=:left, method=:svd)

tn.form = MixedCanonical(orthog_center)

return tn
Expand Down
71 changes: 55 additions & 16 deletions test/MPS_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,15 +95,36 @@ using LinearAlgebra
end

@testset "truncate!" begin
ψ = MPS([rand(2, 2), rand(2, 2, 2), rand(2, 2)])
canonize_site!(ψ, Site(2); direction=:right, method=:svd)
@testset "NonCanonical" begin
ψ = MPS([rand(2, 2), rand(2, 2, 2), rand(2, 2)])
canonize_site!(ψ, Site(2); direction=:right, method=:svd)

truncated = truncate(ψ, [site"2", site"3"]; maxdim=1)
@test size(truncated, inds(truncated; bond=[site"2", site"3"])) == 1

singular_values = tensors(ψ; between=(site"2", site"3"))
truncated = truncate(ψ, [site"2", site"3"]; threshold=singular_values[2] + 0.1)
@test size(truncated, inds(truncated; bond=[site"2", site"3"])) == 1

# If maxdim > size(spectrum), the bond dimension is not truncated
truncated = truncate(ψ, [site"2", site"3"]; maxdim=4)
@test size(truncated, inds(truncated; bond=[site"2", site"3"])) == 2
end

@testset "Canonical" begin
ψ = rand(MPS; n=5, maxdim=16)
canonize!(ψ)

truncated = truncate(ψ, [site"2", site"3"]; maxdim=2)
@test size(truncated, inds(truncated; bond=[site"2", site"3"])) == 2
end

truncated = truncate(ψ, [site"2", site"3"]; maxdim=1)
@test size(truncated, inds(truncated; bond=[site"2", site"3"])) == 1
@testset "MixedCanonical" begin
ψ = rand(MPS; n=5, maxdim=16)

singular_values = tensors(ψ; between=(site"2", site"3"))
truncated = truncate(ψ, [site"2", site"3"]; threshold=singular_values[2] + 0.1)
@test size(truncated, inds(truncated; bond=[site"2", site"3"])) == 1
truncated = truncate(ψ, [site"2", site"3"]; maxdim=3)
@test size(truncated, inds(truncated; bond=[site"2", site"3"])) == 3
end
end

@testset "norm" begin
Expand Down Expand Up @@ -206,18 +227,36 @@ using LinearAlgebra
end

@testset "mixed_canonize!" begin
ψ = MPS([rand(4, 4), rand(4, 4, 4), rand(4, 4, 4), rand(4, 4, 4), rand(4, 4)])
canonized = mixed_canonize(ψ, site"3")
@testset "single Site" begin
ψ = MPS([rand(4, 4), rand(4, 4, 4), rand(4, 4, 4), rand(4, 4, 4), rand(4, 4)])
canonized = mixed_canonize(ψ, site"3")
@test Tenet.check_form(canonized)

@test form(canonized) isa MixedCanonical
@test form(canonized).orthog_center == site"3"

@test isisometry(canonized, site"1"; dir=:right)
@test isisometry(canonized, site"2"; dir=:right)
@test isisometry(canonized, site"4"; dir=:left)
@test isisometry(canonized, site"5"; dir=:left)

@test form(canonized) isa MixedCanonical
@test form(canonized).orthog_center == site"3"
@test contract(canonized) contract(ψ)
end

@testset "multiple Sites" begin
ψ = MPS([rand(4, 4), rand(4, 4, 4), rand(4, 4, 4), rand(4, 4, 4), rand(4, 4)])
canonized = mixed_canonize(ψ, [site"2", site"3"])

@test isisometry(canonized, site"1"; dir=:right)
@test isisometry(canonized, site"2"; dir=:right)
@test isisometry(canonized, site"4"; dir=:left)
@test isisometry(canonized, site"5"; dir=:left)
@test Tenet.check_form(canonized)
@test form(canonized) isa MixedCanonical
@test form(canonized).orthog_center == [site"2", site"3"]

@test contract(canonized) contract(ψ)
@test isisometry(canonized, site"1"; dir=:right)
@test isisometry(canonized, site"4"; dir=:left)
@test isisometry(canonized, site"5"; dir=:left)

@test contract(canonized) contract(ψ)
end
end

@testset "expect" begin
Expand Down

0 comments on commit 4b8c356

Please sign in to comment.