This repository has been archived by the owner on Jul 7, 2024. It is now read-only.

Commit 4fc959a: Keep index count to avoid index clash
mofeing committed Mar 8, 2024 (1 parent: b37656c)
Showing 6 changed files with 40 additions and 43 deletions.
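The gist of the change: index symbols used to be derived from a tensor's position via Tenet.letter(i), so two networks built independently reused the same symbols and clashed when composed. These call sites now use Qrochet.nextindex(), which draws each symbol from a process-global atomic counter (added in src/Utils.jl below). A minimal sketch of the difference, assuming only the Tensor(array, inds) constructor that appears in the hunks below:

    using Tenet, Qrochet

    # Old scheme (position-based): both tensors label their only index Tenet.letter(1),
    # so combining the two networks silently identifies unrelated indices.
    a = Tensor(rand(2), [Tenet.letter(1)])
    b = Tensor(rand(2), [Tenet.letter(1)])

    # New scheme (counter-based): every call consumes a fresh symbol, so tensors
    # built independently can no longer collide by accident.
    a = Tensor(rand(2), [Qrochet.nextindex()])
    b = Tensor(rand(2), [Qrochet.nextindex()])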
ext/QrochetQuacExt.jl (6 changes: 2 additions & 4 deletions)
@@ -12,10 +12,9 @@ Qrochet.evolve!(qtn::Ansatz, gate::Gate) = evolve!(qtn, Qrochet.Dense(gate))

 function Qrochet.Quantum(circuit::Circuit)
     n = lanes(circuit)
-    wire = [[Tenet.letter(i)] for i in 1:n]
+    wire = [[Qrochet.nextindex()] for _ in 1:n]
     tensors = Tensor[]

-    i = n + 1
     for gate in circuit
         G = arraytype(gate)
         array = G(gate)
@@ -27,8 +26,7 @@ function Qrochet.Quantum(circuit::Circuit)
         end

         inds = map(lanes(gate)) do l
-            from, to = last(wire[l]), Tenet.letter(i)
-            i += 1
+            from, to = last(wire[l]), Qrochet.nextindex()
             push!(wire[l], to)
             (from, to)
         end |> x -> zip(x...) |> Iterators.flatten |> collect
src/Ansatz/Chain.jl (51 changes: 20 additions & 31 deletions)
@@ -1,5 +1,4 @@
 using Tenet
-using Tenet: letter
 using LinearAlgebra
 using Random
 using Muscle: gramschmidt!
@@ -32,12 +31,13 @@ function Chain(::State, boundary::Periodic, arrays::Vector{<:AbstractArray})
     @assert all(==(3) ∘ ndims, arrays) "All arrays must have 3 dimensions"

     n = length(arrays)
+    symbols = [nextindex() for _ in 1:2n]

     _tensors = map(enumerate(arrays)) do (i, array)
-        Tensor(array, [letter(i), letter(n + mod1(i - 1, length(arrays))), letter(n + mod1(i, length(arrays)))])
+        Tensor(array, [symbols[i], symbols[n+mod1(i - 1, n)], symbols[n+mod1(i, n)]])
     end

-    sitemap = Dict(Site(i) => letter(i) for i in 1:length(arrays))
+    sitemap = Dict(Site(i) => symbols[i] for i in 1:n)

     Chain(Quantum(TensorNetwork(_tensors), sitemap), boundary)
 end
@@ -48,17 +48,19 @@ function Chain(::State, boundary::Open, arrays::Vector{<:AbstractArray})
     @assert ndims(arrays[end]) == 2 "Last array must have 2 dimensions"

     n = length(arrays)
+    symbols = [nextindex() for _ in 1:2n-1]
+
     _tensors = map(enumerate(arrays)) do (i, array)
         if i == 1
-            Tensor(array, [letter(1), letter(1 + n)])
+            Tensor(array, [symbols[1], symbols[1+n]])
         elseif i == n
-            Tensor(array, [letter(n), letter(n + mod1(n - 1, length(arrays)))])
+            Tensor(array, [symbols[n], symbols[n+mod1(n - 1, n)]])
         else
-            Tensor(array, [letter(i), letter(n + mod1(i - 1, length(arrays))), letter(n + mod1(i, length(arrays)))])
+            Tensor(array, [symbols[i], symbols[n+mod1(i - 1, n)], symbols[n+mod1(i, n)]])
         end
     end

-    sitemap = Dict(Site(i) => letter(i) for i in 1:length(arrays))
+    sitemap = Dict(Site(i) => symbols[i] for i in 1:n)

     Chain(Quantum(TensorNetwork(_tensors), sitemap), boundary)
 end
@@ -67,21 +69,14 @@ function Chain(::Operator, boundary::Periodic, arrays::Vector{<:AbstractArray})
     @assert all(==(4) ∘ ndims, arrays) "All arrays must have 4 dimensions"

     n = length(arrays)
+    symbols = [nextindex() for _ in 1:3n]

     _tensors = map(enumerate(arrays)) do (i, array)
-        Tensor(
-            array,
-            [
-                letter(i),
-                letter(i + n),
-                letter(2 * n + mod1(i - 1, length(arrays))),
-                letter(2 * n + mod1(i, length(arrays))),
-            ],
-        )
+        Tensor(array, [symbols[i], symbols[i+n], symbols[2n+mod1(i - 1, n)], symbols[2n+mod1(i, n)]])
     end

-    sitemap = Dict(Site(i) => letter(i) for i in 1:length(arrays))
-    merge!(sitemap, Dict(Site(i; dual = true) => letter(i + n) for i in 1:length(arrays)))
+    sitemap = Dict(Site(i) => symbols[i] for i in 1:n)
+    merge!(sitemap, Dict(Site(i; dual = true) => symbols[i+n] for i in 1:n))

     Chain(Quantum(TensorNetwork(_tensors), sitemap), boundary)
 end
@@ -92,26 +87,20 @@ function Chain(::Operator, boundary::Open, arrays::Vector{<:AbstractArray})
     @assert ndims(arrays[end]) == 3 "Last array must have 3 dimensions"

     n = length(arrays)
+    symbols = [nextindex() for _ in 1:3n-1]
+
     _tensors = map(enumerate(arrays)) do (i, array)
         if i == 1
-            Tensor(array, [letter(1), letter(n + 1), letter(1 + 2 * n)])
+            Tensor(array, [symbols[1], symbols[n+1], symbols[1+2n]])
         elseif i == n
-            Tensor(array, [letter(n), letter(2 * n), letter(2 * n + mod1(n - 1, length(arrays)))])
+            Tensor(array, [symbols[n], symbols[2n], symbols[2n+mod1(n - 1, n)]])
         else
-            Tensor(
-                array,
-                [
-                    letter(i),
-                    letter(i + n),
-                    letter(2 * n + mod1(i - 1, length(arrays))),
-                    letter(2 * n + mod1(i, length(arrays))),
-                ],
-            )
+            Tensor(array, [symbols[i], symbols[i+n], symbols[2n+mod1(i - 1, n)], symbols[2n+mod1(i, n)]])
         end
     end

-    sitemap = Dict(Site(i) => letter(i) for i in 1:length(arrays))
-    merge!(sitemap, Dict(Site(i; dual = true) => letter(i + n) for i in 1:length(arrays)))
+    sitemap = Dict(Site(i) => symbols[i] for i in 1:n)
+    merge!(sitemap, Dict(Site(i; dual = true) => symbols[i+n] for i in 1:n))

     Chain(Quantum(TensorNetwork(_tensors), sitemap), boundary)
 end
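For orientation, a hedged usage sketch of the constructors touched above. The array ranks follow the asserts in the diff (rank-2 boundary tensors, rank-3 bulk tensors for an open-boundary state); treat State() and Open() simply as instances of the types these methods dispatch on:

    using Qrochet

    # Open-boundary MPS: 4 sites, physical dimension 2, bond dimension 2.
    arrays = [rand(2, 2), rand(2, 2, 2), rand(2, 2, 2), rand(2, 2)]
    ψ = Chain(State(), Open(), arrays)

    # A second chain built from the same arrays receives entirely fresh index
    # symbols from nextindex(), so its indices cannot clash with those of ψ.
    ϕ = Chain(State(), Open(), arrays)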
src/Ansatz/Dense.jl (7 changes: 4 additions & 3 deletions)
@@ -6,11 +6,12 @@ function Dense(::State, array::AbstractArray)
     @assert ndims(array) > 0
     @assert all(>(1), size(array))

+    symbols = [nextindex() for _ in 1:ndims(array)]
     sitemap = map(1:ndims(array)) do i
-        Site(i) => letter(i)
+        Site(i) => symbols[i]
     end |> Dict{Int,Symbol}

-    tensor = Tensor(array, [letter(i) for i in 1:ndims(array)])
+    tensor = Tensor(array, symbols)

     tn = TensorNetwork([tensor])
     qtn = Quantum(tn, sitemap)
@@ -22,7 +23,7 @@ function Dense(::Operator, array::AbstractArray; sitemap::Vector{Site})
     @assert all(>(1), size(array))
     @assert length(sitemap) == ndims(array)

-    tensor_inds = [letter(i) for i in 1:ndims(array)]
+    tensor_inds = [nextindex() for _ in 1:ndims(array)]
     tensor = Tensor(array, tensor_inds)
     tn = TensorNetwork([tensor])

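A small sketch for the updated Dense constructors, with the same caveats (array sizes satisfy the asserts; the ordering of dual and non-dual sites in the operator's sitemap is an assumption for illustration):

    using Qrochet

    # Dense state on 2 sites: a single rank-2 tensor whose indices come from nextindex().
    ψ = Dense(State(), rand(2, 2))

    # Dense operator on 1 site: the caller supplies the site map explicitly.
    O = Dense(Operator(), rand(2, 2); sitemap = [Site(1; dual = true), Site(1)])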
src/Ansatz/Product.jl (11 changes: 6 additions & 5 deletions)
@@ -1,5 +1,4 @@
 using Tenet
-using Tenet: letter
 using LinearAlgebra

 struct Product <: Ansatz
@@ -12,22 +11,24 @@ struct Product <: Ansatz
 end

 function Product(arrays::Vector{<:Vector})
+    symbols = [nextindex() for _ in 1:length(arrays)]
     _tensors = map(enumerate(arrays)) do (i, array)
-        Tensor(array, [letter(i)])
+        Tensor(array, [symbols[i]])
     end

-    sitemap = Dict(Site(i) => letter(i) for i in 1:length(arrays))
+    sitemap = Dict(Site(i) => symbols[i] for i in 1:length(arrays))

     Product(TensorNetwork(_tensors), sitemap)
 end

 function Product(arrays::Vector{<:Matrix})
     n = length(arrays)
+    symbols = [nextindex() for _ in 1:2*length(arrays)]
     _tensors = map(enumerate(arrays)) do (i, array)
-        Tensor(array, [letter(i + n), letter(i)])
+        Tensor(array, [symbols[i+n], symbols[i]])
     end

-    sitemap = merge!(Dict(Site(i, dual = true) => letter(i) for i in 1:n), Dict(Site(i) => letter(i + n) for i in 1:n))
+    sitemap = merge!(Dict(Site(i, dual = true) => symbols[i] for i in 1:n), Dict(Site(i) => symbols[i+n] for i in 1:n))

     Product(TensorNetwork(_tensors), sitemap)
 end
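A corresponding sketch for the Product constructors (vector entries give a product state, matrix entries a product operator; shapes are illustrative):

    using Qrochet

    # Product state: one rank-1 tensor per site, each with a fresh physical index.
    ψ = Product([rand(2) for _ in 1:4])

    # Product operator: one rank-2 tensor per site, with fresh dual/non-dual index pairs.
    O = Product([rand(2, 2) for _ in 1:4])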
src/Qrochet.jl (2 changes: 2 additions & 0 deletions)
@@ -1,5 +1,7 @@
 module Qrochet

+include("Utils.jl")
+
 include("Quantum.jl")
 export Site, @site_str, isdual
 export ninputs, noutputs, inputs, outputs, sites, nsites
src/Utils.jl (6 changes: 6 additions & 0 deletions)
@@ -0,0 +1,6 @@
+using Tenet
+
+const __indexcounter::Threads.Atomic{Int} = Threads.Atomic{Int}(1)
+
+currindex() = Tenet.letter(__indexcounter[])
+nextindex() = Tenet.letter(Threads.atomic_add!(__indexcounter, 1))

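How the new helpers behave, as a rough sketch. The concrete Symbols depend on Tenet.letter, so they are not shown; the threaded loop only illustrates that the Threads.Atomic counter keeps concurrent callers from receiving duplicate indices:

    using Qrochet

    i1 = Qrochet.nextindex()   # consumes the current counter value and returns its Symbol
    i2 = Qrochet.nextindex()   # the counter advanced atomically, so this Symbol differs
    @assert i1 != i2

    Qrochet.currindex()        # peeks at the Symbol the next nextindex() call will return

    # Safe under concurrency: each thread still receives distinct indices.
    Threads.@threads for _ in 1:8
        Qrochet.nextindex()
    end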