Skip to content

Instantly share code, notes, and snippets.

@zenna
Created March 24, 2019 08:43
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save zenna/94cea3c455a6a6c0a574cb240d9a5a29 to your computer and use it in GitHub Desktop.
Save zenna/94cea3c455a6a6c0a574cb240d9a5a29 to your computer and use it in GitHub Desktop.
module MiniOmega
using Random

"""
    Ω <: Random.AbstractRNG

A replayable source of randomness: a trace mapping integer ids to the random values
drawn at those ids, plus an accumulated log-score.

Subtyping `Random.AbstractRNG` lets existing samplers (code that calls
`rand(rng, ...)`) draw from an `Ω` unchanged.

# Fields
- `data`: map from draw id to the random value recorded for that id.
- `i`: id of the most recent draw; reset to `0` before each `RandVar` application.
- `logscore`: running sum of the log-values passed to [`score`](@ref).
"""
mutable struct Ω <: Random.AbstractRNG
    data::Dict{Int, Any}
    i::Int
    logscore::Float64
end

"Construct an empty trace with log-score `0.0` (i.e. `log(1)`)."
Ω() = Ω(Dict{Int, Any}(), 0, 0.0)

"""
    sample(ω::Ω, args...)

Advance `ω`'s id counter and return the value recorded for the new id. If no value
is recorded yet, draw one with `rand(args...)` (from the global RNG) and store it,
so replaying the same sequence of draws against `ω` returns the same values.
"""
function sample(ω::Ω, args...)
    ω.i += 1
    # Lazy default: only draw when the id is not yet recorded, so replays neither
    # waste work nor advance the global RNG.
    return get!(() -> rand(args...), ω.data, ω.i)
end

# Compatibility with `rand`: route `rand(ω, T)` and `rand(ω, T, dims)` through
# `sample`. The `dims` method is kept to resolve method ambiguities with Random.
Base.rand(ω::Ω, T::Type{X}) where {X} = sample(ω, T)
Base.rand(ω::Ω, T::Type{X}, dims::Dims) where {X} = sample(ω, T, dims)

"""
    RandVar(f)

A random variable: a function of an `Ω` trace. Calling `rv(ω)` evaluates `f(ω)`.
"""
struct RandVar{F}
    f::F
end

# Reset the id counter before each application so that draws are matched to the same
# ids (and therefore replayed) every time the variable is evaluated on the same trace.
(rv::RandVar)(ω::Ω) = (ω.i = 0; rv.f(ω))

"Random sample from `x`, evaluated on a fresh trace."
Base.rand(x::RandVar) = x(Ω())

"""
    ConditionUnsatisfied <: Exception

Thrown when a conditioned random variable (see [`cond`](@ref)) is evaluated on a
trace that violates its condition — the "bottom" outcome.
"""
struct ConditionUnsatisfied <: Exception end

function Base.showerror(io::IO, ::ConditionUnsatisfied)
    return print(io, "ConditionUnsatisfied: condition was not satisfied")
end

"""
    cond(x::RandVar, y::RandVar)

Return `x` conditioned on `y`: a `RandVar` that evaluates `y` on the trace and, if it
is `true`, returns `x` evaluated on the same trace; otherwise it throws
[`ConditionUnsatisfied`](@ref).
"""
cond(x::RandVar, y::RandVar) =
    RandVar(ω -> y(ω) ? x(ω) : throw(ConditionUnsatisfied()))

"""
    randrs(x::RandVar)

Rejection sampling: evaluate `x` on fresh traces until an evaluation does not throw
[`ConditionUnsatisfied`](@ref), and return that value. Any other exception is
rethrown. Does not terminate if the condition can never be satisfied.
"""
function randrs(x::RandVar)
    while true
        try
            return x(Ω())
        catch e
            # Only rejections are retried; genuine errors must surface.
            e isa ConditionUnsatisfied || rethrow()
        end
    end
end

"""
    score(ω, logval)

Add `logval` (a log-weight) to `ω.logscore` and return the new total.
"""
score(ω, logval) = ω.logscore += logval

export RandVar, Ω, score
end
## Example
# `MiniOmega` is defined in this file, so load it relative to the current module;
# `using Main.MiniOmega` only works when the file is evaluated directly in `Main`.
using .MiniOmega

# A random variable on [0, 1): it makes a single draw from the trace it is given.
x_(rng) = rand(rng)
x = RandVar(x_)

# Evaluating `x` twice against the same trace replays the recorded draw,
# so both lines show the same value.
ω = Ω()
@show x(ω)
@show x(ω)

using Distributions
# NOTE(review): `fakedata` is not used anywhere else in this example.
fakedata = rand(Normal(0, 1), 10)
"""
    xs_(ω)

Draw ten values from the trace `ω` by inverse-transform sampling (the `quantile` of
`Normal(0, 1)` at ten uniform draws). Add each value's log-density under
`Normal(0, 1)` to `ω`'s log-score, and return the ten values.
"""
function xs_(ω)
    xs = quantile.(Normal(0, 1), [rand(ω) for _ in 1:10])
    for x in xs
        # `score` accumulates *log* values, so pass `logpdf`, not `pdf`.
        score(ω, logpdf(Normal(0, 1), x))
    end
    return xs
end
# Evaluate `xs` once on a fresh trace, then look at the log-score it accumulated.
ω = Ω()
xs = RandVar(xs_)
ω |> xs
ω.logscore
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment