Skip to content

Instantly share code, notes, and snippets.

@sethaxen
Created September 19, 2023 17:53
Show Gist options
  • Save sethaxen/907a4a6eb0460c5fe0cd4f1039dd56d7 to your computer and use it in GitHub Desktop.
Save sethaxen/907a4a6eb0460c5fe0cd4f1039dd56d7 to your computer and use it in GitHub Desktop.
Custom bijectors in Turing
using LinearAlgebra, Random
using Bijectors, Distributions, FillArrays, LogExpFunctions
"""
    DistributionWithTransform(d::ContinuousDistribution, b)

A continuous distribution `d` stored together with a transform `b`.

# Fields
- `dist`: the wrapped continuous distribution.
- `transform`: the object `b` given at construction, stored unchanged.

The variate-form parameter `V` is `Distributions.variate_form(typeof(d))`.
"""
struct DistributionWithTransform{D<:ContinuousDistribution,B,V<:VariateForm} <: Distribution{V,Continuous}
    dist::D
    transform::B

    function DistributionWithTransform(d::ContinuousDistribution, b)
        dist_type = typeof(d)
        variate = Distributions.variate_form(dist_type)
        return new{dist_type,typeof(b),variate}(d, b)
    end
end
"""
    with_transform(d::Distribution, b)

Fallback method: wrap `d` and `b` in a `DistributionWithTransform`.
"""
function with_transform(d::Distribution, b)
    return DistributionWithTransform(d, b)
end
# The bijector reported for the wrapper is the inverse of the stored transform.
function Bijectors.bijector(d::DistributionWithTransform)
    return inverse(d.transform)
end
# Log-density is delegated unchanged to the wrapped distribution.
# NOTE(review): `x` is untyped; confirm this does not collide with Distributions' own
# array-valued `logpdf` methods when `d` is multivariate.
function Distributions.logpdf(d::DistributionWithTransform, x)
    return logpdf(d.dist, x)
end
# Sampling is delegated to the wrapped distribution, using the caller's `rng`.
function Random.rand(rng::AbstractRNG, d::DistributionWithTransform)
    return rand(rng, d.dist)
end
# Field-less marker type; the `with_transform` methods that take `::NonCentered`
# as their second argument dispatch on it.
struct NonCentered end
# special-case MvNormal to avoid redundant computation
"""
    with_transform(d::MvNormal, ::NonCentered)

Return `transformed(base, b)`, where `base` is an `MvNormal` with a `Zeros` mean of
length `length(d)` and covariance `I`, and `b` first scales by the lower Cholesky
factor of `d.Σ` and then shifts by `mean(d)`.
"""
function with_transform(d::MvNormal{T}, ::NonCentered) where {T}
    shift = Bijectors.Shift(mean(d))
    scale = Bijectors.Scale(cholesky(d.Σ).L)
    base = Distributions.MvNormal(Zeros(T, length(d)), I)
    return transformed(base, shift ∘ scale)
end
"""
    with_transform(d::MvNormalCanon, ::NonCentered)

Return `transformed(base, b)`, where `base` is an `MvNormal` with a `Zeros` mean of
length `length(d)` and covariance `I`, and `b` first scales by
`inv(cholesky(d.J).U)` and then shifts by `mean(d)`.
"""
function with_transform(d::MvNormalCanon{T}, ::NonCentered) where {T}
    shift = Bijectors.Shift(mean(d))
    scale = Bijectors.Scale(inv(cholesky(d.J).U))
    base = Distributions.MvNormal(Zeros(T, length(d)), I)
    return transformed(base, shift ∘ scale)
end
"""
    with_transform(d::Normal, ::NonCentered)

Return `transformed(base, b)`, where `base` is `Normal(zero(T), one(T))` and `b`
first scales by `std(d)` and then shifts by `mean(d)`.
"""
function with_transform(d::Normal{T}, ::NonCentered) where {T}
    shift = Bijectors.Shift(mean(d))
    scale = Bijectors.Scale(std(d))
    base = Normal(zero(T), one(T))
    return transformed(base, shift ∘ scale)
end
# Exponential distribution with a softplus transform instead of the exponential (exp) transform
# Model with a single variable `x` drawn from `Exponential()` wrapped with `softplus`.
@model function foo()
    softplus_exponential = with_transform(Exponential(), softplus)
    x ~ softplus_exponential
end
# Sample `foo()` with `NUTS(0.8)` under `MCMCThreads()`, passing 1_000 and 4 as the
# sample-count and chain-count arguments; the result is bound to the global `chns`.
chns = sample(foo(), NUTS(0.8), MCMCThreads(), 1_000, 4)
# eight-schools
# Inputs later passed to `eight_schools` as `y` and `σ`; eight entries each.
y = Float64[28, 8, -3, 7, -1, 1, 18, 12]
σ = Float64[15, 10, 16, 11, 9, 11, 10, 18]
# Hierarchical model over observations `y` with scales `σ`.
#
# Keywords:
# - `J`: length of the latent vector `θ` (defaults to `length(y)`).
# - `transform`: second argument handed to `with_transform` when building the prior of
#   `θ`; `identity` (the default) uses the generic wrapper, `NonCentered()` selects the
#   non-centered method for `Normal`.
@model function eight_schools(y, σ; J=length(y), transform=identity)
    μ ~ Normal(0, 5)
    τ ~ truncated(Cauchy(0, 5); lower=0)
    # The transform must wrap the prior of the latent θ: it only affects how a sampled
    # variable is parameterized. Wrapping the likelihood of the observed `y` (as this
    # model previously did) leaves θ centered regardless of `transform`.
    θ ~ filldist(with_transform(Normal(μ, τ), transform), J)
    y .~ Normal.(θ, σ)
end
## compare centered vs non-centered
# First run: `eight_schools` with its default `transform=identity`.
chns = sample(eight_schools(y, σ), NUTS(0.8), MCMCThreads(), 1_000, 4)
# Mean of the chain's `numerical_error` column (value is only shown when run interactively).
mean(chns[:numerical_error])
# Second run: same sampler settings with `transform=NonCentered()`; `chns` is overwritten.
chns = sample(eight_schools(y, σ; transform=NonCentered()), NUTS(0.8), MCMCThreads(), 1_000, 4)
# Same summary for the second run.
mean(chns[:numerical_error])
@nsiccha
Copy link

nsiccha commented Sep 21, 2023

The model should read

# Eight-schools model in which `transform` is applied to the prior of θ; `y` is modelled
# with plain `Normal.(θ, σ)`.
@model function eight_schools(y, σ; J=length(y), transform=identity)
    μ ~ Normal(0, 5)
    τ ~ truncated(Cauchy(0, 5); lower=0)
    θ_prior = with_transform(Normal(μ, τ), transform)
    θ ~ filldist(θ_prior, J)
    y .~ Normal.(θ, σ)
end

@nsiccha
Copy link

nsiccha commented Sep 21, 2023

I was a bit suspicious about the log jacobian handling, so I've added the "usual" non-centering to the model as a hack and plotted the pair plots of log tau and theta[1], see below:

# Eight-schools model with two ways of obtaining θ: through `with_transform` when a
# transform is given, or by hand from a standard-normal `theta_raw` when
# `transform === missing`.
@model function eight_schools(y, σ; J=length(y), transform=identity)
    μ ~ Normal(0, 5)
    τ ~ truncated(Cauchy(0, 5); lower=0)
    if !ismissing(transform)
        θ ~ filldist(with_transform(Normal(μ, τ), transform), J)
    else
        # Manual non-centering: θ is a deterministic function of μ, τ and theta_raw.
        theta_raw ~ filldist(Normal(0,1), J)
        θ = μ .+ τ .* theta_raw
    end
    y .~ Normal.(θ, σ)
end

using Plots
## compare centered vs non-centered
# Run 1: default `transform=identity`. Print the summed `numerical_error` column and
# scatter θ[1] against log τ.
chns = sample(eight_schools(y, σ), NUTS(0.8), MCMCThreads(), 1_000, 4)
sum(chns[:numerical_error]) |> println
# The τ column has to be selected by name; an empty index `chns[]` selects nothing usable.
p1 = scatter(vec(chns["θ[1]"]), log.(vec(chns["τ"])), alpha=.1)

# Run 2: `transform=NonCentered()`, same summary and plot.
chns = sample(eight_schools(y, σ; transform=NonCentered()), NUTS(0.8), MCMCThreads(), 1_000, 4)
sum(chns[:numerical_error]) |> println
p2 = scatter(vec(chns["θ[1]"]), log.(vec(chns["τ"])), alpha=.1)

# Run 3: `transform=missing` (manual non-centering). θ[1] is not a sampled variable in
# this branch, so it is rebuilt from μ, τ and theta_raw[1] before plotting.
chns = sample(eight_schools(y, σ; transform=missing), NUTS(0.8), MCMCThreads(), 1_000, 4)
sum(chns[:numerical_error]) |> println
t1 = chns["μ"] .+ chns["τ"] .* chns["theta_raw[1]"]
p3 = scatter(vec(t1), log.(vec(chns["τ"])), alpha=.1)
plot(p1, p2, p3, size=(800, 800), layout=(3, 1), link=:both) # link=:both actually only links x for me?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment