Last active September 9, 2023 19:28
Implementation of a distribution family that contains the Dirichlet and the multivariate logit-normal (logistic-normal) distributions as special cases, compatible with Turing
using Bijectors, Distributions, LinearAlgebra, PDMats
"""
    DirichletMvLogitNormal(α::AbstractVector, J::AbstractMatrix; check=false)

A distribution that contains `MvLogitNormal` and `Dirichlet` as special cases.[^Aitchison1985]

Let ``x \\sim \\mathrm{DirichletMvLogitNormal}(α, J)`` be a length ``N`` probability vector,
where ``α`` is a length ``N`` vector, and ``J`` is a size ``(N-1,N-1)`` positive
semi-definite matrix. Given ``z_i = \\log(x_i)``, the density function of the
Dirichlet-logit (logistic) normal distribution is

```math
f(x; α, J) \\propto \\exp\\left( (α - 1_N)^\\top z(x) - \\frac{1}{2} z(x)^\\top Π^\\top J Π z(x) \\right),
```

where ``Π = I_{N-1,N} - 1_{N-1} e_N^\\top``, i.e. ``(Π z)_i = z_i - z_N`` for
``i = 1, …, N-1`` (the additive log-ratio of ``x``).

The distribution is proper when
- `isposdef(J) && sum(α) ≥ 0`, OR
- `J` is positive semi-definite and `all(>(0), α)`.

These conditions are checked when `check=true`.

## Relationship to `Dirichlet`

When ``J = 0_{N-1,N-1}`` and ``α_i > 0`` for all ``i``, the result is a `Dirichlet`
distribution with parameter vector ``α``.

## Relationship to `MvLogitNormal`

When ``1_N^\\top α = 0``, the result is an `MvLogitNormal` distribution with canonical
parameters ``(I_{N-1,N} \\alpha, J) = (\\Sigma^{-1}\\mu, \\Sigma^{-1})``.

!!! note
    This parameterization is related to the ``(\\beta, \\Gamma)`` one in [^Aitchison1985] by
    ```math
    \\begin{aligned}
    \\beta &= \\alpha \\\\
    \\Gamma &= -\\frac{1}{2} Π^\\top J Π
    \\end{aligned}
    ```

[^Aitchison1985]: J Aitchison. A General Class of Distributions on the Simplex.
    J R Stat Soc Series B Stat Methodol, 1985, 47(1): 136-146.
"""
struct DirichletMvLogitNormal{
    T<:Real,
    Tα<:AbstractVector{T},
    TJ<:AbstractMatrix{T},
    TJcolsum<:AbstractVector{<:Real},
    TJsum<:Real,
} <: ContinuousMultivariateDistribution
    # length-N parameter vector
    α::Tα
    # (N-1)×(N-1) matrix of the quadratic form in the log-ratios Π z
    J::TJ
    # column sums of `J` (1ᵀJ), cached for `logpdf`
    Jcolsum::TJcolsum
    # sum of all entries of `J` (1ᵀJ1), cached for `logpdf`
    Jsum::TJsum

    function DirichletMvLogitNormal(
        α::AbstractVector{T}, J::AbstractMatrix{T}; check::Bool=false
    ) where {T<:Real}
        if check
            size(J, 1) == size(J, 2) || throw(DimensionMismatch("J must be square."))
            size(J, 1) == length(α) - 1 ||
                throw(DimensionMismatch("The dimensions of α and J are inconsistent."))
            # Second condition tests the cheap `α > 0` first so the eigenvalue-based
            # semi-definiteness test only runs when it can matter.
            (isposdef(J) && sum(α) ≥ 0) ||
                (all(>(0), α) && _isnonnegdef(J)) ||
                throw(
                    ArgumentError(
                        "Either J must be positive-definite and α have a non-negative sum, or J must be non-negative definite and all entries of α must be positive.",
                    ),
                )
        end
        Jcolsum = dropdims(sum(J; dims=1); dims=1)
        Jsum = sum(Jcolsum)
        return new{T,typeof(α),typeof(J),typeof(Jcolsum),typeof(Jsum)}(
            α, J, Jcolsum, Jsum
        )
    end
end

# Return `true` if the real matrix `J` is symmetric positive semi-definite ("non-negative
# definite"). `isposdef` only accepts strictly positive-definite matrices, so on its own it
# would reject e.g. the all-zero `J` of the `Dirichlet` special case. Eigenvalues of a
# singular PSD matrix can come out slightly negative from roundoff, hence the tolerance
# scaled by the largest eigenvalue magnitude.
function _isnonnegdef(J::AbstractMatrix{<:Real})
    A = Matrix(J)
    issymmetric(A) || return false
    (isempty(A) || isposdef(A)) && return true
    λ = eigvals(Symmetric(A))
    tol = length(λ) * eps(float(eltype(λ))) * maximum(abs, λ)
    return minimum(λ) ≥ -tol
end
# Outer constructor for mixed element types: bring `α` and `J` to their common promoted
# element type, then hand off (forwarding keywords such as `check`) to the inner
# constructor, which requires matching element types.
function DirichletMvLogitNormal(
    α::AbstractVector{<:Real}, J::AbstractMatrix{<:Real}; kwargs...
)
    S = promote_type(eltype(α), eltype(J))
    α_promoted = convert(AbstractArray{S}, α)
    J_promoted = convert(AbstractArray{S}, J)
    return DirichletMvLogitNormal(α_promoted, J_promoted; kwargs...)
end
# Represent a `Dirichlet` as the `J = 0` special case of `DirichletMvLogitNormal`.
# The zero `ScalMat` is built with the element type of `α`, so the outer constructor has
# nothing to promote (e.g. a `Float32` Dirichlet stays `Float32`).
function Base.convert(::Type{DirichletMvLogitNormal}, d::Dirichlet)
    α = d.alpha
    J = PDMats.ScalMat(length(α) - 1, zero(eltype(α)))
    return DirichletMvLogitNormal(α, J)
end
# Conversions from `MvLogitNormal`, defined only when the loaded Distributions version
# provides that type.
if isdefined(Distributions, :MvLogitNormal)
    # From canonical parameters (h, J) of the underlying normal: α = [h; -∑h], so that
    # ∑α = 0, and the same precision matrix J.
    function Base.convert(
        ::Type{DirichletMvLogitNormal}, d::MvLogitNormal{<:MvNormalCanon}
    )
        h = d.normal.h
        α_full = vcat(h, -sum(h))
        return DirichletMvLogitNormal(α_full, invcov(d.normal))
    end

    # Any other parameterization of the underlying normal is first put in canonical form.
    function Base.convert(target::Type{DirichletMvLogitNormal}, d::MvLogitNormal)
        d_canon = MvLogitNormal(canonform(d.normal))
        return convert(target, d_canon)
    end
end
# Element type of the parameters (first type parameter of the struct).
Base.eltype(::Type{<:DirichletMvLogitNormal{T}}) where {T} = T

# Dimension N of the probability vectors (length of `α`).
Base.length(d::DirichletMvLogitNormal) = length(d.α)

# User-facing parameters; the cached `Jcolsum`/`Jsum` are derived, so not included.
Distributions.params(d::DirichletMvLogitNormal) = (d.α, d.J)

# Parameter element type. The instance method is the form Distributions calls; the type
# method uses `<:` so it also matches concrete `DirichletMvLogitNormal{T,...}` types
# (the previous `Type{DirichletMvLogitNormal{T}}` only matched the UnionAll itself).
@inline Distributions.partype(::DirichletMvLogitNormal{T}) where {T} = T
@inline Distributions.partype(::Type{<:DirichletMvLogitNormal{T}}) where {T} = T
# `x` is in the support when it has the distribution's length, all entries are strictly
# positive, and the entries sum to one (up to `isapprox` tolerance).
function Distributions.insupport(d::DirichletMvLogitNormal, x::AbstractVector{<:Real})
    length(x) == length(d) || return false
    all(xi -> xi > 0, x) || return false
    return isapprox(sum(x), 1)
end
# Log of the *unnormalized* density
#     (α - 1)ᵀ z - ½ (Π z)ᵀ J (Π z),   z = log.(x),   Π z = z[1:N-1] .- z[N],
# expanded with the cached column sums (1ᵀJ) and total (1ᵀJ1) of `J` so that Π z is
# never materialized. The cross term uses 1ᵀJ for both halves, which is valid for
# symmetric `J`. No normalizing constant is computed here.
# `x` is restricted to real eltypes so this method is strictly more specific than
# Distributions' generic `logpdf(::Distribution, ::AbstractArray{<:Real})`.
function Distributions.logpdf(d::DirichletMvLogitNormal, x::AbstractVector{<:Real})
    z = log.(x)
    z₋ = @view z[firstindex(z):(end - 1)]
    zₙ = z[end]
    zᵀΓz = (dot(z₋, d.J, z₋) - 2zₙ * dot(d.Jcolsum, z₋) + d.Jsum * zₙ^2) / -2
    # (α - 1)ᵀ z written as αᵀz - ∑z
    return dot(d.α, z) - sum(z) + zᵀΓz
end
# Draws are probability vectors, so use the simplex bijector to map them to and from
# unconstrained space.
function Bijectors.bijector(::DirichletMvLogitNormal)
    return Bijectors.SimplexBijector()
end
