diff --git a/src/Ansatz/Chain.jl b/src/Ansatz/Chain.jl index 09ed2a4..5e14081 100644 --- a/src/Ansatz/Chain.jl +++ b/src/Ansatz/Chain.jl @@ -652,24 +652,27 @@ function unpack_2sitewf!(ψ::Chain, bond, left_inds, right_inds, virtualind) return ψ end +evolve(ψ::Chain, mpo::Chain) = evolve!(copy(ψ), mpo) + """ evolve!(ψ::Chain, mpo::Chain) Applies a Matrix Product Operator (MPO) `mpo` to the [`Chain`](@ref). """ -function evolve(ψ::Chain, mpo::Chain) +function evolve!(ψ::Chain, mpo::Chain) updated_tensors = Tensor[] Λ = Tensor[] L = nsites(ψ) for i in 1:L - t = contract(select(ψ, :tensor, Site(i)), select(mpo, :tensor, Site(i)); dims = (select(ψ, :index, Site(i)),)) + contractedind = select(ψ, :index, Site(i)) + t = contract(select(ψ, :tensor, Site(i)), select(mpo, :tensor, Site(i)); dims = (contractedind,)) physicalind = select(mpo, :index, Site(i)) # Fuse the two right legs of t into one if i == 1 wanted_inds = (physicalind, rightindex(ψ, Site(i)), rightindex(mpo, Site(i))) - new_inds = (physicalind, rightindex(ψ, Site(i))) + new_inds = (contractedind, rightindex(ψ, Site(i))) elseif i < L wanted_inds = ( physicalind, @@ -678,14 +681,14 @@ function evolve(ψ::Chain, mpo::Chain) rightindex(ψ, Site(i)), rightindex(mpo, Site(i)), ) - new_inds = (physicalind, leftindex(ψ, Site(i)), rightindex(ψ, Site(i))) + new_inds = (contractedind, leftindex(ψ, Site(i)), rightindex(ψ, Site(i))) else wanted_inds = (physicalind, leftindex(ψ, Site(i)), leftindex(mpo, Site(i))) - new_inds = (physicalind, leftindex(ψ, Site(i))) + new_inds = (contractedind, leftindex(ψ, Site(i))) end perm = Tenet.__find_index_permutation(wanted_inds, inds(t)) - t = PermutedDimsArray(parent(t), perm) + t = permutedims(t, perm) t = Tensor( reshape(t, tuple(size(t, 1), [size(t, k) * size(t, k + 1) for k in 2:2:length(wanted_inds)]...)), @@ -701,12 +704,18 @@ function evolve(ψ::Chain, mpo::Chain) end end - ψ_ev = MPS(updated_tensors) - for i in 1:L-1 - push!(TensorNetwork(ψ_ev), Tensor(parent(Λ[i]), (rightindex(ψ_ev, Site(i)),))) + + for i in 1:L + i < L && pop!(TensorNetwork(ψ), select(ψ, :between, Site(i), Site(i + 1))) + pop!(TensorNetwork(ψ), select(ψ, :tensor, Site(i))) end - return ψ_ev + for i in 1:L + i < L && push!(TensorNetwork(ψ), Λ[i]) + push!(TensorNetwork(ψ), updated_tensors[i]) + end + + return ψ end function expect(ψ::Chain, observables)