
# dustalov/expectation-maximization.rb

Last active Aug 29, 2015
EM-algorithm coin example
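Reading the script below, the model is as follows. Each entry of `X` counts the heads in one set of `N` = 10 tosses made with either coin A or coin B, and which coin was used for each set is unknown. The normalization in the code implies that both coins are assumed a priori equally likely. Under that assumption, one iteration computes the weights `norm_a` and `norm_b` (written $w_i^A$, $w_i^B$ here) and then the new estimates from the four sums the script accumulates:

$$
w_i^{A} = \frac{\theta_A^{h_i}(1-\theta_A)^{N-h_i}}{\theta_A^{h_i}(1-\theta_A)^{N-h_i} + \theta_B^{h_i}(1-\theta_B)^{N-h_i}}, \qquad w_i^{B} = 1 - w_i^{A}
$$

$$
\hat\theta_A = \frac{\sum_i w_i^{A} h_i}{\sum_i w_i^{A} h_i + \sum_i w_i^{A}(N-h_i)}, \qquad
\hat\theta_B = \frac{\sum_i w_i^{B} h_i}{\sum_i w_i^{B} h_i + \sum_i w_i^{B}(N-h_i)}
$$

The binomial coefficient $\binom{N}{h_i}$ is the same in both likelihoods and cancels in $w_i^A$. That is why it is omitted above, even though the script calls the full binomial pdf. The loop stops once the relative change of both estimates falls below `EPS`.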
```ruby
#!/usr/bin/env ruby

=begin
http://ai.stanford.edu/~chuongdo/papers/em_tutorial.pdf
http://stats.stackexchange.com/questions/72774/numerical-example-to-understand-expectation-maximization
http://math.stackexchange.com/questions/25111/how-does-expectation-maximization-work
http://math.stackexchange.com/questions/81004/how-does-expectation-maximization-work-in-coin-flipping-problem
http://www.youtube.com/watch?v=7e65vXZEv5Q
=end

# gem install distribution
require 'distribution'

# error bound: stop when the relative change of both thetas is below this
EPS = 10**-6

# number of coin tosses in each set
N = 10

# observations: number of heads in each set of N tosses
X = [5, 9, 8, 4, 7]

# initial guesses for the heads probabilities of coins A and B
theta_a, theta_b = 0.6, 0.5

p [theta_a, theta_b]

loop do
  # E-step: posterior probability that each set was tossed with coin A or B
  expectation = X.map do |h|
    like_a = Distribution::Binomial.pdf(h, N, theta_a)
    like_b = Distribution::Binomial.pdf(h, N, theta_b)
    norm_a = like_a / (like_a + like_b)
    norm_b = like_b / (like_a + like_b)
    [norm_a, norm_b, h]
  end

  # M-step: expected heads and tails attributed to each coin,
  # accumulated as [heads_a, tails_a, heads_b, tails_b]
  maximization = expectation.each_with_object([0.0, 0.0, 0.0, 0.0]) do |(norm_a, norm_b, h), r|
    r[0] += norm_a * h; r[1] += norm_a * (N - h)
    r[2] += norm_b * h; r[3] += norm_b * (N - h)
  end

  theta_a_hat = maximization[0] / (maximization[0] + maximization[1])
  theta_b_hat = maximization[2] / (maximization[2] + maximization[3])

  # relative change of each estimate since the previous iteration
  error_a = (theta_a_hat - theta_a).abs / theta_a
  error_b = (theta_b_hat - theta_b).abs / theta_b

  theta_a, theta_b = theta_a_hat, theta_b_hat

  p [theta_a, theta_b]

  break if error_a < EPS && error_b < EPS
end
```
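If the `distribution` gem is not available, the same iteration can be sketched in plain Ruby. This is a minimal variant, not the author's code. `binom_kernel` and `em_step` are hypothetical helper names, and the sketch relies on the binomial coefficient cancelling when the two likelihoods are normalized.

```ruby
# Dependency-free sketch of the EM loop above (assumption: no `distribution` gem).
# C(N, h) is identical for both coins, so it cancels in the normalization and
# theta**h * (1 - theta)**(N - h) is enough to compute the posterior weights.

EPS = 1e-6
N = 10
X = [5, 9, 8, 4, 7]

# Binomial likelihood of h heads in N tosses, without the binomial coefficient.
def binom_kernel(h, theta)
  theta**h * (1 - theta)**(N - h)
end

# One EM iteration: returns the updated [theta_a, theta_b].
def em_step(theta_a, theta_b)
  heads_a = tails_a = heads_b = tails_b = 0.0
  X.each do |h|
    # E-step: posterior weight of each coin for this set (equal priors assumed).
    ka = binom_kernel(h, theta_a)
    kb = binom_kernel(h, theta_b)
    w_a = ka / (ka + kb)
    w_b = kb / (ka + kb)
    # Expected head/tail counts attributed to each coin.
    heads_a += w_a * h
    tails_a += w_a * (N - h)
    heads_b += w_b * h
    tails_b += w_b * (N - h)
  end
  # M-step: maximum-likelihood estimates from the expected counts.
  [heads_a / (heads_a + tails_a), heads_b / (heads_b + tails_b)]
end

history = [[0.6, 0.5]]
loop do
  theta_a, theta_b = history.last
  new_a, new_b = em_step(theta_a, theta_b)
  history << [new_a, new_b]
  # Same stopping rule as the original: relative change below EPS for both.
  break if (new_a - theta_a).abs / theta_a < EPS && (new_b - theta_b).abs / theta_b < EPS
end

history.each { |pair| p pair }
```

With the same starting point (0.6, 0.5) and tolerance, it should trace essentially the same sequence of estimates as the run shown in the comment.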

### dustalov commented May 17, 2015

```
$ ./expectation-maximization.rb
[0.6, 0.5]
[0.7130122354005162, 0.5813393083136627]
[0.7452920360819946, 0.5692557501718728]
[0.768098834367321, 0.5495359141383478]
[0.7831645842999736, 0.5346174541475203]
[0.7910552458637528, 0.5262811670299319]
[0.7945325379936994, 0.5223904375178747]
[0.7959286672497986, 0.5207298780860258]
[0.7964656379225264, 0.5200471890029876]
[0.7966683078984395, 0.5197703896938074]
[0.7967441494752117, 0.5196586622041123]
[0.7967724046132105, 0.5196136079148447]
[0.7967829009034072, 0.5195954342272031]
[0.7967867907879176, 0.5195880980854485]
[0.7967882289822585, 0.5195851341793923]
[0.7967887593831098, 0.5195839356752803]
[0.7967889544439393, 0.5195834506301285]
```
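One property worth checking on this run is that EM never decreases the observed-data log-likelihood. The sketch below evaluates it for several of the printed estimates. It assumes an equal-weight mixture of the two binomials, which is the model implied by the script's normalization. `choose`, `binom_pmf`, and `log_likelihood` are hypothetical helpers, not part of the original gist.

```ruby
N = 10
X = [5, 9, 8, 4, 7]

# Binomial coefficient C(n, k) using exact integer arithmetic at every step.
def choose(n, k)
  (1..k).reduce(1) { |acc, i| acc * (n - i + 1) / i }
end

# Full binomial pmf, including the coefficient.
def binom_pmf(h, n, theta)
  choose(n, h) * theta**h * (1 - theta)**(n - h)
end

# Observed-data log-likelihood, assuming each coin is picked with probability 1/2.
def log_likelihood(theta_a, theta_b)
  X.sum { |h| Math.log(0.5 * binom_pmf(h, N, theta_a) + 0.5 * binom_pmf(h, N, theta_b)) }
end

# A subset of the estimates printed by the run, in iteration order.
estimates = [
  [0.6, 0.5],
  [0.7130122354005162, 0.5813393083136627],
  [0.7452920360819946, 0.5692557501718728],
  [0.768098834367321, 0.5495359141383478],
  [0.7831645842999736, 0.5346174541475203],
  [0.7967889544439393, 0.5195834506301285]
]

# Log-likelihood at each estimate; EM theory says this sequence cannot decrease.
lls = estimates.map { |a, b| log_likelihood(a, b) }
lls.each { |ll| p ll }
```

The last few iterates are left out deliberately. Near convergence, their log-likelihood differences shrink toward floating-point noise.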