# Backpropagating through QR-factorization code
#
# GitHub gist rkube/b965267944115af7d13b3f00e7533572, created August 30, 2021 17:39.
# (It can be saved locally and used with GitHub Desktop.)
#
# Note: GitHub flags this file as containing bidirectional Unicode text that may be
# interpreted or compiled differently than it appears. To review it, open the file in an
# editor that reveals hidden Unicode characters (see GitHub's notes on bidirectional
# Unicode characters).
# Packages used throughout: GPU arrays and solvers, dense linear algebra, RNG seeding,
# and the reverse-mode AD stack (Zygote on top of the ChainRules rule system).
using CUDA, LinearAlgebra, Random
using Zygote, ChainRules, ChainRulesCore

# Disallow scalar indexing of GPU arrays so accidental host-side element access errors out.
CUDA.allowscalar(false)

# Fixed seed so the random test matrices below are reproducible.
Random.seed!(1234)
# Allocate an uninitialised dense `CuArray` with the same size as the packed CUDA Q
# factor `q`. Its element type is the first type parameter of `CuQRPackedQ`
# (presumably the element type of the factor — confirm against CUDA.jl).
# NOTE(review): this extends `Base.similar` for a type owned by CUDA.jl (type piracy).
function Base.similar(q::CUDA.CUSOLVER.CuQRPackedQ{E}) where {E}
    dims = size(q)
    return CuArray{E, length(dims)}(undef, dims)
end
# Zygote adjoint for `sum` over a packed CUDA Q factor.
#
# The primal densifies `xs` once and honours `dims` (previously `dims` was ignored).
# `CuArray(xs)` may have fewer columns than `xs` itself (the `CuArray` rrules in this
# file zero-pad for exactly that reason). So the cotangent `Δ` — a scalar for
# `dims = :`, a reduced array otherwise — is broadcast into the summed columns only,
# and the remaining columns get a zero gradient.
Zygote.@adjoint function sum(xs::CUDA.CUSOLVER.CuQRPackedQ; dims = :)
    dense = CuArray(xs)
    y = sum(dense; dims = dims)
    function sum_qrpackedq_pullback(Δ)
        x̄ = fill!(similar(xs), zero(eltype(dense)))
        # Broadcasting handles both a scalar Δ and a reduced-shape Δ.
        view(x̄, :, axes(dense, 2)) .= Δ
        return (x̄,)
    end
    return y, sum_qrpackedq_pullback
end
# `sum` over a packed CUDA Q factor: densify to a `CuArray` first and reduce that,
# presumably to avoid generic `AbstractQ` element access, which
# `CUDA.allowscalar(false)` forbids. `dims` is forwarded. Without accepting it here,
# `sum(Q; dims = 1)` would select this method and fail with an unsupported-keyword error.
Base.sum(xs::CUDA.CUSOLVER.CuQRPackedQ; dims = :) = sum(CuArray(xs); dims = dims)
# rrule for `CuArray(A)` where `A` is a packed CUDA Q factor.
#
# Dispatches on `Type{CuArray}`. The previous `::typeof(CuArray)` is just `UnionAll`,
# so it matched *any* UnionAll constructor (e.g. `Matrix(A)`) and returned a CuArray
# primal for it.
#
# `CuArray(A)` may have fewer columns than `A` (the thin factor), so the cotangent is
# zero-padded on the right up to `size(A, 2)`. The padding is allocated next to the
# cotangent (on the device for a CuArray cotangent).
function ChainRules.rrule(::Type{CuArray}, A::CUDA.CUSOLVER.CuQRPackedQ)
    function cuarray_ctor_qrpackedq_pullback(Ȳ)
        dy = unthunk(Ȳ)
        dy isa ChainRulesCore.AbstractZero && return (NoTangent(), dy)
        npad = size(A, 2) - size(dy, 2)
        npad > 0 || return (NoTangent(), dy)
        pad = fill!(similar(dy, size(dy, 1), npad), zero(eltype(dy)))
        return (NoTangent(), hcat(dy, pad))
    end
    return CuArray(A), cuarray_ctor_qrpackedq_pullback
end
# rrule for materialising a compact-WY Q factor as a dense `Array` (e.g. `Matrix(Q)`).
#
# `T(Q)` may have fewer columns than the square `Q` (the thin factor), so the cotangent
# is zero-padded on the right up to `size(Q, 2)`. The pad width is derived from the
# cotangent's own width. The former `size(Q, 2) - size(Q.factors, 2)` is negative for
# wide inputs (m < n), which made the padding allocation throw.
function ChainRules.rrule(::Type{T}, Q::LinearAlgebra.QRCompactWYQ) where {T<:Array}
    function array_qrcompactwyq_pullback(Ȳ)
        dy = unthunk(Ȳ)
        dy isa ChainRulesCore.AbstractZero && return (NoTangent(), dy)
        npad = size(Q, 2) - size(dy, 2)
        npad > 0 || return (NoTangent(), dy)
        return (NoTangent(), hcat(dy, zeros(eltype(dy), size(dy, 1), npad)))
    end
    return T(Q), array_qrcompactwyq_pullback
end
# rrule for materialising a packed CUDA Q factor as a dense `CuArray` subtype.
#
# `T(Q)` may have fewer columns than the square `Q` (the thin factor), so the cotangent
# is zero-padded on the right up to `size(Q, 2)`:
#   - The pad width comes from the cotangent's own width; the former
#     `size(Q, 2) - size(Q.factors, 2)` is negative for wide inputs (m < n).
#   - The padding is allocated with `similar(dy, …)` so it lives on the same device as
#     `dy`, instead of a host `falses` BitMatrix, which cannot be concatenated with a
#     CuArray without scalar indexing.
function ChainRules.rrule(::Type{T}, Q::CUDA.CUSOLVER.CuQRPackedQ) where {T<:CuArray}
    function cuarray_qrpackedq_pullback(Ȳ)
        dy = unthunk(Ȳ)
        dy isa ChainRulesCore.AbstractZero && return (NoTangent(), dy)
        npad = size(Q, 2) - size(dy, 2)
        npad > 0 || return (NoTangent(), dy)
        pad = fill!(similar(dy, size(dy, 1), npad), zero(eltype(dy)))
        return (NoTangent(), hcat(dy, pad))
    end
    return T(Q), cuarray_qrpackedq_pullback
end
# rrule for property access `getproperty(F, d)` on a CUDA QR factorisation.
#
# The cotangent is packed into a `Tangent{CuQR}`:
#   - the cotangent of `F.Q` goes in the `factors` slot;
#   - the cotangent of `F.R` goes in the `τ` slot;
#   - slots not touched by `d` are `ZeroTangent()`.
# Consumers must read R̄ from `τ`.
#
# The pullback returns one tangent per primal argument `(getproperty, F, d)`; the symbol
# `d` is non-differentiable.
function ChainRulesCore.rrule(::typeof(getproperty), F::CUDA.CUSOLVER.CuQR, d::Symbol)
    function getproperty_qr_pullback(Ȳ)
        ∂factors = d === :Q ? Ȳ : ZeroTangent()
        ∂τ = d === :R ? Ȳ : ZeroTangent()
        ∂F = Tangent{CUDA.CUSOLVER.CuQR}(; factors = ∂factors, τ = ∂τ)
        return (NoTangent(), ∂F, NoTangent())
    end
    return getproperty(F, d), getproperty_qr_pullback
end
# rrule for property access `getproperty(F, d)` on a CPU compact-WY QR factorisation.
#
# The cotangent is packed into a `Tangent{QRCompactWY}`:
#   - the cotangent of `F.Q` goes in the `factors` slot;
#   - the cotangent of `F.R` goes in the `T` slot;
#   - slots not touched by `d` are `ZeroTangent()`.
#
# The pullback returns one tangent per primal argument `(getproperty, F, d)`; the symbol
# `d` is non-differentiable.
function ChainRules.rrule(::typeof(getproperty), F::LinearAlgebra.QRCompactWY, d::Symbol)
    function getproperty_qr_pullback(Ȳ)
        ∂factors = d === :Q ? Ȳ : ZeroTangent()
        ∂T = d === :R ? Ȳ : ZeroTangent()
        ∂F = Tangent{LinearAlgebra.QRCompactWY}(; factors = ∂factors, T = ∂T)
        return (NoTangent(), ∂F, NoTangent())
    end
    return getproperty(F, d), getproperty_qr_pullback
end
# rrule for the QR factorisation of a dense CUDA array.
#
# The pullback receives a `Tangent` whose `factors` field carries Q̄ and whose `τ`
# field carries R̄. This matches the `getproperty` rrule for `CuQR`, which stores the
# cotangent of `F.R` in `τ`; the previous lookup of `Ȳ.T` never found it, so R̄ was lost.
#
# Backward formula (full-rank, real-valued T assumed), the same one the CPU rule uses:
#     M = R R̄ᵀ - Q̄ᵀ Q,    Ā = (Q̄ + Q copyltu(M)) R⁻ᵀ
# For wide inputs (m < n) the matrices are partitioned as A = [X | Y] and R = [U | V].
# The square formula is applied to X, with an extra Y V̄ᵀ contribution to Q̄ and
# Ȳ_A = Q V̄ for the Y block.
function ChainRules.rrule(::typeof(qr), A::CuArray{T}) where {T}
    QR = qr(A)
    m, n = size(A)
    Q, R = QR
    Q_arr = CuArray(Q)
    R_arr = CuArray(R)

    # Square/tall-case adjoint. copyltu(M) mirrors the lower triangle of M (diagonal
    # included) into the upper triangle. M was previously built transposed
    # (R̄ R' - Q' Q̄), which makes copyltu pick up the wrong triangle.
    function qr_pullback_square_deep(Q̄, R̄, Q, R)
        M = R * R̄' - Q̄' * Q
        M = tril(M) + transpose(tril(M, -1))
        return (Q̄ + Q * M) / R'
    end

    function qr_pullback_cu(Ȳ::Tangent)
        Q̄ = Ȳ.factors
        R̄ = Ȳ.τ
        if m ≥ n
            # Only the columns of Q̄ that pair with R's columns enter the thin factorisation.
            Q̄ = Q̄ isa ChainRules.AbstractZero ? Q̄ : CuArray(Q̄[:, axes(R, 2)])
            Ā = qr_pullback_square_deep(Q̄, R̄, Q_arr, R_arr)
        else
            # partition A = [X | Y] with X = A[1:m, 1:m]
            Y = A[1:m, (m + 1):end]
            # partition R = [U | V]; V itself is not needed
            U = R[1:m, 1:m]
            if R̄ isa ChainRules.AbstractZero
                # Zero cotangents allocated directly on the device.
                V̄ = fill!(similar(Y), zero(T))
                Q̄_prime = fill!(similar(Y, m, m), zero(T))
                Ū = R̄
            else
                # partition R̄ = [Ū | V̄]
                Ū = R̄[1:m, 1:m]
                V̄ = R̄[1:m, (m + 1):end]
                Q̄_prime = Y * V̄'
            end
            Q̄_prime = Q̄ isa ChainRules.AbstractZero ? Q̄_prime : Q̄_prime + Q̄
            X̄ = qr_pullback_square_deep(Q̄_prime, Ū, Q_arr, U)
            Ȳ_block = Q * V̄
            # reassemble Ā = [X̄ | Ȳ_block]
            Ā = [X̄ Ȳ_block]
        end
        return (NoTangent(), Ā)
    end

    return QR, qr_pullback_cu
end
# rrule for the QR factorisation of a generic (CPU) matrix.
#
# The pullback receives a `Tangent` whose `factors` field carries Q̄ and whose `T` field
# carries R̄. This matches the `getproperty` rrule for `QRCompactWY`.
#
# Backward formula (full-rank, real-valued input assumed):
#     M = R R̄ᵀ - Q̄ᵀ Q,    Ā = (Q̄ + Q copyltu(M)) R⁻ᵀ
# For wide inputs (m < n) the matrices are partitioned as A = [X | Y] and R = [U | V].
# The square formula is applied to X, with an extra Y V̄ᵀ contribution to Q̄ and
# Ȳ_A = Q V̄ for the Y block.
function ChainRules.rrule(::typeof(qr), A::AbstractMatrix{T}) where {T}
    QR = qr(A)
    m, n = size(A)

    # Square/tall-case adjoint. copyltu(M) mirrors the lower triangle of M (diagonal
    # included) into the upper triangle.
    function qr_pullback_square_deep(Q̄, R̄, Q, R)
        M = R * R̄' - Q̄' * Q
        M = tril(M) + transpose(tril(M, -1))
        return (Q̄ + Q * M) / R'
    end

    function qr_pullback(Ȳ::Tangent)
        Q̄ = Ȳ.factors
        R̄ = Ȳ.T
        Q = QR.Q
        R = QR.R
        if m ≥ n
            # Q̄ may span the full square Q; keep only the thin columns that pair with R.
            Q̄ = Q̄ isa ChainRules.AbstractZero ? Q̄ : Q̄[:, axes(R, 2)]
            Ā = qr_pullback_square_deep(Q̄, R̄, Matrix(Q), R)
        else # m < n: short and wide A
            Y = @view A[1:m, (m + 1):end]
            U = R[1:m, 1:m]
            # Zero cotangents use the factorisation's element type (previously always Float64).
            S = eltype(R)
            if R̄ isa ChainRules.AbstractZero
                V̄ = zeros(S, size(Y))
                Q̄_prime = zeros(S, m, m)
                Ū = R̄
            else
                # partition R̄ = [Ū | V̄]
                Ū = @view R̄[1:m, 1:m]
                V̄ = @view R̄[1:m, (m + 1):end]
                Q̄_prime = Y * V̄'
            end
            Q̄_prime = Q̄ isa ChainRules.AbstractZero ? Q̄_prime : Q̄_prime + Q̄
            # Previously called with an extra `A` argument, which raised a MethodError.
            X̄ = qr_pullback_square_deep(Q̄_prime, Ū, Matrix(Q), U)
            Ȳ_block = Q * V̄
            # reassemble Ā = [X̄ | Ȳ_block]
            Ā = [X̄ Ȳ_block]
        end
        return (NoTangent(), Ā)
    end

    return QR, qr_pullback
end
# Driver: compute Zygote gradients of the same QR-based scalar loss on the host and
# on the GPU, for a random 6×4 Float32 matrix.
#
# `x_host` uses the same element type as `V_host` (Float32). A Float64 vector here
# forced mixed-precision products that promote to Float64.
V_host = rand(Float32, (6, 4))
V_dev = V_host |> CuArray
x_host = Float32[1, 2, 3, 4, 5, 6]
x_dev = x_host |> CuArray
Q_host, R_host = qr(V_host)
Q_dev, R_dev = qr(V_dev)

# Loss on the GPU: sum of squares of Q*Q'*x, where Q is the dense Q factor of `A`.
# `x` defaults to the global `x_dev`; passing it explicitly avoids the non-const global.
function myfun_dev(A, x = x_dev)
    Q, R = qr(A)
    Q2 = CuArray(Q)
    return sum((Q2 * Q2' * x) .^ 2)
end

# Same loss on the host; `x` defaults to the global `x_host`.
function myfun_host(A, x = x_host)
    Q, R = qr(A)
    Q2 = Matrix(Q)
    return sum((Q2 * Q2' * x) .^ 2)
end

res_host = Zygote.gradient(myfun_host, V_host)
res_dev = Zygote.gradient(myfun_dev, V_dev)
# Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.