Skip to content
This repository has been archived by the owner on Jul 7, 2024. It is now read-only.

Commit

Permalink
Update evolve code
Browse files Browse the repository at this point in the history
  • Loading branch information
jofrevalles committed Jun 4, 2024
1 parent b126744 commit 8fc17ab
Showing 1 changed file with 19 additions and 10 deletions.
29 changes: 19 additions & 10 deletions src/Ansatz/Chain.jl
Original file line number Diff line number Diff line change
Expand Up @@ -652,24 +652,27 @@ function unpack_2sitewf!(ψ::Chain, bond, left_inds, right_inds, virtualind)
return ψ
end

evolve::Chain, mpo::Chain) = evolve!(copy(ψ), mpo)

"""
evolve!(ψ::Chain, mpo::Chain)
Applies a Matrix Product Operator (MPO) `mpo` to the [`Chain`](@ref).
"""
function evolve::Chain, mpo::Chain)
function evolve!::Chain, mpo::Chain)
updated_tensors = Tensor[]
Λ = Tensor[]
L = nsites(ψ)

for i in 1:L
t = contract(select(ψ, :tensor, Site(i)), select(mpo, :tensor, Site(i)); dims = (select(ψ, :index, Site(i)),))
contractedind = select(ψ, :index, Site(i))
t = contract(select(ψ, :tensor, Site(i)), select(mpo, :tensor, Site(i)); dims = (contractedind,))
physicalind = select(mpo, :index, Site(i))

# Fuse the two right legs of t into one
if i == 1
wanted_inds = (physicalind, rightindex(ψ, Site(i)), rightindex(mpo, Site(i)))
new_inds = (physicalind, rightindex(ψ, Site(i)))
new_inds = (contractedind, rightindex(ψ, Site(i)))
elseif i < L
wanted_inds = (
physicalind,
Expand All @@ -678,14 +681,14 @@ function evolve(ψ::Chain, mpo::Chain)
rightindex(ψ, Site(i)),
rightindex(mpo, Site(i)),
)
new_inds = (physicalind, leftindex(ψ, Site(i)), rightindex(ψ, Site(i)))
new_inds = (contractedind, leftindex(ψ, Site(i)), rightindex(ψ, Site(i)))
else
wanted_inds = (physicalind, leftindex(ψ, Site(i)), leftindex(mpo, Site(i)))
new_inds = (physicalind, leftindex(ψ, Site(i)))
new_inds = (contractedind, leftindex(ψ, Site(i)))
end

perm = Tenet.__find_index_permutation(wanted_inds, inds(t))
t = PermutedDimsArray(parent(t), perm)
t = permutedims(t, perm)

t = Tensor(
reshape(t, tuple(size(t, 1), [size(t, k) * size(t, k + 1) for k in 2:2:length(wanted_inds)]...)),
Expand All @@ -701,12 +704,18 @@ function evolve(ψ::Chain, mpo::Chain)
end
end

ψ_ev = MPS(updated_tensors)
for i in 1:L-1
push!(TensorNetwork(ψ_ev), Tensor(parent(Λ[i]), (rightindex(ψ_ev, Site(i)),)))

for i in 1:L
i < L && pop!(TensorNetwork(ψ), select(ψ, :between, Site(i), Site(i + 1)))
pop!(TensorNetwork(ψ), select(ψ, :tensor, Site(i)))
end

return ψ_ev
for i in 1:L
i < L && push!(TensorNetwork(ψ), Λ[i])
push!(TensorNetwork(ψ), updated_tensors[i])
end

return ψ
end

function expect::Chain, observables)
Expand Down

0 comments on commit 8fc17ab

Please sign in to comment.