Created
September 19, 2023 17:53
-
-
Save sethaxen/907a4a6eb0460c5fe0cd4f1039dd56d7 to your computer and use it in GitHub Desktop.
Custom bijectors in Turing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# NOTE(review): `@model`, `sample`, `NUTS` and `MCMCThreads` used below come from Turing,
# which is not loaded here — presumably `using Turing` is run first.
using Distributions, Bijectors, Random, FillArrays, LogExpFunctions
using LinearAlgebra  # cholesky, I
"""
    DistributionWithTransform(dist::ContinuousDistribution, transform)

Pair the continuous distribution `dist` with a user-chosen `transform`.

The wrapper copies the variate form (`Univariate`, `Multivariate`, …) of `dist` and is
always `Continuous`. `transform` is stored as-is; the inverse of it is what gets used as
the bijector, so it is expected to map unconstrained values onto the support of `dist`.
"""
struct DistributionWithTransform{D<:ContinuousDistribution,B,V<:VariateForm} <:
       Distribution{V,Continuous}
    dist::D
    transform::B
    function DistributionWithTransform(dist::ContinuousDistribution, transform)
        D, B = typeof(dist), typeof(transform)
        # The variate form is a type parameter so the wrapper dispatches like `dist`.
        return new{D,B,Distributions.variate_form(D)}(dist, transform)
    end
end
"""
    with_transform(d::Distribution, b)

Wrap `d` in a `DistributionWithTransform` so the sampler works through the custom
transform `b` instead of the default bijector of `d`. `b` should map unconstrained values
onto the support of `d`. Methods specialized on particular `(d, b)` pairs may return a
different distribution type.
"""
with_transform(d::Distribution, b) = DistributionWithTransform(d, b)

# Bijectors expects `bijector(d)` to map the support of `d` to unconstrained space, while
# the stored transform goes the other way — hence the inverse.
Bijectors.bijector(d::DistributionWithTransform) = inverse(d.transform)

# Density and sampling are exactly those of the wrapped distribution; only the
# parameterization seen by the sampler changes.
Distributions.logpdf(d::DistributionWithTransform, x) = logpdf(d.dist, x)
Random.rand(rng::AbstractRNG, d::DistributionWithTransform) = rand(rng, d.dist)

# Distributions requires every multivariate distribution to implement `length`; forward
# it so a multivariate wrapper reports the dimension of the distribution it wraps.
Base.length(d::DistributionWithTransform{<:Any,<:Any,Multivariate}) = length(d.dist)
"""
    NonCentered()

Transform tag for `with_transform`. It requests a non-centered parameterization: the
sampler draws `z` from a standard normal and the value is obtained through an affine map
`x = μ + A z`. Only `Normal`, `MvNormal` and `MvNormalCanon` are supported; any other
distribution throws an `ArgumentError`.
"""
struct NonCentered end

# Without this method, any other distribution would be wrapped with `NonCentered()` as its
# "transform". That has no inverse and would only fail later, inside the sampler.
function with_transform(d::Distribution, ::NonCentered)
    msg = "NonCentered() is only implemented for Normal, MvNormal and MvNormalCanon, " *
          "got $(typeof(d))"
    throw(ArgumentError(msg))
end

# Standard multivariate normal N(0, I) of dimension `n` with element type `T`.
_standard_mvnormal(::Type{T}, n::Integer) where {T} = Distributions.MvNormal(Zeros(T, n), I)

# Special-case MvNormal, reusing the Cholesky factor of Σ: with Σ = L Lᵀ and
# z ~ N(0, I), x = μ + L z ~ N(μ, Σ).
function with_transform(d::MvNormal{T}, ::NonCentered) where {T}
    b = Bijectors.Shift(mean(d)) ∘ Bijectors.Scale(cholesky(d.Σ).L)
    return transformed(_standard_mvnormal(T, length(d)), b)
end

# Precision J = Uᵀ U gives Σ = U⁻¹ U⁻ᵀ, so x = μ + U⁻¹ z ~ N(μ, Σ).
function with_transform(d::MvNormalCanon{T}, ::NonCentered) where {T}
    b = Bijectors.Shift(mean(d)) ∘ Bijectors.Scale(inv(cholesky(d.J).U))
    return transformed(_standard_mvnormal(T, length(d)), b)
end

# Univariate case: x = μ + σ z with z ~ N(0, 1).
function with_transform(d::Normal{T}, ::NonCentered) where {T}
    b = Bijectors.Shift(mean(d)) ∘ Bijectors.Scale(std(d))
    return transformed(Normal(zero(T), one(T)), b)
end
# Exponential prior sampled through a custom transform: softplus maps ℝ onto (0, ∞) and
# replaces the default log/exp pair used for positive-support distributions.
@model function foo()
    x ~ with_transform(Exponential(), softplus)
end

chns = sample(foo(), NUTS(0.8), MCMCThreads(), 1000, 4)
# Eight-schools data: observed per-school effects `y` and their noise scales `σ` (used as
# the likelihood standard deviations).
y = Float64[28, 8, -3, 7, -1, 1, 18, 12]
σ = Float64[15, 10, 16, 11, 9, 11, 10, 18]
# Hierarchical eight-schools model. `transform` is forwarded to `with_transform` on the
# school effects θ: the default `identity` keeps the centered parameterization, and
# `transform=NonCentered()` selects the non-centered one.
@model function eight_schools(y, σ; J=length(y), transform=identity)
    μ ~ Normal(0, 5)
    τ ~ truncated(Cauchy(0, 5); lower=0)
    # The reparameterization must act on the latent θ. `y` is observed, so a transform
    # attached to its likelihood is never used by the sampler.
    θ ~ filldist(with_transform(Normal(μ, τ), transform), J)
    y .~ Normal.(θ, σ)
end
## compare centered vs non-centered
# Report the fraction of NUTS transitions flagged as numerical errors for each variant.
# The values are printed because bare expressions are not shown when run as a script.
chns = sample(eight_schools(y, σ), NUTS(0.8), MCMCThreads(), 1_000, 4)
println("centered: ", mean(chns[:numerical_error]))
chns = sample(eight_schools(y, σ; transform=NonCentered()), NUTS(0.8), MCMCThreads(), 1_000, 4)
println("non-centered: ", mean(chns[:numerical_error]))
I was a bit suspicious of the log-Jacobian handling. As a check, I added the "usual" manual non-centered parameterization to the model as a hack (selected with `transform=missing`). For each of the three variants I then made a pair plot of θ[1] against log τ; see below:
# Eight-schools variant used to cross-check `with_transform`:
# - `transform=missing` selects a hand-written non-centered parameterization
#   (standard-normal `theta_raw` plus an explicit affine map to θ);
# - any other `transform` is applied to θ's prior through `with_transform`.
@model function eight_schools(y, σ; J=length(y), transform=identity)
    μ ~ Normal(0, 5)
    τ ~ truncated(Cauchy(0, 5); lower=0)
    if !ismissing(transform)
        θ ~ filldist(with_transform(Normal(μ, τ), transform), J)
    else
        theta_raw ~ filldist(Normal(0, 1), J)
        # θ is a plain assignment here, so it is not recorded in the chain.
        θ = μ .+ τ .* theta_raw
    end
    y .~ Normal.(θ, σ)
end
using Plots
## compare centered vs non-centered
# Scatter the θ[1] draws (`theta1`) against log τ taken from `chns`.
_funnel_scatter(theta1, chns) = scatter(vec(theta1), log.(vec(chns[:τ])); alpha=0.1)

chns = sample(eight_schools(y, σ), NUTS(0.8), MCMCThreads(), 1_000, 4)
println(sum(chns[:numerical_error]))
p1 = _funnel_scatter(chns["θ[1]"], chns)

chns = sample(eight_schools(y, σ; transform=NonCentered()), NUTS(0.8), MCMCThreads(), 1_000, 4)
println(sum(chns[:numerical_error]))
p2 = _funnel_scatter(chns["θ[1]"], chns)

# The manual variant does not record θ, so rebuild θ[1] from μ, τ and theta_raw[1].
chns = sample(eight_schools(y, σ; transform=missing), NUTS(0.8), MCMCThreads(), 1_000, 4)
println(sum(chns[:numerical_error]))
t1 = chns["μ"] .+ chns["τ"] .* chns["theta_raw[1]"]
p3 = _funnel_scatter(t1, chns)

# NOTE(review): the author observed that link=:both appeared to link only the x axes.
plot(p1, p2, p3; size=(800, 800), layout=(3, 1), link=:both)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
The model should read