Skip to content

Instantly share code, notes, and snippets.

@rkube
Last active August 25, 2021 20:03
Show Gist options
  • Save rkube/b17ef683409d76a3f01bcc590b85de6e to your computer and use it in GitHub Desktop.
Save rkube/b17ef683409d76a3f01bcc590b85de6e to your computer and use it in GitHub Desktop.
Minimal working example of QR pullback
using LinearAlgebra
using Zygote
using ChainRules
using ChainRulesCore
using Random
using CUDA
# Fix the CPU RNG so the random matrices drawn below (then copied to the GPU) repeat run to run.
Random.seed!(Random.default_rng(), 1234)
"""
    _qr_pullback_square_deep(Q̄, R̄, Q, R)

Return the cotangent of `A` for a QR factorization `A = Q * R` with square, upper
triangular `R`, given the cotangents `Q̄` (of `Q`) and `R̄` (of `R`):

    M = R̄ * R' - Q' * Q̄
    M = triu(M) + transpose(triu(M, 1))     # the `copyltu` step
    Ā = (Q̄ + Q * M) / R'

At most one of `Q̄`, `R̄` may be a ChainRulesCore `AbstractZero`. Its term then drops
out through ChainRulesCore's zero arithmetic. If both are zero, `triu` would receive a
zero tangent instead of a matrix, so callers must short-circuit that case.
"""
function _qr_pullback_square_deep(Q̄, R̄, Q, R)
    M = R̄ * R' - Q' * Q̄
    # Mirror the strict upper triangle of M into its lower triangle (diagonal kept).
    M = triu(M) + transpose(triu(M, 1))
    return (Q̄ + Q * M) / R'
end

"""
    rrule(::typeof(qr), A::CuArray{T}) where {T}

ChainRules reverse rule for `qr` of an `m × n` CUDA matrix `A`.

The pullback takes a `Tangent` for the returned factorization. Its `factors` field
carries the cotangent of `Q` and its `τ` field carries the cotangent of `R`; this is the
layout written by the `getproperty` rule for `CUDA.CUSOLVER.CuQR`. Missing, `nothing`,
or zero entries count as zero cotangents. The pullback returns `(NoTangent(), Ā)`.

- `m ≥ n`: the square formula is applied with the dense factors.
- `m < n`: `A = [X | Y]` and `R = [U | V]` are split after column `m`. `X̄` comes from
  the square formula applied to `(Q̄ + Y * V̄', Ū)` and `U`, and `Ȳ = Q * V̄`.
"""
function ChainRulesCore.rrule(::typeof(qr), A::CuArray{T}) where {T}
    QR = qr(A)
    m, n = size(A)
    Q, R = QR
    # Dense copies of the factors for use inside the pullback.
    # NOTE(review): the tall case assumes `CuArray(Q)` is the thin m × min(m, n)
    # factor (as CUDA.jl's conversion is believed to produce) — confirm.
    Q_arr = CuArray(Q)
    R_arr = CuArray(R)

    # No cotangent reached the factorization at all.
    qr_pullback_cu(::ChainRulesCore.AbstractZero) = (NoTangent(), ZeroTangent())

    function qr_pullback_cu(Ȳ::Tangent)
        # `factors` holds Q̄ and `τ` holds R̄. A field that is absent from the
        # Tangent reads as a zero tangent; map an explicit `nothing` to zero too.
        Q̄ = something(Ȳ.factors, ZeroTangent())
        R̄ = something(Ȳ.τ, ZeroTangent())
        if Q̄ isa ChainRulesCore.AbstractZero && R̄ isa ChainRulesCore.AbstractZero
            return (NoTangent(), ZeroTangent())
        end

        if m ≥ n
            # Keep only the columns of Q̄ matching `Q_arr`, so that `Q_arr' * Q̄`
            # conforms with `R̄ * R_arr'` (the packed `Q` reports m columns).
            if !(Q̄ isa ChainRulesCore.AbstractZero)
                Q̄ = CuArray(Q̄[:, axes(Q_arr, 2)])
            end
            Ā = _qr_pullback_square_deep(Q̄, R̄, Q_arr, R_arr)
        else
            # Partition A = [X | Y] and R = [U | V] after column m; V is not needed.
            Y = A[1:m, (m + 1):end]
            U = R[1:m, 1:m]
            if R̄ isa ChainRulesCore.AbstractZero
                Ū = R̄
                # Allocate the zero blocks directly on the device.
                V̄ = fill!(similar(Y), zero(T))
                Q̄_prime = fill!(similar(Q_arr), zero(T))
            else
                # Partition R̄ = [Ū | V̄] the same way.
                Ū = R̄[1:m, 1:m]
                V̄ = R̄[1:m, (m + 1):end]
                Q̄_prime = Y * V̄'
            end
            if !(Q̄ isa ChainRulesCore.AbstractZero)
                Q̄_prime = Q̄_prime + CuArray(Q̄)
            end
            X̄ = _qr_pullback_square_deep(Q̄_prime, Ū, Q_arr, U)
            Ȳ_block = Q * V̄
            # Reassemble Ā = [X̄ | Ȳ].
            Ā = hcat(X̄, Ȳ_block)
        end
        return (NoTangent(), Ā)
    end

    return QR, qr_pullback_cu
end
"""
    rrule(::typeof(getproperty), F::CUDA.CUSOLVER.CuQR, d::Symbol)

Reverse rule for reading a property of a CUDA QR factorization.

The cotangent of `F` is a `Tangent{typeof(F)}` whose slots are repurposed:
- `factors` carries the cotangent of `F.Q`.
- `τ` carries the cotangent of `F.R`.

These are not true tangents of the stored fields. A consumer of this tangent must read
them with that meaning. Any property other than `:Q` or `:R` gets zero in both slots.
"""
function ChainRulesCore.rrule(::typeof(getproperty), F::CUDA.CUSOLVER.CuQR, d::Symbol)
    function getproperty_qr_pullback(Ȳ)
        # Absent cotangents are `ZeroTangent()` rather than `nothing`, so consumers
        # can detect them with `isa AbstractZero` and use zero arithmetic.
        ∂factors = d === :Q ? Ȳ : ZeroTangent()
        ∂τ = d === :R ? Ȳ : ZeroTangent()
        ∂F = Tangent{typeof(F)}(; factors=∂factors, τ=∂τ)
        # One cotangent per primal argument: `getproperty` itself, `F`, and `d`.
        return (NoTangent(), ∂F, NoTangent())
    end
    return getproperty(F, d), getproperty_qr_pullback
end
# Square 4 × 4 test matrix, drawn on the CPU and copied to the GPU.
V = CuArray(rand(Float32, 4, 4))
"""
    f1(V)

Scalar test objective: the sum of all entries of the `Q` factor of `qr(V)`.
The `R` factor does not enter the result.
"""
function f1(V)
    # Destructuring a QR factorization yields Q first, so `first` picks Q.
    orthogonal_factor = first(qr(V))
    return sum(orthogonal_factor)
end
# Gradient for the square input: exercises the m ≥ n branch of the QR pullback.
res = Zygote.gradient(f1, V)
# Gradient for a wide 4 × 6 input: exercises the m < n branch, which splits A = [X | Y].
V2 = CuArray(rand(Float32, 4, 6))
res_2 = Zygote.gradient(f1, V2)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment