Skip to content

Instantly share code, notes, and snippets.

@josephcc
Created February 23, 2014 19:51
Show Gist options
  • Save josephcc/9176384 to your computer and use it in GitHub Desktop.
Save josephcc/9176384 to your computer and use it in GitHub Desktop.
#!/usr/bin/env ruby
# Expectation-maximisation fit of mixture weights ("lambdas") for a weighted
# sum of several per-item value streams (presumably probabilities — they are
# fed to Math.log, so they must be positive).
#
# Each input file holds one number per line (parsed with String#to_f); line i
# of every file refers to the same item. The streams are transposed, so all
# files must have the same number of lines. For item i the mixture value is
#   sum_k lambda_k * stream_k[i]
# and the criterion is the mean natural log of that value over all items.
class EM
  # Mean log mixture value for the current lambdas (nil when there are no items).
  attr_reader :criterion

  # @param files [Array<String>] paths of the stream files, one value per line.
  # @param prefix [String] prepended to "output.log" to form the log path.
  # @param init [Array<Float>, nil] initial lambdas; defaults to equal weights.
  # @param weight [Array<Numeric>, nil] per-stream multipliers applied to the
  #   lambdas inside each update step (not in the criterion); defaults to all 1.
  # @param uni [Boolean] when true, append a uniform stream (1/N for each of
  #   the N items) as an extra mixture component.
  # @param threshold [Float] stop when the relative criterion gain drops below this.
  # @param log [Boolean] when true, write one CSV row per iteration to the log file.
  def initialize(files, prefix: '', init: nil, weight: nil, uni: false, threshold: 0.001/100, log: true)
    @threshold = threshold
    # File.readlines closes each file; File.open(fn).map kept one descriptor
    # per stream open until garbage collection.
    @streams = files.map { |fn| File.readlines(fn).map(&:to_f) }
    if uni
      size = @streams[0].size
      @streams << [1.0 / size] * size
    end
    @lbds = init || [1.0 / @streams.size] * @streams.size
    @weight = weight || [1] * @lbds.size
    # Opened only after the streams were read, so a bad stream path neither
    # truncates an existing log nor leaks its handle.
    @logf = File.open("#{prefix}output.log", 'w') if log
    @criterion = convergence
    @iterations = 0
    # Initial row; no step has run yet, so 0.01 is only a placeholder delta.
    do_log(0.01)
  end

  # Apply EM updates until the relative criterion gain falls below the
  # threshold or the iteration counter reaches +max_iterations+. The counter
  # is cumulative across calls.
  #
  # @param max_iterations [Integer] cap on the iteration counter.
  # @return [nil]
  def run(max_iterations: 100_000)
    delta = 1.0
    while delta >= @threshold && @iterations < max_iterations
      lbds = update
      criterion = convergence(lbds)
      # Relative improvement. If this evaluates to NaN, the >= comparison
      # above is false and the loop ends.
      delta = (criterion - @criterion) / @criterion.abs
      update_model(lbds, criterion)
      do_log(delta)
    end
  end

  # Commit the result of one step and advance the iteration counter.
  def update_model(lbds, criterion)
    @lbds = lbds
    @criterion = criterion
    @iterations += 1
  end

  # Compute one EM step and return the new lambdas without mutating state.
  # With wl_k = lambda_k * weight_k and D_i = sum_k wl_k * stream_k[i], the new
  # lambda for stream k is wl_k times the mean over items of stream_k[i] / D_i.
  def update
    weighted_lbds = @lbds.zip(@weight).map { |lbd, w| lbd * w }
    denominators = @streams.transpose.map do |probs|
      probs.zip(weighted_lbds).map { |prob, wl| prob * wl }.inject(:+)
    end
    weighted_lbds.zip(@streams).map do |wl, stream|
      stream.zip(denominators).map { |prob, d| wl * prob / d }.inject(:+) / stream.size
    end
  end

  # Mean natural log of the per-item mixture value for +lbds+ (weights are not
  # applied here). Returns nil when there are no items.
  def convergence(lbds = @lbds)
    mixture = @streams.transpose.map { |probs| probs.zip(lbds).map { |prob, l| prob * l }.inject(:+) }
    mixture.map { |x| Math.log(x) / mixture.size }.inject(:+)
  end

  # Append one CSV row: iteration, criterion, delta, sum of lambdas, then each
  # lambda rounded to 2 decimals. Does nothing when logging is disabled.
  def do_log(delta)
    return unless @logf
    @logf.write("#{@iterations},%.6f,%.6f,#{@lbds.inject(:+)},#{@lbds.map { |x| x.round(2) }.join(',')}\n" % [@criterion, delta])
  end

  # Close the log file if one was opened. Ruby never calls this method
  # automatically; callers that enable logging must invoke it themselves.
  def finalize
    @logf.close if @logf
  end
end
# Write a grid of the EM criterion evaluated at every initial weight vector
# [x/steps, y/steps, (steps-x-y)/steps] with integers x, y >= 1 and
# x + y < steps, so all three weights are positive. No EM iterations are run;
# only the criterion at each starting point is recorded.
#
# Each row is "w1 w2 criterion". A blank line precedes the rows for each new
# x (presumably the scan-line separator for gnuplot's splot — confirm).
#
# @param fn [String] output path (overwritten).
# @param files [Array<String>] stream files forwarded to EM.new.
# @param steps [Integer] grid resolution.
def surface(fn, files, steps: 50)
  # Block form closes the file even if EM.new raises.
  File.open(fn, 'w') do |output|
    # Explicit nested loops fix the row order; Array#repeated_permutation
    # documents no ordering guarantee, and the separators depend on it.
    1.upto(steps - 2) do |x|
      output.write "\n"
      1.upto(steps - x - 1) do |y|
        init = [x, y, steps - x - y].map { |z| z / steps.to_f }
        em = EM.new(files, init: init, log: false)
        output.write "#{[*init[0...2], em.criterion].join(' ')}\n"
      end
    end
  end
end
# Script entry point: writes the criterion surface for the stream files named
# on the command line to out2.data. surface() varies three mixture weights,
# so three stream files are expected. Guarded so the file can be loaded
# without side effects.
if __FILE__ == $PROGRAM_NAME
  abort "usage: #{$PROGRAM_NAME} STREAM_FILE STREAM_FILE STREAM_FILE" if ARGV.empty?
  surface('out2.data', ARGV)
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment