Skip to content

Instantly share code, notes, and snippets.

@johnmyleswhite
Last active December 17, 2015 20:19
Show Gist options
  • Save johnmyleswhite/5666467 to your computer and use it in GitHub Desktop.
Save johnmyleswhite/5666467 to your computer and use it in GitHub Desktop.
Estimating the parameters of a Dirichlet in Julia
using Distributions
using LinearAlgebra: norm  # `norm` lives in the LinearAlgebra stdlib on Julia ≥ 0.7

# Compare the two maximum-likelihood estimators of Dirichlet concentration parameters
# (`fixed_point` and `newton`) on one million simulated draws.
# NOTE(review): both estimators must already be defined when this runs; in this file
# they appear further down, so define/include them first — confirm intended layout.

d = Dirichlet([100.0, 17.0, 31.0, 45.0])
X = rand(d, 1_000_000)  # one draw per column, as both estimators expect

for (label, estimator) in (("Fixed point", fixed_point), ("Newton", newton))
    # The first call also pays JIT-compilation cost, so time a second, compiled call.
    alpha = estimator(X)
    elapsed = @elapsed estimator(X)
    err = norm(d.alpha - alpha, Inf)
    # Report explicitly: bare expression values are not displayed when run as a script.
    println(label, ": ", elapsed, " s, max abs error vs. true alpha = ", err)
end
"""
    fixed_point(X; tol=1e-8, maxiter=10_000)

Estimate the concentration parameters of a Dirichlet distribution from the samples
stored in the columns of `X` (K × N: one K-dimensional observation per column).

Runs the fixed-point iteration

    alpha[k] ← invdigamma(digamma(sum(alpha)) + mean(log.(X[k, :])))

starting from `alpha = ones(K)`. Within a sweep, every component uses the total
`sum(alpha)` from the start of that sweep. Iteration stops when no component changes
by more than `tol` in a sweep, or after `maxiter` sweeps; the latter emits a warning.
Prints the number of sweeps performed and returns the estimate as a `Vector{Float64}`.

Every entry of `X` should be strictly positive. A zero entry gives a `-Inf`
log-mean, and a negative entry makes `log` throw a `DomainError`.

# Throws
- `ArgumentError` if `X` has no columns (no samples).

NOTE(review): `digamma` and `invdigamma` are not defined here. On Julia ≥ 1.0 they come
from SpecialFunctions.jl — confirm they are in scope where this is used.
"""
function fixed_point(
    X::AbstractMatrix{<:Real}; tol::Real = 1e-8, maxiter::Integer = 10_000
)
    K, N = size(X)
    N > 0 || throw(ArgumentError("X must contain at least one sample (column)"))

    # Per-component mean log-observation, accumulated in Float64 column by column so
    # the inner loop walks contiguous memory (Julia arrays are column-major).
    lpbar = zeros(Float64, K)
    for column in eachcol(X)
        for (k, x) in enumerate(column)
            lpbar[k] += log(x)
        end
    end
    lpbar ./= N

    alpha = ones(Float64, K)
    maxdelta = Inf
    iteration = 0
    while maxdelta > tol && iteration < maxiter
        iteration += 1
        maxdelta = 0.0
        # digamma(sum(alpha)) is fixed for the whole sweep, so evaluate it once.
        psi0 = digamma(sum(alpha))
        for k in 1:K
            updated = invdigamma(psi0 + lpbar[k])
            maxdelta = max(maxdelta, abs(alpha[k] - updated))
            alpha[k] = updated
        end
    end
    maxdelta <= tol || @warn "fixed_point did not converge" maxiter maxdelta
    # Same output as `@printf "Fixed Point %d\n" iteration`, without needing Printf.
    println("Fixed Point ", iteration)
    return alpha
end
"""
    newton(X; tol=1e-8, maxiter=10_000)

Estimate the concentration parameters of a Dirichlet distribution from the samples
stored in the columns of `X` (K × N: one K-dimensional observation per column).

Applies Newton's method to the mean log-likelihood, starting from `alpha = ones(K)`.
The Hessian is a diagonal matrix plus a rank-one term, so each Newton step is solved
in O(K) with the Sherman–Morrison formula instead of a dense K × K solve.

The gradient and Hessian are per-sample averages. The factor N in the total
log-likelihood cancels in the Newton step, and leaving it out keeps the stopping rule
independent of the sample size. Iteration stops when no component of `alpha` moves by
more than `tol` in a step, or after `maxiter` steps; the latter emits a warning.
Prints the number of steps performed and returns the estimate as a `Vector{Float64}`.

Every entry of `X` should be strictly positive. A zero entry gives a `-Inf`
log-mean, and a negative entry makes `log` throw a `DomainError`.

# Throws
- `ArgumentError` if `X` has no columns (no samples).

NOTE(review): `digamma`, `trigamma` are not defined here. On Julia ≥ 1.0 they come from
SpecialFunctions.jl — confirm they are in scope where this is used.
NOTE(review): steps are not safeguarded. Starting from `ones(K)` far from the optimum, a
step could push some `alpha[k]` to ≤ 0; a moment-matching start may be more robust.
"""
function newton(
    X::AbstractMatrix{<:Real}; tol::Real = 1e-8, maxiter::Integer = 10_000
)
    K, N = size(X)
    N > 0 || throw(ArgumentError("X must contain at least one sample (column)"))

    # Per-component mean log-observation, accumulated in Float64 column by column so
    # the inner loop walks contiguous memory (Julia arrays are column-major).
    lpbar = zeros(Float64, K)
    for column in eachcol(X)
        for (k, x) in enumerate(column)
            lpbar[k] += log(x)
        end
    end
    lpbar ./= N

    alpha = ones(Float64, K)
    g = Vector{Float64}(undef, K)  # per-sample gradient
    q = Vector{Float64}(undef, K)  # diagonal part of the per-sample Hessian
    iteration = 0
    maxstep = Inf
    while maxstep > tol && iteration < maxiter
        iteration += 1
        alpha0 = sum(alpha)
        psi0 = digamma(alpha0)
        sum_gq = 0.0  # Σ g[k] / q[k]
        sum_iq = 0.0  # Σ 1 / q[k]
        for k in 1:K
            g[k] = psi0 - digamma(alpha[k]) + lpbar[k]
            q[k] = -trigamma(alpha[k])
            sum_gq += g[k] / q[k]
            sum_iq += 1 / q[k]
        end
        # Hessian = Diagonal(q) + z * ones(K) * ones(K)' with z = trigamma(alpha0).
        # Sherman–Morrison: (Hessian \ g)[k] = (g[k] - b) / q[k], with b as below.
        b = sum_gq / (1 / trigamma(alpha0) + sum_iq)
        maxstep = 0.0
        for k in 1:K
            step = (g[k] - b) / q[k]
            alpha[k] -= step
            maxstep = max(maxstep, abs(step))
        end
    end
    maxstep <= tol || @warn "newton did not converge" maxiter maxstep
    # Same output as `@printf "Newton's Method %d\n" iteration`, without needing Printf.
    println("Newton's Method ", iteration)
    return alpha
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment