Last active
July 9, 2021 20:32
-
-
Save rkube/1f9ba50fe37cc0e65583098ac982703c to your computer and use it in GitHub Desktop.
Updated code for qr-factorization pullback
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using LinearAlgebra | |
using Zygote | |
using ChainRules | |
using ChainRulesCore | |
using Random | |
using Statistics | |
using FiniteDifferences | |
using ChainRulesTestUtils | |
# Session setup for this script:
# - redefine `ChainRulesCore.debug_mode` to return `true` (presumably switching on
#   ChainRulesCore's extra debugging checks — confirm against the ChainRulesCore docs);
# - seed the global RNG so the random test matrices created below are reproducible.
function ChainRulesCore.debug_mode()
    return true
end
Random.seed!(1234)
"""
    _qr_pullback_square_deep(Q̄, R̄, Q, R)

Return the cotangent `Ā` of `A = Q * R` for a square or tall-and-skinny `A`, following
Seeger et al. (2019), https://arxiv.org/pdf/1710.08717.pdf:

    M = R R̄ᵀ - Q̄ᵀ Q
    Ā = [Q̄ + Q copyltu(M)] R⁻ᵀ

`copyltu(C)` is the symmetric matrix obtained by mirroring the strict lower triangle of `C`
onto its upper triangle, `copyltu(C)ᵢⱼ = C_{max(i,j), min(i,j)}`.

`R` must be square (its columns are solved against via `/ R'`) and `Q` must have as many
columns as `R` has rows. Either `Q̄` or `R̄` may be a `ChainRulesCore.AbstractZero`, but not
both: `tril` has no method for a zero tangent.

NOTE(review): the code uses adjoint (`'`) for the products but `transpose` inside
`copyltu`; this has only been exercised with real inputs — confirm before using it with
complex matrices.
"""
function _qr_pullback_square_deep(Q̄, R̄, Q, R)
    M = R * R̄' - Q̄' * Q
    # M <- copyltu(M)
    M = tril(M) + transpose(tril(M, -1))
    return (Q̄ + Q * M) / R'
end

# rrule for `qr` of a matrix.
#
# The incoming cotangent is a `Tangent` whose `factors` field carries Q̄ and whose `T` field
# carries R̄. This is how the `getproperty` rrule for `QRCompactWY` in this file packs them.
#
# * m ≥ n (square or tall-and-skinny): Seeger et al. (2019); see `_qr_pullback_square_deep`.
# * m < n (short and wide): Liao et al. (2019), https://arxiv.org/pdf/1903.09650.pdf.
#   Partition A = [X | Y] with X = A[:, 1:m] (m × m) and Y = A[:, m+1:n] (m × (n-m)),
#   R = [U | V] with U = R[:, 1:m] (m × m), and R̄ = [Ū | V̄]. Then
#
#       Ā = [ (Q̄ + Y V̄ᵀ + Q copyltu(M)) U⁻ᵀ  |  Q V̄ ],   M = U Ūᵀ - (Q̄ + Y V̄ᵀ)ᵀ Q
#
#   i.e. the square rule applied to (Q̄ + Y V̄ᵀ, Ū, Q, U) gives the X-block.
#
# See also https://github.com/JuliaDiff/ChainRules.jl/pull/306 and
# https://github.com/pytorch/pytorch/blob/b162d95e461a5ea22f6840bf492a5dbb2ebbd151/torch/csrc/autograd/FunctionsManual.cpp
function ChainRulesCore.rrule(::typeof(qr), A::AbstractMatrix{T}) where {T}
    QR = qr(A)
    m, n = size(A)

    # Nothing flows back into `qr`: the cotangent of `A` is zero.
    qr_pullback(::ChainRulesCore.AbstractZero) = (NoTangent(), ZeroTangent())

    function qr_pullback(Ȳ::Tangent)
        Q̄ = Ȳ.factors
        R̄ = Ȳ.T
        # Both parts zero: nothing to propagate (and `tril` would fail on a zero tangent).
        if Q̄ isa ChainRulesCore.AbstractZero && R̄ isa ChainRulesCore.AbstractZero
            return (NoTangent(), ZeroTangent())
        end
        Q = QR.Q
        R = QR.R
        if m ≥ n
            # `QR.Q` is the full m × m factor; keep only the columns paired with `R`.
            if !(Q̄ isa ChainRulesCore.AbstractZero)
                Q̄ = Q̄[:, axes(R, 2)]
            end
            Qthin = m > n ? @view(Q[:, axes(R, 2)]) : Q
            Ā = _qr_pullback_square_deep(Q̄, R̄, Qthin, R)
        else
            # Partition A = [X | Y]; only Y is needed explicitly.
            Y = @view A[1:m, (m + 1):end]
            # Partition R = [U | V]; V is not needed.
            U = R[1:m, 1:m]
            if R̄ isa ChainRulesCore.AbstractZero
                # V̄ = 0: the Y-block of Ā vanishes and Q̄ needs no Y V̄ᵀ correction.
                # Q̄ is non-zero here because the all-zero case returned above.
                Ū = R̄
                Q̄_prime = Q̄
                Ā_Y = zeros(eltype(R), m, n - m)
            else
                # Partition R̄ = [Ū | V̄].
                Ū = @view R̄[1:m, 1:m]
                V̄ = @view R̄[1:m, (m + 1):end]
                Q̄_prime = Y * V̄'
                if !(Q̄ isa ChainRulesCore.AbstractZero)
                    Q̄_prime = Q̄_prime + Q̄
                end
                Ā_Y = Q * V̄
            end
            Ā_X = _qr_pullback_square_deep(Q̄_prime, Ū, Q, U)
            # Ā = [X̄ | Ȳ]
            Ā = [Ā_X Ā_Y]
        end
        return (NoTangent(), Ā)
    end

    return QR, qr_pullback
end
# rrule for `getproperty` on a compact-WY QR factorization, so that `F.Q` and `F.R` can be
# differentiated.
#
# The factorization stores `factors` and `T` in the compact WY representation, see
# R. Schreiber and C. van Loan, SIAM J. Sci. Stat. Comput. 10, 53-57 (1989). Instead of
# back-propagating into those stored matrices, the cotangent is packed into a `Tangent` whose
# `factors` field carries Q̄ and whose `T` field carries R̄. The `qr` rrule in this file reads
# them back from those two fields.
#
# NOTE(review): any other property (`:factors`, `:T`, `:τ`, ...) gets a zero cotangent, so a
# gradient flowing through it is silently dropped — confirm this is intended.
function ChainRulesCore.rrule(
    ::typeof(getproperty), F::LinearAlgebra.QRCompactWY, d::Symbol
)
    function getproperty_qr_pullback(Ȳ)
        ∂factors = d === :Q ? Ȳ : ZeroTangent()
        ∂T = d === :R ? Ȳ : ZeroTangent()
        ∂F = Tangent{typeof(F)}(; factors=∂factors, T=∂T)
        # One cotangent per argument of `getproperty(F, d)`: the function itself, `F`, `d`.
        return (NoTangent(), ∂F, NoTangent())
    end
    return getproperty(F, d), getproperty_qr_pullback
end
# Compare two compact-WY QR factorizations through their `Q` and `R` factors rather than
# through their stored fields (presumably the generic fallback does not handle
# `QRCompactWY` — confirm). `msg` is prefixed to each sub-comparison's message. It defaults
# to "" so that the two-argument form `test_approx(actual, expected)` also dispatches here.
# Returns whatever the final `R` comparison returns.
function ChainRulesTestUtils.test_approx(
    actual::LinearAlgebra.QRCompactWY,
    expected::LinearAlgebra.QRCompactWY,
    msg::AbstractString="";
    kwargs...,
)
    ChainRulesTestUtils.test_approx(actual.Q, expected.Q, msg * " Q:"; kwargs...)
    return ChainRulesTestUtils.test_approx(actual.R, expected.R, msg * " R:"; kwargs...)
end
# Flatten a compact-WY QR factorization for FiniteDifferences.
#
# The vector is the thin Q factor (m × min(m, n)) followed by the R factor (min(m, n) × n).
# The inverse map splits the vector back into the two parts, multiplies the perturbed
# factors together, and re-factorizes the product into the same factorization type `S`.
function FiniteDifferences.to_vec(x::S) where {S <: LinearAlgebra.QRCompactWY}
    # `Matrix(F.Q)` is the thin Q. Its column count equals the row count of `F.R` for both
    # tall and wide inputs, so the product below is always well-defined.
    Q = Matrix(x.Q)
    q_vec, q_back = to_vec(Q)
    r_vec, r_back = to_vec(x.R)
    nq = length(q_vec)
    function QRCompact_from_vec(v)
        # Each part of `v` is rebuilt with its own inverse map.
        Q_new = q_back(v[1:nq])
        R_new = r_back(v[(nq + 1):end])
        F = qr(Q_new * R_new)
        return S(F.factors, F.T)
    end
    return [q_vec; r_vec], QRCompact_from_vec
end
"""
    f1(V)

Scalar test function on the `Q` factor of `qr(V)`.

For square or tall `V` (m ≥ n), return the sum of the leading m × n block of the m × m `Q`,
i.e. of the thin Q. For wide `V` (m < n), return the sum of the full m × m `Q`.
"""
function f1(V)
    m, n = size(V)
    Q, _ = qr(V)
    Qpart = m ≥ n ? @view(Q[axes(V)...]) : Q
    return sum(Qpart)
end
"""
    f2(V)

Scalar test function on the `R` factor of `qr(V)`: return the sum of all entries of `R`.
"""
function f2(V)
    _, R = qr(V)
    return sum(R)
end
# Random test inputs: square (Float32), tall-and-skinny, and short-and-wide, so that every
# branch of the `qr` pullback is exercised.
V1 = rand(Float32, (4, 4));
V2 = randn(7, 4);
V3 = randn(5, 11)

"""
    _check_gradient(f, V, label)

Print `label`, then compute the gradient of the scalar function `f` at `V` with Zygote and
with a 5-point central finite-difference scheme. Throw an error if the two gradients are not
`≈`. Return the pair `(ad, fd)` of gradient tuples from `Zygote.gradient` and
`FiniteDifferences.grad`.
"""
function _check_gradient(f, V, label)
    println(label)
    ad = Zygote.gradient(f, V)
    fd = FiniteDifferences.grad(central_fdm(5, 1), f, V)
    # Explicit check rather than `@assert`: `@assert` is meant for internal invariants and
    # may be disabled.
    ad[1] ≈ fd[1] || error("$label: Zygote gradient does not match finite differences")
    return ad, fd
end

res1_V1_ad, res1_V1_fd = _check_gradient(f1, V1, "f1, V1")
println("")
res2_V1_ad, res2_V1_fd = _check_gradient(f2, V1, "f2, V1")
println("")
res1_V2_ad, res1_V2_fd = _check_gradient(f1, V2, "f1, V2")
println("")
res2_V2_ad, res2_V2_fd = _check_gradient(f2, V2, "f2, V2")
println("")
res1_V3_ad, res1_V3_fd = _check_gradient(f1, V3, "f1, V3")
println("")
res2_V3_ad, res2_V3_fd = _check_gradient(f2, V3, "f2, V3")
# ChainRulesTestUtils check of the `qr` rrule on each test matrix (square, tall, wide).
# Type inference is not checked. A 3-point central finite-difference scheme with loose
# tolerances (rtol = atol = 1e-3) is used, which is expected to work for these small matrices.
for V in (V1, V2, V3)
    test_rrule(qr, V; check_inferred=false, fdm=central_fdm(3, 1), rtol=1e-3, atol=1e-3)
end
3,1 Top |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment