@ztangent
Created September 21, 2021 19:52
Gen Truncated Normal Distribution
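For reference, this is the density that `Gen.logpdf` evaluates, written for X ~ N(μ, σ²) truncated to [a, b] (the `lb` and `ub` arguments). φ and Φ are the standard normal density and CDF. Z is the probability mass inside the interval, and `d.logtp` in the code holds log Z:

```latex
\log p(x \mid \mu, \sigma, a, b) =
\begin{cases}
\log \phi\!\left(\frac{x-\mu}{\sigma}\right) - \log \sigma - \log Z, & a \le x \le b,\\
-\infty, & \text{otherwise},
\end{cases}
\qquad
Z = \Phi\!\left(\frac{b-\mu}{\sigma}\right) - \Phi\!\left(\frac{a-\mu}{\sigma}\right).
```

When Z underflows to zero, the two fallback branches replace it with a single tail probability. They use 1 − Φ((a−μ)/σ) (`logccdf` at `lb`) when the interval lies far in the upper tail, and Φ((b−μ)/σ) (`logcdf` at `ub`) when it lies far in the lower tail. Each approximation is accurate only when the mass beyond the other bound is negligible.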
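```julia
using Gen
import Distributions              # brings the `Distributions` module name into scope
using Distributions: truncated

struct TruncatedNormal <: Gen.Distribution{Float64} end

"""
    trunc_normal(mu::Real, std::Real, lb::Real, ub::Real)

Samples a `Float64` value from a normal distribution with mean `mu` and
standard deviation `std`, truncated to the interval `[lb, ub]`.
"""
const trunc_normal = TruncatedNormal()

(d::TruncatedNormal)(mu, std, lb, ub) = Gen.random(d, mu, std, lb, ub)

Gen.random(::TruncatedNormal, mu::Real, std::Real, lb::Real, ub::Real) =
    rand(truncated(Distributions.Normal(mu, std), lb, ub))

function Gen.logpdf(::TruncatedNormal, x::Real, mu::Real, std::Real, lb::Real, ub::Real)
    # Zero density outside the truncation interval
    (lb <= x <= ub) || return -Inf
    d = truncated(Distributions.Normal(mu, std), lb, ub)
    untrunc_lpdf = Distributions.logpdf(d.untruncated, x)
    if d.tp > 0
        # Usual case: subtract the log of the mass inside [lb, ub]
        untrunc_lpdf - d.logtp
    elseif Distributions.cdf(d.untruncated, lb) ≈ 1.0
        # Interval far in the upper tail, where tp underflows to 0
        untrunc_lpdf - Distributions.logccdf(d.untruncated, lb)
    elseif Distributions.cdf(d.untruncated, ub) ≈ 0.0
        # Interval far in the lower tail, where tp underflows to 0
        untrunc_lpdf - Distributions.logcdf(d.untruncated, ub)
    else
        # No usable normalizer; previously this case returned `nothing`
        -Inf
    end
end

Gen.is_discrete(::TruncatedNormal) = false
# No gradients are implemented for this distribution
Gen.has_output_grad(::TruncatedNormal) = false
Gen.has_argument_grads(::TruncatedNormal) = (false, false, false, false)
```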
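The tail fallbacks in `Gen.logpdf` approximate the normalizer by a single tail probability. A somewhat more general option is to compute log Z directly in log space by differencing the tail probabilities on whichever side of the mean the interval lies. The sketch below assumes only Distributions.jl. `log_normalizer` and `trunc_normal_logpdf` are hypothetical helpers and are not part of Gen or Distributions.

```julia
using Distributions: Normal, cdf, logcdf, logccdf, logpdf

# Hypothetical helper: log P(lb < X < ub) for X ~ Normal(mu, std),
# computed in log space so it stays finite deep in either tail.
function log_normalizer(mu::Real, std::Real, lb::Real, ub::Real)
    n = Normal(mu, std)
    if lb >= mu
        # Interval above the mean: Z = ccdf(lb) - ccdf(ub)
        a, b = logccdf(n, lb), logccdf(n, ub)
    elseif ub <= mu
        # Interval below the mean: Z = cdf(ub) - cdf(lb)
        a, b = logcdf(n, ub), logcdf(n, lb)
    else
        # Interval contains the mean, so there is no tail underflow
        return log(cdf(n, ub) - cdf(n, lb))
    end
    # log(exp(a) - exp(b)) for a >= b, evaluated stably
    return a + log1p(-exp(b - a))
end

# Hypothetical truncated-normal log density built on the helper
function trunc_normal_logpdf(x::Real, mu::Real, std::Real, lb::Real, ub::Real)
    (lb <= x <= ub) || return -Inf
    return logpdf(Normal(mu, std), x) - log_normalizer(mu, std, lb, ub)
end
```

For example, `log_normalizer(0.0, 1.0, 40.0, 41.0)` returns a finite value (around −805). The naive `log(cdf(n, 41.0) - cdf(n, 40.0))` gives `-Inf` for the same interval because both CDF values round to 1.0.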