Skip to content

Instantly share code, notes, and snippets.

@sharanry
Created August 25, 2019 07:03
Show Gist options
  • Save sharanry/27d1616bc3c4a4f0a3bfd70c286dcc0b to your computer and use it in GitHub Desktop.
Save sharanry/27d1616bc3c4a4f0a3bfd70c286dcc0b to your computer and use it in GitHub Desktop.
diff --git a/src/interface.jl b/src/interface.jl
index 4536bda..1581def 100644
--- a/src/interface.jl
+++ b/src/interface.jl
@@ -98,7 +98,7 @@ julia> cb = b ∘ b;
julia> x = randn(2, 3)
2×3 Array{Float64,2}:
- 0.0660476 -0.77195 -1.7832
+ 0.0660476 -0.77195 -1.7832
-0.147743 -1.46459 0.264924
julia> forward(cb, x)
@@ -107,10 +107,10 @@ ERROR: MethodError: no method matching +(::Array{Float64,1}, ::Float64)
julia> forward(cb, x, zeros(size(x, 2)))
(rv = [1.10887 0.32029 -0.704563; -0.639206 -1.97935 -0.243419], logabsdetjac = [0.018534, 1.46352e-5, 0.00521633])
```
-
+
"""
forward(b::Bijector, x) = forward(b, x, zero(eltype(x)))
-forward(b::Bijector, x, logjac) = (rv=b(x), logabsdetjac=logjac + logabsdetjac(b, x))
+forward(b::Bijector, x, logjac) = (rv=b(x), logabsdetjac=logjac .+ logabsdetjac(b, x))  # `.+` so `logjac` may be a scalar or one entry per column of `x`
forward(ib::Inversed{<: Bijector}, y) = (
rv=ib(y),
logabsdetjac=logabsdetjac(ib, y)
@@ -227,11 +227,11 @@ logabsdetjac(cb::Composed, x) = _logabsdetjac(x, cb.ts...)
function _forward(f, b1::Bijector, b2::Bijector)
f1 = forward(b1, f.rv)
f2 = forward(b2, f1.rv)
- return (rv=f2.rv, logabsdetjac=f2.logabsdetjac + f1.logabsdetjac + f.logabsdetjac)
+ return (rv=f2.rv, logabsdetjac=f2.logabsdetjac .+ f1.logabsdetjac .+ f.logabsdetjac)  # sum all three terms; `.+` mixes scalar and per-column terms
end
function _forward(f, b::Bijector, bs::Bijector...)
f1 = forward(b, f.rv)
- f_ = (rv=f1.rv, logabsdetjac=f1.logabsdetjac + f.logabsdetjac)
+ f_ = (rv=f1.rv, logabsdetjac=f1.logabsdetjac .+ f.logabsdetjac)  # fold in this step's term, then recurse on the remaining bijectors
return _forward(f_, bs...)
end
# if `x` represents multiple elements to act on, we want to allow the user to
@@ -244,7 +244,7 @@ end
function forward(cb::Composed, x, logjac)
rv = x
logjac_ = logjac
-
+
for t in cb.ts
res = forward(t, rv)
rv = res.rv
@@ -308,17 +308,35 @@ struct Shift{T} <: Bijector
a::T
end
-(b::Shift)(x) = b.a + x
+# Construct a zero `Shift` for `dims`-dimensional inputs; `container` (default
+# `Array`) is applied to the `dims × 1` zero matrix.
+Shift(dims::Int, container=Array) = Shift(container(zeros(dims, 1)))
+
+# Broadcast so scalar inputs keep working alongside array inputs whose columns
+# are samples (a `dims × 1` shift is added to every column).
+(b::Shift)(x) = b.a .+ x
inv(b::Shift) = Shift(-b.a)
-logabsdetjac(b::Shift, x::T) where T = zero(T)
+# A shift does not change volume: zero for a scalar, one zero per column otherwise.
+logabsdetjac(b::Shift, x::T) where T<:Real = zero(T)
+logabsdetjac(b::Shift, x::T) where T<:AbstractArray = zeros(eltype(x), size(x, 2))
struct Scale{T} <: Bijector
a::T
end
+# Identity `Scale` for `dims`-dimensional inputs. `one(zeros(...))` builds the
+# identity matrix without drawing from (and advancing) the global RNG.
+Scale(dims::Int, container=Array) = Scale(container(one(zeros(dims, dims))))
+
(b::Scale)(x) = b.a * x
-inv(b::Scale) = Scale(b.a^(-1))
-logabsdetjac(b::Scale, x) = log(abs(b.a))
+inv(b::Scale) = Scale(inv(b.a))
+logabsdetjac(b::Scale, x::T) where T<: Real = log(abs(b.a))
+# Array input: one entry per column (sample), each log|det(a)|; `fill` keeps the
+# element type of the log-determinant instead of promoting to Float64.
+logabsdetjac(b::Scale, x::T) where T<: AbstractArray = fill(log(abs(det(b.a))), size(x, 2))
+# A scalar `a` scales all `size(x, 1)` coordinates of a column, so each column
+# contributes `size(x, 1) * log|a|` rather than `log|a|`.
+logabsdetjac(b::Scale{<:Real}, x::AbstractArray) = fill(size(x, 1) * log(abs(b.a)), size(x, 2))
####################
# Simplex bijector #
@@ -404,7 +410,7 @@ function (ib::Inversed{<: SimplexBijector{Val{proj}}})(y::AbstractVector{T}) whe
else
x[K] = _clamp(one(T) - sum_tmp - y[K], ib.orig)
end
-
+
return x
end
@@ -438,7 +444,7 @@ end
function logabsdetjac(b::SimplexBijector, x::AbstractVector{T}) where T
ϵ = _eps(T)
lp = zero(T)
-
+
K = length(x)
sum_tmp = zero(eltype(x))
@@ -460,7 +466,7 @@ end
DistributionBijector(d::Distribution)
DistributionBijector{<: ADBackend, D}(d::Distribution)
-This is the default `Bijector` for a distribution.
+This is the default `Bijector` for a distribution.
It uses `link` and `invlink` to compute the transformations, and `AD` to compute
the `jacobian` and `logabsdetjac`.
@@ -503,7 +509,7 @@ const Transformed = TransformedDistribution
Couples distribution `d` with the bijector `b` by returning a `TransformedDistribution`.
-If no bijector is provided, i.e. `transformed(d)` is called, then
+If no bijector is provided, i.e. `transformed(d)` is called, then
`transformed(d, bijector(d))` is returned.
"""
transformed(d::Distribution, b::Bijector) = TransformedDistribution(d, b)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment