Last active: November 1, 2019 17:41
-
-
Save zenna/3bbbd99e51582c87b74184de0ecf7e19 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
module MiniOmega

using Cassette
import Base: ~

export sample, unif, pointwise, <|, rt

"Index of a dimension of Ω: a tuple of `Int`s of any length (a path of projections)."
const ID = NTuple{N, Int} where N

"""
    Ω

Ω is a hypercube. A value of `Ω` is a point in it whose coordinates are drawn
with `rand()` the first time they are read and cached in `data`. Repeated reads of
the same `ID` therefore return the same value.
"""
struct Ω
    data::Dict{ID, Real}
end

"Sample a random ω ∈ Ω (no coordinates are drawn until they are read)."
sample(::Type{Ω}) = Ω(Dict{ID, Real}())

# Passing `rand` as a function makes the draw lazy. The global RNG is only
# advanced when `id` has not been seen before; cache hits consume no randomness.
Base.getindex(ω::Ω, id) = get!(rand, ω.data, id)

"""
    ΩProj

ΩProj is a projection of omega onto a particular dimension: the point `ω` together
with the index path `id` at which primitive random variables read it.
"""
struct ΩProj
    ω::Ω
    id::ID
end

"Single primitive random variable: the coordinate of `ωπ.ω` at index `ωπ.id`."
unif(ωπ::ΩProj) = ωπ.ω[ωπ.id]

"Sample a fresh ΩProj whose index path starts at `(1,)`."
sample(::Type{ΩProj}) = ΩProj(sample(Ω), (1,))

"Project `ωπ` onto `id` by appending `id` to its current index path."
proj(ωπ::ΩProj, id) = ΩProj(ωπ.ω, (ωπ.id..., id))

"Draw one sample of the random variable `f` by applying it to a freshly sampled ΩProj."
sample(f) = f(sample(ΩProj))

# (Conditional) independence
# Use Cassette to augment the environment with extra state (the sequence index and
# the shared random variables) without changing the model functions themselves.
Cassette.@context CIIDCtx

"""
    ciid(f, id, shared)

Conditionally independent, identically distributed given `shared`.

Return the `id`th element of an (exchangeable) sequence of random variables
that are identically distributed with `f` but conditionally independent given
the random variables in `shared`.
"""
function ciid(f, id, shared)
    ctx = CIIDCtx(metadata = (shared = shared, id = id))
    return ω -> Cassette.overdub(ctx, f, ω)
end

"IID: `ciid` with nothing shared."
ciid(f, id::Integer) = ciid(f, id, ())

# Every primitive draw made inside a ciid-wrapped call is redirected to the
# sub-dimension named by the sequence index. Different `id`s therefore read
# different coordinates of ω.
Cassette.overdub(ctx::CIIDCtx, ::typeof(unif), ω::ΩProj) =
    unif(proj(ω, ctx.metadata.id))

# Shared random variables are evaluated outside the context, on the un-projected
# ω, so every element of the sequence sees the same value for them. Any other
# function applied to an ΩProj is traced recursively so its draws get projected.
function Cassette.overdub(ctx::CIIDCtx, x, ω::ΩProj)
    if x in ctx.metadata.shared
        return x(ω)
    else
        return Cassette.recurse(ctx, x, ω)
    end
end

# Syntactic Sugar (to make model-building nicer)

"Random tuple: the random variable `ω -> (f₁(ω), f₂(ω), …)` over the given `fs`."
rt(fs...) = ω -> map(f -> f(ω), fs)

"""
Supports the notation `i ~ x <| (y, z)`, which denotes the `i`th element of an
(exchangeable) sequence of random variables that are identically distributed
with `x` but conditionally independent given `y` and `z`.
"""
struct Plate{F, S}
    f::F
    shared::S
end

@inline <|(f, shared::Tuple) = Plate(f, shared)
~(id::Integer, f) = ciid(f, id)
~(id::Integer, plate::Plate) = ciid(plate.f, id, plate.shared)

# Pointwise
Cassette.@context PWCtx

# `const` makes this a true constant binding instead of a mutable untyped global.
const Lifted = Union{map(typeof, (+, -, /, *))...}

# Applying a lifted operator to random variables (functions of ω) builds a new
# random variable that applies the operator to their values at the same ω.
Cassette.overdub(::PWCtx, op::Lifted, x::Function) = ω -> op(x(ω))
Cassette.overdub(::PWCtx, op::Lifted, x::Function, y::Function) = ω -> op(x(ω), y(ω))
Cassette.overdub(::PWCtx, op::Lifted, x::Function, y) = ω -> op(x(ω), y)
Cassette.overdub(::PWCtx, op::Lifted, x, y::Function) = ω -> op(x, y(ω))

"""
    pointwise(f)

Call the zero-argument function `f` with `+`, `-`, `*` and `/` lifted, so that
arithmetic on random variables builds new random variables pointwise in ω.
"""
pointwise(f) = Cassette.overdub(PWCtx(), f)

# AutoId
# NOTE(review): this context is declared but not used anywhere in this file.
Cassette.@context AutoIdCtx

end
using .MiniOmega

"""
    test()

Demo of MiniOmega. Prints samples of independent, iid and conditionally
independent random variables, and of a small linear-regression model written
both explicitly and with `pointwise`.
"""
function test()
    uniform(a, b) = ω -> unif(ω) * (b - a) + a

    # Independent draws, and an iid copy of `y`
    x = 1 ~ uniform(0, 1)
    y = 2 ~ uniform(0, 1)
    d = 7 ~ y
    z = ω -> (x(ω), x(ω), y(ω), d(ω))
    @show sample(z)

    function a(ω)
        x_ = x(ω)
        # `local` keeps these from rebinding the outer `d` that `z` captures.
        # Without it, every call to `a` would overwrite `z`'s dependency.
        local d = 3 ~ uniform(0, 4)
        local e = 4 ~ uniform(0, 4)
        return x_ + d(ω) + e(ω)
    end
    @show sample(rt(x, y, z, a))

    # Conditional Independence: m1 and m2 share μ, m3 shares nothing (iid).
    # A fresh name (instead of reassigning `x`) keeps `z` and `a` above bound
    # to the original `x`.
    μ = 1 ~ uniform(0, 19)
    measure(ω) = μ(ω) + uniform(0, 1)(ω)
    m1 = 3 ~ measure <| (μ,)
    m2 = 4 ~ measure <| (μ,)
    m3 = 5 ~ measure
    @show sample(rt(m1, m2, m3))

    # Linear regression
    α = 1 ~ uniform(0, 1)
    β = 2 ~ uniform(0, 1)
    f(x, i) = i ~ (ω -> x * α(ω) + β(ω) + uniform(0.0, 1.0)(ω)) <| (α, β)
    y1 = f(0.3, 3)
    y2 = f(0.3, 4)
    @show sample(rt(y1, y2))

    # Same model, with arithmetic lifted pointwise over random variables.
    # Returning the pair directly avoids assigning to the outer y1b/y2b from
    # inside the do-block closure.
    y1b, y2b = pointwise() do
        f2(x, i) = i ~ (x * α + β + uniform(0.0, 1.0)) <| (α, β)
        return f2(0.3, 3), f2(0.3, 4)
    end
    @show sample(rt(y1, y2, y1b, y2b))
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.