Skip to content

Instantly share code, notes, and snippets.

@dharasim
Created March 27, 2018 16:14
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save dharasim/33021a0b667b29bf9bdd511fd0033d54 to your computer and use it in GitHub Desktop.
Save dharasim/33021a0b667b29bf9bdd511fd0033d54 to your computer and use it in GitHub Desktop.
using StatsFuns.RFunctions: betarand, gammarand
using LogProbs
"""
    categorical_sample(tokens, weights)

Draw one element of `tokens` with probability proportional to the matching entry of
`weights`. The weights do not have to sum to one.

`tokens` and `weights` are traversed in lockstep with `zip`. Returns `nothing` when no
weight is positive.
"""
function categorical_sample(tokens, weights)
    threshold = rand() * sum(weights)
    cum_weights = zero(eltype(weights))
    fallback = nothing
    for (t, w) in zip(tokens, weights)
        cum_weights += w
        if cum_weights > threshold
            return t
        end
        # Remember the last token that carries mass. `sum(weights)` and the running
        # total below may round differently, so `threshold` can end up >= the final
        # running total; in that case the draw belongs to the last weighted token
        # instead of silently falling through and returning `nothing`.
        if w > zero(w)
            fallback = t
        end
    end
    return fallback
end

"""
    categorical_sample(d::Dict)

Draw a key of `d`, using the values of `d` as unnormalized weights.
"""
function categorical_sample(d::Dict)
    return categorical_sample(keys(d), values(d))
end

"""
    categorical_sample(v::AbstractVector)

Draw an index of `v`, using the entries of `v` as unnormalized weights.
"""
function categorical_sample(v::AbstractVector)
    # `eachindex` yields the native indices of `v`, so this also works for views and
    # other non-`Vector` arrays.
    return categorical_sample(eachindex(v), v)
end
"""
    DirCat{T, C}

Mutable container with a single field `counts`, a dictionary that maps each outcome of
type `T` to a count of type `C`.

NOTE(review): the name suggests a Dirichlet-categorical model in which `counts` holds
the concentration parameters (pseudo-counts) — confirm against the intended use.
"""
mutable struct DirCat{T,C}
    counts::Dict{T,C}
end
"""
    DirCat(support, priors)

Build a `DirCat` whose `counts` dictionary pairs each element of `support` with the
element of `priors` at the same position (the two are matched with `zip`).
"""
function DirCat(support, priors)
    initial = Dict(outcome => prior for (outcome, prior) in zip(support, priors))
    return DirCat(initial)
end
"""
    support(dc::DirCat)

Return the keys of `dc.counts`, i.e. the outcomes stored in `dc`.
"""
function support(dc::DirCat)
    return keys(dc.counts)
end
"""
    sample(dc::DirCat)

Draw one outcome from `dc`.

One `gammarand(c, 1)` value is drawn per stored count `c`; these draws serve as
unnormalized weights for picking a key of `dc.counts` with `categorical_sample`.
"""
function sample(dc::DirCat)
    table = dc.counts
    # presumably `gammarand(c, 1)` is a Gamma(shape = c, scale = 1) draw, so the
    # normalized weights form a Dirichlet sample — confirm against StatsFuns.
    draws = collect(gammarand(c, 1) for c in values(table))
    return categorical_sample(keys(table), draws)
end
"""
    logscore(dc::DirCat, obs)

Score the single observation `obs` under `dc` and wrap the result in `LogProb`.

The wrapped value is `lbeta(total, 1) - lbeta(dc.counts[obs], 1)` with `total` the sum
of all counts. Because `B(a, 1) = 1 / a`, this equals `log(dc.counts[obs] / total)`.
Indexing `dc.counts[obs]` fails for an `obs` that is not a key of `dc.counts`.
"""
function logscore(dc::DirCat, obs)
    table = dc.counts
    total_term = lbeta(sum(values(table)), 1)
    obs_term = lbeta(table[obs], 1)
    return LogProb(total_term - obs_term)
end
"""
    add_obs!(dc::DirCat, obs)

Increase the count stored for `obs` in `dc.counts` by one and return the incremented
value. `obs` must already be a key of `dc.counts`.
"""
function add_obs!(dc::DirCat, obs)
    bumped = dc.counts[obs] + 1
    dc.counts[obs] = bumped
    return bumped
end
"""
    rm_obs!(dc::DirCat, obs)

Decrease the count stored for `obs` in `dc.counts` by one and return the decremented
value. `obs` must already be a key of `dc.counts`; no lower bound is enforced.
"""
function rm_obs!(dc::DirCat, obs)
    lowered = dc.counts[obs] - 1
    dc.counts[obs] = lowered
    return lowered
end
"""
    DirMul{T, C}

Mutable container with a single field `counts`, a dictionary that maps each outcome of
type `T` to a count of type `C`.

NOTE(review): the name suggests a Dirichlet-multinomial model in which `counts` holds
the concentration parameters (pseudo-counts) — confirm against the intended use.
"""
mutable struct DirMul{T,C}
    counts::Dict{T,C}
end
"""
    DirMul(support)

Build a `DirMul` in which every element of `support` starts with a count of `1.0`.
"""
function DirMul(support)
    initial = Dict(outcome => 1.0 for outcome in support)
    return DirMul(initial)
end
"""
    support(dm::DirMul)

Return the keys of `dm.counts`, i.e. the outcomes stored in `dm`.
"""
function support(dm::DirMul)
    return keys(dm.counts)
end
"""
    sample(dm::DirMul, n)

Draw `n` outcomes from `dm` and return them as a dictionary of counts.

A single set of weights is drawn first (one `gammarand(c, 1)` value per stored count
`c`); all `n` categorical draws then reuse those weights. The result has one entry per
key of `dm.counts`, initialised to `0`.
"""
function sample(dm::DirMul, n)
    table = dm.counts
    draws = collect(gammarand(c, 1) for c in values(table))
    tally = Dict(k => 0 for k in keys(table))
    for i in 1:n
        winner = categorical_sample(keys(table), draws)
        tally[winner] += 1
    end
    return tally
end
"""
    logscore(dm::DirMul, obs::Associative)

Score the count dictionary `obs` under `dm` and wrap the result in `LogProb`.

With `n = sum(values(obs))` and `a = sum(values(dm.counts))`, the wrapped value is

    log(n) + lbeta(a, n) - sum(log(k) + lbeta(dm.counts[x], k))

where the sum runs over the outcomes `x` of `dm` whose observed count `k` is positive.
This is the logarithm of the Dirichlet-multinomial probability mass function.

Outcomes of `dm` that are missing from `obs` count as zero, so sparse observations (as
accepted by `add_obs!` and `rm_obs!`) can be scored. An observation whose total is zero
has probability one and is scored with the log-space value `0.0`.

NOTE(review): keys of `obs` that are not outcomes of `dm` still contribute to `n` but
to none of the per-outcome terms — confirm whether such input should be rejected.
"""
function logscore(dm::DirMul, obs::Associative)
    n = isempty(obs) ? 0 : sum(values(obs))
    if n == 0
        # The general formula would evaluate `log(0) + lbeta(a, 0)` (-Inf + Inf) and
        # reduce over an empty collection. The value is handed to `LogProb` in the
        # same (log-space) form as in the general case below.
        return LogProb(0.0)
    end
    head = log(n) + lbeta(sum(values(dm.counts)), n)
    tail = zero(head)
    for (x, c) in dm.counts
        k = get(obs, x, 0)
        if k > 0
            tail += log(k) + lbeta(c, k)
        end
    end
    return LogProb(head - tail)
end
"""
    add_obs!(dm::DirMul, obs::Associative)

Add the counts in `obs` to `dm.counts`, key by key. Every key of `obs` must already be
a key of `dm.counts`. Returns `nothing`.
"""
function add_obs!(dm::DirMul, obs::Associative)
    for (x, k) in obs
        dm.counts[x] += k
    end
    return nothing
end
"""
    rm_obs!(dm::DirMul, obs::Associative)

Subtract the counts in `obs` from `dm.counts`, key by key. Every key of `obs` must
already be a key of `dm.counts`; no lower bound is enforced. Returns `nothing`.
"""
function rm_obs!(dm::DirMul, obs::Associative)
    for (x, k) in obs
        dm.counts[x] -= k
    end
    return nothing
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment