Last active
April 2, 2019 10:11
-
-
Save GiggleLiu/0fb539d1a453b2cc6aca769d14d2cc79 to your computer and use it in GitHub Desktop.
Automatic Differentiation over Tensor Renormalization Group
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#=
To use this script, please install Flux.jl first by typing `]add Flux` in a Julia REPL.
The script relies on `Flux.Tracker`, so it needs a Flux release that still ships that module.
=#
import LinearAlgebra | |
using Flux.Tracker: @grad, data, track, TrackedTuple, TrackedArray | |
using Flux | |
import Flux.Tracker: _forward | |
""" | |
backward function for svd. | |
""" | |
function svd_back(U, S, V, dU, dS, dV; η::Real=1e-12)
    # Pull the adjoints (dU, dS, dV) of the SVD factors A = U * Diagonal(S) * V'
    # back to an adjoint of A, which is the return value.
    #
    # `η` must be positive. It regularises the two divisions below:
    #   * 1/s       is replaced by s / (s^2 + η)
    #   * 1/(Δs²)   is replaced by Δs² / (Δs²^2 + η), which also makes the
    #               diagonal of `F` exactly zero instead of 0/0.
    S2 = S .^ 2
    Sinv = @. S / (S2 + η)

    # F[i, j] ≈ 1 / (S[j]^2 - S[i]^2) for i != j.
    F = S2' .- S2
    @. F = F / (F^2 + η)

    UdU = U' * dU
    VdV = V' * dV

    Smat = LinearAlgebra.Diagonal(S)
    Sinvmat = LinearAlgebra.Diagonal(Sinv)

    # Contributions of dU and dV inside the span of U and V.
    Su = (F .* (UdU - UdU')) * Smat
    Sv = Smat * (F .* (VdV - VdV'))

    # The last two terms carry the components of dU / dV that lie outside the
    # column spaces of U / V.
    return U * (Su + Sv + LinearAlgebra.Diagonal(dS)) * V' +
        (LinearAlgebra.I - U * U') * dU * Sinvmat * V' +
        U * Sinvmat * dV' * (LinearAlgebra.I - V * V')
end
"""redefine svd!, using Tuple as output.""" | |
function svd!(A)
    # Untracked input: defer to the standard library routine.
    return LinearAlgebra.svd!(A)
end

function svd!(A::TrackedArray)
    # Tracked input: register the call with Tracker via `track`.
    return track(svd!, A)
end

function svd(A)
    # Non-mutating variant: run `svd!` on a copy of `A`.
    return svd!(copy(A))
end
function _forward(::typeof(svd!), a)
    # Forward pass: factorise the untracked data of `a` with the stdlib `svd!`.
    U, S, V = LinearAlgebra.svd!(data(a))
    # Backward pass: Δ is splatted into `svd_back` as (dU, dS, dV); the result is
    # wrapped in a 1-tuple, one entry for the single argument `a`.
    pullback(Δ) = (svd_back(U, S, V, Δ...),)
    return (U, S, Matrix(V)), pullback
end
# Iteration protocol for TrackedTuple, so that it can be destructured
# (e.g. `U, S, V = svd(Ma)`): yields `xs[1]`, `xs[2]`, ... up to `length(xs)`.
function Base.iterate(xs::TrackedTuple, state=1)
    state <= length(xs) || return nothing
    return (xs[state], state + 1)
end
""" | |
SVD decomposition for TRG: `Ma` is an input matrix; it returns two rank-3 tensors.
""" | |
function trg_svd(Ma, Dmax; tol::Real=1e-12)
    # Split the (D^2 × D^2) matrix `Ma` into two rank-3 tensors through a truncated
    # SVD: S1 has size (D, D, Dkeep) and S3 has size (Dkeep, D, D), where each one
    # absorbs the square root of the kept singular values.
    #
    # `Dmax` caps the number of kept singular values; `tol` drops singular values
    # smaller than it. Any real `tol` is accepted (it used to be `Float64` only).
    U, S, V = svd(Ma)

    # `S` is in descending order, so with `rev=true` the end of the range returned
    # by `searchsorted` is the index of the last singular value that is >= tol.
    Dkeep = min(last(searchsorted(S, tol, rev=true)), Dmax)

    # The row index of `Ma` is treated as a pair of indices of dimension D each.
    D = isqrt(size(Ma, 1))

    FS = S[1:Dkeep]
    S1 = reshape(view(U, :, 1:Dkeep) .* sqrt.(FS'), (D, D, Dkeep))
    S3 = reshape(sqrt.(FS) .* view(V', 1:Dkeep, :), (Dkeep, D, D))
    return S1, S3
end
""" | |
TRG(K::RT, Dcut::Int, niter::Int) -> RT | |
TRG main program, `Dcut` is the maximum bond dimension. | |
""" | |
function TRG(K::RT, Dcut::Int, niter::Int) where RT
    # Returns lnZ: the weighted sum of the logs of the per-iteration normalisation
    # factors plus the log of the final tensor trace. `Dcut` is the maximum bond
    # dimension kept by `trg_svd`; `niter` is the number of coarse-graining steps.

    # Initial rank-4 tensor T[i, j, k, l] = Σ_a M[a, i] * M[a, j] * M[a, k] * M[a, l].
    # NOTE(review): this looks like the usual 2D Ising construction with coupling K — confirm.
    D = 2
    inds = 1:D
    M = [sqrt(cosh(K)) sqrt(sinh(K));
         sqrt(cosh(K)) -sqrt(sinh(K))]
    T = [mapreduce(a -> M[a, i] * M[a, j] * M[a, k] * M[a, l], +, inds) for i in inds, j in inds, k in inds, l in inds]
    # When K is tracked the comprehension holds tracked scalars; repack it with
    # `Tracker.collect` before doing array operations on it.
    eltype(T) <: Tracker.TrackedReal && (T = Tracker.collect(T))

    # The weight 2^(niter-n+1) is computed in floating point: with `Int` arithmetic
    # it wraps around silently once the exponent reaches 63. The value is identical
    # to the integer power for smaller exponents.
    two = oftype(float(Tracker.data(K)), 2)

    lnZ = zero(RT)
    for n in 1:niter
        # Normalise by the largest element and account for it in lnZ.
        maxval = maximum(T)
        T = T / maxval
        lnZ += two^(niter - n + 1) * log(maxval)

        # Two different matricisations of T, each split by a truncated SVD.
        D = size(T, 1)
        Ma = reshape(permutedims(T, (3, 2, 1, 4)), (D^2, D^2))
        Mb = reshape(permutedims(T, (4, 3, 2, 1)), (D^2, D^2))
        S1, S3 = trg_svd(Ma, Dcut)
        S2, S4 = trg_svd(Mb, Dcut)

        # T[r, u, l, d] := S1[w, a, r] * S2[a, b, u] * S3[l, b, g] * S4[d, g, w]
        # NOTE(review): Dc is read from S1 only; the reshapes below assume S2, S3
        # and S4 were truncated to the same bond dimension — confirm.
        Dc = size(S1, 3)
        S12 = reshape(permutedims(S1, (1, 3, 2)), D * Dc, D) * reshape(S2, D, D * Dc) # wr,bu
        S123 = reshape(permutedims(reshape(S12, D * Dc, D, Dc), (1, 3, 2)), D * Dc * Dc, D) *
            reshape(permutedims(S3, (2, 1, 3)), D, Dc * D) # wru,lg
        T = reshape(permutedims(reshape(S123, D, Dc * Dc * Dc, D), (2, 3, 1)), Dc * Dc * Dc, D * D) *
            reshape(permutedims(S4, (2, 3, 1)), D * D, Dc) # rul,d
        T = reshape(T, Dc, Dc, Dc, Dc)
    end

    # Final contraction: Σ_{i,j} T[i, j, i, j].
    trace = zero(RT)
    for i in axes(T, 1)
        for j in axes(T, 2)
            trace += T[i, j, i, j]
        end
    end
    return lnZ + log(trace)
end
#################### The following is a test for the correctness of autodiff ####################
using Test | |
"""numerical gradient""" | |
function num_grad(K::Float64, Dcut, niter; δ::Float64=1e-5)
    # Central finite difference of `TRG` with respect to K, using step width δ:
    # (TRG(K + δ/2) - TRG(K - δ/2)) / δ.
    half = δ / 2
    lower = TRG(K - half, Dcut, niter)
    upper = TRG(K + half, Dcut, niter)
    return (upper - lower) / δ
end
# Compare the Tracker gradient of TRG with the finite-difference estimate at K = 0.5
# (Dcut = 24, 20 iterations), to a relative tolerance of 1e-3.
K = 0.5
@test Tracker.gradient(x -> TRG(x, 24, 20), K)[1] ≈ num_grad(data(K), 24, 20) rtol=1e-3
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment