Last active
January 6, 2016 11:44
-
-
Save seinosuke/7438fad6c92f25e2a8f4 to your computer and use it in GitHub Desktop.
【Ruby】 EMアルゴリズムでクラスタリング 参照→( http://syoshinsyakangeisagi.blogspot.com/2016/01/ruby-em.html )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Matrix is used throughout this class; require it here so the class does not
# depend on the caller having loaded it first.
require 'matrix'

# Clusters patterns into k Gaussian components with a hard-assignment variant
# of the EM algorithm.
#
# Every pattern is stored as a 1 x dimension Matrix (a row vector). Each
# cluster i has a mean (1 x dimension Matrix) and a covariance
# (dimension x dimension Matrix). #e_step moves every pattern to the cluster
# whose density is highest at that pattern; #m_step re-estimates each
# cluster's mean and covariance from its current members, weighting every
# member by its responsibility P(cluster | pattern).
class EMA
  attr_reader :all_patterns, :k

  # Seeds the model: picks k random patterns as provisional means, assigns
  # every pattern to the nearest provisional mean (squared Euclidean
  # distance), then takes the plain mean and covariance of each group.
  #
  # @param options [Hash]
  #   :dimension - number of components per pattern
  #   :k         - number of clusters
  #   :patterns  - array of patterns, each an array of numbers
  # @raise [ArgumentError] if there are fewer patterns than clusters
  def initialize(options = {})
    @dimension = options[:dimension]
    @k = options[:k]
    points = options[:patterns].map { |pattern| Matrix[pattern.map(&:to_f)] }
    if points.size < @k
      raise ArgumentError, "need at least k=#{@k} patterns, got #{points.size}"
    end

    seeds = points.sample(@k)
    @all_patterns = Array.new(@k) { [] }
    points.each { |point| @all_patterns[nearest_index(seeds, point)] << point }

    # A seed can end up with no members when the data contains duplicate
    # points (ties go to the first seed). Dividing by a member count of zero
    # would produce NaN, so such a cluster keeps its seed as mean and gets an
    # identity covariance.
    @means = @all_patterns.each_with_index.map do |members, i|
      members.empty? ? seeds[i] : weighted_mean(members, uniform_weights(members))
    end
    @sigmas = @all_patterns.each_with_index.map do |members, i|
      if members.empty?
        Matrix.identity(@dimension).map(&:to_f)
      else
        weighted_covariance(members, uniform_weights(members), @means[i])
      end
    end
  end

  # Runs one iteration: reassignment (#e_step) followed by re-estimation
  # (#m_step).
  def update
    e_step
    m_step
  end

  # Reassigns every pattern to the cluster whose Gaussian density is highest
  # at that pattern.
  def e_step
    points = @all_patterns.flatten
    @all_patterns = Array.new(@k) { [] }
    points.each do |point|
      best = @k.times.max_by { |i| gauss(i, point) }
      @all_patterns[best] << point
    end
  end

  # Recomputes the mean and covariance of every cluster from its current
  # members, weighting each member x by its responsibility prob(i, x).
  #
  # All responsibilities are evaluated up front, while @means and @sigmas
  # still hold the previous iteration's values, and the same weights are used
  # for the mean, the covariance numerator and the normaliser. A cluster that
  # has no members, or whose total responsibility is not positive, keeps its
  # previous parameters.
  def m_step
    weights = @all_patterns.each_with_index.map do |members, i|
      members.map { |x| prob(i, x) }
    end

    updated = @all_patterns.each_with_index.map do |members, i|
      total = weights[i].inject(0.0, :+)
      # `total > 0.0` is also false for NaN, which covers underflowed densities.
      next [@means[i], @sigmas[i]] unless total > 0.0

      mean = weighted_mean(members, weights[i])
      [mean, weighted_covariance(members, weights[i], mean)]
    end

    @means = updated.map(&:first)
    @sigmas = updated.map(&:last)
  end

  # Value at x of the normal density of cluster i.
  #
  # Uses Matrix#inv on the cluster covariance, which raises when that matrix
  # is singular.
  #
  # @param i [Integer] cluster index
  # @param x [Matrix] pattern as a 1 x dimension row vector
  # @return [Float]
  def gauss(i, x)
    diff = x - @means[i]
    exponent = -0.5 * (diff * @sigmas[i].inv * diff.t)[0, 0]
    Math.exp(exponent) / normalizer(i)
  end

  # Probability P(Wi|x) that x was generated by cluster i: the density of
  # cluster i at x divided by the sum of all cluster densities at x
  # (clusters are weighted equally).
  #
  # @param i [Integer] cluster index
  # @param x [Matrix] pattern as a 1 x dimension row vector
  # @return [Float]
  def prob(i, x)
    gauss(i, x) / @k.times.inject(0.0) { |sum, j| sum + gauss(j, x) }
  end

  # Returns the density of cluster i as a gnuplot expression in x and y.
  # Only meaningful for two-dimensional data; not used by the algorithm.
  #
  # @param i [Integer] cluster index
  # @return [String]
  def gauss_str(i)
    s1 = Math.sqrt(@sigmas[i][0, 0])
    s2 = Math.sqrt(@sigmas[i][1, 1])
    m1 = @means[i][0, 0]
    m2 = @means[i][0, 1]
    rho = @sigmas[i][0, 1] / (s1 * s2)
    [
      "exp( -( ((x-#{m1})/#{s1})**2 - #{2.0*rho}*((x-#{m1})/#{s1})*((y-#{m2})/#{s2}) + ((y-#{m2})/#{s2})**2 )",
      " / ",
      " #{2.0*(1.0 - rho**2)} )",
      " / ",
      "#{normalizer(i)}",
    ].join
  end

  private

  # Index of the seed closest to point by squared Euclidean distance; the
  # first seed wins ties.
  def nearest_index(seeds, point)
    seeds.each_with_index.min_by do |seed, _|
      (seed - point).inject(0.0) { |sum, v| sum + v * v }
    end.last
  end

  # Weight 1.0 for every member, which turns the weighted estimators below
  # into the plain sample mean and covariance.
  def uniform_weights(members)
    Array.new(members.size, 1.0)
  end

  # Sum of w * x over the members divided by the sum of the weights.
  def weighted_mean(members, weights)
    zero = Matrix[Array.new(@dimension) { 0.0 }]
    total = members.zip(weights).inject(zero) { |sum, (x, w)| sum + (x * w) }
    total / weights.inject(0.0, :+)
  end

  # Sum of w * (x - mean)^T (x - mean) over the members divided by the sum of
  # the weights.
  def weighted_covariance(members, weights, mean)
    scatter = members.zip(weights).inject(Matrix.zero(@dimension)) do |sum, (x, w)|
      diff = x - mean
      sum + (diff.t * diff * w)
    end
    scatter / weights.inject(0.0, :+)
  end

  # Normalising constant of cluster i's density: (2*pi)^(d/2) * sqrt(det(sigma)).
  def normalizer(i)
    (Math.sqrt(2.0*Math::PI)**@dimension) * Math.sqrt(@sigmas[i].det)
  end
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
require 'matrix' | |
require 'open3' | |
require 'pp' | |
require 'pry' | |
require "./ema" | |
include Math | |
# Draws random 2-D points from an axis-aligned normal distribution using the
# Box-Muller transform.
#
# @param num [Integer] number of points to draw
# @param m [Array<Numeric>] mean as [x, y]
# @param s [Array<Numeric>] standard deviation along each axis as [x, y]
# @return [Array<Array<Float>>] num pairs [x, y]
def generate(num = 100, m = [0, 0], s = [1, 1])
  num.times.map do
    # Kernel#rand is uniform on [0, 1). Flip the first draw onto (0, 1] so
    # that log never receives 0, which would give infinite/NaN coordinates.
    r1 = 1.0 - rand
    r2 = rand
    radius = Math.sqrt(-2 * Math.log(r1))
    angle = 2 * Math::PI * r2
    x = radius * Math.cos(angle)
    y = radius * Math.sin(angle)
    [s[0]*x + m[0], s[1]*y + m[1]]
  end
end
# Arbitrary sample data: eight Gaussian blobs, each described as
# [number of points, mean, per-axis standard deviation].
blob_specs = [
  [100, [2.2, 4.0], [0.2, 0.9]],
  [100, [1.7, 5.0], [0.2, 0.9]],
  [100, [1.2, 6.0], [0.2, 0.9]],
  [200, [6.0, 3.0], [0.8, 0.8]],
  [200, [7.0, 4.0], [0.8, 0.8]],
  [50,  [5.0, 7.5], [0.8, 0.2]],
  [50,  [6.0, 8.0], [0.8, 0.2]],
  [50,  [7.0, 8.5], [0.8, 0.2]],
]
# Blobs are generated in list order and concatenated into one flat list of
# [x, y] pairs.
patterns = blob_specs.flat_map { |count, mean, spread| generate(count, mean, spread) }

# Model configuration: three clusters over two-dimensional patterns.
options = { patterns: patterns, k: 3, dimension: 2 }
ema = EMA.new(options)

# Upper bounds of the plotted x and y ranges.
x_size = 10
y_size = 10
# Drives gnuplot to write an animated GIF with one frame per model update.
# Each frame draws the sum of the cluster densities (shifted up by `offset`)
# as a surface and the patterns, coloured by cluster, on the xy plane.
Open3.popen3('gnuplot') do |gp_in, _gp_out, _gp_err|
  gif_path = "./ema_result.gif"
  offset = 2
  loop_num = 40

  # One-time terminal, range and 3-D view settings, sent in this order.
  [
    "set terminal gif animate delay 12 optimize size 480, 450",
    "set tmargin at screen 0.9",
    "set bmargin at screen 0.23",
    "set output '#{gif_path}'",
    "set xrange [0:#{x_size}]",
    "set yrange [0:#{y_size}]",
    "set zrange [0:#{offset + 1.0}]",
    "unset ztics",
    "unset colorbox",
    "set ticslevel 0",
    "set view 61, 30, 1, 1",
    "set hidden3d",
    "set isosample 60",
    "set palette defined (0 'white', 1 'light-red')",
    "set cbrange [#{offset}:#{offset + 0.2}]",
    "set pm3d at b",
  ].each { |setting| gp_in.puts setting }

  loop_num.times do |frame|
    # First splot entry: "(d0) + (d1) + ... + offset", the summed densities.
    density_sum = ema.k.times.map { |i| "(#{ema.gauss_str(i)}) + " }.join
    entries = ["#{density_sum} #{offset} notitle"]

    # One inline-data entry per cluster; the line colour identifies the cluster.
    ema.k.times do |i|
      entries << "'-' notitle pt 1 ps 0.5 lc #{i+2} nohidden3d"
    end

    # Entries are separated by a comma plus a gnuplot line continuation.
    command = "splot " + entries.join(",\\\n") + "\n"

    # Inline data: each cluster's patterns at z = 0, terminated by "e".
    ema.k.times do |i|
      ema.all_patterns[i].flat_map(&:to_a).each do |x, y|
        command << "#{x}, #{y}, 0.0\n"
      end
      command << "e\n"
    end
    gp_in.puts command

    # Advance the model by one iteration before drawing the next frame.
    ema.update

    # Progress bar; the escape sequence moves the cursor up one line so the
    # next iteration overwrites it.
    filled = "*" * frame
    blank = " " * (loop_num - frame)
    puts " [#{filled}#{blank}]"
    printf "\e[1A"
    STDOUT.flush
  end

  puts " [#{"*"*loop_num}]"
  gp_in.puts "set output"
  gp_in.puts "exit"
  gp_in.close
end
# Drives gnuplot to write two PNG scatter plots: the raw patterns in a single
# colour, then the patterns coloured by their final cluster.
Open3.popen3('gnuplot') do |gp_in, _gp_out, _gp_err|
  # Sends the terminal, output file and axis range settings shared by both plots.
  png_setup = lambda do |path|
    [
      "set terminal png size 480, 450",
      "set output '#{path}'",
      "set xrange [0:#{x_size}]",
      "set yrange [0:#{y_size}]",
    ].each { |setting| gp_in.puts setting }
  end

  # Patterns before clustering.
  png_setup.call("./ema_patterns.png")
  raw_plot = "plot " + "'-' notitle pt 1 ps 0.5 lc 1\n"
  patterns.each do |x, y|
    raw_plot << "#{x}, #{y}\n"
  end
  raw_plot << "e\n"
  gp_in.puts raw_plot
  gp_in.puts "set output"

  # Patterns after clustering: one inline-data series per cluster, each with
  # its own line colour, separated by a comma plus a gnuplot line continuation.
  png_setup.call("./ema_result.png")
  series = ema.k.times.map { |i| "'-' notitle pt 1 ps 0.5 lc #{i+2}" }
  clustered_plot = "plot " + series.join(",\\\n") + "\n"

  # Inline data for each cluster, terminated by "e".
  ema.k.times do |i|
    ema.all_patterns[i].flat_map(&:to_a).each do |x, y|
      clustered_plot << "#{x}, #{y}\n"
    end
    clustered_plot << "e\n"
  end
  gp_in.puts clustered_plot

  gp_in.puts "set output"
  gp_in.puts "exit"
  gp_in.close
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment