Skip to content
This repository has been archived by the owner on Jul 7, 2024. It is now read-only.

Commit

Permalink
Add distributed contraction example with Dagger
Browse files Browse the repository at this point in the history
  • Loading branch information
mofeing committed Jun 4, 2024
1 parent e89322d commit e6b41ee
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 0 deletions.
11 changes: 11 additions & 0 deletions examples/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
Dagger = "d58978e5-989f-55fb-8d15-ea34adc7bf54"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
EinExprs = "b1794770-133b-4de1-afb4-526377e9f4c5"
KaHyPar = "2a6221f6-aa48-11e9-3542-2d9e0ef01880"
Qrochet = "881a8f22-b5d0-48b0-96e5-a244b33f36d4"
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
Tenet = "85d41934-b9cd-44e1-8730-56d86f15f3ec"
TimespanLogging = "a526e669-04d3-4846-9525-c66122c55f63"
Yao = "5872b779-8223-5990-8dd0-5abbb0748c8c"
71 changes: 71 additions & 0 deletions examples/dagger.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
using Tenet
using Qrochet
using Yao: Yao
using EinExprs
using AbstractTrees
using Distributed
using Dagger
using TimespanLogging
using KaHyPar

m = 10
circuit = Yao.EasyBuild.rand_google53(m);
H = Quantum(circuit)
ψ = Product(fill([1, 0], Yao.nqubits(circuit)))
qtn = merge(Quantum(ψ), H, Quantum(ψ)')
tn = Tenet.TensorNetwork(qtn)

contract_smaller_dims = 20
target_size = 24

Tenet.transform!(tn, Tenet.ContractSimplification())
path = einexpr(
tn,
optimizer = HyPar(
parts = 2,
imbalance = 0.41,
edge_scaler = (ind_size) -> 10 * Int(round(log2(ind_size))),
vertex_scaler = (prod_size) -> 100 * Int(round(exp2(prod_size))),
),
);

max_dims_path = @show maximum(ndims, Branches(path))
flops_path = @show mapreduce(flops, +, Branches(path))
@show log10(flops_path)

grouppath = deepcopy(path);
function recursiveforeach!(f, expr)
f(expr)
foreach(arg -> recursiveforeach!(f, arg), args(expr))
end
sizedict = merge(Iterators.map(i -> i.size, Leaves(path))...);
recursiveforeach!(grouppath) do expr
merge!(expr.size, sizedict)
if all(<(contract_smaller_dims) ndims, expr.args)
empty!(expr.args)
end
end

max_dims_grouppath = maximum(ndims, Branches(grouppath))
flops_grouppath = mapreduce(flops, +, Branches(grouppath))
targetinds = findslices(SizeScorer(), grouppath, size = 2^(target_size));

subexprs = map(Leaves(grouppath)) do expr
EinExprs.select(path, tuple(head(expr)...)) |> only
end

addprocs(3)
@everywhere using Dagger, Tenet

disttn = Tenet.TensorNetwork(
map(subexprs) do subexpr
Tensor(
distribute( # data
parent(Tenet.contract(tn; path = subexpr)),
Blocks([i targetinds ? 1 : 2 for i in head(subexpr)]...),
),
head(subexpr), # inds
)
end,
)
@show Tenet.contract(disttn; path = grouppath)

0 comments on commit e6b41ee

Please sign in to comment.