# Possible Bijectors.jl interface (gist by @torfjelde, 2019-06-12)
using Distributions, Bijectors
using ForwardDiff
using Tracker
using Turing
using LinearAlgebra: logabsdet  # needed for the AD-based Jacobian fallbacks below

import Random: AbstractRNG
import Distributions: logpdf, rand, rand!, _rand!, _logpdf
abstract type Bijector end
abstract type CustomBijector{AD} <: Bijector end
"Computes the transformation."
transform(b::Bijector, x) = begin end
transform(b::Bijector) = x -> transform(b, x)
"Computes the inverse transformation of the Bijector."
inverse(b::Bijector, y) = begin end
inverse(b::Bijector) = y -> inverse(b, y)
# TODO: rename? a bit of a mouthful
# TODO: allow batch-computation, especially for univariate case
"Computes the log absolute determinant of the Jacobian of the inverse transformation."
detloginvjac(b::Bijector, y) = error("`detloginvjac` not implemented for $(typeof(b))")

# AD-based fallbacks for `CustomBijector`, dispatching on the AD backend.
detloginvjac(b::CustomBijector{AD}, y::T) where {AD <: Turing.Core.ForwardDiffAD, T <: Real} = log(abs(ForwardDiff.derivative(z -> inverse(b, z), y)))
detloginvjac(b::CustomBijector{AD}, y::AbstractVector) where AD <: Turing.Core.ForwardDiffAD = logabsdet(ForwardDiff.jacobian(z -> inverse(b, z), y))[1]
# FIXME: untrack? i.e. `Tracker.data(...)`
detloginvjac(b::CustomBijector{AD}, y::T) where {AD <: Turing.Core.TrackerAD, T <: Real} = log(abs(Tracker.gradient(z -> inverse(b, z[1]), [y])[1][1]))
detloginvjac(b::CustomBijector{AD}, y::AbstractVector) where AD <: Turing.Core.TrackerAD = logabsdet(Tracker.jacobian(z -> inverse(b, z), y))[1]
# Example bijector
struct Identity <: Bijector end
transform(::Identity, x) = x
inverse(::Identity, y) = y
# log|det J⁻¹| of the identity map is log(1) = 0
detloginvjac(::Identity, y::T) where T <: Real = zero(T)
detloginvjac(::Identity, y::AbstractVector{T}) where T <: Real = zero(T)
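# A further example (added sketch, not part of the original gist): a
# hypothetical `ScaleBijector` with an analytic log-abs-det Jacobian,
# illustrating how a concrete bijector can bypass the AD fallbacks above.
struct ScaleBijector{T <: Real} <: Bijector
    a::T
end
transform(b::ScaleBijector, x) = b.a .* x
inverse(b::ScaleBijector, y) = y ./ b.a
# d/dy (y / a) = 1 / a, so log|det J⁻¹| = -log|a| (times the length for vectors)
detloginvjac(b::ScaleBijector, y::Real) = -log(abs(b.a))
detloginvjac(b::ScaleBijector, y::AbstractVector) = -length(y) * log(abs(b.a))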
# Transformed distributions
struct UnivariateTransformed{D <: UnivariateDistribution, B <: Bijector} <: Distribution{Univariate, Continuous}
    dist::D
    transform::B
end

struct MultivariateTransformed{D <: MultivariateDistribution, B <: Bijector} <: Distribution{Multivariate, Continuous}
    dist::D
    transform::B
end
# implement these on a case-by-case basis, e.g. `PDMatDistribution = Union{InverseWishart, Wishart}`
transformed(d::UnivariateDistribution, b::Bijector) = UnivariateTransformed(d, b)
transformed(d::MultivariateDistribution, b::Bijector) = MultivariateTransformed(d, b)
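# Hedged sketch of the case-by-case idea mentioned above: a matrix-variate
# version would need its own wrapper type (hypothetical `MatrixTransformed`,
# not defined here), so the method is only shown commented out.
# const PDMatDistribution = Union{InverseWishart, Wishart}
# transformed(d::PDMatDistribution, b::Bijector) = MatrixTransformed(d, b)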
# Example of specific distribution impl
transformed(d::Normal) = transformed(d, Identity())
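# Another example along the same lines (an added assumption, not in the
# original gist): default `MvNormal` to the identity bijector as well.
transformed(d::MvNormal) = transformed(d, Identity())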
# size
Base.length(td::MultivariateTransformed) = length(td.dist)
# logp
logpdf(td::UnivariateTransformed, x::Real) =
    logpdf(td.dist, inverse(td.transform, x)) + detloginvjac(td.transform, x)

_logpdf(td::MultivariateTransformed, x::AbstractVector{<:Real}) =
    logpdf(td.dist, inverse(td.transform, x)) + detloginvjac(td.transform, x)

# Returns both the log-density and the log-abs-det Jacobian term, so the latter can be reused.
logpdf_with_jac(td::MultivariateTransformed, x::AbstractVector{<:Real}) = begin
    z = detloginvjac(td.transform, x)
    return (logpdf(td.dist, inverse(td.transform, x)) + z, z)
end
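# For symmetry, a sketch of the univariate analogue (added, not in the
# original gist):
logpdf_with_jac(td::UnivariateTransformed, x::Real) = begin
    z = detloginvjac(td.transform, x)
    return (logpdf(td.dist, inverse(td.transform, x)) + z, z)
end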
# rand
rand(rng::AbstractRNG, td::UnivariateTransformed) = transform(td.transform, rand(td.dist))
_rand!(rng::AbstractRNG, td::MultivariateTransformed, x::AbstractVector{<:Real}) = begin
    rand!(rng, td.dist, x)
    y = transform(td.transform, x)
    copyto!(x, y)
end
########################
# Some simple examples #
########################
# Example: `Univariate`
d = Normal()
td = transformed(d)
@info "Univariate" rand(td, 10)
@info "Univariate" logpdf.(td, rand(td, 10))
# Example: `Multivariate`
d = MvNormal(zeros(5), ones(5))
td = transformed(d, Identity())
@info "Multivariate" rand(td, 10)
@info "Multivariate" logpdf(td, rand(td))
x = rand(td)
@assert transform(td.transform, x) == x
# Example: Custom stuff
struct PositiveTransform{AD} <: CustomBijector{AD} end
transform(::PositiveTransform, x) = log(x)
inverse(::PositiveTransform, y) = exp(y)
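# Added remark: analytically d/dy exp(y) = exp(y), so the log-abs-det of the
# inverse is simply `y`. We deliberately leave `detloginvjac` undefined for
# `PositiveTransform` so that the AD-based fallbacks above are exercised; the
# final `@assert` below compares the AD result against this analytic value,
# since `log(x) = y`.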
t = PositiveTransform{Turing.Core.ADBackend(:reverse_diff)}()
# t = PositiveTransform{Turing.Core.ADBackend(:forward_diff)}() # <= also works
d = InverseGamma()
td = transformed(d, t)
rand(td)
logpdf.(td, rand(td, 10))
inverse(td.transform, rand(td))
detloginvjac(td.transform, rand(td))
x = rand(td.dist)
y = transform(td.transform, x)
# manual check of the Tracker-based log-abs-det Jacobian at `y`
log(abs(Tracker.gradient(z -> inverse(td.transform, z[1]), [y])[1][1]))
@assert inverse(td.transform)(transform(td.transform, x)) ≈ x "f⁻¹ ∘ f ≠ identity"
@assert logpdf(td, y) ≈ logpdf(td.dist, x) + log(x) "autodiff detloginvjac ≠ true detloginvjac"