@rkube (last active July 9, 2021)
Updated code for the QR factorization pullback
using LinearAlgebra
using Zygote
using ChainRules
using ChainRulesCore
using Random
using Statistics
using FiniteDifferences
using ChainRulesTestUtils
ChainRulesCore.debug_mode() = true
Random.seed!(1234);
function ChainRules.rrule(::typeof(qr), A::AbstractMatrix{T}) where {T}
    QR = qr(A)
    m, n = size(A)
    function qr_pullback(Ȳ::Tangent)
        # For the square (m = n) and deep (m > n) cases, use the rule derived by
        # Seeger et al. (2019), https://arxiv.org/pdf/1710.08717.pdf
        #
        #   Ā = [Q̄ + Q copyltu(M)] R⁻ᵀ
        #
        # where copyltu(C) is the symmetric matrix generated from C by copying its lower
        # triangle onto its upper triangle: copyltu(C)ᵢⱼ = C[max(i,j), min(i,j)]
        #
        # This code is re-used in the wide case, so it lives in a separate function.
        function qr_pullback_square_deep(Q̄, R̄, A, Q, R)
            M = R * R̄' - Q̄' * Q
            # M <- copyltu(M)
            M = tril(M) + transpose(tril(M, -1))
            Ā = (Q̄ + Q * M) / R'
        end
        # For the wide (m < n) case, we implement the rule derived by
        # Liao et al. (2019), https://arxiv.org/pdf/1903.09650.pdf
        #
        #   Ā = ( [Q̄ + Y V̄ᵀ + Q copyltu(M)] U⁻ᵀ, Q V̄ )
        #
        # where A = [X | Y] is the column-wise concatenation of X (m × m) and Y (m × (n - m)),
        # and R = [U | V] is partitioned the same way. Both X and U are full-rank square matrices.
        #
        # See also the discussion in https://github.com/JuliaDiff/ChainRules.jl/pull/306
        # and https://github.com/pytorch/pytorch/blob/b162d95e461a5ea22f6840bf492a5dbb2ebbd151/torch/csrc/autograd/FunctionsManual.cpp
        Q̄ = Ȳ.factors
        R̄ = Ȳ.T
        Q = QR.Q
        R = QR.R
        if m ≥ n
            # qr returns the full QR factorization, including silent columns, so we crop them
            Q̄ = Q̄ isa ChainRules.AbstractZero ? Q̄ : Q̄[:, axes(R, 2)]
            Q = if m > n
                @view Q[:, axes(R, 2)]
            else
                Q
            end
            Ā = qr_pullback_square_deep(Q̄, R̄, A, Q, R)
        else  # This is the case m < n, i.e. a short and wide matrix A
            # partition A = [X | Y]
            # X = A[1:m, 1:m]
            Y = @view A[1:m, m + 1:end]
            # partition R = [U | V]; we don't need V
            U = R[1:m, 1:m]
            if R̄ isa ChainRules.AbstractZero
                V̄ = zeros(size(Y))
                Q̄_prime = zeros(size(Q))
                Ū = R̄
            else
                # partition R̄ = [Ū | V̄]
                @show size(R̄)
                Ū = @view R̄[1:m, 1:m]
                V̄ = @view R̄[1:m, m + 1:end]
                @show size(Ū), size(V̄)
                Q̄_prime = Y * V̄'
            end
            Q̄_prime = Q̄ isa ChainRules.AbstractZero ? Q̄_prime : Q̄_prime + Q̄
            X̄ = qr_pullback_square_deep(Q̄_prime, Ū, A, Q, U)
            Ȳ = Q * V̄
            # partition Ā = [X̄ | Ȳ]
            Ā = [X̄ Ȳ]
        end
        return (NoTangent(), Ā)
    end
    return QR, qr_pullback
end
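# --- Illustrative checks (a hedged sketch, not part of the pullback itself) ---
# 1) The copyltu operation used above, copyltu(C)ᵢⱼ = C[max(i,j), min(i,j)],
#    mirrors the lower triangle of C onto its upper triangle:
let C = [1 2 3; 4 5 6; 7 8 9]
    @assert tril(C) + transpose(tril(C, -1)) == [1 4 7; 4 5 8; 7 8 9]
end
# 2) Shapes in the wide case (m < n): with A = [X | Y] and R = [U | V],
#    X and U are m × m while Y and V are m × (n - m):
let A = randn(3, 5)
    m, n = size(A)
    R = qr(A).R
    @assert size(R[1:m, 1:m]) == (m, m)          # U
    @assert size(A[:, m + 1:end]) == (m, n - m)  # Y
end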
function ChainRulesCore.rrule(::typeof(getproperty), F::LinearAlgebra.QRCompactWY, d::Symbol)
    function getproperty_qr_pullback(Ȳ)
        # The QR factorization is calculated from `factors` and `T`, the matrices stored in the
        # compact WY format, see
        # R. Schreiber and C. van Loan, Sci. Stat. Comput. 10, 53-57 (1989).
        # Instead of backpropagating through those factors, we re-use `factors` to carry Q̄
        # and `T` to carry R̄ in the Tangent object.
        ∂factors = if d === :Q
            Ȳ
        else
            nothing
        end
        ∂T = if d === :R
            @show size(Ȳ)
            Ȳ
        else
            nothing
        end
        ∂F = Tangent{LinearAlgebra.QRCompactWY}(; factors=∂factors, T=∂T)
        # One tangent for the function and for each primal argument (F and d)
        return (NoTangent(), ∂F, NoTangent())
    end
    return getproperty(F, d), getproperty_qr_pullback
end
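# Hedged sanity check of the convention above: the QRCompactWY object stores the compact WY
# representation in `factors` and `T`, with R recoverable as the upper triangle of the leading
# block of `factors`; the pullback merely re-uses those two field names as carriers for
# Q̄ (in `factors`) and R̄ (in `T`).
let F = qr(randn(5, 3))
    @assert triu(F.factors[1:3, 1:3]) ≈ F.R
    _, pb = ChainRulesCore.rrule(getproperty, F, :R)
    ∂F = pb(ones(3, 3))[2]
    @assert ∂F.T == ones(3, 3)    # the cotangent of R travels in the T field
end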
# Fix equality for QR objects:
function ChainRulesTestUtils.test_approx(actual::LinearAlgebra.QRCompactWY, expected::LinearAlgebra.QRCompactWY, msg::String; kwargs...)
    ChainRulesTestUtils.test_approx(actual.Q, expected.Q, msg * " Q:"; kwargs...)
    ChainRulesTestUtils.test_approx(actual.R, expected.R, msg * " R:"; kwargs...)
end
function FiniteDifferences.to_vec(x::S) where {S <: LinearAlgebra.QRCompactWY}
    m, n = size(x)
    # Q restricted to the column range of R
    Q = @view x.Q[axes(x.Q, 1), axes(x.R, 2)]
    q_vec, q_back = to_vec(Q)
    r_vec, r_back = to_vec(x.R)
    function QRCompact_from_vec(v)
        # Split the stacked vector back into its Q- and R-parts
        Q_new = q_back(v[1:length(q_vec)])
        R_new = r_back(v[length(q_vec) + 1:end])
        # Re-factorize Q * R to obtain a QRCompactWY object of the original type again
        QR = qr(Q_new * R_new)
        return S(QR.factors, QR.T)
    end
    return [q_vec; r_vec], QRCompact_from_vec
end
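# Hedged illustration of the flattening above: for a 5×3 input the thin Q is 5×3 and R is 3×3,
# so the stacked vector has 5*3 + 3*3 = 24 entries.
let F = qr(randn(5, 3))
    v, _ = FiniteDifferences.to_vec(F)
    @assert length(v) == 5 * 3 + 3 * 3
end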
function f1(V)
    m, n = size(V)
    Q, _ = qr(V)
    # Crop Q to the size of V in the deep case so that the sum runs over the thin factor
    Q = if m ≥ n
        @view Q[axes(V)...]
    else
        Q
    end
    return sum(Q)
end
function f2(V)
    _, R = qr(V)
    return sum(R)
end
V1 = rand(Float32, (4, 4));
V2 = randn(7, 4);
V3 = randn(5, 11);
println("f1, V1")
res1_V1_ad = Zygote.gradient(f1, V1)
res1_V1_fd = FiniteDifferences.grad(central_fdm(5,1), f1, V1)
@assert res1_V1_ad[1] ≈ res1_V1_fd[1]
println("")
println("f2, V1")
res2_V1_ad = Zygote.gradient(f2, V1)
res2_V1_fd = FiniteDifferences.grad(central_fdm(5,1), f2, V1)
@assert res2_V1_ad[1] ≈ res2_V1_fd[1]
println("")
println("f1, V2")
res1_V2_ad = Zygote.gradient(f1, V2)
res1_V2_fd = FiniteDifferences.grad(central_fdm(5,1), f1, V2)
@assert res1_V2_ad[1] ≈ res1_V2_fd[1]
println("")
println("f2, V2")
res2_V2_ad = Zygote.gradient(f2, V2)
res2_V2_fd = FiniteDifferences.grad(central_fdm(5,1), f2, V2)
@assert res2_V2_ad[1] ≈ res2_V2_fd[1]
println("")
println("f1, V3")
res1_V3_ad = Zygote.gradient(f1, V3)
res1_V3_fd = FiniteDifferences.grad(central_fdm(5,1), f1, V3)
@assert res1_V3_ad[1] ≈ res1_V3_fd[1]
println("")
println("f2, V3")
res2_V3_ad = Zygote.gradient(f2, V3)
res2_V3_fd = FiniteDifferences.grad(central_fdm(5,1), f2, V3)
@assert res2_V3_ad[1] ≈ res2_V3_fd[1]
# ChainRulesTestUtils unit tests; at minimum the small square case should pass
test_rrule(qr, V1, check_inferred=false, fdm=central_fdm(3,1), rtol=1e-3, atol=1e-3)
test_rrule(qr, V2, check_inferred=false, fdm=central_fdm(3,1), rtol=1e-3, atol=1e-3)
test_rrule(qr, V3, check_inferred=false, fdm=central_fdm(3,1), rtol=1e-3, atol=1e-3)