Skip to content

Instantly share code, notes, and snippets.

@GiggleLiu
Last active April 2, 2019 10:11
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save GiggleLiu/0fb539d1a453b2cc6aca769d14d2cc79 to your computer and use it in GitHub Desktop.
Automatic Differentiation over Tensor Renormalization Group
#=
To use this script, please install Flux.jl first by typing `]add Flux` in a Julia REPL.
=#
import LinearAlgebra
using Flux.Tracker: @grad, data, track, TrackedTuple, TrackedArray
using Flux
import Flux.Tracker: _forward
"""
backward function for svd.
"""
function svd_back(U, S, V, dU, dS, dV)
NS = length(S)
S2 = S.^2
Sinv = @. S/(S2+1e-12)
F = S2' .- S2
@. F = F/(F^2+1e-12)
UdU = U'*dU
VdV = V'*dV
Su = (F.*(UdU-UdU'))*LinearAlgebra.Diagonal(S)
Sv = LinearAlgebra.Diagonal(S) * (F.*(VdV-VdV'))
U * (Su + Sv + LinearAlgebra.Diagonal(dS)) * V' +
(LinearAlgebra.I - U*U') * dU*LinearAlgebra.Diagonal(Sinv) * V' +
U*LinearAlgebra.Diagonal(Sinv) * dV' * (LinearAlgebra.I - V*V')
end
"""redefine svd!, using Tuple as output."""
svd!(A) = LinearAlgebra.svd!(A)
svd!(A::TrackedArray) = track(svd!, A)
svd(A) = svd!(A |> copy)
"""
Tracker forward rule for `svd!`: factorize the untracked data and
return the `(U, S, V)` primal together with a pullback that maps the
factor cotangents to the input cotangent via `svd_back`.
"""
function _forward(::typeof(svd!), a)
    U, S, V = LinearAlgebra.svd!(data(a))
    # `Matrix(V)` materializes the (lazy) V factor for the primal output.
    back(Δ) = (svd_back(U, S, V, Δ...),)
    return (U, S, Matrix(V)), back
end
# Give TrackedTuple the standard indexed-iteration protocol so it can be
# destructured (`U, S, V = svd(A)`).
function Base.iterate(xs::TrackedTuple, state=1)
    state > length(xs) && return nothing
    return xs[state], state + 1
end
"""
svd decoposition for trg, `Ma` is a input matrix, it returns two rank 3 tensors.
"""
function trg_svd(Ma, Dmax; tol::Float64=1e-12)
U, S, V = svd(Ma)
Dmax = min(searchsorted(S, tol, rev=true).stop, Dmax)
D = isqrt(size(Ma, 1))
FS = S[1:Dmax]
S1 = reshape(view(U,:,1:Dmax) .* sqrt.(FS'), (D, D, Dmax))
S3 = reshape(sqrt.(FS) .* view(V',1:Dmax,:), (Dmax, D, D))
S1, S3
end
"""
TRG(K::RT, Dcut::Int, niter::Int) -> RT
TRG main program, `Dcut` is the maximum bond dimension.
"""
function TRG(K::RT, Dcut::Int, niter::Int) where RT
D = 2
inds = 1:D
M = [sqrt(cosh(K)) sqrt(sinh(K));
sqrt(cosh(K)) -sqrt(sinh(K))]
T = [mapreduce(a->M[a, i] * M[a, j] * M[a, k] * M[a, l], +, inds) for i in inds, j in inds, k in inds, l in inds]
eltype(T) <: Tracker.TrackedReal && (T = Tracker.collect(T))
lnZ = zero(RT)
for n in 1:niter
maxval = maximum(T)
T = T/maxval
lnZ += 2^(niter-n+1)*log(maxval)
D = size(T, 1)
Ma = reshape(permutedims(T, (3, 2, 1, 4)), (D^2, D^2))
Mb = reshape(permutedims(T, (4, 3, 2, 1)), (D^2, D^2))
S1, S3 = trg_svd(Ma, Dcut)
S2, S4 = trg_svd(Mb, Dcut)
# T[r, u, l, d] := S1[w, a, r] * S2[a, b, u] * S3[l, b, g] * S4[d, g, w]
Dc = size(S1, 3)
S12 = reshape(permutedims(S1, (1,3,2)), D*Dc, D) * reshape(S2, D, D*Dc) # wr,bu
S123 = reshape(permutedims(reshape(S12, D*Dc, D, Dc), (1,3,2)), D*Dc*Dc, D)*reshape(permutedims(S3, (2,1,3)), D, Dc*D) # wru,lg
T = reshape(permutedims(reshape(S123, D, Dc*Dc*Dc, D), (2,3,1)), Dc*Dc*Dc, D*D)* reshape(permutedims(S4, (2,3,1)), D*D, Dc) # rul,d
T = reshape(T, Dc, Dc, Dc, Dc)
end
trace = zero(RT)
for i in 1:size(T, 1)
for j in 1:size(T, 2)
trace += T[i, j, i, j]
end
end
lnZ += log(trace)
end
#################### The following is a test for the correctness of autodiff ####################
using Test
"""numerical gradient"""
function num_grad(K::Float64, Dcut, niter; δ::Float64=1e-5)
x = K - δ/2
Z0 = TRG(x, Dcut, niter)
x = K + δ/2
Z1 = TRG(x, Dcut, niter)
(Z1-Z0)/δ
end
# Smoke test: the tracked (reverse-mode) gradient of lnZ w.r.t. K must
# match the finite-difference estimate to ~0.1% relative tolerance.
K = 0.5
@test isapprox(Tracker.gradient(x->TRG(x, 24, 20), K)[1], num_grad(data(K), 24, 20), rtol=1e-3)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment