Last active
September 9, 2023 19:28
-
-
Save sethaxen/fc732d4a3baf3e0aca3bde42d20f02ff to your computer and use it in GitHub Desktop.
Implementation of a family of distributions that contains the Dirichlet and multivariate logit-normal (logistic normal) distributions as special cases, compatible with Turing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using Bijectors, Distributions, LinearAlgebra, PDMats

"""
    DirichletMvLogitNormal(α::AbstractVector, J::AbstractMatrix; check=false)

A distribution that contains `MvLogitNormal` and `Dirichlet` as special cases.[^Aitchison1985]

Let ``x \\sim \\mathrm{DirichletMvLogitNormal}(α, J)`` be a length ``N`` probability vector,
where ``α`` is a length ``N`` vector, and ``J`` is a size ``(N-1,N-1)`` positive
semi-definite matrix. Given ``z_i = \\log(x_i)``, the density function of the
Dirichlet-logit (logistic) normal distribution is

```math
f(x; α, J) \\propto \\exp\\left( (α - 1_N)^\\top z(x) - \\frac{1}{2} z(x)^\\top Π^\\top J Π z(x) \\right),
```

where ``Π = I_{N-1,N} - 1_{N-1} e_N^\\top`` is the ``(N-1) \\times N`` matrix that maps
``z`` to ``z_{1:N-1} - z_N 1_{N-1}``.

The distribution is proper when

- `isposdef(J) && sum(α) ≥ 0` OR
- `!isposdef(-J) && all(>(0), α)`

These conditions are checked when `check=true`.

## Relationship to `Dirichlet`

When ``J = 0_{N-1,N-1}`` and ``α_i > 0`` for all ``i``, then the result is Dirichlet
with parameter vector ``α``.

## Relationship to `MvLogitNormal`

When ``1_N^\\top α = 0``, the result is an MvLogitNormal distribution with canonical
parameters
``(I_{N-1,N} \\alpha, J) = (\\Sigma^{-1}\\mu, \\Sigma^{-1})``.

!!! note
    This parameterization is related to the ``(\\beta, \\Gamma)`` one in [^Aitchison1985] by

    ```math
    \\begin{aligned}
    \\beta &= \\alpha \\\\
    \\Gamma &= -\\frac{1}{2} Π^\\top J Π
    \\end{aligned}.
    ```

[^Aitchison1985]:
    J Aitchison. A General Class of Distributions on the Simplex.
    J R Stat Soc Series B Stat Methodol, 1985, 47(1): 136-146.
"""
struct DirichletMvLogitNormal{
    T<:Real,TA<:AbstractVector{T},TJ<:AbstractMatrix{T},Ts<:AbstractVector{T}
} <: ContinuousMultivariateDistribution
    # Length-N parameter vector α.
    α::TA
    # (N-1)×(N-1) matrix J of the quadratic term.
    J::TJ
    # Column sums of `J`, cached by the outer constructor for use in `logpdf`.
    Jcolsum::Ts
    # Sum of all entries of `J`, cached by the outer constructor for use in `logpdf`.
    Jsum::T
end
# Outer constructor for parameters that already share the element type `T`.
#
# When `check=true`, the arguments are validated before construction:
#   - `J` must be square (`DimensionMismatch` otherwise),
#   - `J` must have `length(α) - 1` rows (`DimensionMismatch` otherwise),
#   - one of the two documented properness conditions must hold (`ArgumentError` otherwise).
# The column sums of `J` and their total are computed once here and stored on the struct.
function DirichletMvLogitNormal(
    α::AbstractVector{T}, J::AbstractMatrix{T}; check::Bool=false
) where {T<:Real}
    if check
        size(J, 1) == size(J, 2) || throw(DimensionMismatch("J must be square."))
        size(J, 1) == length(α) - 1 ||
            throw(DimensionMismatch("The dimensions of α and J are inconsistent."))
        # The second branch uses `!isposdef(-J)`, matching the documented condition, so
        # that J = 0 (the Dirichlet special case) is accepted. Note that this only rules
        # out negative-definite J; it is not a full positive semi-definiteness test.
        (isposdef(J) && sum(α) ≥ 0) ||
            (!isposdef(-J) && all(>(0), α)) ||
            throw(
                ArgumentError(
                    "Either J must be positive-definite and α have a non-negative sum, or J must be non-negative definite and all entries of α must be positive.",
                ),
            )
    end
    # `sum(J; dims=1)` is a 1×(N-1) row of column sums; drop the singleton dimension.
    J_colsum = dropdims(sum(J; dims=1); dims=1)
    J_sum = sum(J_colsum)
    return DirichletMvLogitNormal(α, J, J_colsum, J_sum)
end
# Fallback constructor for `α` and `J` with differing element types: both are converted
# to their promoted element type and the call is forwarded, keyword arguments included.
function DirichletMvLogitNormal(
    α::AbstractVector{<:Real}, J::AbstractMatrix{<:Real}; kwargs...
)
    S = Base.promote_eltype(α, J)
    α_promoted = convert(AbstractArray{S}, α)
    J_promoted = convert(AbstractArray{S}, J)
    return DirichletMvLogitNormal(α_promoted, J_promoted; kwargs...)
end
# Convert a `Dirichlet` by reusing its `alpha` vector and pairing it with an all-zero
# (N-1)×(N-1) scaled-identity matrix as `J`.
function Base.convert(::Type{DirichletMvLogitNormal}, d::Dirichlet)
    concentration = d.alpha
    zero_J = PDMats.ScalMat(length(concentration) - 1, 0)
    return DirichletMvLogitNormal(concentration, zero_J)
end
# These conversions are only defined when the loaded Distributions provides `MvLogitNormal`.
if isdefined(Distributions, :MvLogitNormal)
    # Canonical-form case: `α` is the potential vector `h` extended by `-sum(h)` (so that
    # `α` sums to zero), and `J` is the inverse covariance of the underlying normal.
    function Base.convert(::Type{DirichletMvLogitNormal}, d::MvLogitNormal{<:MvNormalCanon})
        h = d.normal.h
        α = vcat(h, -sum(h))
        return DirichletMvLogitNormal(α, invcov(d.normal))
    end

    # Any other `MvLogitNormal`: rewrap the underlying normal in canonical form, then
    # dispatch to the method above.
    function Base.convert(T::Type{DirichletMvLogitNormal}, d::MvLogitNormal)
        canonical = MvLogitNormal(canonform(d.normal))
        return convert(T, canonical)
    end
end
# Element type is the first type parameter `T` of the distribution.
function Base.eltype(::Type{<:DirichletMvLogitNormal{T}}) where {T}
    return T
end

# The length of the distribution is the length of its `α` vector.
function Base.length(d::DirichletMvLogitNormal)
    return length(d.α)
end

# Parameters are returned as the tuple `(α, J)`; the cached sums are not included.
function Distributions.params(d::DirichletMvLogitNormal)
    return (d.α, d.J)
end
# Parameter type is the first type parameter `T`.
# `partype` is queried on distribution instances, so an instance method is provided.
# The type-level method is kept for backward compatibility and uses `<:` (as `eltype`
# does) so that fully-parameterized concrete types also resolve to `T`.
@inline Distributions.partype(::DirichletMvLogitNormal{T}) where {T} = T
@inline Distributions.partype(::Type{<:DirichletMvLogitNormal{T}}) where {T} = T
# Return `true` when `x` has the same length as the distribution, every entry is strictly
# positive, and the entries sum to approximately one.
function Distributions.insupport(d::DirichletMvLogitNormal, x::AbstractVector{<:Real})
    length(d) == length(x) || return false
    all(>(0), x) || return false
    return sum(x) ≈ 1
end
# Log-density of `d` at the probability vector `x`, computed as
#     (α - 1)ᵀ z - ½ (z₋ - zₙ 1)ᵀ J (z₋ - zₙ 1),    with z = log.(x),
# where z₋ is `z` without its last entry and zₙ is the last entry.
# No normalizing constant is subtracted, so the value is the log of the unnormalized
# density stated in the type's docstring.
function Distributions.logpdf(d::DirichletMvLogitNormal, x::AbstractVector)
    z = log.(x)
    z₋ = @view z[firstindex(z):(end - 1)]
    zₙ = z[end]
    # Expand the quadratic form using the cached column sums and total sum of `J`:
    #   (z₋ - zₙ 1)ᵀ J (z₋ - zₙ 1) = z₋ᵀ J z₋ - 2 zₙ (1ᵀ J z₋) + zₙ² (1ᵀ J 1).
    # Using the column sums for the cross term relies on `J` being symmetric.
    zᵀΓz = (dot(z₋, d.J, z₋) - 2zₙ * dot(d.Jcolsum, z₋) + d.Jsum * zₙ^2) / -2
    return dot(d.α, z) - sum(z) + zᵀΓz
end
# The bijector for this distribution is `Bijectors.SimplexBijector()`.
function Bijectors.bijector(::DirichletMvLogitNormal)
    return Bijectors.SimplexBijector()
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment