Skip to content

Instantly share code, notes, and snippets.

@sethaxen
Created May 4, 2023 20:41
Show Gist options
  • Save sethaxen/2ce75fa84755e7e0cfd54cc9555d120b to your computer and use it in GitHub Desktop.
Save sethaxen/2ce75fa84755e7e0cfd54cc9555d120b to your computer and use it in GitHub Desktop.
WrappedNormal for Distributions.jl
using Distributions
using IrrationalConstants
using LogExpFunctions
using Random
"""
    WrappedNormalDensityMethod

Abstract supertype of the tags selecting which truncated series is used to evaluate a
`WrappedNormal` density.
"""
abstract type WrappedNormalDensityMethod end

"""
    WrappedMethod()

Tag selecting the direct sum of shifted normal densities.
"""
struct WrappedMethod <: WrappedNormalDensityMethod
end

"""
    JacobiMethod()

Tag selecting the Fourier (Jacobi theta function) series.
"""
struct JacobiMethod <: WrappedNormalDensityMethod
end
"""
    WrappedNormal(μ, σ; [tol])

The wrapped normal distribution with mean direction `μ` and scale `σ` (the standard
deviation of the underlying normal distribution) has probability density function
```math
f(x; \\mu, \\sigma) = \\sum_{k=-\\infty}^{\\infty} g(x + 2 \\pi k; \\mu, \\sigma),
```
where ``g(x; \\mu, \\sigma)`` is the density function of the [`Normal(μ, σ)`](@ref)
distribution. The support is taken to be the single period ``[\\mu - \\pi, \\mu + \\pi]``.

The infinite series is evaluated in truncated form, using either the direct sum above or
the equivalent Fourier (Jacobi theta function) representation
```math
f(x; \\mu, \\sigma) = \\frac{1}{2\\pi}\\left(1 + 2 \\sum_{k=1}^{\\infty} \\rho^{k^2} \\cos(k (x - \\mu))\\right),
\\qquad \\rho = e^{-\\sigma^2 / 2},
```
whichever needs fewer terms.

`tol` (default `eps(T)`) is the target absolute error of the truncated density. It
determines which of the two representations is used and how many terms are kept.
[^KurzGilitschenskiHanebeck2014]

## External links

- Wrapped normal distribution on Wikipedia (https://en.wikipedia.org/wiki/Wrapped_normal_distribution)

[^KurzGilitschenskiHanebeck2014]: G. Kurz, I. Gilitschenski, U. D. Hanebeck.
"Efficient evaluation of the probability density function of a wrapped normal distribution,"
2014 Sensor Data Fusion: Trends, Solutions, Applications (SDF), pp1-5.
doi: [10.1109/SDF.2014.6954713](https://doi.org/10.1109/SDF.2014.6954713).
arXiv: [1405.6397](https://arxiv.org/abs/1405.6397)
"""
struct WrappedNormal{T<:Real} <: ContinuousUnivariateDistribution
    μ::T  # mean direction
    σ::T  # standard deviation of the underlying (unwrapped) normal
    # Series used to evaluate the density/CDF. NOTE(review): abstractly typed, so the struct
    # is not isbits and `logpdf`/`logcdf` branch on this field at runtime.
    method::WrappedNormalDensityMethod
    # Truncation index n: the direct sum uses k ∈ -n:n, the Fourier series uses k ∈ 1:n.
    nterms::Int
end
"""
    WrappedNormal(μ::T, σ::T; tol=eps(T), check_args=true)

Construct a `WrappedNormal`. Choose whichever series (direct sum or Fourier/Jacobi) needs
fewer terms to reach absolute error `tol`, and store the rounded-up term count.

With `check_args=true` (the default), throw a `DomainError` unless `σ > 0` and `tol > 0`.
Both term-count estimates take `log(σ)`, `log(tol)` and divide by `σ`, so other values
would otherwise surface later as obscure `DomainError`/`InexactError`s or a degenerate
distribution.
"""
function WrappedNormal(μ::T, σ::T; tol=eps(T), check_args::Bool=true) where {T<:Real}
    if check_args
        σ > zero(σ) || throw(
            DomainError(σ, "WrappedNormal: the condition σ > zero(σ) is not satisfied."),
        )
        tol > zero(tol) || throw(
            DomainError(tol, "WrappedNormal: the condition tol > zero(tol) is not satisfied."),
        )
    end
    # Unrounded term-count estimates for each representation; keep the cheaper one.
    nw = _num_terms_wrappednormal_wrapped(μ, σ, tol)
    nj = _num_terms_wrappednormal_jacobi(μ, σ, tol)
    if nw < nj
        return WrappedNormal{T}(μ, σ, WrappedMethod(), ceil(Int, nw))
    else
        return WrappedNormal{T}(μ, σ, JacobiMethod(), ceil(Int, nj))
    end
end
# Mixed parameter types are promoted to a common type before calling the main constructor.
WrappedNormal(μ::Real, σ::Real; kwargs...) = WrappedNormal(promote(μ, σ)...; kwargs...)
# Integer parameters are converted to floating point first, as Distributions does for
# `Normal`. Without this, `WrappedNormal(0, 1)` hits the `T = Int` method, where the
# default `tol = eps(T)` has no method for integers.
function WrappedNormal(μ::Integer, σ::Integer; kwargs...)
    return WrappedNormal(float(μ), float(σ); kwargs...)
end
# Display only the user-facing parameters μ and σ (not the series method or term count).
function Base.show(io::IO, d::WrappedNormal)
    return show(io, d, (:μ, :σ))
end

# The support is one full period centred on μ (the same convention as VonMises), so the
# density integrates to 1 over [minimum(d), maximum(d)].
function Base.minimum(d::WrappedNormal)
    return d.μ - π
end
function Base.maximum(d::WrappedNormal)
    return d.μ + π
end

# Parameters in constructor order, and their element type.
function Distributions.params(d::WrappedNormal)
    μ, σ = d.μ, d.σ
    return (μ, σ)
end
Distributions.partype(::WrappedNormal{T}) where {T<:Real} = T
"""
    _std_angle(x::Real)

Fold the angle `x` into `[-π, π]` by removing the nearest multiple of `2π`.
"""
function _std_angle(x::Real)
    return rem2pi(x, RoundNearest)
end
# Log-density at `x`. The offset from the mean is folded into [-π, π], so the density is
# 2π-periodic in `x`, and the series selected by `d.method` is evaluated there.
function Distributions.logpdf(d::WrappedNormal, x::Real)
    δ = _std_angle(x - d.μ)
    σ, nterms = d.σ, d.nterms
    d.method isa WrappedMethod && return _logpdf_wrappednormal_wrapped(δ, σ, nterms)
    return _logpdf_wrappednormal_jacobi(δ, σ, nterms)
end
# Log-CDF at `x`. The support is the single period [μ - π, μ + π] (see `minimum`/`maximum`),
# so the CDF is 0 at or below it and 1 at or above it, rather than repeating with period
# 2π. Inside the support, the series selected by `d.method` is evaluated at z = x - μ.
function Distributions.logcdf(d::WrappedNormal, x::Real)
    z = x - d.μ
    R = float(typeof(z))
    x ≤ minimum(d) && return R(-Inf)
    x ≥ maximum(d) && return zero(R)
    # Inside the support z already lies in (-π, π). It is deliberately not re-wrapped:
    # rounding could otherwise send a point just below μ + π to ≈ -π (CDF ≈ 0).
    if d.method isa WrappedMethod
        return _logcdf_wrappednormal_wrapped(z, d.σ, d.nterms)
    else
        return _logcdf_wrappednormal_jacobi(z, d.σ, d.nterms)
    end
end
Distributions.cdf(d::WrappedNormal, x::Real) = exp(logcdf(d, x))
# Unrounded truncation index n for the direct sum over k ∈ -n:n, targeting absolute error
# ϵ. It grows linearly in σ. `μ` only participates in the type promotion.
function _num_terms_wrappednormal_wrapped(μ, σ, ϵ)
    T = Base.promote_eltype(μ, σ, ϵ)
    # x = -log(2 π^(3/2) ϵ). Clamp at 1 *before* taking the square root: for loose
    # tolerances (ϵ ≳ 0.09) x is negative, and `max(1, sqrt(x))` would throw a DomainError.
    # For x ≥ 0 this equals the unclamped `max(1, sqrt(x))`.
    x = -T(logtwo) - log(ϵ) - (3//2) * T(logπ)
    return 1 + ((σ / sqrt2) / π) * sqrt(max(1, x))
end
# Unrounded truncation index n for the Fourier (Jacobi) series over k ∈ 1:n, targeting
# absolute error ϵ. It shrinks like 1/σ. `μ` only participates in the type promotion.
function _num_terms_wrappednormal_jacobi(μ, σ, ϵ)
    T = Base.promote_eltype(μ, σ, ϵ)
    # The k-th coefficient is ρ^(k²) = exp(-k²σ²/2). Solving exp(-k²σ²/2) = √2 π σ ϵ
    # gives k = sqrt(-2 log(√2 π σ ϵ)) / σ = sqrt(-log(2 π² σ² ϵ²)) / σ. Note the factor 2
    # under the root: without it the count is √2 too small. For example, σ = 2 at
    # ϵ = eps() would keep 3 terms and leave an error ≈ exp(-32)/π ≈ 4e-15.
    x = -T(logtwo) - 2 * (logπ + log(σ) + log(ϵ))
    # Clamp before the square root so that a very large σ (x < 0) yields 1/σ instead of
    # a DomainError.
    return sqrt(max(1, x)) / σ
end
# Log-density from the direct ("wrapped") sum of normal densities,
#     f(z) = Σ_{k=-n}^{n} exp(-((z + 2πk)/σ)² / 2) / (√(2π) σ),
# evaluated with `logsumexp` to avoid underflow. `z` is the offset from the mean (callers
# pass an angle folded into [-π, π]).
function _logpdf_wrappednormal_wrapped(z, σ, n)
    T = Base.promote_eltype(z, σ)
    # log of the normalising constant 1 / (√(2π) σ)
    logc = -T(log2π) / 2 - log(σ)
    # Convert 2π to T: `k * twoπ` with an Int k is Float64 and would silently promote
    # Float32 inputs to a Float64 result.
    period = T(twoπ)
    log_series = logsumexp((muladd(k, period, z) / σ)^2 / -2 for k in (-n):n)
    return log_series + logc
end
# Log-CDF over [-π, z] from the direct ("wrapped") representation. With X ~ Normal(0, σ),
#     F(z) = Σ_{k=-n}^{n} P(2πk - π < X ≤ 2πk + z),
# i.e. each wrapped copy contributes only the mass between the lower end of its period and
# z. This gives F(-π) = 0 and F(π) = P(|X| ≤ (2n+1)π) ≈ 1. Assumes z ∈ [-π, π].
function _logcdf_wrappednormal_wrapped(z, σ, n)
    T = Base.promote_eltype(z, σ)
    # `Normal{T}` requires both parameters to be of type T, so convert σ explicitly.
    dnorm = Normal{T}(zero(T), T(σ))
    period = T(twoπ)
    lower = -T(π)
    return logsumexp(
        _logprob_centered_normal_interval(dnorm, muladd(k, period, lower), muladd(k, period, z))
        for k in (-n):n
    )
end

# log P(a < X ≤ b) for X ~ `dnorm`, a normal distribution centred at zero.
# An interval lying mostly on the positive half-line is mirrored onto the negative one
# (valid by symmetry), so `logdiffcdf` never subtracts two CDF values close to 1.
function _logprob_centered_normal_interval(dnorm, a, b)
    b = max(a, b)  # an empty interval has probability 0, i.e. log-probability -Inf
    if a + b > 0
        return Distributions.logdiffcdf(dnorm, -a, -b)
    else
        return Distributions.logdiffcdf(dnorm, b, a)
    end
end
# Log-density from the Fourier (Jacobi theta function) series truncated after n terms,
#     f(z) = (1 + 2 Σ_{k=1}^{n} ρ^(k²) cos(k z)) / (2π),   ρ = exp(-σ²/2).
# `log1p` keeps precision when the correction 2Σ… is small (large σ, hence small ρ).
function _logpdf_wrappednormal_jacobi(z, σ, n)
    q = exp(-σ^2 / 2)  # ρ in the formula above
    fourier_sum = sum(q^(k^2) * cos(k * z) for k in 1:n)
    return log1p(2 * fourier_sum) - log2π
end
# Log-CDF over [-π, z] from the Fourier (Jacobi) series, integrated term by term:
#     F(z) = (π + z + 2 Σ_{k=1}^{n} ρ^(k²) sin(k z) / k) / (2π),   ρ = exp(-σ²/2).
# Assumes z ∈ [-π, π] and n ≥ 1 (`sum` over an empty range would throw).
function _logcdf_wrappednormal_jacobi(z, σ, n)
    ρ = exp(-σ^2 / 2)
    s = π + z + 2 * sum(ρ^(k^2) * sin(k * z) / k for k in 1:n)
    # Near z = ±π, floating-point rounding (sin(kπ) is not exactly 0) and truncation can
    # push the value slightly outside [0, 2π]. For example, at z = -π it becomes slightly
    # negative and `log` would throw a DomainError. Clamp to a valid probability.
    p = clamp(s / twoπ, zero(s), one(s))
    return log(p)
end
# Sample by drawing from the zero-mean normal with standard deviation σ, folding the draw
# into [-π, π] with `_std_angle`, and shifting by μ; samples lie in [μ - π, μ + π].
function Base.rand(rng::AbstractRNG, d::WrappedNormal)
    offset = d.σ * randn(rng)
    return d.μ + _std_angle(offset)
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment