Skip to content

Instantly share code, notes, and snippets.

@zenna

zenna/ciid.jl

Last active Nov 1, 2019
Embed
What would you like to do?
module MiniOmega

using Cassette
import Base: ~

export sample, unif, pointwise, <|, rt

"Identifier of one dimension of `Ω`: a tuple of `Int`s of any length."
const ID = NTuple{N, Int} where N

"""
    Ω

A point in a hypercube whose coordinates are indexed by `ID` tuples.

Coordinates are drawn lazily: the first lookup of an `ID` draws `rand()` and
caches it in `data`, so later lookups of the same `ID` return the same value.
"""
struct Ω
    data::Dict{ID, Real}
end

"Sample a random `ω ∈ Ω`, starting with no coordinates drawn yet."
sample(::Type{Ω}) = Ω(Dict{ID, Real}())

# Return coordinate `id` of `ω`, drawing it with `rand()` on first access.
# `rand` is passed as a default *function*, so a value is drawn only when `id`
# is missing; the eager form `get!(d, id, rand())` would advance the global RNG
# on every lookup, including cache hits.
Base.getindex(ω::Ω, id) = get!(rand, ω.data, id)

"""
    ΩProj

Projection of an `Ω` onto the dimension named by `id`; `proj` appends to `id`
to reach further dimensions.
"""
struct ΩProj
    ω::Ω
    id::ID
end

"Single primitive random variable: the coordinate of `ωπ.ω` at `ωπ.id`."
unif(ωπ::ΩProj) = ωπ.ω[ωπ.id]

"Sample an `ΩProj` over a fresh `Ω`, rooted at dimension `(1,)`."
sample(::Type{ΩProj}) = ΩProj(sample(Ω), (1,))

"Project `ωπ` one level deeper by appending `id` to its dimension tuple."
proj(ωπ::ΩProj, id) = ΩProj(ωπ.ω, (ωπ.id..., id))

"Sample the random variable `f` by applying it to a fresh `ΩProj`."
sample(f) = f(sample(ΩProj))

# (Conditional) independence
# Cassette is used to augment the execution environment with extra state: the
# CIIDCtx metadata carries the sequence index `id` and the tuple `shared` of
# random variables that are conditioned on.
Cassette.@context CIIDCtx

"""
    ciid(f, id, shared)

Conditionally independent, identically distributed copy of `f`.

Return the `id`th element of an (exchangeable) sequence of random variables that
are identically distributed with `f` but conditionally independent given the
random variables in `shared`.
"""
function ciid(f, id, shared)
    ctx = CIIDCtx(metadata = (shared = shared, id = id))
    return ω -> Cassette.overdub(ctx, f, ω)
end

"IID: `ciid` with nothing shared."
ciid(f, id::Integer) = ciid(f, id, ())

# Inside a ciid context, each primitive `unif` draw is redirected to a dimension
# extended by this sequence element's `id`.
Cassette.overdub(ctx::CIIDCtx, ::typeof(unif), ω::ΩProj) =
    unif(proj(ω, ctx.metadata.id))

# Random variables listed in `shared` are called directly (not traced by this
# context), so their `unif` draws are not redirected; everything else is traced.
function Cassette.overdub(ctx::CIIDCtx, x, ω::ΩProj)
    if x in ctx.metadata.shared
        return x(ω)
    else
        return Cassette.recurse(ctx, x, ω)
    end
end

# Syntactic sugar (to make model building nicer)

"Random tuple: the random variable `ω -> (f₁(ω), f₂(ω), …)`."
rt(fs...) = ω -> map(f -> f(ω), fs)

"""
    Plate(f, shared)

Pairs a random variable `f` with the tuple `shared` of random variables it is
conditioned on. It is built by `f <| shared` and consumed by `id ~ plate`.

Supports the notation `i ~ x <| (y, z)`: the `i`th element of an (exchangeable)
sequence of random variables that are identically distributed with `x` but
conditionally independent given `y` and `z`.
"""
struct Plate{F, S}
    f::F
    shared::S
end

@inline <|(f, shared::Tuple) = Plate(f, shared)

# NOTE(review): this extends `Base.~` for `(Integer, Any)`, which is type piracy.
# It is kept because the `i ~ f` surface syntax depends on it.
~(id::Integer, f) = ciid(f, id)
~(id::Integer, plate::Plate) = ciid(plate.f, id, plate.shared)

# Pointwise: within `pointwise`, the arithmetic operators below applied to
# functions of ω return a new function of ω instead of erroring.
Cassette.@context PWCtx

# `const`, so the alias is a fixed binding rather than an untyped mutable global.
const Lifted = Union{map(typeof, (+, -, /, *))...}

Cassette.overdub(::PWCtx, op::Lifted, x::Function) = ω -> op(x(ω))
Cassette.overdub(::PWCtx, op::Lifted, x::Function, y::Function) = ω -> op(x(ω), y(ω))
Cassette.overdub(::PWCtx, op::Lifted, x::Function, y) = ω -> op(x(ω), y)
Cassette.overdub(::PWCtx, op::Lifted, x, y::Function) = ω -> op(x, y(ω))

"""
    pointwise(f)

Call the zero-argument function `f` with `+`, `-`, `/` and `*` lifted pointwise
over random variables (functions of ω).
"""
pointwise(f) = Cassette.overdub(PWCtx(), f)

# AutoId (placeholder: the context is declared but not used yet)
Cassette.@context AutoIdCtx

end
using .MiniOmega
"""
    test()

Demo for MiniOmega: builds independent, conditionally independent and pointwise
random variables and prints a sample of each with `@show`.
"""
function test()
    uniform(a, b) = ω -> unif(ω) * (b - a) + a
    x = 1 ~ uniform(0, 1)
    y = 2 ~ uniform(0, 1)
    d = 7 ~ y
    z = ω -> (x(ω), x(ω), y(ω), d(ω))
    @show sample(z)
    function a(ω)
        x_ = x(ω)
        # `local` stops these from rebinding test()'s own `d`: a nested function
        # that assigns to an enclosing local would otherwise update that local.
        local d = 3 ~ uniform(0, 4)
        local e = 4 ~ uniform(0, 4)
        return x_ + d(ω) + e(ω)
    end
    @show sample(rt(x, y, z, a))
    # Conditional independence
    x = 1 ~ uniform(0, 19)
    measure(ω) = x(ω) + uniform(0, 1)(ω)
    m1 = 3 ~ measure <| (x,)
    m2 = 4 ~ measure <| (x,)
    m3 = 5 ~ measure
    @show sample(rt(m1, m2, m3))
    # Linear regression
    α = 1 ~ uniform(0, 1)
    β = 2 ~ uniform(0, 1)
    f(x, i) = i ~ (ω -> x * α(ω) + β(ω) + uniform(0.0, 1.0)(ω)) <| (α, β)
    y1 = f(0.3, 3)
    y2 = f(0.3, 4)
    @show sample(rt(y1, y2))
    # Same model written with pointwise-lifted arithmetic. The do-block returns
    # the tuple directly instead of assigning to (and rebinding) the outer names.
    y1b, y2b = pointwise() do
        f2(x, i) = i ~ (x * α + β + uniform(0.0, 1.0)) <| (α, β)
        return f2(0.3, 3), f2(0.3, 4)
    end
    @show sample(rt(y1, y2, y1b, y2b))
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment