Skip to content

Instantly share code, notes, and snippets.

@sethaxen
Last active September 9, 2023 19:28
Show Gist options
  • Save sethaxen/fc732d4a3baf3e0aca3bde42d20f02ff to your computer and use it in GitHub Desktop.
Save sethaxen/fc732d4a3baf3e0aca3bde42d20f02ff to your computer and use it in GitHub Desktop.
# Implementation of a family of distributions that contains the Dirichlet and the
# multivariate logit (logistic) normal as special cases, compatible with Turing.
using Bijectors, Distributions, LinearAlgebra, PDMats
"""
DirichletMvLogitNormal(α::AbstractVector, J::AbstractMatrix; check=false)
A distribution that contains `MvLogitNormal` and `Dirichlet` as special cases.[^Aitchison1985]
Let ``x \\sim \\mathrm{DirichletMvLogitNormal}(α, J)`` be a length ``N`` probability vector,
where ``α`` is a length ``N`` vector, and ``J`` is a size ``(N-1,N-1)`` positive
semi-definite matrix. Given ``z_i = \\log(x_i)``, the density function of the
Dirichlet-logit (logistic) normal distribution is
```math
f(x; α, J) \\propto \\exp\\left( (α - 1_N)^\\top z(x) - \\frac{1}{2} z(x)^\\top Π^\\top J Π z(x) \\right),
```
where ``Π = I_{N,N-1} - e_N 1_N^\\top``.
The distribution is proper when
- `isposdef(J) && sum(α) ≥ 0` OR
- `!isposdef(-J) && all(>(0), α)`
These conditions are checked when `check=true`.
## Relationship to `Dirichlet`
When ``J = 0_{N-1,N-1}`` and ``α_i > 0`` for all ``i``, then the result is Dirichlet
with parameter vector ``α``.
## Relationship to `MvLogitNormal`
When ``1_N^\\top α = 0``, the result is an MvLogitNormal distribution with canonical
parameters
``(I_{N-1,N} \\alpha, J) = (\\Sigma^{-1}\\mu, \\Sigma^{-1})``.
!!! note
This parameterization is related to the ``(\\beta, \\Gamma)`` one in [^Aitchison1985] by
```math
\\begin{aligned}
\\beta &= \\alpha\\
\\Gamma &= -\\frac{1}{2} Π^\\top J Π
\\end{aligned}.
```
[^Aitchison1985]:
J Aitchison. A General Class of Distributions on the Simplex.
J R Stat Soc Series B Stat Methodol, 1985, 47(1): 136-146.
"""
struct DirichletMvLogitNormal{
T<:Real,TA<:AbstractVector{T},TJ<:AbstractMatrix{T},Ts<:AbstractVector{T}
} <: ContinuousMultivariateDistribution
α::TA
J::TJ
Jcolsum::Ts
Jsum::T
end
# Outer constructor for parameters that already share the element type `T`.
#
# Arguments
# - `α`: parameter vector (length N).
# - `J`: square matrix of size (N-1, N-1).
# - `check`: when `true`, validate the dimensions and the propriety conditions and throw
#   `DimensionMismatch` / `ArgumentError` on violation. When `false` (the default) no
#   validation is performed at all.
#
# The column sums of `J` and their total are cached in the returned object so that
# `logpdf` does not have to recompute them on every call.
function DirichletMvLogitNormal(
    α::AbstractVector{T}, J::AbstractMatrix{T}; check::Bool=false
) where {T<:Real}
    if check
        size(J, 1) == size(J, 2) || throw(DimensionMismatch("J must be square."))
        size(J, 1) == length(α) - 1 ||
            throw(DimensionMismatch("The dimensions of α and J are inconsistent."))
        # Propriety: either J is positive definite and sum(α) ≥ 0, or J is not negative
        # definite and every α_i > 0. The second branch must accept J = 0 (the Dirichlet
        # special case), so it tests `!isposdef(-J)` rather than `isposdef(-J)`.
        # NOTE(review): `!isposdef(-J)` is only a necessary condition for J being
        # positive semi-definite; it does not reject indefinite matrices.
        isproper = (isposdef(J) && sum(α) ≥ 0) || (!isposdef(-J) && all(>(0), α))
        if !isproper
            msg = "Either J must be positive-definite and α have a non-negative sum, " *
                "or J must be non-negative definite and all entries of α must be positive."
            throw(ArgumentError(msg))
        end
    end
    J_colsum = dropdims(sum(J; dims=1); dims=1)
    J_sum = sum(J_colsum)
    return DirichletMvLogitNormal(α, J, J_colsum, J_sum)
end
# Fallback constructor for `α` and `J` with differing element types: both are converted
# to their promoted element type and forwarded, together with any keyword arguments, to
# the same-eltype constructor.
function DirichletMvLogitNormal(
    α::AbstractVector{<:Real}, J::AbstractMatrix{<:Real}; kwargs...
)
    S = Base.promote_eltype(α, J)
    α_promoted = convert(AbstractArray{S}, α)
    J_promoted = convert(AbstractArray{S}, J)
    return DirichletMvLogitNormal(α_promoted, J_promoted; kwargs...)
end
# Represent a `Dirichlet` as a `DirichletMvLogitNormal`: α is the Dirichlet's `alpha`
# and J is a zero-valued `ScalMat` of size length(α) - 1.
function Base.convert(::Type{DirichletMvLogitNormal}, d::Dirichlet)
    nfree = length(d.alpha) - 1
    return DirichletMvLogitNormal(d.alpha, PDMats.ScalMat(nfree, 0))
end
# These conversions are only defined when the loaded Distributions version provides
# `MvLogitNormal`.
if isdefined(Distributions, :MvLogitNormal)
    # Canonical-form case: α is the potential vector `h` of the underlying normal with
    # one extra entry appended so that α sums to zero, and J is `invcov` of that normal.
    function Base.convert(
        ::Type{DirichletMvLogitNormal}, d::MvLogitNormal{<:MvNormalCanon}
    )
        h = d.normal.h
        α = vcat(h, -sum(h))
        return DirichletMvLogitNormal(α, invcov(d.normal))
    end

    # Generic case: bring the underlying normal into canonical form first, then reuse
    # the canonical-form method.
    function Base.convert(T::Type{DirichletMvLogitNormal}, d::MvLogitNormal)
        canonical = MvLogitNormal(canonform(d.normal))
        return convert(T, canonical)
    end
end
# Element type of the distribution: the scalar type parameter `T`.
function Base.eltype(::Type{<:DirichletMvLogitNormal{T}}) where {T}
    return T
end
# Dimension of the distribution: the number of entries in α.
function Base.length(d::DirichletMvLogitNormal)
    return length(d.α)
end
# Parameters as the tuple `(α, J)`; the cached sums of `J` are not included.
function Distributions.params(d::DirichletMvLogitNormal)
    return (d.α, d.J)
end
# Scalar type `T` of the parameters.
#
# The instance method is the form Distributions.jl uses (`partype(d)`). The type-level
# method uses `Type{<:...}` so that it also matches concrete, fully parameterised types;
# `Type{DirichletMvLogitNormal{T}}` alone matches only the partially applied type.
Distributions.partype(::DirichletMvLogitNormal{T}) where {T} = T
@inline Distributions.partype(::Type{<:DirichletMvLogitNormal{T}}) where {T} = T
# `x` is in the support when it has the same length as the distribution, all of its
# entries are strictly positive, and its entries sum to approximately 1.
function Distributions.insupport(d::DirichletMvLogitNormal, x::AbstractVector{<:Real})
    length(d) == length(x) || return false
    all(>(0), x) || return false
    return sum(x) ≈ 1
end
# Unnormalised log-density at the probability vector `x`.
#
# With z = log.(x), this returns
#     (α - 1)ᵀ z - ½ (z₋ - zₙ 1)ᵀ J (z₋ - zₙ 1),
# where z₋ holds all entries of z but the last and zₙ is the last entry. No normalising
# constant is added, and `x` is not checked against the support.
function Distributions.logpdf(d::DirichletMvLogitNormal, x::AbstractVector)
    # Take logs once; every term below is expressed in z = log.(x).
    z = log.(x)
    z₋ = @view z[firstindex(z):(end - 1)]
    zₙ = z[end]
    # Expanded quadratic form using the cached column sums and total of J, so that
    # z₋ .- zₙ is never materialised. The cross term equals the exact expansion when
    # J is symmetric.
    zᵀΓz = (dot(z₋, d.J, z₋) - 2zₙ * dot(d.Jcolsum, z₋) + d.Jsum * zₙ^2) / -2
    return dot(d.α, z) - sum(z) + zᵀΓz
end
# The bijector for this distribution is `Bijectors.SimplexBijector()`.
function Bijectors.bijector(::DirichletMvLogitNormal)
    return Bijectors.SimplexBijector()
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment