Skip to content

Instantly share code, notes, and snippets.

@sschnug
Created February 26, 2016 20:26
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save sschnug/a397c772c7c0392db723 to your computer and use it in GitHub Desktop.
Trimodal Gaussian mixture fit (maximum-likelihood estimation with JuMP)
using JuMP, AmplNLWriter, NLopt
# ---------------------------------------------------------------------------
# Synthetic data: a three-component Gaussian mixture with known parameters.
# Component k contributes round(Int, weight_k * n) draws from N(mu_k, sig_k^2).
# ---------------------------------------------------------------------------
n = 1000

# Ground-truth means, standard deviations and mixture weights.
mu1_true, mu2_true, mu3_true = 0.3, 0.55, 0.10
sig1_true, sig2_true, sig3_true = 0.08, 0.12, 0.32
a_0_true, a_1_true, a_2_true = 0.5, 0.2, 0.3

# Fixed seed so the generated sample is reproducible.
srand(1)

# Draw one component's share of the n samples: round(Int, weight * n) points
# scaled by `sig` and shifted by `mu`.
sample_component(weight, sig, mu) = randn(round(Int, weight * n)) * sig + mu

# Components are drawn in order 1, 2, 3 so the RNG stream is consumed identically.
s1 = sample_component(a_0_true, sig1_true, mu1_true)
s2 = sample_component(a_1_true, sig2_true, mu2_true)
s3 = sample_component(a_2_true, sig3_true, mu3_true)

for component in (s1, s2, s3)
    println(length(component))
end

# Stack the three components into one observation vector.
data = [s1; s2; s3]
println(length(data))
# ---------------------------------------------------------------------------
# Maximum-likelihood model for a three-component Gaussian mixture.
# Decision variables: means mu_1..mu_3, standard deviations sigma_1..sigma_3 and
# mixture weights a_0..a_2 (constrained to sum to one).
# Objective: the log-likelihood of `data` under the mixture,
#   log L = sum_i log( sum_k a_k * exp(-(x_i - mu_k)^2 / (2 sigma_k^2)) / (sigma_k * sqrt(2*pi)) ).
# ---------------------------------------------------------------------------
m = Model(solver=AmplNLSolver("couenne"))
#m = Model(solver=NLoptSolver(algorithm=:LD_LBFGS))
#m = Model(solver=NLoptSolver(algorithm=:GN_DIRECT_L))
#m = Model(solver=NLoptSolver(algorithm=:LD_MMA))

# Smallest admissible standard deviation. sigma = 0 divides by zero in the
# density, and a mixture likelihood grows without bound as one component's
# sigma shrinks onto a single data point, so the lower bound must be > 0.
sigma_min = 0.01

@defVar(m, 0 <= mu_1 <= 1, start=0.25)
@defVar(m, 0 <= mu_2 <= 1, start=0.6)
@defVar(m, 0 <= mu_3 <= 1, start=0.2)
@defVar(m, sigma_min <= sigma_1 <= 1, start=0.1)
@defVar(m, sigma_min <= sigma_2 <= 1, start=0.1)
@defVar(m, sigma_min <= sigma_3 <= 1, start=0.2)
@defVar(m, 0 <= a_0 <= 1, start=0.45)
@defVar(m, 0 <= a_1 <= 1, start=0.15)
@defVar(m, 0 <= a_2 <= 1, start=0.25)
# Mixture weights form a probability vector.
@addConstraint(m, a_0 + a_1 + a_2 == 1.0)

# Gaussian normalising constant, identical for every term, so computed once
# as plain data rather than inside the nonlinear expression.
inv_sqrt_2pi = 1 / sqrt(2 * π)
# Sum over every observation actually generated: per-component counts are
# rounded, so length(data) is not guaranteed to equal the nominal n.
n_obs = length(data)

# Each component term is a_k / (sigma_k * sqrt(2π)) * exp(-(x - mu_k)^2 / (2 sigma_k^2));
# exp() is used directly as a nonlinear built-in.
@setNLObjective(m, Max, sum{
    log(
        (a_0 * inv_sqrt_2pi / sigma_1) * exp(-(data[i] - mu_1)^2 / (2 * sigma_1^2)) +
        (a_1 * inv_sqrt_2pi / sigma_2) * exp(-(data[i] - mu_2)^2 / (2 * sigma_2^2)) +
        (a_2 * inv_sqrt_2pi / sigma_3) * exp(-(data[i] - mu_3)^2 / (2 * sigma_3^2))),
    i=1:n_obs})
# Optional warm start: uncomment to set initial values explicitly instead of
# relying on the `start=` values given when the variables were declared.
# setValue(mu_1, 0.25)
# setValue(mu_2, 0.6)
# setValue(mu_3, 0.2)
# setValue(sigma_1, 0.1)
# setValue(sigma_2, 0.1)
# setValue(sigma_3, 0.2)
# setValue(a_0, 0.45)
# setValue(a_1, 0.15)
# setValue(a_2, 0.25)

# Solve and report the solver status first: the values printed below are only
# a meaningful fit when the status indicates a successful solve.
status = solve(m)
println("Solver status: ", status)

# Fitted component parameters.
println("μ1 = ", getValue(mu_1))
println("σ1 = ", getValue(sigma_1))
println("μ2 = ", getValue(mu_2))
println("σ2 = ", getValue(sigma_2))
println("μ3 = ", getValue(mu_3))
println("σ3 = ", getValue(sigma_3))
# Fitted mixture weights, labelled individually so they can be told apart.
println("a0 = ", getValue(a_0))
println("a1 = ", getValue(a_1))
println("a2 = ", getValue(a_2))
println("MLE objective: ", getObjectiveValue(m))
#println(data)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment