# Bijector from the Cholesky factor of a correlation matrix with structural zeros.
#
# Source: GitHub gist sethaxen/0b1c9cf92cbea99035a0261e16921a69
# (last active October 9, 2023 23:34).
# GitHub flagged this file as possibly containing bidirectional Unicode text; to review it,
# open the file in an editor that reveals hidden Unicode characters.
using Bijectors, ChangesOfVariables, LinearAlgebra
"""
    VecCholeskyCorrBijectorWithZeros(zero_inds)

Bijector from the upper Cholesky factor of a correlation matrix with structural zeros to an
unconstrained vector.

Each `(i, j)` in `zero_inds` declares that variables `i` and `j` are uncorrelated, i.e. the
correlation matrix is zero at `[i, j]` and `[j, i]`. The order within a pair is irrelevant.
On construction, the pairs are:

- normalized to `i < j` (upper-triangular order);
- sorted by column `j`, then by row `i`;
- de-duplicated.

The inverse transform walks the columns in order and relies on this ordering.

Throws an `ArgumentError` if a pair refers to a diagonal entry (`i == j`) or contains a
non-positive index.
"""
struct VecCholeskyCorrBijectorWithZeros{Z<:AbstractVector{Tuple{Int,Int}}} <: Bijectors.Bijector
    zero_inds::Z
    function VecCholeskyCorrBijectorWithZeros(zero_inds::AbstractVector{Tuple{Int,Int}})
        for (i, j) in zero_inds
            i == j && throw(ArgumentError(
                "diagonal entry ($i, $j) of a correlation matrix cannot be zero"))
            min(i, j) >= 1 || throw(ArgumentError(
                "zero indices must be positive, got ($i, $j)"))
        end
        # `minmax` returns the pair in (row < column) order. Unlike `sort` on a `Tuple`,
        # it is available on every Julia version.
        inds = Tuple{Int,Int}[minmax(i, j) for (i, j) in zero_inds]
        # Sort by column, then row, so each column's constraints are contiguous.
        sort!(inds; by=reverse)
        unique!(inds)
        # Parameterize on the type actually stored, not on the input container type.
        # Otherwise inputs such as views fail to convert.
        return new{typeof(inds)}(inds)
    end
end
"""
    output_size(b::VecCholeskyCorrBijectorWithZeros, sz::Tuple{Int,Int})

Return the length (as a 1-tuple) of the vector produced from an `sz`-sized matrix.

This is the number of strictly upper-triangular entries of an `n × n` matrix
(`n = sz[1]`), minus the number of structural zeros stored in `b`.
"""
function Bijectors.output_size(b::VecCholeskyCorrBijectorWithZeros, sz::Tuple{Int,Int})
    nrows = first(sz)
    num_free = div(nrows * (nrows - 1), 2) - length(b.zero_inds)
    return (num_free,)
end
"""
    output_size(b::Inverse{<:VecCholeskyCorrBijectorWithZeros}, sz::Tuple{Int})

Return the `(n, n)` size of the matrix produced from a vector of length `sz[1]`.

The vector length plus the number of structural zeros gives the count of strictly
upper-triangular entries. `n` is recovered from that count via
`Bijectors._triu1_dim_from_length`.
"""
function Bijectors.output_size(b::Inverse{<:VecCholeskyCorrBijectorWithZeros}, sz::Tuple{Int})
    num_zeros = length(inverse(b).zero_inds)
    num_triu = only(sz) + num_zeros
    dim = Bijectors._triu1_dim_from_length(num_triu)
    return (dim, dim)
end
"""
    with_logabsdet_jacobian(b::Inverse{<:VecCholeskyCorrBijectorWithZeros}, y::AbstractVector{<:Real})

Map the unconstrained vector `y` to the upper Cholesky factor `U` of a correlation matrix
`Σ = U'U` with zeros at the pairs in `inverse(b).zero_inds`. Return `(Cholesky(U), logJ)`.

Columns of `U` are built left to right:

- Column `j` is a unit vector whose free entries are `tanh`-transformed elements of `y`,
  combined by stick-breaking.
- Each zero pair `(i, j)` requires `Σ[i, j] = dot(U[:, i], U[:, j]) = 0`.
- To satisfy this, column `j` is first built with its leading entries set to zero. It is
  then multiplied by the `Q` factor of a QR decomposition of the constraining columns,
  which makes it orthogonal to them.

`logJ` is the sum of the stick-breaking `log1p(-z^2)` terms, minus `log|R[k, k]|` over the
diagonal of every `R` factor.

The element type of `U` is `float(eltype(y))`, so integer-valued `y` is accepted.
"""
function ChangesOfVariables.with_logabsdet_jacobian(
    b::Inverse{<:VecCholeskyCorrBijectorWithZeros}, y::AbstractVector{<:Real}
)
    # `tanh`/`sqrt` below yield non-integers, so the output must be floating point even
    # when `y` holds integers. Using `T` for `s` and `logJ` also keeps the loop type-stable.
    T = float(eltype(y))
    x = similar(y, T, Bijectors.output_size(b, size(y)))
    # Pairs (i, j) with i < j, sorted by column j and then row i by the constructor.
    zero_inds = inverse(b).zero_inds
    iy = 1                      # next element of `y` to consume
    iz = firstindex(zero_inds)  # first zero pair not yet assigned to a column
    logJ = zero(T)              # logabsdet of the Jacobian
    for j in axes(x, 2)
        xj = @view x[1:j, j]
        # Columns that column j must be orthogonal to: the rows i of all pairs (i, j).
        # `by=last` is applied to the search key `j` as well; `last(j) == j` for an integer.
        zero_inds_tail = @view zero_inds[iz:end]
        idx_perp = map(first, zero_inds_tail[searchsorted(zero_inds_tail, j; by=last)])
        num_perp = length(idx_perp)
        iz += num_perp
        # Zero the head of xj, so it has no component along the first `num_perp` columns
        # of Q below. Those columns span the constraining columns when the latter are
        # linearly independent.
        xj[1:num_perp] .= 0
        # Build the tail of xj by stick-breaking: a unit vector with nonnegative last entry.
        s = one(T)  # length remaining for the entries not yet filled
        for k in (num_perp + 1):(j - 1)
            zk = tanh(y[iy])
            xj[k] = zk * s
            s *= sqrt(1 - zk^2)
            logJ += (j - k + 1) * log1p(-zk^2) / 2
            iy += 1
        end
        xj[j] = s  # diagonal entry; for j == 1 this sets x[1, 1] = 1
        if num_perp > 0
            xperp = @view x[1:j, idx_perp]
            Q, R = qr(xperp)
            # Lift xj into the nullspace of xperp', i.e. so that xperp' * xj == 0.
            # Q is orthogonal, so xj remains a unit vector.
            # NOTE(review): positivity of xj[j] after this rotation appears to rely on Q
            # leaving the last row untouched (rows of xperp below max(idx_perp) are
            # zero). Confirm.
            lmul!(Q, xj)
            for iR in diagind(R)
                logJ -= log(abs(R[iR]))
            end
        end
        # Zero out the strictly lower triangle of this column.
        x[(j + 1):end, j] .= 0
    end
    return Cholesky(UpperTriangular(x)), logJ
end
# To comment on this gist, sign up for or sign in to GitHub.