Last active: September 7, 2016 07:06
-
-
Save seinosuke/4f565647f33c815e18123404a224a3e1 to your computer and use it in GitHub Desktop.
【Ruby】 VBEMアルゴリズム 参照→( http://syoshinsyakangeisagi.blogspot.com/2016/09/ruby-vbemprml10.html )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
require 'matrix' | |
require 'open3' | |
require 'pp' | |
require 'pry' | |
require "./utils" | |
require "./vbem" | |
LOOP = 150      # number of VB update iterations to run (one GIF frame each)
THRESHOLD = 0.1 # components whose E[pi_k] is below this are not drawn

# Axis ranges shared by both plots.
XRANGE = [-1.0, 1.0].freeze
YRANGE = [-1.0, 1.0].freeze

# Runs one gnuplot session.
#
# Sends the settings common to both plots (terminal, output file, grid,
# square aspect ratio, plot-specific `settings`, 0.1 tics, x/y ranges, in
# that order), then yields gnuplot's stdin so the caller can emit its plot
# commands, and finally unsets the output (flushing the image) and exits.
#
# IO.popen in write-only mode leaves gnuplot's stdout/stderr attached to the
# terminal, so gnuplot errors stay visible instead of being captured in pipes
# that are never read (which could also fill up and block gnuplot).
#
# @param terminal [String] argument to "set terminal"
# @param output_file [String] path of the image to write
# @param settings [Array<String>] extra commands sent after "set size square"
def run_gnuplot(terminal:, output_file:, settings: [])
  IO.popen("gnuplot", "w") do |gp|
    gp.puts "set terminal #{terminal}"
    gp.puts "set output '#{output_file}'"
    gp.puts "set grid"
    gp.puts "set size square"
    settings.each { |line| gp.puts line }
    gp.puts "set xtics 0.1"
    gp.puts "set ytics 0.1"
    gp.puts "set xrange [#{XRANGE.join(':')}]"
    gp.puts "set yrange [#{YRANGE.join(':')}]"
    yield gp
    gp.puts "set output"
    gp.puts "exit"
  end
end

# Synthetic data: three rotated 2-D Gaussian clusters, scaled by 0.1 so that
# they fit inside XRANGE x YRANGE.
patterns = Utils.generate(80, [4, 3.5], [1, 2], 60)
patterns += Utils.generate(120, [-4, 3.5], [3, 1], 60)
patterns += Utils.generate(100, [0, -4.5], [3, 1], 30)
patterns.map! { |pattern| pattern.map { |x| x * 0.1 } }

options = {
  k: 5,
  X: patterns
}
vbem = VBEM.new(options)

# The observations as one gnuplot inline-data block (terminated by "e").
# It never changes, so it is built once and reused for every plot.
pattern_data = patterns.map { |x1, x2| "#{x1}, #{x2}\n" }.join + "e\n"

############################################################
# Original data
############################################################
run_gnuplot(terminal: "png size 600, 600", output_file: "./original_data.png") do |gp|
  gp.puts "plot '-' notitle with points lt 1 lw 1 lc rgb 'black'\n#{pattern_data}"
end

############################################################
# Animated GIF of the intermediate results
############################################################
gif_settings = [
  "set parametric",
  "set style fill transparent solid 0.2 border lc rgb 'red'"
]
run_gnuplot(terminal: "gif animate delay 10 size 600, 600",
            output_file: "./result.gif",
            settings: gif_settings) do |gp|
  gp.puts "set trange [0:#{2 * Math::PI}]"
  LOOP.times do |t|
    vbem.update

    # Only components with a large enough expected mixing weight are drawn.
    shown = vbem.k.times.reject { |k| vbem.E_pi[k] < THRESHOLD }

    # Each shown component contributes a filled ellipse (a parametric curve)
    # plus its mean as an inline-data point; the observations come last.
    items = shown.flat_map do |k|
      px, py = Utils.gnuplot_ellipse(vbem.E_mu[k], vbem.E_lambda[k].inv)
      ["#{px}, #{py} notitle with filledcurves lt 1 lw 2 lc rgb 'red'",
       "'-' notitle with points lt 7 lw 2 lc rgb 'red'"]
    end
    items << "'-' notitle with points lt 1 lw 1 lc rgb 'black'"

    # Inline data blocks must appear in the same order as the '-' entries.
    mean_data = shown.map { |k| "#{vbem.E_mu[k][0, 0]}, #{vbem.E_mu[k][0, 1]}\ne\n" }.join
    gp.puts "plot #{items.join(",\\\n")}\n#{mean_data}#{pattern_data}"

    # Text progress bar, redrawn in place by moving the cursor up one line.
    puts " [#{('*' * ((t.to_f / LOOP) * 10).to_i).ljust(9, ' ')}]"
    print "\e[1A"
    $stdout.flush
  end
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
include Math

# Numerical helpers for the VBEM demo: the digamma function, a scalar
# lower-bound helper, a 2-D Gaussian sampler and a gnuplot ellipse generator.
module Utils
  class << self
    # Digamma function psi(x) = d/dx ln Gamma(x), for x > 0.
    #
    # Applies the recurrence psi(x) = psi(x + 1) - 1/x until the argument
    # exceeds 6, then evaluates the asymptotic series in y = x - 1/2:
    #   ln(y) + 1/(24 y^2) - 7/(960 y^4) + 31/(8064 y^6) - 127/(30720 y^8)
    #
    # The result is clamped so that |result| lies in [1e-10, 1e+10]; an exact
    # 0.0 is mapped to -1e-10.
    #
    # @param x [Numeric] argument, must be > 0
    # @return [Float]
    # @raise [ArgumentError] if x is not > 0 (this includes NaN)
    def digamma(x)
      raise ArgumentError, "digamma is only implemented for x > 0 (got #{x.inspect})" unless x > 0

      result = 0.0
      loop do
        result -= 1.0 / x
        x += 1.0
        break if x > 6.0
      end
      x -= 1.0 / 2.0
      xx = 1.0 / x
      xx2 = xx * xx
      xx4 = xx2 * xx2
      result += log(x) + (1.0 / 24.0) * xx2 - (7.0 / 960.0) * xx4 +
                (31.0 / 8064.0) * xx4 * xx2 - (127.0 / 30720.0) * xx4 * xx4
      result = (result > 0 ? 1e+10 : -1e+10) if result.abs > 1e+10
      result = (result > 0 ? 1e-10 : -1e-10) if result.abs < 1e-10
      result
    end

    # Returns `min` when x is below it, otherwise x. A NaN x is returned
    # unchanged because the comparison is false.
    def max(x, min)
      x < min ? min : x
    end

    # Draws `num` samples from a 2-D Gaussian using the Box-Muller transform.
    #
    # Standard-normal pairs are scaled per axis by `s`, rotated
    # counter-clockwise by `theta_deg` degrees and shifted by the mean `m`.
    #
    # @param num [Integer] number of samples
    # @param m [Array(Numeric, Numeric)] mean
    # @param s [Array(Numeric, Numeric)] per-axis standard deviation before rotation
    # @param theta_deg [Numeric] rotation angle in degrees
    # @return [Array<Array(Float, Float)>]
    def generate(num = 100, m = [0, 0], s = [1, 1], theta_deg = 0)
      theta = PI * (theta_deg / 180.0)
      rotation = Matrix[[cos(theta), -sin(theta)], [sin(theta), cos(theta)]]
      num.times.map do
        u1 = rand
        # rand is in [0, 1); log(0) would make the sample infinite, so redraw.
        u1 = rand while u1.zero?
        u2 = rand
        radius = sqrt(-2 * log(u1))
        x = radius * cos(2 * PI * u2)
        y = radius * sin(2 * PI * u2)
        rotated = rotation * Matrix[[s[0] * x, s[1] * y]].t
        (rotated.t + Matrix[m]).to_a.flatten
      end
    end

    # Returns gnuplot parametric expressions [x(t), y(t)] (in the dummy
    # variable t) for an ellipse derived from a bivariate normal.
    #
    # The curve is the set of points where the normal density equals `c1`:
    # in standardised coordinates it satisfies
    #   z1^2 - 2 rho z1 z2 + z2^2 = -2 (1 - rho^2) ln(2 pi c1 s1 s2 sqrt(1 - rho^2)).
    # NOTE(review): it is therefore a density level set, not a probability-mass
    # confidence region. When c1 exceeds the peak density the right-hand side
    # is negative and `.abs` only keeps the radius real.
    #
    # @param mean [Matrix] 1x2 row vector
    # @param sigma [Matrix] 2x2 covariance matrix
    # @param c1 [Float] density level
    # @return [Array(String, String)]
    def gnuplot_ellipse(mean, sigma, c1 = 0.95)
      s1 = sqrt(sigma[0, 0])
      s2 = sqrt(sigma[1, 1])
      m1 = mean[0, 0]
      m2 = mean[0, 1]
      rho = sigma[0, 1] / (s1 * s2)
      c2 = (-2 * (1 - rho**2) * log(2 * PI * c1 * s1 * s2 * sqrt(1 - rho**2))).abs
      # Semi-axes along the +/-45 degree directions of the standardised form.
      a = sqrt(c2 / (1 + rho))
      b = sqrt(c2 / (1 - rho))
      theta = PI * (315.0 / 180.0)
      px = "#{a * s1 * cos(theta)}*cos(t) - #{b * s1 * sin(theta)}*sin(t) "
      py = "#{a * s2 * sin(theta)}*cos(t) + #{b * s2 * cos(theta)}*sin(t) "
      [px << "+#{m1}", py << "+#{m2}"]
    end
  end
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Variational Bayesian EM (VBEM) for a Gaussian mixture model: a Dirichlet
# prior on the mixing weights and a Gaussian-Wishart prior on each component's
# mean and precision (cf. Bishop, PRML ch. 10, as referenced in the gist title).
#
# Relies on Utils.digamma and Utils.max, which are defined elsewhere in this
# file. Per-component quantities are indexed by k in 0...@k; vectors are
# stored as 1xD row matrices.
class VBEM
  attr_reader :m, :W, :r, :n, :k
  attr_reader :E_lambda, :E_pi
  alias :E_mu :m

  # Lower bound applied to every responsibility so that no component's N_k
  # (used as a divisor in the M-step) can become exactly zero.
  R_FLOOR = 1e-10

  # @param options [Hash]
  #   :X       [Array<Array<Numeric>>] observations, one array per sample
  #   :k       [Integer] number of mixture components
  #   :verbose [Boolean] print E[pi] after every #update (default: true)
  def initialize(options = {})
    @dim = options[:X].first.size
    @k = options[:k]
    @X = options[:X].map { |x| Matrix[x.map(&:to_f)] }
    @num = options[:X].size
    @verbose = options.fetch(:verbose, true)
    set_hyper_parameters
    rand_init
  end

  # Advances one iteration: an M-step followed by an E-step.
  # @return [Array<Float>] the updated expected mixing weights E[pi]
  def update
    vb_m_step
    vb_e_step
    p @E_pi if @verbose
    @E_pi
  end

  # Initialises the prior hyper-parameters alpha_0, beta_0, m_0, W_0 (and its
  # inverse) and nu_0.
  def set_hyper_parameters
    @alpha_0 = Matrix[Array.new(@k) { 1.0 }]
    @beta_0 = Matrix[Array.new(@k) { 1e-3 }]
    @m_0 = Matrix[Array.new(@dim) { 0.0 }]
    @W_0 = Matrix.identity(@dim).map(&:to_f)
    @invW_0 = @W_0.inv
    @nu_0 = Matrix[Array.new(@k) { @dim }]
  end

  # Initialises the responsibilities r (N rows of K values) with uniform
  # random numbers, each row normalised to sum to 1.
  def rand_init
    @r = Array.new(@num) do
      row = Array.new(@k) { rand }
      total = row.inject(:+)
      row.map { |v| v / total }
    end
  end

  ############################################################
  # VB M-step
  ############################################################
  def vb_m_step
    # N_k = sum_n r_nk                                          (1 x K)
    @n = Matrix[@r.transpose.map { |column| column.inject(:+) }]
    @alpha = @alpha_0 + @n
    @beta = @beta_0 + @n
    @nu = @nu_0 + @n
    # Responsibility-weighted means xbar_k                      (K x 1xD)
    @x_ = @k.times.map do |k|
      sum = Matrix[Array.new(@dim) { 0.0 }]
      @num.times { |n| sum += @r[n][k] * @X[n] }
      sum / @n[0, k]
    end
    # Responsibility-weighted scatter matrices S_k              (K x DxD)
    @S = @k.times.map do |k|
      sum = Matrix.zero(@dim)
      @num.times do |n|
        d = @X[n] - @x_[k]
        sum += @r[n][k] * (d.t * d)
      end
      sum / @n[0, k]
    end
    # Posterior means m_k = (beta_0 m_0 + N_k xbar_k) / beta_k   (K x 1xD)
    @m = @k.times.map do |k|
      (@m_0 * @beta_0[0, k] + @n[0, k] * @x_[k]) / @beta[0, k]
    end
    # Posterior scale matrices
    # W_k^-1 = W_0^-1 + N_k S_k + beta_0 N_k / (beta_0 + N_k) (xbar_k - m_0)^T (xbar_k - m_0)
    @W = @k.times.map do |k|
      diff = @x_[k] - @m_0
      shrink = (@beta_0[0, k] * @n[0, k]) / (@beta_0[0, k] + @n[0, k])
      (@invW_0 + @n[0, k] * @S[k] + shrink * diff.t * diff).inv
    end
  end

  ############################################################
  # VB E-step
  ############################################################
  def vb_e_step
    alpha_sum = @alpha.to_a.flatten.inject(:+)
    # E[pi_k] = alpha_k / sum_j alpha_j                          (1 x K)
    @E_pi = @k.times.map { |k| @alpha[0, k] / alpha_sum }
    # E[ln pi_k] = psi(alpha_k) - psi(sum_j alpha_j)             (1 x K)
    psi_alpha_sum = Utils.digamma(alpha_sum)
    @E_lnpi = @k.times.map { |k| Utils.digamma(@alpha[0, k]) - psi_alpha_sum }
    # E[Lambda_k] = nu_k W_k                                     (K x DxD)
    @E_lambda = @k.times.map { |k| @nu[0, k] * @W[k] }
    # E[ln|Lambda_k|] = sum_{i=1..D} psi((nu_k + 1 - i) / 2) + D ln 2 + ln|W_k|
    # All D digamma terms are accumulated into the sum.          (1 x K)
    @E_lnlambda = @k.times.map do |k|
      psi_terms = (1..@dim).inject(0.0) { |sum, i| sum + Utils.digamma((@nu[0, k] + 1 - i) / 2.0) }
      @dim * Math.log(2.0) + Math.log(@W[k].det) + psi_terms
    end
    # E[(x_n - mu_k)^T Lambda_k (x_n - mu_k)]
    #   = D / beta_k + nu_k (x_n - m_k)^T W_k (x_n - m_k)        (N x K)
    @E_x_mWx_m = Matrix.build(@num, @k) do |n, k|
      d = @X[n] - @m[k]
      @dim / @beta[0, k] + @nu[0, k] * (d * @W[k] * d.t)[0, 0]
    end
    # Expected log Gaussian term of ln rho_nk                    (N x K)
    @E_lnN = Matrix.build(@num, @k) do |n, k|
      (@E_lnlambda[k] - @dim * Math.log(2.0 * Math::PI) - @E_x_mWx_m[n, k]) / 2.0
    end
    # ln rho_nk = E[ln pi_k] + E_lnN[n, k]; responsibilities r   (N x K)
    @lnrho = @E_lnN + Matrix[*Array.new(@num) { @E_lnpi }]
    @r = @lnrho.to_a.map { |row| normalize_log_row(row) }
  end

  private

  # Turns one row of unnormalised log-responsibilities into probabilities.
  # The max-shifted log-sum-exp keeps rows whose entries are all very negative
  # from underflowing to log(0) (which would yield NaN). Each value is then
  # floored at R_FLOOR and the row renormalised so it sums to 1.
  def normalize_log_row(row)
    shift = row.max
    logsumexp = shift + Math.log(row.inject(0.0) { |sum, v| sum + Math.exp(v - shift) })
    probs = row.map { |v| Utils.max(Math.exp(v - logsumexp), R_FLOOR) }
    total = probs.inject(:+)
    probs.map { |v| v / total }
  end
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.