Skip to content

Instantly share code, notes, and snippets.

@rkube
Last active July 8, 2021 17:54
Show Gist options
  • Save rkube/ccdd21b8009e5be281f3870a0caec47c to your computer and use it in GitHub Desktop.
Save rkube/ccdd21b8009e5be281f3870a0caec47c to your computer and use it in GitHub Desktop.
pullbacks for QR decomposition
using LinearAlgebra
using Zygote
using ChainRules
using ChainRulesCore
using Random
using Statistics
using FiniteDifferences
using ChainRulesTestUtils
# Override `ChainRulesCore.debug_mode` so that it returns `true`.
# NOTE(review): presumably this turns on ChainRulesCore's more verbose debug checks —
# confirm against the ChainRulesCore documentation.
ChainRulesCore.debug_mode() = true
# Seed the default RNG so the random test matrices drawn below are reproducible.
Random.seed!(1234);
# Core of the QR pullback for a square triangular factor `R`; shared by the
# tall/square case and by the leading block of the wide case.
#
# Follows the rule of Seeger et al. (2019), https://arxiv.org/pdf/1710.08717.pdf :
#
#     Ā = [Q̄ + Q copyltu(M)] R⁻ᵀ
#
# The symmetric matrix is assembled from the upper triangle of M = R̄ Rᵀ - Qᵀ Q̄,
# mirrored into the strict lower triangle.
# NOTE(review): the paper phrases copyltu in terms of the lower triangle; presumably
# the M formed here is the transpose of the paper's M — confirm against the reference.
#
# `Q̄` and `R̄` may each be an `AbstractZero`, but not both at once.
function _qr_pullback_square(Q̄, R̄, Q, R)
    M = R̄ * R' - Q' * Q̄
    M_sym = triu(M) + transpose(triu(M, 1))
    return (Q̄ + Q * M_sym) / R'
end

# Reverse-mode rule for `qr(A)`.
#
# Returns the factorization together with a pullback. The pullback expects a `Tangent`
# whose `factors` field carries Q̄ and whose `T` field carries R̄ (either may be an
# `AbstractZero`), and returns `(NoTangent(), Ā)`.
#
# * m ≥ n (square or tall): the rule of Seeger et al. (2019) is applied directly.
# * m < n (wide): the rule of Liao et al. (2019), https://arxiv.org/pdf/1903.09650.pdf ,
#   is used. With A = [X | Y] (X is m×m, Y is m×(n-m)) and R = [U | V]:
#
#       X̄ = square rule applied to (Q̄ + Y V̄ᵀ, Ū, Q, U),    Ȳ = Q V̄,    Ā = [X̄ | Ȳ]
#
# See also the discussion in https://github.com/JuliaDiff/ChainRules.jl/pull/306
# and https://github.com/pytorch/pytorch/blob/b162d95e461a5ea22f6840bf492a5dbb2ebbd151/torch/csrc/autograd/FunctionsManual.cpp
function ChainRules.rrule(::typeof(qr), A::AbstractMatrix{T}) where {T}
    QR = qr(A)
    m, n = size(A)

    function qr_pullback(Ȳ::Tangent)
        Q̄ = Ȳ.factors
        R̄ = Ȳ.T
        # No cotangent reached either factor: the gradient is structurally zero, and the
        # matrix algebra below cannot operate on two `AbstractZero`s.
        if Q̄ isa ChainRules.AbstractZero && R̄ isa ChainRules.AbstractZero
            return (NoTangent(), ZeroTangent())
        end
        Q = QR.Q
        R = QR.R

        if m ≥ n
            # Keep only the columns of Q̄ that line up with the columns of Q.
            Q̄_cols = Q̄ isa ChainRules.AbstractZero ? Q̄ : @view Q̄[:, axes(Q, 2)]
            return (NoTangent(), _qr_pullback_square(Q̄_cols, R̄, Q, R))
        end

        # Wide case: partition A = [X | Y] (X itself is not needed) and R = [U | V]
        # (V itself is not needed).
        Y = A[1:m, (m + 1):end]
        U = R[1:m, 1:m]
        if R̄ isa ChainRules.AbstractZero
            # Allocate zeros in the factorization's element type so that e.g. Float32
            # inputs are not silently promoted to Float64.
            Ū = R̄
            V̄ = zeros(eltype(R), size(Y))
            Q̄_prime = zeros(eltype(R), m, m)
        else
            # partition R̄ = [Ū | V̄]
            Ū = R̄[1:m, 1:m]
            V̄ = R̄[1:m, (m + 1):end]
            Q̄_prime = Y * V̄'
        end
        if !(Q̄ isa ChainRules.AbstractZero)
            Q̄_prime = Q̄_prime + Q̄
        end
        X̄ = _qr_pullback_square(Q̄_prime, Ū, Q, U)
        # Use a fresh name rather than rebinding the pullback argument `Ȳ`.
        Ȳ_blk = Q * V̄
        # Reassemble Ā = [X̄ | Ȳ]
        return (NoTangent(), [X̄ Ȳ_blk])
    end

    return QR, qr_pullback
end
# Reverse-mode rule for `getproperty(F, d)` on a compact-WY QR factorization.
#
# The QR factorization is calculated from `factors` and `T`, matrices stored in the
# compact WY format, see R. Schreiber and C. van Loan, Sci. Stat. Comput. 10, 53-57 (1989).
# Instead of backpropagating through these fields, the `Tangent` re-uses `factors` to
# carry Q̄ and `T` to carry R̄; the `qr` rrule reads them back from those fields.
#
# Returns the requested property and a pullback yielding one tangent per primal
# argument: `(NoTangent(), ∂F, NoTangent())` for `(getproperty, F, d)`.
function ChainRulesCore.rrule(::typeof(getproperty), F::LinearAlgebra.QRCompactWY, d::Symbol)
    function getproperty_qr_pullback(Ȳ)
        # Route the incoming cotangent to the slot matching the accessed property; the
        # other slot is an explicit zero so downstream `AbstractZero` checks succeed.
        ∂factors = d === :Q ? Ȳ : ZeroTangent()
        ∂T = d === :R ? Ȳ : ZeroTangent()
        ∂F = Tangent{LinearAlgebra.QRCompactWY}(; factors=∂factors, T=∂T)
        # The `Symbol` argument is not differentiable.
        return (NoTangent(), ∂F, NoTangent())
    end
    return getproperty(F, d), getproperty_qr_pullback
end
# Random matrices used as inputs for the gradient checks below.
# 4×4 square matrix with Float32 entries.
V1 = rand(Float32, (4, 4));
# 7×4 (tall) matrix.
# NOTE(review): `V2` is rebound to a 4×6 Float32 matrix further down before this value
# is used anywhere in the visible script — confirm whether the tall case was meant to
# be exercised.
V2 = randn(7, 4);
# 40×50 (wide) matrix; in the visible script it is only passed to `test_rrule`.
V3 = randn(40, 50)
# Scalar test objective: the sum of all entries of the Q factor of `qr(V)`.
# (The former `where T` declared a type parameter that was bound by no argument.)
function f1(V)
    # Destructuring a QR factorization yields Q first, then R.
    Q, _ = qr(V)
    return sum(Q)
end
# Scalar test objective: the sum of all entries of the R factor of `qr(V)`.
# (The former `where T` declared a type parameter that was bound by no argument.)
function f2(V)
    # Destructuring a QR factorization yields Q first, then R.
    _, R = qr(V)
    return sum(R)
end
# Compare Zygote's reverse-mode gradients with 5-point central finite differences.
# Square 4×4 input, objective f1 (sum of Q).
res1_V1_ad = Zygote.gradient(f1, V1)
res1_V1_fd = FiniteDifferences.grad(central_fdm(5,1), f1, V1)
@assert res1_V1_ad[1] ≈ res1_V1_fd[1]
# Square 4×4 input, objective f2 (sum of R).
res2_V1_ad = Zygote.gradient(f2, V1)
res2_V1_fd = FiniteDifferences.grad(central_fdm(5,1), f2, V1)
@assert res2_V1_ad[1] ≈ res2_V1_fd[1]
# Rebind `V2` to a wide 4×6 Float32 matrix; this replaces the 7×4 matrix defined above.
V2 = rand(Float32, (4, 6))
# Wide 4×6 input, objective f1 (sum of Q).
res1_V2_ad = Zygote.gradient(f1, V2)
res1_V2_fd = FiniteDifferences.grad(central_fdm(5,1), f1, V2)
@assert res1_V2_ad[1] ≈ res1_V2_fd[1]
# Wide 4×6 input, objective f2 (sum of R).
res2_V2_ad = Zygote.gradient(f2, V2)
res2_V2_fd = FiniteDifferences.grad(central_fdm(5,1), f2, V2)
@assert res2_V2_ad[1] ≈ res2_V2_fd[1]
# Approximate comparison of two compact-WY QR factorizations for ChainRulesTestUtils:
# the Q factors are compared first, then the R factors, each with a suffixed message.
# All keyword arguments are forwarded unchanged to both comparisons.
function ChainRulesTestUtils.test_approx(
    actual::LinearAlgebra.QRCompactWY,
    expected::LinearAlgebra.QRCompactWY,
    msg::String;
    kwargs...
)
    q_msg = msg * " Q:"
    r_msg = msg * " R:"
    ChainRulesTestUtils.test_approx(actual.Q, expected.Q, q_msg; kwargs...)
    return ChainRulesTestUtils.test_approx(actual.R, expected.R, r_msg; kwargs...)
end
# `to_vec` support for QR factorizations, needed by FiniteDifferences.jl.
#
# The factorization is flattened as the vector form of the concatenation `[Q R]`.
# The returned closure inverts this: it rebuilds the `[Q R]` matrix, splits it back into
# its Q and R column blocks, and re-factorizes their product.
# NOTE(review): `[x.Q x.R]` requires Q and R to have the same number of rows, and `x.Q`
# must be defined for the argument type — confirm this holds for every type in the
# `Union` and for tall inputs.
function FiniteDifferences.to_vec(x::S) where {S <: Union{LinearAlgebra.QRCompactWYQ, LinearAlgebra.QRCompactWY}}
    # Number of columns contributed by Q; everything after that belongs to R.
    q_cols = size(x.Q, 2)
    x_vec, x_back = to_vec([x.Q x.R])
    function QRCompact_from_vec(v)
        # `x_back` returns ONE matrix shaped like `[Q R]`; destructuring it directly
        # would only pick out its first two scalar entries.
        QR_mat = x_back(v)
        Q_new = QR_mat[:, 1:q_cols]
        R_new = QR_mat[:, (q_cols + 1):end]
        QR = LinearAlgebra.qr(Q_new * R_new)
        return S(QR.factors, QR.T)
    end
    return x_vec, QRCompact_from_vec
end
# Check the `qr` rrule against 3-point central finite differences for each test matrix.
# Type-inference checks are disabled. The tolerance keywords are `rtol`/`atol`
# throughout (two of the calls previously used `reltol`/`abstol`).
test_rrule(qr, V1, check_inferred=false, fdm=central_fdm(3,1), rtol=1e-3, atol=1e-3)
test_rrule(qr, V2, check_inferred=false, fdm=central_fdm(3,1), rtol=1e-3, atol=1e-3)
test_rrule(qr, V3, check_inferred=false, fdm=central_fdm(3,1), rtol=1e-3, atol=1e-3)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment