Skip to content

Instantly share code, notes, and snippets.

@ybenjo
Created August 29, 2013 13:22
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ybenjo/6378026 to your computer and use it in GitHub Desktop.
Save ybenjo/6378026 to your computer and use it in GitHub Desktop.
GMM
# GMM: one-dimensional Gaussian Mixture Model fitted with the EM algorithm.
# Usage:
#   gmm = GMM.new([1, 2, 3, 5, 2, 1, 10, 20, 30, 20], {k: 2})
#   gmm.train
# Convergence threshold: training stops once the log likelihood improves by
# at most this much between consecutive iterations.
# Float literals are used on purpose: in Ruby, Integer ** negative Integer
# (e.g. 10 ** -5) evaluates to a Rational, not a Float.
LIM = 1e-5
# Floor applied to Gaussian densities so p(x|mu, sigma) never becomes 0.0.
PROB_LIM = 1e-100
# Gaussian probability density N(x | mu, sigma), floored at PROB_LIM.
#
# NOTE: +sigma+ is the VARIANCE (sigma squared), not the standard deviation.
# It enters the formula directly:
#   1 / sqrt(2 * pi * sigma) * exp(-(x - mu)^2 / (2 * sigma))
#
# @param x [Numeric] point at which the density is evaluated
# @param mu [Numeric] mean
# @param sigma [Numeric] variance; expected to be positive (a variance of 0
#   yields NaN, a negative one raises Math::DomainError)
# @return [Float] the density, or PROB_LIM as a Float when the density is
#   smaller than PROB_LIM
def dist(x, mu, sigma)
  # 2.0 (not 2) keeps the exponent in floating point even when every
  # argument is an Integer; otherwise Integer#/ would floor the quotient.
  prob = Math.exp(-((x - mu) ** 2) / (2.0 * sigma)) / Math.sqrt(2 * Math::PI * sigma)
  # Avoid p(x|mu, sigma) = 0.0, which would make EM responsibilities 0/0 and
  # the log likelihood -Infinity. to_f keeps the return type Float even if
  # PROB_LIM is defined as a Rational (e.g. 10 ** -100).
  prob < PROB_LIM ? PROB_LIM.to_f : prob
end
# One-dimensional Gaussian mixture model fitted with the
# expectation-maximization (EM) algorithm.
#
# Parameters are stored in Hashes keyed by component index (0...k):
#   mu[k]    - component mean
#   sigma[k] - component variance (passed to +dist+ as its variance argument)
#   pi[k]    - mixing weight; the weights sum to 1
# Responsibilities live in +gamma+, keyed by the one-entry Hash {n => k}
# (data index => component index). That key shape is kept because +gamma+
# is public through attr_reader.
# +log_l+ holds the log likelihood recorded after every EM iteration.
#
# Depends on the top-level +dist+ density function and the LIM constant.
class GMM
  # Raised when the log likelihood of the training data evaluates to NaN.
  class NumericalError < StandardError; end

  attr_reader :log_l, :mu, :sigma, :pi, :gamma

  # @param ary [Array<Numeric>] training observations; must contain at least
  #   two distinct values
  # @param options [Hash] :k number of components (default 2),
  #   :srand seed for the random initial parameters (default 0)
  # @raise [ArgumentError] if +ary+ has fewer than two distinct values or
  #   :k is not a positive Integer
  def initialize(ary, options = {})
    @values = ary
    if @values.uniq.size < 2
      raise ArgumentError, 'GMM needs at least two distinct values'
    end

    @k = options[:k] || 2
    # With no components, train would loop forever (the log likelihood is
    # -Infinity every iteration, so converge? never holds).
    unless @k.is_a?(Integer) && @k >= 1
      raise ArgumentError, "k must be a positive Integer (got #{@k.inspect})"
    end

    # NOTE: Kernel#srand seeds Ruby's global RNG, so this also affects later
    # Kernel#rand calls made elsewhere in the process.
    srand(options[:srand] || 0)
    init_params
    @log_l = []
  end

  # Log likelihood of the training data under the current parameters.
  #
  # @return [Float]
  # @raise [NumericalError] if the result is NaN
  def log_likelihood
    ll = @values.inject(0.0) { |acc, v| acc + Math.log(mixture_density(v)) }
    raise NumericalError, 'log likelihood is NaN' if ll.nan?
    ll
  end

  # E step: for every observation x_n and component k, sets
  #   gamma[{n => k}] = pi_k * N(x_n) / sum_j pi_j * N_j(x_n)
  def e_step
    @values.each_with_index do |v, n|
      weighted = Array.new(@k) { |k| weighted_density(v, k) }
      total = weighted.inject(:+)
      weighted.each_with_index { |w, k| @gamma[n => k] = w / total }
    end
  end

  # M step: re-estimates mu, sigma (variance) and pi for every component
  # from the current responsibilities.
  def m_step
    size = @values.size.to_f
    @k.times do |k|
      # N_k: effective number of observations assigned to component k
      n_k = responsibility_sum(k) { 1 }
      @mu[k] = responsibility_sum(k) { |v| v } / n_k
      # uses the mean just updated above
      @sigma[k] = responsibility_sum(k) { |v| (v - @mu[k]) ** 2 } / n_k
      @pi[k] = n_k / size
    end
  end

  # True once at least two iterations have run and the latest one improved
  # the log likelihood by at most LIM (a decrease also counts).
  def converge?
    @log_l.size > 1 && @log_l[-1] - @log_l[-2] <= LIM
  end

  # Runs EM iterations (E step, M step, then records the log likelihood)
  # until converge? holds.
  #
  # @param max_iter [Integer, nil] optional cap on the number of iterations
  #   run by this call; nil (the default) iterates until convergence only.
  # @return [nil]
  def train(max_iter = nil)
    iterations = 0
    until converge? || (max_iter && iterations >= max_iter)
      e_step
      m_step
      @log_l.push(log_likelihood)
      iterations += 1
    end
  end

  # Mixture density sum_k pi_k * N(new_x | mu_k, sigma_k).
  #
  # @param new_x [Numeric]
  # @return [Float]
  def prob(new_x)
    mixture_density(new_x)
  end

  private

  # Draws random starting parameters: mu and sigma uniformly in [0, 1), and
  # random mixing weights normalized to sum to 1.
  def init_params
    @mu = {}
    @sigma = {}
    @pi = {}
    @gamma = {}
    @k.times do |i|
      # The draw order (mu, sigma, pi per component) determines the seeded start.
      @mu[i] = rand
      @sigma[i] = rand
      @pi[i] = rand
    end
    total = @pi.values.inject(0.0, :+)
    @pi.each_key { |i| @pi[i] /= total }
  end

  # pi_k * N(x | mu_k, sigma_k) for a single component k.
  def weighted_density(x, k)
    @pi[k] * dist(x, @mu[k], @sigma[k])
  end

  # Sum of the weighted component densities over all components.
  def mixture_density(x)
    (0...@k).inject(0.0) { |acc, k| acc + weighted_density(x, k) }
  end

  # Sum over the data of gamma[{n => k}] * f(x_n), where f is the given block.
  def responsibility_sum(k)
    sum = 0.0
    @values.each_with_index { |v, n| sum += @gamma[n => k] * yield(v) }
    sum
  end
end
# Demo: fit the default two-component model to two well-separated clusters
# and print the trained model object.
if $PROGRAM_NAME == __FILE__
  model = GMM.new([1, 2, 3, 5, 2, 1, 199, 200, 201, 200])
  model.train
  p model
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment