Skip to content
This repository has been archived by the owner on Jul 7, 2024. It is now read-only.

Commit

Permalink
Add Distributed example
Browse files Browse the repository at this point in the history
  • Loading branch information
mofeing committed Jun 4, 2024
1 parent 0207695 commit e1e8eb6
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 0 deletions.
5 changes: 5 additions & 0 deletions examples/Project.toml
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
ClusterManagers = "34f1f09b-3a8b-5176-ab39-66d58a4d544e"
Dagger = "d58978e5-989f-55fb-8d15-ea34adc7bf54"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
EinExprs = "b1794770-133b-4de1-afb4-526377e9f4c5"
IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
KaHyPar = "2a6221f6-aa48-11e9-3542-2d9e0ef01880"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
Qrochet = "881a8f22-b5d0-48b0-96e5-a244b33f36d4"
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
Revise = "295af30f-e4ad-537b-8983-00126c2a3abe"
Tenet = "85d41934-b9cd-44e1-8730-56d86f15f3ec"
TimespanLogging = "a526e669-04d3-4846-9525-c66122c55f63"
Yao = "5872b779-8223-5990-8dd0-5abbb0748c8c"
103 changes: 103 additions & 0 deletions examples/distributed.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
using Yao: Yao
using Qrochet
using Tenet
using EinExprs
using KaHyPar
using Random
using Distributed
using ClusterManagers
using AbstractTrees

n = 64
depth = 6

circuit = Yao.chain(n)

for _ in 1:depth
perm = randperm(n)

for (i, j) in Iterators.partition(perm, 2)
push!(circuit, Yao.put((i, j) => Yao.EasyBuild.FSimGate(2π * rand(), 2π * rand())))
# push!(circuit, Yao.control(n, i, j => Yao.phase(2π * rand())))
end
end

H = Quantum(circuit)
ψ = zeros(Product, n)

tn = TensorNetwork(merge(Quantum(ψ), H, Quantum(ψ)'))
transform!(tn, Tenet.ContractSimplification())

path = einexpr(
tn,
optimizer = HyPar(
parts = 2,
imbalance = 0.41,
edge_scaler = (ind_size) -> 10 * Int(round(log2(ind_size))),
vertex_scaler = (prod_size) -> 100 * Int(round(exp2(prod_size))),
),
)

@show maximum(ndims, Branches(path))
@show maximum(length, Branches(path)) * sizeof(eltype(tn)) / 1024^3

@show log10(mapreduce(flops, +, Branches(path)))

cutinds = findslices(SizeScorer(), path, size = 2^24)
cuttings = [[i => dim for dim in 1:size(tn, i)] for i in cutinds]

# mock sliced path - valid for all slices
proj_inds = first.(cuttings)
slice_path = view(path.path, proj_inds...)

expr = Tenet.codegen(Val(:outplace), slice_path)

manager = SlurmManager(2 * 112 - 1)
addprocs(manager, cpus_per_task = 1, exeflags = "--project=$(Base.active_project())")
# @everywhere using LinearAlgebra
# @everywhere LinearAlgebra.BLAS.set_num_threads(2)

@everywhere using Tenet, EinExprs, IterTools, LinearAlgebra, Reactant, AbstractTrees
@everywhere tn = $tn
@everywhere slice_path = $slice_path
@everywhere cuttings = $cuttings
@everywhere expr = $expr

partial_results = map(enumerate(workers())) do (i, worker)
Distributed.@spawnat worker begin
# interleaved chunking without instantiation
it = takenth(Iterators.drop(Iterators.product(cuttings...), i - 1), nworkers())

f = @eval $expr
mock_slice = view(tn, first(it)...)
tensors′ = [
Tensor(Reactant.ConcreteRArray(copy(parent(mock_slice[head(leaf)...]))), inds(mock_slice[head(leaf)...])) for leaf in Leaves(slice_path)
]
g = Reactant.compile(f, Tuple(tensors′))

# local reduction of chunk
accumulator = zero(eltype(tn))

for proj_inds in it
slice = view(tn, proj_inds...)
tensors′ = [
Tensor(
Reactant.ConcreteRArray(copy(parent(mock_slice[head(leaf)...]))),
inds(mock_slice[head(leaf)...]),
) for leaf in Leaves(slice_path)
]
res = only(g(tensors′...))

# avoid OOM due to garbage accumulation
GC.gc()

accumulator += res
end

return accumulator
end
end

@show result = sum(Distributed.fetch.(partial_results))

rmprocs(workers())

0 comments on commit e1e8eb6

Please sign in to comment.