@dustalov
Last active August 29, 2015 14:21
EM-algorithm coin example
```ruby
#!/usr/bin/env ruby
=begin
http://ai.stanford.edu/~chuongdo/papers/em_tutorial.pdf
http://stats.stackexchange.com/questions/72774/numerical-example-to-understand-expectation-maximization
http://math.stackexchange.com/questions/25111/how-does-expectation-maximization-work
http://math.stackexchange.com/questions/81004/how-does-expectation-maximization-work-in-coin-flipping-problem
http://www.youtube.com/watch?v=7e65vXZEv5Q
=end

# requires the distribution gem: gem install distribution
require 'distribution'

# convergence tolerance on the relative change of each theta
# (1e-6 instead of 10**-6, which evaluates to a Rational in Ruby)
EPS = 1e-6

# number of tosses in each set
N = 10

# number of heads observed in each of the five sets
X = [5, 9, 8, 4, 7]

# initial guesses for the head probabilities of coins A and B
theta_a, theta_b = 0.6, 0.5
p [theta_a, theta_b]

loop do
  # E-step: posterior probability that each set was tossed with coin A or B,
  # assuming both coins are equally likely to be picked a priori
  expectation = X.map do |h|
    like_a = Distribution::Binomial.pdf(h, N, theta_a)
    like_b = Distribution::Binomial.pdf(h, N, theta_b)
    norm_a = like_a / (like_a + like_b)
    norm_b = like_b / (like_a + like_b)
    [norm_a, norm_b, h]
  end

  # M-step: expected heads and tails attributed to each coin,
  # r = [heads_a, tails_a, heads_b, tails_b]
  maximization = expectation.each_with_object([0.0, 0.0, 0.0, 0.0]) do |(norm_a, norm_b, h), r|
    r[0] += norm_a * h
    r[1] += norm_a * (N - h)
    r[2] += norm_b * h
    r[3] += norm_b * (N - h)
  end

  # re-estimate each theta as expected heads / expected tosses
  theta_a_hat = maximization[0] / (maximization[0] + maximization[1])
  theta_b_hat = maximization[2] / (maximization[2] + maximization[3])

  # relative change since the previous iteration
  error_a = (theta_a_hat - theta_a).abs / theta_a
  error_b = (theta_b_hat - theta_b).abs / theta_b

  theta_a, theta_b = theta_a_hat, theta_b_hat
  p [theta_a, theta_b]

  break if error_a < EPS && error_b < EPS
end
```
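Written as formulas, one pass of the loop computes the following. Here $h_i$ is the number of heads in set $i$, and $w_i^A$, $w_i^B$ are notation introduced only for this sketch, standing for `norm_a` and `norm_b`. Like the script, it assumes each coin is equally likely to be picked for a set. The binomial coefficient $\binom{N}{h_i}$ is shared by both coins and cancels in the E-step, and since `maximization[0] + maximization[1]` equals $N\sum_i w_i^A$, the M-step reduces to expected heads over expected tosses.

```latex
\begin{aligned}
w_i^{A} &= \frac{\theta_A^{h_i}(1-\theta_A)^{N-h_i}}
               {\theta_A^{h_i}(1-\theta_A)^{N-h_i} + \theta_B^{h_i}(1-\theta_B)^{N-h_i}},
\qquad w_i^{B} = 1 - w_i^{A} \\[4pt]
\hat\theta_A &= \frac{\sum_i w_i^{A} h_i}{\sum_i w_i^{A} h_i + \sum_i w_i^{A}(N-h_i)}
             = \frac{\sum_i w_i^{A} h_i}{N \sum_i w_i^{A}},
\qquad
\hat\theta_B = \frac{\sum_i w_i^{B} h_i}{N \sum_i w_i^{B}} \\[4pt]
\text{stop when } &\frac{|\hat\theta_A - \theta_A|}{\theta_A} < \varepsilon
\ \text{ and } \ \frac{|\hat\theta_B - \theta_B|}{\theta_B} < \varepsilon
\end{aligned}
```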
Output on the coin-flipping example from Do & Batzoglou (the tutorial linked first in the script header):

```
$ ./expectation-maximization.rb
[0.6, 0.5]
[0.7130122354005162, 0.5813393083136627]
[0.7452920360819946, 0.5692557501718728]
[0.768098834367321, 0.5495359141383478]
[0.7831645842999736, 0.5346174541475203]
[0.7910552458637528, 0.5262811670299319]
[0.7945325379936994, 0.5223904375178747]
[0.7959286672497986, 0.5207298780860258]
[0.7964656379225264, 0.5200471890029876]
[0.7966683078984395, 0.5197703896938074]
[0.7967441494752117, 0.5196586622041123]
[0.7967724046132105, 0.5196136079148447]
[0.7967829009034072, 0.5195954342272031]
[0.7967867907879176, 0.5195880980854485]
[0.7967882289822585, 0.5195851341793923]
[0.7967887593831098, 0.5195839356752803]
[0.7967889544439393, 0.5195834506301285]
```
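The trace above converges to roughly θ_A ≈ 0.80 and θ_B ≈ 0.52. As a cross-check, here is a minimal dependency-free sketch of the same iteration that drops the `distribution` gem (the binomial coefficient cancels in the E-step) and also records the observed-data log-likelihood, which EM should never decrease. The helper names `log_choose`, `log_likelihood`, `em_step` and `run_em` are hypothetical, introduced only for this sketch, and the log-likelihood assumes each set picks coin A or B with probability 1/2, matching the equal-prior E-step of the script.

```ruby
# Dependency-free sketch of the two-coin EM (no `distribution` gem).
N = 10
X = [5, 9, 8, 4, 7]
EPS = 1e-6

# log C(n, k) via log-gamma; Math.lgamma returns [log|Gamma(x)|, sign]
def log_choose(n, k)
  Math.lgamma(n + 1)[0] - Math.lgamma(k + 1)[0] - Math.lgamma(n - k + 1)[0]
end

# Observed-data log-likelihood, assuming mixing weights of 1/2 for each coin
def log_likelihood(theta_a, theta_b)
  X.sum do |h|
    like_a = theta_a**h * (1 - theta_a)**(N - h)
    like_b = theta_b**h * (1 - theta_b)**(N - h)
    log_choose(N, h) + Math.log(0.5 * like_a + 0.5 * like_b)
  end
end

# One EM iteration; returns the updated [theta_a, theta_b]
def em_step(theta_a, theta_b)
  heads_a = tosses_a = heads_b = tosses_b = 0.0
  X.each do |h|
    # C(N, h) is common to both likelihoods, so it is omitted here
    like_a = theta_a**h * (1 - theta_a)**(N - h)
    like_b = theta_b**h * (1 - theta_b)**(N - h)
    w_a = like_a / (like_a + like_b) # E-step: P(coin A | h heads)
    w_b = 1.0 - w_a
    heads_a += w_a * h
    tosses_a += w_a * N
    heads_b += w_b * h
    tosses_b += w_b * N
  end
  [heads_a / tosses_a, heads_b / tosses_b] # M-step
end

# Iterate with the same relative-change stopping rule as the script;
# returns [[theta_a, theta_b, log_likelihood], ...] including the start point
def run_em(theta_a, theta_b)
  history = [[theta_a, theta_b, log_likelihood(theta_a, theta_b)]]
  loop do
    new_a, new_b = em_step(theta_a, theta_b)
    done = (new_a - theta_a).abs / theta_a < EPS &&
           (new_b - theta_b).abs / theta_b < EPS
    theta_a, theta_b = new_a, new_b
    history << [theta_a, theta_b, log_likelihood(theta_a, theta_b)]
    break if done
  end
  history
end

history = run_em(0.6, 0.5)
history.each { |a, b, ll| puts format('%.6f %.6f  log L = %.6f', a, b, ll) }
```

Up to floating-point rounding, the theta columns should follow the same trajectory as the gem-based script, and the log-likelihood column should be non-decreasing from row to row.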
