Last active
July 8, 2021 17:54
-
-
Save rkube/ccdd21b8009e5be281f3870a0caec47c to your computer and use it in GitHub Desktop.
pullbacks for QR decomposition
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using LinearAlgebra
using Random
using Statistics
using Test

using ChainRules
using ChainRulesCore
using ChainRulesTestUtils
using FiniteDifferences
using Zygote
# Override ChainRulesCore's `debug_mode()` switch so that it reports `true`.
function ChainRulesCore.debug_mode()
    return true
end

# Seed the global RNG so the random test matrices created below are reproducible.
Random.seed!(1234)
# Shared kernel of the QR pullback for a square triangular factor `R`, following
# Seeger et al. (2019) https://arxiv.org/pdf/1710.08717.pdf
#
#     Ā = [Q̄ + Q copyltu(M)] R⁻ᵀ
#
# where copyltu(C) is the symmetric matrix generated from C by taking its lower triangle
# and copying it to the upper triangle: copyltu(C)ᵢⱼ = C_{max(i,j), min(i,j)}.
#
# Here `M = R̄ Rᵀ - Qᵀ Q̄`, and the symmetric matrix is built from the *upper* triangle of
# this `M` (equivalently, from the lower triangle of `Mᵀ`).
#
# `Q̄` and `R̄` may each be a ChainRulesCore zero tangent, but not both at the same time
# (`triu` is not defined for zero tangents); the caller handles the all-zero case.
function _qr_pullback_square_deep(Q̄, R̄, Q, R)
    M = R̄ * R' - Q' * Q̄
    # M <- symmetric matrix mirrored from the upper triangle of M
    M = triu(M) + transpose(triu(M, 1))
    return (Q̄ + Q * M) / R'
end

# Reverse-mode rule for `qr(A)`.
#
# Returns the factorization together with a pullback. The pullback expects a `Tangent`
# whose `factors` field carries Q̄ and whose `T` field carries R̄ (the two fields are
# re-used as carriers; they are not tangents of the compact-WY storage itself).
#
# * m ≥ n: the Seeger et al. rule implemented in `_qr_pullback_square_deep` is applied.
#   NOTE(review): for m > n the shapes of `QR.Q`, `Q̄` and `R` should be double-checked;
#   no tall input is exercised by the checks in this file.
# * m < n (wide): the rule derived by Liao et al. (2019)
#   https://arxiv.org/pdf/1903.09650.pdf is used,
#
#       Ā = ([Q̄ + Y V̄ᵀ + Q copyltu(M)] U⁻ᵀ, Q V̄)
#
#   where A = (X, Y) is the column-wise concatenation of X (m×m) and Y (m×(n-m)), and
#   R = (U, V) is partitioned the same way. Both X and U are assumed to be full rank.
#
# See also the discussion in https://github.com/JuliaDiff/ChainRules.jl/pull/306
# and https://github.com/pytorch/pytorch/blob/b162d95e461a5ea22f6840bf492a5dbb2ebbd151/torch/csrc/autograd/FunctionsManual.cpp
function ChainRules.rrule(::typeof(qr), A::AbstractMatrix{T}) where {T}
    QR = qr(A)
    m, n = size(A)

    function qr_pullback(Ȳ::Tangent)
        Q̄ = Ȳ.factors
        R̄ = Ȳ.T

        # Nothing flows back if neither Q nor R received a cotangent.
        if Q̄ isa ChainRulesCore.AbstractZero && R̄ isa ChainRulesCore.AbstractZero
            return (NoTangent(), ZeroTangent())
        end

        Q = QR.Q
        R = QR.R

        if m ≥ n
            # Restrict Q̄ to the columns of Q (a zero tangent is passed through untouched).
            Q̄_cols = Q̄ isa ChainRulesCore.AbstractZero ? Q̄ : view(Q̄, :, axes(Q, 2))
            Ā = _qr_pullback_square_deep(Q̄_cols, R̄, Q, R)
        else
            # partition A = [X | Y]; only Y is needed
            Y = A[1:m, (m + 1):n]
            # partition R = [U | V]; only U is needed
            U = R[1:m, 1:m]

            if R̄ isa ChainRulesCore.AbstractZero
                Ū = R̄
                # Use the element type of the factorization instead of a hard-coded
                # Float64, so that e.g. Float32 inputs keep Float32 gradients.
                V̄ = zeros(eltype(R), m, n - m)
            else
                # partition R̄ = [Ū | V̄]
                Ū = R̄[1:m, 1:m]
                V̄ = R̄[1:m, (m + 1):n]
            end

            # Q̄′ = Q̄ + Y V̄ᵀ: contribution of V = QᵀY to the cotangent of Q.
            Q̄_prime = Y * V̄'
            if !(Q̄ isa ChainRulesCore.AbstractZero)
                Q̄_prime = Q̄_prime + Q̄
            end

            X̄ = _qr_pullback_square_deep(Q̄_prime, Ū, Q, U)
            # Ā = [X̄ | Q V̄]
            Ā = hcat(X̄, Q * V̄)
        end
        return (NoTangent(), Ā)
    end

    return QR, qr_pullback
end
# Reverse-mode rule for `getproperty(F, d)` on a compact-WY QR factorization.
#
# The QR factorization is stored as `factors` and `T` in the compact WY format, see
# R. Schreiber and C. van Loan, Sci. Stat. Comput. 10, 53-57 (1989).
# Instead of backpropagating through that storage, the fields of the returned `Tangent`
# are re-used as carriers: `factors` carries Q̄ (when `d === :Q`) and `T` carries R̄
# (when `d === :R`). The `qr` rrule in this file reads them back under the same convention.
#
# Fields that receive no cotangent are set to `ZeroTangent()` so that downstream
# `isa AbstractZero` checks recognise them.
# NOTE(review): any other property name (e.g. `:factors`, `:T`) yields an all-zero tangent.
function ChainRulesCore.rrule(::typeof(getproperty), F::LinearAlgebra.QRCompactWY, d::Symbol)
    function getproperty_qr_pullback(Ȳ)
        ∂factors = d === :Q ? Ȳ : ZeroTangent()
        ∂T = d === :R ? Ȳ : ZeroTangent()
        ∂F = Tangent{LinearAlgebra.QRCompactWY}(; factors=∂factors, T=∂T)
        # One tangent per input: the function `getproperty`, `F`, and the symbol `d`.
        return (NoTangent(), ∂F, NoTangent())
    end
    return getproperty(F, d), getproperty_qr_pullback
end
# Random test inputs: square 4×4 (Float32), tall 7×4 and wide 40×50 (both Float64).
# Note: `V2` is reassigned to a wide 4×6 Float32 matrix further down, before its first use.
V1 = rand(Float32, 4, 4)
V2 = randn(7, 4)
V3 = randn(40, 50)
# Scalar test function: factorize `V` with `qr` and return the sum of the entries of Q.
# NOTE(review): `sum(Q)` relies on the Q factor being iterable like a matrix — confirm
# this holds on the Julia version in use.
function f1(V)
    Q, _ = qr(V)
    return sum(Q)
end
# Scalar test function: factorize `V` with `qr` and return the sum of the entries of R.
function f2(V)
    _, R = qr(V)
    return sum(R)
end
# Compare Zygote's reverse-mode gradients with central finite differences.
# `@test` is used so that failures are reported through the Test framework.

# Square input (4×4): gradient of sum(Q) ...
res1_V1_ad = Zygote.gradient(f1, V1)
res1_V1_fd = FiniteDifferences.grad(central_fdm(5, 1), f1, V1)
@test res1_V1_ad[1] ≈ res1_V1_fd[1]

# ... and of sum(R).
res2_V1_ad = Zygote.gradient(f2, V1)
res2_V1_fd = FiniteDifferences.grad(central_fdm(5, 1), f2, V1)
@test res2_V1_ad[1] ≈ res2_V1_fd[1]

# Wide input (4×6). This replaces the previous value of `V2`.
V2 = rand(Float32, (4, 6))

res1_V2_ad = Zygote.gradient(f1, V2)
res1_V2_fd = FiniteDifferences.grad(central_fdm(5, 1), f1, V2)
@test res1_V2_ad[1] ≈ res1_V2_fd[1]

res2_V2_ad = Zygote.gradient(f2, V2)
res2_V2_fd = FiniteDifferences.grad(central_fdm(5, 1), f2, V2)
@test res2_V2_ad[1] ≈ res2_V2_fd[1]
# Approximate equality for QR objects: compare the Q factors first, then the R factors,
# delegating each comparison to the existing `test_approx` methods.
function ChainRulesTestUtils.test_approx(
    actual::LinearAlgebra.QRCompactWY,
    expected::LinearAlgebra.QRCompactWY,
    msg::String;
    kwargs...,
)
    q_actual, q_expected = actual.Q, expected.Q
    ChainRulesTestUtils.test_approx(q_actual, q_expected, msg * " Q:"; kwargs...)

    r_actual, r_expected = actual.R, expected.R
    return ChainRulesTestUtils.test_approx(r_actual, r_expected, msg * " R:"; kwargs...)
end
# FiniteDifferences.jl needs a working `to_vec` for QR objects.
#
# The object is flattened as the vectorization of the horizontally concatenated matrix
# `[Q R]`. The returned closure inverts this: it rebuilds `[Q R]`, splits it back into
# its Q and R column blocks and re-factorizes their product with `qr`.
#
# NOTE(review): the method signature also admits `QRCompactWYQ`, but the body accesses
# `x.Q` and `x.R` — confirm whether that type is ever passed here.
function FiniteDifferences.to_vec(x::S) where {S <: Union{LinearAlgebra.QRCompactWYQ, LinearAlgebra.QRCompactWY}}
    Q, R = x.Q, x.R
    # Remember where the Q block ends so the flat vector can be split again.
    ncols_Q = size(Q, 2)
    x_vec, x_back = to_vec([Q R])
    function QRCompact_from_vec(v)
        QR_mat = x_back(v)
        Q_new = QR_mat[:, 1:ncols_Q]
        R_new = QR_mat[:, (ncols_Q + 1):end]
        return qr(Q_new * R_new)
    end
    return x_vec, QRCompact_from_vec
end
# Unit test should work for small matrix.
# Tolerances are passed uniformly as `rtol`/`atol` in all three calls.
test_rrule(qr, V1; check_inferred=false, fdm=central_fdm(3, 1), rtol=1e-3, atol=1e-3)
test_rrule(qr, V2; check_inferred=false, fdm=central_fdm(3, 1), rtol=1e-3, atol=1e-3)
test_rrule(qr, V3; check_inferred=false, fdm=central_fdm(3, 1), rtol=1e-3, atol=1e-3)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment