Created
September 19, 2023 17:53
-
-
Save sethaxen/907a4a6eb0460c5fe0cd4f1039dd56d7 to your computer and use it in GitHub Desktop.
Custom bijectors in Turing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using Distributions, Bijectors, Random, FillArrays, LogExpFunctions
using LinearAlgebra: I, cholesky
"""
    DistributionWithTransform(dist, transform)

Wrapper that carries a continuous distribution `dist` together with a user-chosen
`transform`.

The wrapper is itself a continuous `Distribution` whose variate form `V` is copied from
the type of `dist`. Elsewhere in this file, `logpdf` and `rand` are forwarded to `dist`,
and `Bijectors.bijector` returns `inverse(transform)`.

# Fields
- `dist`: the wrapped continuous distribution.
- `transform`: the stored transform; its inverse is reported as the bijector.
"""
struct DistributionWithTransform{D<:ContinuousDistribution,B,V<:VariateForm} <: Distribution{V,Continuous}
    dist::D
    transform::B
    function DistributionWithTransform(dist::D, transform) where {D<:ContinuousDistribution}
        # The variate form is a property of the wrapped distribution's type.
        V = Distributions.variate_form(D)
        return new{D,typeof(transform),V}(dist, transform)
    end
end
"""
    with_transform(d::Distribution, b)

Generic fallback: pair the distribution `d` with the transform `b` by wrapping both in a
`DistributionWithTransform`.
"""
function with_transform(d::Distribution, b)
    return DistributionWithTransform(d, b)
end
# The bijector reported for a wrapped distribution is the inverse of the stored transform.
function Bijectors.bijector(d::DistributionWithTransform)
    return inverse(d.transform)
end
# Log-density is forwarded unchanged to the wrapped distribution.
# NOTE(review): `x` is untyped; for array-variate wrapped distributions this method may be
# ambiguous with Distributions' own array-typed `logpdf` methods — confirm before relying on it.
function Distributions.logpdf(d::DistributionWithTransform, x)
    return logpdf(d.dist, x)
end
# Sampling is forwarded unchanged to the wrapped distribution, using the caller's RNG.
function Random.rand(rng::AbstractRNG, d::DistributionWithTransform)
    return rand(rng, d.dist)
end
"""
    NonCentered()

Field-less marker type. Passing an instance as the second argument of `with_transform`
selects the methods in this file that express a normal distribution as a shifted and
scaled standard normal.
"""
struct NonCentered
end
# special-case MvNormal to avoid redundant computation
#
# Returns a standard multivariate normal pushed through `x -> mean(d) + L * x`, where `L`
# is the lower Cholesky factor of `d.Σ`.
function with_transform(d::MvNormal{T}, ::NonCentered) where {T}
    chol_lower = cholesky(d.Σ).L
    affine = Bijectors.Shift(mean(d)) ∘ Bijectors.Scale(chol_lower)
    base = Distributions.MvNormal(Zeros(T, length(d)), I)
    return transformed(base, affine)
end
# Non-centered form of a canonical-parameter multivariate normal: a standard multivariate
# normal pushed through `x -> mean(d) + inv(U) * x`, where `U` is the upper Cholesky factor
# of `d.J` (so `inv(U) * inv(U)' == inv(d.J)`).
function with_transform(d::MvNormalCanon{T}, ::NonCentered) where {T}
    inv_upper = inv(cholesky(d.J).U)
    affine = Bijectors.Shift(mean(d)) ∘ Bijectors.Scale(inv_upper)
    base = Distributions.MvNormal(Zeros(T, length(d)), I)
    return transformed(base, affine)
end
# Non-centered form of a univariate normal: a standard normal (in the element type `T` of
# `d`) pushed through `x -> mean(d) + std(d) * x`.
function with_transform(d::Normal{T}, ::NonCentered) where {T}
    location, scale = mean(d), std(d)
    affine = Bijectors.Shift(location) ∘ Bijectors.Scale(scale)
    base = Normal(zero(T), one(T))
    return transformed(base, affine)
end
# exponential with softplus transform instead of exponential:
# `x` has an Exponential() prior, but its bijector is `inverse(softplus)`.
@model function foo()
    prior = with_transform(Exponential(), softplus)
    x ~ prior
end

# 4 chains of 1000 draws each, run on threads
chns = sample(foo(), NUTS(0.8), MCMCThreads(), 1000, 4)
# eight-schools
y = [28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0]
σ = [15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0]

# Eight-schools hierarchical model.
#
# Arguments:
# - `y`: observed effects, one per school.
# - `σ`: standard errors of `y`, same length as `y`.
# - `T`: element type used for the latent vector `θ` (Turing's usual `::Type{T}=Float64`
#   idiom, so that `θ` can hold AD number types).
#
# Keywords:
# - `J`: number of schools (defaults to `length(y)`).
# - `transform`: second argument passed to `with_transform` for the prior on each `θ[j]`;
#   `identity` keeps the centered parameterization, `NonCentered()` samples the
#   standardized effects instead.
#
# The transform must be attached to the *latent* `θ`. Attaching it to the likelihood of
# the observed `y` (as was done previously) has no effect on the sampling geometry,
# because `y` is data and its distribution is only used to evaluate `logpdf`.
@model function eight_schools(y, σ, ::Type{T}=Float64; J=length(y), transform=identity) where {T}
    μ ~ Normal(0, 5)
    τ ~ truncated(Cauchy(0, 5); lower=0)
    # Same prior for every school; build it once outside the loop.
    prior = with_transform(Normal(μ, τ), transform)
    θ = Vector{T}(undef, J)
    for j in eachindex(θ)
        θ[j] ~ prior
    end
    y .~ Normal.(θ, σ)
end
## compare centered vs non-centered
# Each run: 4 threaded chains of 1000 NUTS draws; the mean of the `numerical_error` column
# is then reported for that run.

# centered (default `transform=identity`)
chns = sample(eight_schools(y, σ), NUTS(0.8), MCMCThreads(), 1000, 4)
mean(chns[:numerical_error])

# non-centered
chns = sample(
    eight_schools(y, σ; transform=NonCentered()), NUTS(0.8), MCMCThreads(), 1000, 4
)
mean(chns[:numerical_error])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
I was a bit suspicious about the log-Jacobian handling, so I added the "usual" non-centering to the model as a hack and drew pair plots of log tau and theta[1]; see below: