Last active
November 8, 2023 14:31
-
-
Save sethaxen/4d5a6d22e56794a90ece6224b8100f08 to your computer and use it in GitHub Desktop.
Demo of a bijective transform to a QR factorization using elementary reflectors
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using Bijectors, ChangesOfVariables, LinearAlgebra | |
""" | |
    QRReflectorBijector(n, k; special::Bool=false, fullR::Bool=false)

A bijector that maps parameters of elementary LAPACK-style reflectors to a QR factorization
with size `(n, k)`.

The `R` factor of the element has a positive real diagonal.
If `fullR=false`, then the `R` factor is diagonal, which reduces the number of
unconstrained parameters; otherwise, it is upper triangular.

If `n=k` and `special=true`, then the `Q` factor is constrained to have a determinant of
`+1`.

If the unconstrained parameters are complex, then the resulting QR factorization is also
complex.

If the parameters follow a standard normal distribution, then the `Q` factor follows the
normalized invariant measure on the Stiefel manifold.

# Example

The following models can be used as submodels inside any Turing model to generate a
`QRPackedQ` that follows the normalized invariant measure on the Stiefel manifold.

```julia
using Bijectors, Turing

@model function StiefelInvariantMeasure(n::Int, k::Int; complex=false, special=false)
    b = QRReflectorBijector(n, k; special)
    m = only(output_size(b, (n, k)))
    if complex
        z_re ~ filldist(Normal(), m)
        z_im ~ filldist(Normal(), m)
        z = Base.complex.(z_re, z_im)
    else
        z ~ filldist(Normal(), m)
    end
    Q = transform(inverse(b), z).Q
    return Q
end

OrthogonalHaar(n) = StiefelInvariantMeasure(n, n)
SpecialOrthogonalHaar(n) = StiefelInvariantMeasure(n, n; special=true)
UnitaryHaar(n) = StiefelInvariantMeasure(n, n; complex=true)
SpecialUnitaryHaar(n) = StiefelInvariantMeasure(n, n; complex=true, special=true)
```

Here, for example, we sample a semi-orthogonal matrix `X`, a 3x3 rotation matrix `Rot`, and
a matrix representation `Q` of a unit quaternion, where the special orthogonal matrix `Rot`
follows the matrix von Mises-Fisher distribution highly concentrated around the identity
matrix.

```julia
julia> @model function demo()
           @submodel prefix="Rot" Rot = SpecialOrthogonalHaar(3)
           @submodel prefix="Q" Q = SpecialUnitaryHaar(2)
           @submodel prefix="X" X = StiefelInvariantMeasure(8, 3)
           Turing.@addlogprob! tr((10^2 * I) * Rot)
           return (; Rot, Q, X)
       end;

julia> chn = sample(demo(), NUTS(0.8), MCMCThreads(), 500, 4);

julia> params = generated_quantities(demo(), chn);

julia> mean(x -> Matrix(x.Rot), params)
3×3 Matrix{Float64}:
  0.994954      0.000719649  -0.000884539
 -0.000716116   0.995093      0.00346598
  0.000718068  -0.00300927    0.994792

julia> mean(x -> Matrix(x.Q), params)
2×2 Matrix{ComplexF64}:
 -0.00489216-0.00368763im  -0.000754158-0.0419339im
 0.000754158-0.0419339im    -0.00489216+0.00368763im
```
""" | |
struct QRReflectorBijector <: Bijectors.Bijector
    # number of rows of the factorization
    n::Int
    # number of columns of the factorization
    k::Int
    # whether the strict upper triangle of `R` is parameterized
    fullR::Bool
    # whether the last reflector is left unparameterized and completed from the others
    special::Bool
end
# Validating keyword constructor.
#
# Throws `DimensionMismatch` unless `n ≥ k ≥ 1`. A request for `special=true` with `n > k`
# is downgraded to `special=false` with a warning instead of failing.
function QRReflectorBijector(n::Int, k::Int; fullR::Bool=false, special::Bool=false)
    if n < k
        throw(DimensionMismatch("n=$n must be at least k=$k."))
    end
    if k < 1
        throw(DimensionMismatch("k=$k must be at least 1."))
    end
    use_special = special
    if special && n > k
        @warn "special=true is only supported for n=k. Setting to false"
        use_special = false
    end
    return QRReflectorBijector(n, k, fullR, use_special)
end
# Multiplier applied to the log-Jacobian for the given element type:
# 1 for real element types, 2 for complex ones.
function _numdim(::Type{T}) where {T<:Real}
    return 1
end
function _numdim(::Type{T}) where {T<:Complex}
    return 2
end
# Size `(n, k)` of the factorization produced by the inverse transform from a parameter
# vector of length `m`. Throws `DimensionMismatch` if `m` is not the length implied by
# the bijector's `n`, `k`, `fullR` and `special` settings.
function output_size(b::Inverse{QRReflectorBijector}, (m,)::Tuple{Int})
    fwd = Bijectors.inverse(b)
    nrows, ncols = fwd.n, fwd.k
    # when `fullR` is false the strict upper triangle contributes no parameters
    dropped = fwd.fullR ? 0 : div(ncols * (ncols - 1), 2)
    m_expected = nrows * ncols - dropped - fwd.special
    if m != m_expected
        throw(
            DimensionMismatch(
                "size of input $((m,)) does not match expected size $((m_expected,))."
            ),
        )
    end
    return (nrows, ncols)
end
# Length (as a 1-tuple) of the parameter vector produced by the forward transform for a
# factorization of size `(n, k)`. Throws `DimensionMismatch` if `(n, k)` differs from the
# size stored in `b`.
function output_size(b::QRReflectorBijector, (n, k)::NTuple{2,Int})
    if !(b.n == n && b.k == k)
        throw(
            DimensionMismatch(
                "size of input $((n, k)) does not match expected size $((b.n, b.k))."
            ),
        )
    end
    # when `fullR` is false the strict upper triangle contributes no parameters
    dropped = b.fullR ? 0 : div(k * (k - 1), 2)
    return (n * k - dropped - b.special,)
end
# Inverse transform: build a `LinearAlgebra.QR` from the parameter vector `z`.
#
# Column by column, the parameters are copied into the packed `factors` matrix and turned
# into a reflector in place by `get_reflector!`. Returns `(F, logJ)`.
function ChangesOfVariables.with_logabsdet_jacobian(
    b::Inverse{QRReflectorBijector}, z::AbstractVector{T}
) where {T<:Union{Real,Complex}}
    fwd = Bijectors.inverse(b)
    Base.require_one_based_indexing(z)
    n, k = output_size(b, size(z))
    factors = similar(z, n, k)
    τ = similar(z, k)
    keep_upper = fwd.fullR
    if !keep_upper
        # R is diagonal: clear the upper triangle of the freshly allocated buffer
        tril!(factors)
    end
    pos = 1
    for j in 1:k
        nbelow = n - j + 1  # entries of column j on and below the diagonal
        if keep_upper
            # entries above the diagonal are taken verbatim from z
            nabove = j - 1
            copyto!(view(factors, 1:nabove, j), view(z, pos:(pos - 1 + nabove)))
            pos += nabove
        end
        # with `special`, the last reflector has no parameters of its own
        j == n && fwd.special && break
        # move the reflector parameters into column j
        α = z[pos]
        x = view(factors, (j + 1):n, j)
        copyto!(x, view(z, (pos + 1):(pos + nbelow - 1)))
        # overwrite x with the reflector vector; store τ and the diagonal of R
        τ[j], _, factors[j, j] = get_reflector!(α, x)
        pos += nbelow
    end
    F = LinearAlgebra.QR(factors, τ)
    if fwd.special
        _complete_tau!(τ)
        factors[n, n] = 1
    end
    logJ = _numdim(T) * sum(j -> (n - j) * log(real(factors[j, j])), 1:k)
    return F, logJ
end
# Forward transform: unpack a `LinearAlgebra.QR` into its vector of reflector parameters.
#
# For each column `j`, the parameters are recovered by applying the stored reflector
# (scalar `τ[j]`, vector `[1; factors[j+1:n, j]]`) to `R[j, j] * e1`. When `b.fullR` is
# true, the entries above the diagonal are copied through unchanged.
#
# Returns `(z, logJ)`. Throws `DimensionMismatch` (via `output_size`) if the size of
# `F.factors` does not match `(b.n, b.k)`.
function ChangesOfVariables.with_logabsdet_jacobian(
    b::QRReflectorBijector, F::LinearAlgebra.QR{T}
) where {T<:Union{Real,Complex}}
    factors = F.factors
    τ = F.τ
    n, k = size(factors)
    z = similar(factors, output_size(b, (n, k)))
    # scratch buffer holding the full reflector vector (with its leading 1)
    tmp = similar(factors, n)
    fullR = b.fullR
    nlower = n
    iz = 1
    for j in 1:k
        # copy to upper part without modification
        if fullR
            nupper = n - nlower
            copyto!(view(z, iz:(iz - 1 + nupper)), view(factors, 1:nupper, j))
            iz += nupper
        end
        # with `special`, the last reflector carries no parameters, so stop here
        j == n && b.special && break
        # initialize lower part as R[j, j] * e1
        zj = view(z, iz:(iz + nlower - 1))
        zj[1] = factors[j, j]
        fill!(view(zj, 2:nlower), 0)
        # set up reflector in tmp
        xfull = view(tmp, j:n)
        xfull[1] = 1
        copyto!(view(xfull, 2:nlower), view(factors, (j + 1):n, j))
        # apply reflector to lower part
        apply_reflector!(τ[j], xfull, zj)
        iz += nlower
        nlower -= 1
    end
    dimT = _numdim(T)
    logJ = -dimT * sum(1:k) do j
        return (n - j) * log(real(factors[j, j]))
    end
    return z, logJ
end
# adapted from XLARFGP from LAPACK
#
# Given the scalar `α` and the vector `x`, compute `(τ, x, β)` with `x` overwritten in
# place by `x ./ η` and `β` real and non-negative.
function get_reflector!(α, x::AbstractVector)
    xnorm = norm(x)
    # nothing below the diagonal and a real pivot: only the sign may need flipping
    if iszero(xnorm) && isreal(α)
        reα = real(α)
        τ_trivial = reα > 0 ? zero(α) : oftype(α, 2)
        return τ_trivial, x, abs(reα)
    end
    β_signed = -copysign(hypot(α, xnorm), real(α))
    (shift, β_scaled, α_scaled, _, nrm) = _possibly_rescale!(β_signed, α, x, xnorm)
    if β_scaled ≥ 0
        β_abs = β_scaled
        η = α_scaled - β_scaled
    else
        # η is assembled from quotients by γ rather than by direct subtraction
        β_abs = -β_scaled
        γ = real(α_scaled) + β_abs
        if α_scaled isa Real
            η = -nrm * (nrm / γ)
        else
            imα = α_scaled - real(α_scaled)
            abs_imα = abs(imα)
            δ = -(abs_imα * (abs_imα / γ) + nrm * (nrm / γ))
            η = δ + imα
        end
    end
    τ = -η / β_abs
    x ./= η
    # undo any rescaling applied by `_possibly_rescale!`
    return τ, x, ldexp(β_abs, shift)
end
# Base-2 exponent of `floatmin(T) / eps(T)`.
function _get_log2_safmin(::Type{T}) where {T<:Real}
    ratio = floatmin(T) / eps(T)
    return exponent(ratio)
end
# Scale `x`, `α` and `β` up by powers of two while `abs(β)` is below the safe minimum.
#
# Returns `(m, β, α, x, xnorm)`. `m` is the power of two that the caller passes to `ldexp`
# to undo the scaling; it is 0 when no scaling happened. `x` is modified in place.
function _possibly_rescale!(β, α, x, xnorm)
    T = float(real(eltype(x)))
    log2_safmin = _get_log2_safmin(T)
    safmin = exp2(T(log2_safmin))
    invsafmin = exp2(-T(log2_safmin))
    # common case: β is already large enough, nothing to do
    abs(β) ≥ safmin && return (0, β, α, x, xnorm)
    nscalings = 0
    while abs(β) < safmin
        rmul!(x, invsafmin)
        β *= invsafmin
        α *= invsafmin
        nscalings += 1
    end
    # recompute the norm and β from the rescaled quantities
    xnorm = norm(x)
    β = -copysign(hypot(α, xnorm), real(α))
    return (nscalings * log2_safmin, β, α, x, xnorm)
end
# Apply the reflector with scalar `τ` and vector `x` to `b` in place:
# `b .-= τ .* x .* dot(x, b)`. A zero `τ` leaves `b` untouched. Returns `b`.
function apply_reflector!(τ::Number, x::AbstractVector, b::AbstractVector)
    if iszero(τ)
        return b
    end
    projection = dot(x, b)
    b .-= (τ .* x) .* projection
    return b
end
# type piracy so ForwardDiff plays well with hypot(::Complex, ::Complex)
# https://github.com/JuliaDiff/ForwardDiff.jl/issues/537
# Both methods forward a numeric *value* to the floating-point version of its type.
function Base.floatmin(x::Number)
    return floatmin(float(typeof(x)))
end
function Base.floatmax(x::Number)
    return floatmax(float(typeof(x)))
end
# Overwrite the last entry of `τs` with `1 - conj(d)`, where `d` is the value `_det_tau`
# returns for all preceding entries. Returns `τs`.
function _complete_tau!(τs::AbstractVector)
    leading = view(τs, 1:(lastindex(τs) - 1))
    det_leading = _det_tau(leading)
    τs[end] = 1 - conj(det_leading)
    return τs
end
# copied from LinearAlgebra.jl
# https://github.com/JuliaLang/julia/blob/5aaa94854367ca875375e38ae14f369f124e7315/stdlib/LinearAlgebra/src/abstractq.jl#L448-L459
# Determinant from the parity of the number of Householder reflections; zero entries of
# `τs` are not counted.
function _det_tau(τs::AbstractVector{T}) where {T<:Real}
    nreflections = count(!iszero, τs)
    if isodd(nreflections)
        return -one(T)
    else
        return one(T)
    end
end
# In the complex case, each Householder reflector has one non-unit eigenvalue
# `λ = 1 - c*τ` (where `c = v'v`), which must satisfy `abs2(λ) == 1`. Combined with the
# constraint `c > 0`, this gives `λ = -sign(τ)^2`; the determinant is the product of
# these eigenvalues, with zero `τ` contributing a factor of one.
# See: https://github.com/JuliaLang/julia/pull/32887#issuecomment-521935716
function _det_tau(τs)
    return prod(τs) do τ
        iszero(τ) ? one(τ) : -sign(τ)^2
    end
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment