Skip to content

Instantly share code, notes, and snippets.

@sethaxen
Last active October 9, 2023 23:34
Show Gist options
  • Save sethaxen/0b1c9cf92cbea99035a0261e16921a69 to your computer and use it in GitHub Desktop.
Save sethaxen/0b1c9cf92cbea99035a0261e16921a69 to your computer and use it in GitHub Desktop.
Bijector from Cholesky factor of correlation matrix with structural zeros
using Bijectors, ChangesOfVariables, LinearAlgebra
"""
    VecCholeskyCorrBijectorWithZeros(zero_inds)

Bijector for the Cholesky factor of a correlation matrix with structural zeros.

`zero_inds` is a vector of `(i, j)` index pairs. The inverse transform treats each
pair as "column `j` of the upper Cholesky factor is perpendicular to column `i`".

The constructor stores a normalized copy of `zero_inds`:

- each pair is ordered so that the smaller index comes first,
- the pairs are sorted by `(larger index, smaller index)`,
- duplicate pairs are removed.

The input vector is never mutated.

# Throws
- `ArgumentError` if any pair has `i == j`; a column cannot be perpendicular to
  itself.
"""
struct VecCholeskyCorrBijectorWithZeros{Z<:AbstractVector{Tuple{Int,Int}}} <: Bijectors.Bijector
    zero_inds::Z
    function VecCholeskyCorrBijectorWithZeros(zero_inds::AbstractVector{Tuple{Int,Int}})
        for (i, j) in zero_inds
            i == j && throw(
                ArgumentError("zero_inds must not contain diagonal entries, got ($i, $j)")
            )
        end
        # `minmax` orders each pair without relying on `sort` for tuples, which is
        # not defined on older Julia releases. The typed comprehension guarantees a
        # concrete `Vector{Tuple{Int,Int}}` even when `zero_inds` is empty.
        ordered = Tuple{Int,Int}[minmax(i, j) for (i, j) in zero_inds]
        # Sorting by the reversed pair groups entries by their larger index, which
        # is the order the inverse transform consumes them in (column by column).
        normalized = unique!(sort(ordered; by=reverse))
        # Parametrize on the normalized container rather than the input type so that
        # non-`Vector` inputs (views, ranges of tuples, ...) do not need converting.
        return new{typeof(normalized)}(normalized)
    end
end
"""
    Bijectors.output_size(b::VecCholeskyCorrBijectorWithZeros, sz::Tuple{Int,Int})

Return the size of the unconstrained vector produced by `b` for a matrix of size
`sz`, as a 1-tuple.

The length is the number of strictly-upper-triangular entries of an `n × n` matrix
(`n = sz[1]`) minus the number of structural zeros stored in `b`.
"""
function Bijectors.output_size(
    b::VecCholeskyCorrBijectorWithZeros, sz::Tuple{Int,Int}
)
    nrows = first(sz)
    num_strict_upper = div(nrows * (nrows - 1), 2)
    num_zeros = length(b.zero_inds)
    return (num_strict_upper - num_zeros,)
end
"""
    Bijectors.output_size(b::Inverse{<:VecCholeskyCorrBijectorWithZeros}, sz::Tuple{Int})

Return the `(n, n)` size of the matrix produced by the inverse transform for an
unconstrained vector of size `sz`.

The structural zeros are added back to the vector length before it is passed to
`Bijectors._triu1_dim_from_length`, which presumably recovers `n` from the number
of strictly-upper-triangular entries (verify against Bijectors.jl).
"""
function Bijectors.output_size(
    b::Inverse{<:VecCholeskyCorrBijectorWithZeros}, sz::Tuple{Int}
)
    (num_free,) = sz
    num_zeros = length(inverse(b).zero_inds)
    dim = Bijectors._triu1_dim_from_length(num_free + num_zeros)
    return (dim, dim)
end
"""
    with_logabsdet_jacobian(b::Inverse{<:VecCholeskyCorrBijectorWithZeros}, y)

Map the unconstrained vector `y` to an upper Cholesky factor whose columns are unit
vectors, with column `j` perpendicular to column `i` for every pair `(i, j)` stored
in the bijector's `zero_inds`.

# Returns
A tuple `(C, logJ)` where `C` is a `Cholesky` wrapping the upper-triangular factor
and `logJ` is the accumulated log-absolute-determinant of the Jacobian.

The computation is carried out in `float(eltype(y))`, so integer-valued `y` is
accepted and promoted.
"""
function ChangesOfVariables.with_logabsdet_jacobian(
    b::Inverse{<:VecCholeskyCorrBijectorWithZeros}, y::AbstractVector{<:Real}
)
    # `tanh`/`sqrt` below produce floating-point values; promote up front so that
    # an integer-valued `y` neither throws on assignment into `x` nor makes the
    # accumulators change type inside the loop.
    T = float(eltype(y))
    x = similar(y, T, Bijectors.output_size(b, size(y)))
    zero_inds = inverse(b).zero_inds
    iy = firstindex(y)          # next unread entry of `y`
    iz = firstindex(zero_inds)  # next unconsumed entry of `zero_inds`
    logJ = zero(T)              # logabsdet of Jacobian
    for j in axes(x, 2)
        xj = @view x[1:j, j]
        # get column indices of perpendicular columns; `zero_inds` is sorted by
        # its larger index, so the pairs belonging to column `j` are contiguous
        zero_inds_tail = @view zero_inds[iz:end]
        idx_perp = map(first, zero_inds_tail[searchsorted(zero_inds_tail, j; by=last)])
        num_perp = length(idx_perp)
        iz += num_perp
        # fill head of xj with zeros
        xj[1:num_perp] .= 0
        # construct tail of xj as a unit vector on the positive hemisphere
        s = one(T)
        for k in (num_perp + 1):(j - 1)
            zk = tanh(y[iy])
            xj[k] = zk * s
            s *= sqrt(1 - zk^2)
            logJ += (j - k + 1) * log1p(-zk^2) / 2
            iy += 1
        end
        xj[j] = s
        if num_perp > 0
            # rows > i of an earlier column i were zeroed at the end of its own
            # iteration, so restricting to rows 1:j loses nothing
            xperp = @view x[1:j, idx_perp]
            Q, R = qr(xperp)
            lmul!(Q, xj) # lift xj to nullspace of xperp', i.e. so that xperp'xj = 0
            # subtract the log-magnitudes of R's diagonal from logJ
            for iR in diagind(R)
                logJ -= log(abs(R[iR]))
            end
        end
        # zero out the lower triangle (must happen before later columns read
        # this column through `xperp`)
        x[(j + 1):end, j] .= 0
    end
    return Cholesky(UpperTriangular(x)), logJ
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment