# dustalov/expectation-maximization.rb

Last active Aug 29, 2015
EM-algorithm coin example
```ruby
#!/usr/bin/env ruby

=begin
http://ai.stanford.edu/~chuongdo/papers/em_tutorial.pdf
http://stats.stackexchange.com/questions/72774/numerical-example-to-understand-expectation-maximization
http://math.stackexchange.com/questions/25111/how-does-expectation-maximization-work
http://math.stackexchange.com/questions/81004/how-does-expectation-maximization-work-in-coin-flipping-problem
http://www.youtube.com/watch?v=7e65vXZEv5Q
=end

# gem install distribution
require 'distribution'

# error bound
EPS = 10**-6

# number of coin tosses
N = 10

# observations
X = [5, 9, 8, 4, 7]

# randomly initialized thetas
theta_a, theta_b = 0.6, 0.5

p [theta_a, theta_b]

loop do
  expectation = X.map do |h|
    like_a = Distribution::Binomial.pdf(h, N, theta_a)
    like_b = Distribution::Binomial.pdf(h, N, theta_b)

    norm_a = like_a / (like_a + like_b)
    norm_b = like_b / (like_a + like_b)

    [norm_a, norm_b, h]
  end

  maximization = expectation.each_with_object([0.0, 0.0, 0.0, 0.0]) do |(norm_a, norm_b, h), r|
    r[0] += norm_a * h; r[1] += norm_a * (N - h)
    r[2] += norm_b * h; r[3] += norm_b * (N - h)
  end

  theta_a_hat = maximization[0] / (maximization[0] + maximization[1])
  theta_b_hat = maximization[2] / (maximization[2] + maximization[3])

  error_a = (theta_a_hat - theta_a).abs / theta_a
  error_b = (theta_b_hat - theta_b).abs / theta_b

  theta_a, theta_b = theta_a_hat, theta_b_hat

  p [theta_a, theta_b]

  break if error_a < EPS && error_b < EPS
end
```
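For readers who want to trace a single iteration without installing the `distribution` gem, here is a minimal sketch of one E-step and M-step on the same data. The helper names `binomial_pmf` and `em_step` are hypothetical and are not part of the gist or the gem; the hand-written binomial probability is assumed to stand in for `Distribution::Binomial.pdf`.

```ruby
# Binomial probability of k heads in n tosses with heads probability p.
# Hand-written stand-in (an assumption) for Distribution::Binomial.pdf.
def binomial_pmf(k, n, p)
  # Build C(n, k) incrementally; every intermediate value is an exact integer.
  coeff = (1..k).reduce(1) { |acc, i| acc * (n - k + i) / i }
  coeff * p**k * (1 - p)**(n - k)
end

# One EM iteration for the two-coin problem (hypothetical helper).
# heads: number of heads observed in each series of n tosses.
def em_step(heads, n, theta_a, theta_b)
  sums = heads.each_with_object([0.0, 0.0, 0.0, 0.0]) do |h, s|
    # E-step: posterior weight of each coin for this series.
    like_a = binomial_pmf(h, n, theta_a)
    like_b = binomial_pmf(h, n, theta_b)
    w_a = like_a / (like_a + like_b)
    w_b = like_b / (like_a + like_b)

    # Accumulate expected heads and tails attributed to each coin.
    s[0] += w_a * h
    s[1] += w_a * (n - h)
    s[2] += w_b * h
    s[3] += w_b * (n - h)
  end

  # M-step: re-estimate each theta as expected heads / expected tosses.
  [sums[0] / (sums[0] + sums[1]), sums[2] / (sums[2] + sums[3])]
end

theta_a, theta_b = em_step([5, 9, 8, 4, 7], 10, 0.6, 0.5)
p [theta_a, theta_b]
```

Starting from `0.6` and `0.5`, a single call should land close to the first update printed by the script. The binomial coefficient cancels when the two likelihoods are normalized, so it is kept here only to mirror the original computation.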

### dustalov commented May 17, 2015

```
$ ./expectation-maximization.rb
[0.6, 0.5]
[0.7130122354005162, 0.5813393083136627]
[0.7452920360819946, 0.5692557501718728]
[0.768098834367321, 0.5495359141383478]
[0.7831645842999736, 0.5346174541475203]
[0.7910552458637528, 0.5262811670299319]
[0.7945325379936994, 0.5223904375178747]
[0.7959286672497986, 0.5207298780860258]
[0.7964656379225264, 0.5200471890029876]
[0.7966683078984395, 0.5197703896938074]
[0.7967441494752117, 0.5196586622041123]
[0.7967724046132105, 0.5196136079148447]
[0.7967829009034072, 0.5195954342272031]
[0.7967867907879176, 0.5195880980854485]
[0.7967882289822585, 0.5195851341793923]
[0.7967887593831098, 0.5195839356752803]
[0.7967889544439393, 0.5195834506301285]
```