Skip to content

Instantly share code, notes, and snippets.

@sethaxen
Last active November 8, 2023 14:31
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save sethaxen/4d5a6d22e56794a90ece6224b8100f08 to your computer and use it in GitHub Desktop.
Save sethaxen/4d5a6d22e56794a90ece6224b8100f08 to your computer and use it in GitHub Desktop.
Demo of a bijective transform to a QR factorization using elementary reflectors
using Bijectors, ChangesOfVariables, LinearAlgebra
"""
    QRReflectorBijector(n, k; special::Bool=false, fullR::Bool=false)

A bijector that maps parameters of elementary LAPACK-style reflectors to a QR factorization
with size `(n, k)`.

The `R` factor of the element has a positive real diagonal.
If `fullR=false`, then the `R` factor is diagonal, which reduces the number of
unconstrained parameters; otherwise, it is upper triangular.

If `n=k` and `special=true`, then the `Q` factor is constrained to have a determinant of
`+1`.

If the unconstrained parameters are complex, then the resulting QR factorization is also
complex.

If the parameters follow a standard normal distribution, then the `Q` factor follows the
normalized invariant measure on the Stiefel manifold.

# Fields

- `n::Int`: number of rows of the factorization.
- `k::Int`: number of columns of the factorization (`1 ≤ k ≤ n` is enforced by the keyword
  constructor).
- `fullR::Bool`: whether the strictly upper-triangular part of `R` is parameterized.
- `special::Bool`: whether the determinant of `Q` is constrained to `+1`.

# Example

The following models can be used as submodels inside any Turing model to generate the `Q`
factor of a packed `QR` factorization that follows the normalized invariant measure on the
Stiefel manifold.

```julia
using Bijectors, Turing

@model function StiefelInvariantMeasure(n::Int, k::Int; complex=false, special=false)
    b = QRReflectorBijector(n, k; special)
    m = only(output_size(b, (n, k)))
    if complex
        z_re ~ filldist(Normal(), m)
        z_im ~ filldist(Normal(), m)
        z = Base.complex.(z_re, z_im)
    else
        z ~ filldist(Normal(), m)
    end
    Q = transform(inverse(b), z).Q
    return Q
end

OrthogonalHaar(n) = StiefelInvariantMeasure(n, n)
SpecialOrthogonalHaar(n) = StiefelInvariantMeasure(n, n; special=true)
UnitaryHaar(n) = StiefelInvariantMeasure(n, n; complex=true)
SpecialUnitaryHaar(n) = StiefelInvariantMeasure(n, n; complex=true, special=true)
```

Here for example we sample a semi-orthogonal matrix `X`, a 3x3 rotation matrix `Rot`, and a
matrix representation `Q` of a unit quaternion, where the special orthogonal matrix follows
the Matrix von Mises-Fisher distribution highly concentrated around the identity matrix.

```julia
julia> @model function demo()
           @submodel prefix="Rot" Rot = SpecialOrthogonalHaar(3)
           @submodel prefix="Q" Q = SpecialUnitaryHaar(2)
           @submodel prefix="X" X = StiefelInvariantMeasure(8, 3)
           Turing.@addlogprob! tr((10^2 * I) * Rot)
           return (; Rot, Q, X)
       end;

julia> chn = sample(demo(), NUTS(0.8), MCMCThreads(), 500, 4);

julia> params = generated_quantities(demo(), chn);

julia> mean(x -> Matrix(x.Rot), params)
3×3 Matrix{Float64}:
  0.994954      0.000719649  -0.000884539
 -0.000716116   0.995093      0.00346598
  0.000718068  -0.00300927    0.994792

julia> mean(x -> Matrix(x.Q), params)
2×2 Matrix{ComplexF64}:
 -0.00489216-0.00368763im  -0.000754158-0.0419339im
 0.000754158-0.0419339im    -0.00489216+0.00368763im
```
"""
struct QRReflectorBijector <: Bijectors.Bijector
    n::Int
    k::Int
    fullR::Bool
    special::Bool
end
# Validating keyword constructor: checks `1 ≤ k ≤ n`, and silently (with a warning)
# drops `special` for non-square factorizations before delegating to the default
# positional constructor.
function QRReflectorBijector(n::Int, k::Int; fullR::Bool=false, special::Bool=false)
    if n < k
        throw(DimensionMismatch("n=$n must be at least k=$k."))
    end
    if k < 1
        throw(DimensionMismatch("k=$k must be at least 1."))
    end
    use_special = special
    if special && n > k
        @warn "special=true is only supported for n=k. Setting to false"
        use_special = false
    end
    return QRReflectorBijector(n, k, fullR, use_special)
end
# Multiplier applied to the log-Jacobian per scalar parameter: 1 for real element
# types, 2 for complex ones.
_numdim(::Type{T}) where {T<:Real} = 1
_numdim(::Type{T}) where {T<:Complex} = 2
# Size of the QR factorization produced by the inverse bijector from a parameter vector
# of length `m`. Throws `DimensionMismatch` when `m` is not the parameter count implied
# by the wrapped bijector's `n`, `k`, `fullR` and `special`.
function output_size(b::Inverse{QRReflectorBijector}, (m,)::Tuple{Int})
    forward = Bijectors.inverse(b)
    nrows, ncols = forward.n, forward.k
    # the strictly upper-triangular entries of R are only parameters when fullR=true
    nfixed = forward.fullR ? 0 : div(ncols * (ncols - 1), 2)
    m_expected = nrows * ncols - nfixed - forward.special
    if m != m_expected
        throw(
            DimensionMismatch(
                "size of input $((m,)) does not match expected size $((m_expected,))."
            ),
        )
    end
    return (nrows, ncols)
end
# Length (as a 1-tuple) of the unconstrained parameter vector for a factorization of
# size `(n, k)`. Throws `DimensionMismatch` when `(n, k)` differs from `(b.n, b.k)`.
function output_size(b::QRReflectorBijector, (n, k)::NTuple{2,Int})
    if b.n != n || b.k != k
        throw(
            DimensionMismatch(
                "size of input $((n, k)) does not match expected size $((b.n, b.k))."
            ),
        )
    end
    # the strictly upper-triangular entries of R are only parameters when fullR=true
    nfixed = b.fullR ? 0 : div(k * (k - 1), 2)
    return (n * k - nfixed - b.special,)
end
# Inverse transform: unpack the parameter vector `z` column by column into packed
# Householder factors and reflector coefficients, and return the resulting
# `LinearAlgebra.QR` together with the log-Jacobian term.
function ChangesOfVariables.with_logabsdet_jacobian(
    b::Inverse{QRReflectorBijector}, z::AbstractVector{T}
) where {T<:Union{Real,Complex}}
    forward = Bijectors.inverse(b)
    Base.require_one_based_indexing(z)
    nrows, ncols = output_size(b, size(z))
    packed = similar(z, nrows, ncols)
    taus = similar(z, ncols)
    if !forward.fullR
        # zero the strictly upper part; the rest is overwritten below
        tril!(packed)
    end
    pos = 1  # index of the next unread entry of z
    for col in 1:ncols
        nbelow = nrows - col + 1  # entries on and below the diagonal in this column
        if forward.fullR
            # entries above the diagonal are copied verbatim
            nabove = col - 1
            copyto!(view(packed, 1:nabove, col), view(z, pos:(pos - 1 + nabove)))
            pos += nabove
        end
        if col == nrows && forward.special
            # last reflector is determined by the determinant constraint
            break
        end
        # first entry is the head of the reflected vector, the rest is its tail
        head = z[pos]
        tail = view(packed, (col + 1):nrows, col)
        copyto!(tail, view(z, (pos + 1):(pos + nbelow - 1)))
        # tail is overwritten in place with the reflector vector
        reflector = get_reflector!(head, tail)
        taus[col] = reflector[1]
        packed[col, col] = reflector[3]
        pos += nbelow
    end
    F = LinearAlgebra.QR(packed, taus)
    if forward.special
        _complete_tau!(taus)
        packed[nrows, nrows] = 1
    end
    logJ = _numdim(T) * sum(j -> (nrows - j) * log(real(packed[j, j])), 1:ncols)
    return F, logJ
end
# Forward transform: flatten a packed `LinearAlgebra.QR` factorization `F` into the
# unconstrained parameter vector `z` and return it with the log-Jacobian term.
#
# For each column `j`, the block of `z` holding the on-and-below-diagonal parameters is
# initialized to `R[j, j] * e1` and the `j`-th elementary reflector (unit head, tail read
# from below the diagonal of `F.factors`, coefficient `F.τ[j]`) is applied to it. When
# `b.fullR` is set, the entries above the diagonal are copied unchanged in front of that
# block.
#
# Throws `DimensionMismatch` (via `output_size`) if `size(F.factors) != (b.n, b.k)`.
function ChangesOfVariables.with_logabsdet_jacobian(
    b::QRReflectorBijector, F::LinearAlgebra.QR{T}
) where {T<:Union{Real,Complex}}
    factors = F.factors
    τ = F.τ
    n, k = size(factors)
    z = similar(factors, output_size(b, (n, k)))
    tmp = similar(factors, n)  # scratch space for the explicit reflector vector
    fullR = b.fullR
    nlower = n  # number of entries on and below the diagonal of the current column
    iz = 1      # index of the next unwritten entry of z
    for j in 1:k
        # copy to upper part without modification
        if fullR
            nupper = n - nlower
            copyto!(view(z, iz:(iz - 1 + nupper)), view(factors, 1:nupper, j))
            iz += nupper
        end
        # the last reflector of a special (det = +1) factorization carries no free
        # parameters; this must read the flag from `b` (there is no `binv` here)
        j == n && b.special && break
        # initialize lower part as R[j, j] * e1
        zj = view(z, iz:(iz + nlower - 1))
        zj[1] = factors[j, j]
        fill!(view(zj, 2:nlower), 0)
        # set up reflector in tmp
        xfull = view(tmp, j:n)
        xfull[1] = 1
        copyto!(view(xfull, 2:nlower), view(factors, (j + 1):n, j))
        # apply reflector to lower part
        apply_reflector!(τ[j], xfull, zj)
        iz += nlower
        nlower -= 1
    end
    dimT = _numdim(T)
    logJ = -dimT * sum(1:k) do j
        return (n - j) * log(real(factors[j, j]))
    end
    return z, logJ
end
# adapted from XLARFGP from LAPACK
#
# Build an elementary reflector from the head `α` and tail `x` of a vector. Returns
# `(τ, x, β)`: the reflector coefficient `τ`, the tail `x` overwritten in place with the
# scaled reflector vector, and the real, non-negative value `β` that replaces the head.
function get_reflector!(α, x::AbstractVector)
    tailnorm = norm(x)
    if iszero(tailnorm) && isreal(α)
        # nothing to annihilate: the reflector is either the identity or a sign flip
        if real(α) > 0
            return zero(α), x, abs(real(α))
        else
            return oftype(α, 2), x, abs(real(α))
        end
    end
    β = -copysign(hypot(α, tailnorm), real(α))
    # guard against underflow; `shift` is the binary exponent to undo the scaling with
    (shift, β, head, _, tailnorm) = _possibly_rescale!(β, α, x, tailnorm)
    if β ≥ 0
        η = head - β
    else
        β = -β
        γ = real(head) + β
        if head isa Real
            # cancellation-free form of head - β
            η = -tailnorm * (tailnorm / γ)
        else
            imhead = head - real(head)
            abs_imhead = abs(imhead)
            δ = -(abs_imhead * (abs_imhead / γ) + tailnorm * (tailnorm / γ))
            η = δ + imhead
        end
    end
    τ = -η / β
    x ./= η
    return τ, x, ldexp(β, shift)
end
# Binary exponent of the threshold `floatmin(T) / eps(T)` below which rescaling is applied.
function _get_log2_safmin(::Type{T}) where {T<:Real}
    return exponent(floatmin(T) / eps(T))
end

# If `abs(β)` is below the safe minimum, repeatedly scale `x` (in place), `β` and `α` up
# by the inverse safe minimum, then recompute `xnorm` and `β` from the scaled values.
# Returns `(m, β, α, x, xnorm)` where `m` is the binary exponent such that `ldexp(β, m)`
# undoes the scaling (`0` when no scaling was needed).
function _possibly_rescale!(β, α, x, xnorm)
    R = float(real(eltype(x)))
    log2_safmin = _get_log2_safmin(R)
    safmin = exp2(R(log2_safmin))
    upscale = exp2(-R(log2_safmin))
    # common case: already safely representable, leave everything untouched
    abs(β) ≥ safmin && return (0, β, α, x, xnorm)
    nscalings = 0
    while abs(β) < safmin
        rmul!(x, upscale)
        β *= upscale
        α *= upscale
        nscalings += 1
    end
    xnorm = norm(x)
    β = -copysign(hypot(α, xnorm), real(α))
    return (nscalings * log2_safmin, β, α, x, xnorm)
end
# Apply the elementary reflector `I - τ * x * x'` to `b` in place and return `b`.
# A zero `τ` is the identity and leaves `b` untouched.
function apply_reflector!(τ::Number, x::AbstractVector, b::AbstractVector)
    if !iszero(τ)
        overlap = dot(x, b)
        b .-= τ .* x .* overlap
    end
    return b
end
# type piracy so ForwardDiff plays well with hypot(::Complex, ::Complex)
# https://github.com/JuliaDiff/ForwardDiff.jl/issues/537
# NOTE(review): these methods extend Base functions on a Base-owned argument type
# (`Number`), so they affect every package loaded in the same session. They forward the
# instance-level call to the type-level method of the floating-point version of
# `typeof(x)`. Presumably they can be removed once the linked upstream issue is resolved;
# confirm before dropping them.
Base.floatmin(x::Number) = floatmin(float(typeof(x)))
Base.floatmax(x::Number) = floatmax(float(typeof(x)))
# Overwrite the last reflector coefficient in `τs` with `1 - conj(d)`, where `d` is the
# determinant contributed by all preceding coefficients, and return `τs`.
function _complete_tau!(τs::AbstractVector)
    leading = view(τs, 1:(lastindex(τs) - 1))
    τs[end] = 1 - conj(_det_tau(leading))
    return τs
end
# copied from LinearAlgebra.jl
# https://github.com/JuliaLang/julia/blob/5aaa94854367ca875375e38ae14f369f124e7315/stdlib/LinearAlgebra/src/abstractq.jl#L448-L459
# Compute `det` from the number of Householder reflections. Handle
# the case `Q.τ` contains zeros.
# A zero coefficient is not counted as a reflection; the result is `-1` for an odd number
# of reflections and `+1` otherwise, in the element type of `τs`.
function _det_tau(τs::AbstractVector{<:Real})
    unit = one(eltype(τs))
    nreflections = count(!iszero, τs)
    return iseven(nreflections) ? unit : -unit
end
# In the complex case, we need to compute the non-unit eigenvalue `λ = 1 - c*τ`
# (where `c = v'v`) of each Householder reflector. As the reflector's determinant must
# have unit modulus, it must satisfy `abs2(λ) == 1`.
# Combining this with the constraint `c > 0`, it turns out that the eigenvalue
# (hence the determinant) can be computed as `λ = -sign(τ)^2`.
# See: https://github.com/JuliaLang/julia/pull/32887#issuecomment-521935716
#
# Returns the product of the per-reflector determinants; a zero `τ` contributes `1`.
# An empty `τs` (no reflectors) yields `one(eltype(τs))` instead of throwing, since
# `prod` over an empty collection with an anonymous mapping function is an error.
function _det_tau(τs)
    isempty(τs) && return one(eltype(τs))
    return prod(τs) do τ
        return iszero(τ) ? one(τ) : -sign(τ)^2
    end
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment