Skip to content

Instantly share code, notes, and snippets.

@phipsgabler
Created April 26, 2021 12:12
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save phipsgabler/4d91bcee6d3d56021ee485b8dd00a5fc to your computer and use it in GitHub Desktop.
Save phipsgabler/4d91bcee6d3d56021ee485b8dd00a5fc to your computer and use it in GitHub Desktop.
# trying to refactor https://codereview.stackexchange.com/q/259818/180160
# not tested for correctness
# Default lower bound for component weights when they are used as divisors in
# `normalize!` (passed there as the `eps` keyword).
const EPSILON = 1.0e-6
"""
    log_ϕ(x, μ, σ)

Return the log-density of a normal distribution with mean `μ` and standard
deviation `σ` at `x`:

    log N(x | μ, σ) = -((x - μ) / σ)^2 / 2 - log(2π) / 2 - log(σ)

The normalising constant enters with a negative sign. Adding `log(2π)` inside the
parentheses (rather than subtracting it) is what makes this a proper density.
"""
log_ϕ(x, μ, σ) = -(((x - μ) / σ)^2 + log(2π)) / 2 - log(σ)
"""
    ϕ(x, μ, σ)

Density of the normal distribution with mean `μ` and standard deviation `σ` at
`x`, obtained by exponentiating `log_ϕ`.
"""
function ϕ(x, μ, σ)
    logdensity = log_ϕ(x, μ, σ)
    return exp(logdensity)
end
"""
    log_likelihood(x, θ)

Return `∑ₜ log ∑ₖ θ.p[k] * N(x[t] | θ.μ[k], θ.σ[k])`, the log-likelihood of the
samples `x` under the Gaussian mixture `θ`. `θ` has component weights `p`, means
`μ` and standard deviations `σ`.

The inner mixture sum is evaluated in log space with the log-sum-exp trick. A
sample far from every component therefore gives a finite contribution instead of
`log(0) = -Inf` from underflowed densities.
"""
function log_likelihood(x, θ)
    return sum(x) do xₜ
        # log pₖ + log N(xₜ | μₖ, σₖ) for every component k
        terms = log.(θ.p) .+ log_ϕ.(xₜ, θ.μ, θ.σ)
        m = maximum(terms)
        # If every term is -Inf (e.g. all weights zero), return it directly:
        # subtracting -Inf from -Inf would produce NaN.
        isfinite(m) ? m + log(sum(exp.(terms .- m))) : m
    end
end
"""
    normalize!(θ; eps=EPSILON)

Convert the accumulated statistics in `θ` into mixture parameters, in place, and
return `θ`.

On entry, for each component `k`:
- `θ.μ[k]` holds a weighted sum of samples.
- `θ.σ[k]` holds a weighted sum of squared deviations.
- `θ.p[k]` holds the total weight, as accumulated by `k_means_update!` / `em_update!`.

On exit:
- `μ` is the weighted mean.
- `σ` is the square root of the weighted mean squared deviation.
- `p` is rescaled to sum to one.

Weights below `eps` are clamped to `eps` before they are used as divisors. A
component that received no weight then gets finite `μ` and `σ` instead of
`0/0 = NaN`.
"""
function normalize!(θ; eps=EPSILON)
    # Clamp once and use the same divisor for μ and σ.
    weight = clamp.(θ.p, eps, Inf)
    θ.μ ./= weight
    θ.σ .= sqrt.(θ.σ ./ weight)
    θ.p ./= sum(θ.p)
    return θ
end
"""
    initialize!(θ)

Zero the `μ`, `σ` and `p` arrays of `θ` in place (each with the zero of its own
element type) and return `θ`.
"""
function initialize!(θ)
    for field in (θ.μ, θ.σ, θ.p)
        fill!(field, zero(eltype(field)))
    end
    return θ
end
"""
    k_means_update!(θ_new, x, θ; eps=EPSILON)

Perform one hard-assignment (k-means style) step and write the result into `θ_new`.

Each sample is assigned to the component of `θ` whose mean is closest in
absolute distance. The sample, its squared deviation from that component's
*current* mean `θ.μ`, and a unit count are accumulated into `θ_new`. The
accumulated statistics are then turned into parameters by `normalize!` (using
`eps`), and `θ_new` is returned.
"""
function k_means_update!(θ_new, x, θ; eps=EPSILON)
    # Start each step from zeroed accumulators.
    initialize!(θ_new)
    μ_acc, σ_acc, p_acc = θ_new.μ, θ_new.σ, θ_new.p
    for sample in x
        nearest = argmin(abs.(θ.μ .- sample))
        deviation = sample - θ.μ[nearest]
        μ_acc[nearest] += sample
        σ_acc[nearest] += deviation^2
        p_acc[nearest] += 1
    end
    return normalize!(θ_new; eps=eps)
end
"""
    em_update!(θ_new, p̂, x, θ; eps=EPSILON)

Perform one EM step for a 1-D Gaussian mixture and write the result into `θ_new`.

For each sample `xₜ`:
- The E-step computes responsibilities
  `p̂[k] ∝ θ.p[k] * N(xₜ | θ.μ[k], θ.σ[k])`, normalised to sum to one.
- The responsibility-weighted sample, the squared deviation from the current
  mean `θ.μ`, and the responsibility itself are accumulated into `θ_new`.

`normalize!` (with `eps`) then turns the accumulated statistics into parameters,
and `θ_new` is returned.

`p̂` is a scratch buffer with one entry per component; its contents are
overwritten.
"""
function em_update!(θ_new, p̂, x, θ; eps=EPSILON)
    initialize!(θ_new)
    for xₜ in x
        # log pₖ + log N(xₜ | μₖ, σₖ). The normal density already contains the
        # 1/σₖ factor, so the weights must not be divided by σ again.
        p̂ .= log.(θ.p) .+ log_ϕ.(xₜ, θ.μ, θ.σ)
        # Shift by the maximum before exponentiating. A sample far from every
        # component would otherwise underflow to 0/0 = NaN.
        p̂ .= exp.(p̂ .- maximum(p̂))
        p̂ ./= sum(p̂)
        θ_new.μ .+= xₜ .* p̂
        # Deviations are measured from the previous means θ.μ, as in k_means_update!.
        θ_new.σ .+= (xₜ .- θ.μ) .^ 2 .* p̂
        θ_new.p .+= p̂
    end
    return normalize!(θ_new; eps=eps)
end
"""
    l1_norm(x, y)

Return the L1 distance `∑ᵢ |x[i] - y[i]|` between the equally-shaped arrays
`x` and `y`.

The reducer is `+`; reducing with `max` would give the L∞ distance instead.
"""
l1_norm(x, y) = mapreduce(abs ∘ -, +, x, y)
"""
    metric(θ, θ′)

Distance between two parameter sets: the sum, over every property name of `θ`,
of `l1_norm` between the corresponding properties of `θ` and `θ′`.
"""
function metric(θ, θ′)
    distances = (l1_norm(getproperty(θ, name), getproperty(θ′, name))
                 for name in propertynames(θ))
    return sum(distances)
end
"""
    em(x, k, N; tol=3e-4, eps=EPSILON, n_kmeans=4)

Fit a `k`-component 1-D Gaussian mixture to the samples `x` and return the latest
estimate as a named tuple `(; μ, σ, p)`.

Initialisation:
- Means are `k` values drawn from `x` with `rand`, so repeats are possible.
- Standard deviations start at 1.
- Weights are uniform.

Iteration:
- The first `n_kmeans` updates are hard-assignment k-means steps.
- Subsequent updates are EM steps.
- At most `N - 1` updates are performed in total.
- The EM phase stops early once `metric` between two consecutive estimates
  drops below `tol`.

`eps` is forwarded to both update functions.
"""
function em(x, k, N; tol=3e-4, eps=EPSILON, n_kmeans=4)
    # float.(…) so integer-valued data still yields floating-point means that
    # can be divided in place by `normalize!`.
    θ = (; μ = float.(rand(x, k)), σ = ones(k), p = ones(k) ./ k)
    # Second parameter set that each update writes into. The two sets are
    # swapped after every step, so no arrays are allocated per iteration.
    θ_next = map(copy, θ)
    p̂ = copy(θ.p)  # scratch buffer for the EM responsibilities
    for i in 1:(N - 1)
        if i <= n_kmeans
            k_means_update!(θ_next, x, θ; eps=eps)
        else
            em_update!(θ_next, p̂, x, θ; eps=eps)
        end
        # After the swap, θ holds the newest estimate and θ_next the previous one.
        θ, θ_next = θ_next, θ
        if i > n_kmeans && metric(θ, θ_next) < tol
            break
        end
    end
    return θ
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment