Last active
November 24, 2015 09:59
-
-
Save seinosuke/8cd29c69ffd2815a72b7 to your computer and use it in GitHub Desktop.
Rubyで学習ベクトル量子化 参照→( http://syoshinsyakangeisagi.blogspot.com/2015/11/ruby.html )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
require 'matrix'

# Learning Vector Quantization (LVQ1) classifier.
#
# Each class is represented by a fixed number of prototype vectors
# ("representative patterns"). Training takes the prototype nearest to each
# learning pattern and moves it towards the pattern when their classes agree,
# or away from it when they disagree.
class LVQ
  # Snapshots (deep copies) of all representative patterns. One snapshot is
  # appended after each class's patterns are processed in #correct_errors,
  # so the last entry is the most recent state.
  attr_accessor :log

  # Default learning rate.
  ALPHA = 0.005
  # Default number of representative patterns per class.
  PROTOTYPES_PER_CLASS = 4
  # Default number of passes performed by #learn.
  EPOCHS = 50

  # @param learning_patterns [Array<Array<Array<Numeric>>>] one list of
  #   patterns per class; each pattern is a list of +dimension+ coordinates
  # @param dimension [Integer] number of coordinates per pattern
  # @param alpha [Float] learning rate used by #correct_errors
  # @param prototypes_per_class [Integer] representative patterns per class
  # @raise [ArgumentError] if some class has no learning patterns (its mean
  #   would be 0/0 = NaN, making its prototypes meaningless)
  def initialize(learning_patterns, dimension,
                 alpha: ALPHA, prototypes_per_class: PROTOTYPES_PER_CLASS)
    @log = []
    @dimension = dimension
    @alpha = alpha
    @class_num = learning_patterns.size
    @learning_patterns = learning_patterns.map do |patterns|
      patterns.map { |pattern| Vector[*pattern.map(&:to_f)] }
    end
    if @learning_patterns.any?(&:empty?)
      raise ArgumentError, 'every class needs at least one learning pattern'
    end

    zero = Vector[*Array.new(@dimension) { 0.0 }]
    # Mean vector of each class.
    means = @learning_patterns.map do |patterns|
      patterns.inject(zero, :+).map { |v| v / patterns.size.to_f }
    end

    # Initial representative patterns: each class's mean plus a random
    # offset drawn from rand (i.e. [0, 1)) in every coordinate.
    @representative_patterns = means.map do |mean|
      Array.new(prototypes_per_class) do
        mean + Vector[*Array.new(@dimension) { rand }]
      end
    end
  end

  # Trains the model by running #correct_errors +epochs+ times.
  #
  # @param epochs [Integer] number of passes over the learning patterns
  def learn(epochs = EPOCHS)
    epochs.times { correct_errors }
  end

  # One LVQ1 pass over every learning pattern, class by class, updating the
  # nearest representative pattern for each one. A snapshot of all
  # representative patterns is appended to #log after each class.
  def correct_errors
    @learning_patterns.each_with_index do |patterns, class_index|
      patterns.each do |pattern|
        r_i, r_j = nearest_neighbor(pattern)
        prototype = @representative_patterns[r_i][r_j]
        step = @alpha * (pattern - prototype)
        # Attract when the nearest prototype has the right class, repel otherwise.
        @representative_patterns[r_i][r_j] =
          r_i == class_index ? prototype + step : prototype - step
      end
      # Deep copy so later updates do not alter the logged snapshot.
      @log << Marshal.load(Marshal.dump(@representative_patterns))
    end
  end

  # Locates the representative pattern closest to +l_pattern+, using the
  # squared Euclidean distance over the first +dimension+ coordinates.
  #
  # @param l_pattern [#[]] a pattern indexable by coordinate (Vector or Array)
  # @return [Array(Integer, Integer)] [class index, index within that class]
  def nearest_neighbor(l_pattern)
    candidates = @representative_patterns.each_with_index.flat_map do |patterns, i|
      patterns.each_with_index.map do |r_pattern, j|
        distance = @dimension.times.inject(0) do |sum, k|
          sum + (r_pattern[k] - l_pattern[k])**2
        end
        [[i, j], distance]
      end
    end
    candidates.min_by(&:last).first
  end
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
require "./lvq" | |
require "./voronoi" | |
require 'matrix' | |
require 'open3' | |
require 'pp' | |
include Math | |
# Draws +num+ 2-D points from an axis-aligned normal distribution using the
# Box-Muller transform.
#
# @param num [Integer] number of points to draw
# @param m [Array(Numeric, Numeric)] mean in x and y
# @param s [Array(Numeric, Numeric)] standard deviation in x and y
# @return [Array<Array(Float, Float)>] the sampled [x, y] points
def generate(num = 100, m = [0, 0], s = [1, 1])
  Array.new(num) do
    # rand lies in [0, 1), so 1.0 - rand lies in (0, 1]: Math.log never sees
    # 0, which would give -Infinity and an infinite/NaN point.
    r1 = 1.0 - rand
    r2 = rand
    radius = Math.sqrt(-2 * Math.log(r1))
    theta = 2 * Math::PI * r2
    x = radius * Math.cos(theta)
    y = radius * Math.sin(theta)
    [s[0] * x + m[0], s[1] * y + m[1]]
  end
end
# Toy learning data for 3 classes. Each class is the union of several
# 100-point clusters drawn by generate, listed as [mean, standard deviation].
class_clusters = [
  [[[25, 35], [9, 3]], [[50, 40], [9, 3]]],
  [[[85, 40], [4, 9]], [[75, 55], [6, 4]], [[70, 25], [9, 3]]],
  [[[30, 80], [9, 3]], [[50, 70], [9, 4]], [[30, 60], [9, 3]]]
]
learning_patterns = class_clusters.map do |clusters|
  clusters.flat_map { |mean, sd| generate(100, mean, sd) }
end

# Train a 2-D LVQ model, then build the border sampler from the final
# snapshot of its representative patterns.
lvq = LVQ.new(learning_patterns, 2)
lvq.learn
voronoi = Voronoi.new(100, 100, lvq.log.last)
# Plot 1 (./lvq_patterns.png): the learning patterns only, one colour per class.
# One inline-data ('-') series per class. Joining the specs with ',' leaves
# no dangling comma after the last series.
patterns_series = Array.new(learning_patterns.size) do |i|
  "'-' notitle pt 1 ps 0.5 lc #{i + 1}"
end
patterns_lines = [
  'set terminal png size 480, 450',
  "set output './lvq_patterns.png'",
  'set xrange [0:100]',
  'set yrange [0:100]',
  "plot #{patterns_series.join(',')}"
]
# Learning patterns of each class, each series terminated by 'e'.
learning_patterns.each do |patterns|
  patterns.each { |x, y| patterns_lines << "#{x}, #{y}" }
  patterns_lines << 'e'
end
patterns_lines << 'set output' << 'exit'
# capture2e reads gnuplot's combined stdout/stderr while feeding stdin, so
# the pipe cannot fill up and gnuplot's messages are reported, not dropped.
patterns_log, = Open3.capture2e('gnuplot', stdin_data: patterns_lines.join("\n") + "\n")
warn patterns_log unless patterns_log.empty?
# Plot 2 (./lvq_result.gif): animated GIF with one frame per snapshot in
# lvq.log. For each class it draws the learning patterns (pt 1, small) and
# the representative patterns (pt 7, large) in the same colour.
# Two inline-data series per class. Joining with ',' leaves no dangling comma.
animation_series = Array.new(learning_patterns.size) do |i|
  "'-' notitle pt 1 ps 0.5 lc #{i + 1},'-' notitle pt 7 ps 2 lc #{i + 1}"
end
animation_plot = "plot #{animation_series.join(',')}"
animation_lines = [
  'set terminal gif animate optimize size 480, 450',
  "set output './lvq_result.gif'",
  'set xrange [0:100]',
  'set yrange [0:100]'
]
lvq.log.each do |representative_patterns|
  animation_lines << animation_plot
  learning_patterns.each_with_index do |patterns, i|
    # Learning patterns of class i.
    patterns.each { |x, y| animation_lines << "#{x}, #{y}" }
    animation_lines << 'e'
    # Representative patterns of class i in this snapshot.
    representative_patterns[i].map(&:to_a).each { |x, y| animation_lines << "#{x}, #{y}" }
    animation_lines << 'e'
  end
end
animation_lines << 'set output' << 'exit'
# capture2e reads gnuplot's combined stdout/stderr while feeding stdin, so
# the pipe cannot fill up and gnuplot's messages are reported, not dropped.
animation_log, = Open3.capture2e('gnuplot', stdin_data: animation_lines.join("\n") + "\n")
warn animation_log unless animation_log.empty?
# Plot 3 (./lvq_voronoi.png): the learning patterns, the final snapshot of
# representative patterns and the class border points from voronoi.border
# (light blue, '#afeeee').
# Two inline-data series per class plus one for the border points.
voronoi_series = Array.new(learning_patterns.size) do |i|
  "'-' notitle pt 1 ps 0.5 lc #{i + 1},'-' notitle pt 7 ps 2 lc #{i + 1}"
end
voronoi_series << "'-' notitle pt 1 ps 1 lc rgb '#afeeee'"
voronoi_lines = [
  'set terminal png size 480, 450',
  "set output './lvq_voronoi.png'",
  'set xrange [0:100]',
  'set yrange [0:100]',
  "plot #{voronoi_series.join(',')}"
]
voronoi_prototypes = lvq.log.last
learning_patterns.each_with_index do |patterns, i|
  # Learning patterns of class i.
  patterns.each { |x, y| voronoi_lines << "#{x}, #{y}" }
  voronoi_lines << 'e'
  # Final representative patterns of class i.
  voronoi_prototypes[i].map(&:to_a).each { |x, y| voronoi_lines << "#{x}, #{y}" }
  voronoi_lines << 'e'
end
# Border points between classes.
voronoi.border.each { |x, y| voronoi_lines << "#{x}, #{y}" }
voronoi_lines << 'e' << 'set output' << 'exit'
# capture2e reads gnuplot's combined stdout/stderr while feeding stdin, so
# the pipe cannot fill up and gnuplot's messages are reported, not dropped.
voronoi_log, = Open3.capture2e('gnuplot', stdin_data: voronoi_lines.join("\n") + "\n")
warn voronoi_log unless voronoi_log.empty?
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Approximates the borders between classes of a Voronoi diagram by sampling
# a regular grid. A grid point is kept when its two nearest generators
# belong to different classes and are (almost) equally far away.
class Voronoi
  # @param i [Integer] grid extent along the first (x) axis
  # @param j [Integer] grid extent along the second (y) axis
  # @param generators [Array<Array<#to_a>>] generators grouped by class:
  #   generators[c] lists the points of class c (anything whose #to_a starts
  #   with the x and y coordinates, e.g. a Vector or an Array)
  def initialize(i, j, generators)
    @i_size, @j_size = i, j
    @generators = generators.each_with_index.flat_map do |c_generators, c|
      c_generators.map { |generator| { :class => c, :generator => generator } }
    end
  end

  # Grid points lying on a class border.
  #
  # The grid covers [0, i) x [0, j) in steps of 1/subdivisions. A point is
  # kept when its two nearest generators have different classes and their
  # squared distances differ by less than +threshold+.
  #
  # @param threshold [Float] max difference of the two squared distances
  # @param subdivisions [Integer] grid points per unit along each axis
  # @return [Array<Array(Float, Float)>] border points as [x, y], x-major order;
  #   empty when there are fewer than two generators
  def border(threshold: 9.0, subdivisions: 10)
    # With a single generator there is no second-nearest one and no border.
    return [] if @generators.size < 2

    sites = @generators.map do |g|
      gi, gj = g[:generator].to_a
      [g[:class], gi, gj]
    end
    xs = grid_axis(@i_size, subdivisions)
    ys = grid_axis(@j_size, subdivisions)

    xs.each_with_object([]) do |i, ans|
      ys.each do |j|
        nearest, second = sites.map { |c, gi, gj|
          [c, ((i - gi)**2) + ((j - gj)**2)]
        }.sort_by(&:last)
        next if nearest[0] == second[0]
        ans << [i, j] if (nearest[1] - second[1]).abs < threshold
      end
    end
  end

  private

  # Coordinates 0, 1/subdivisions, ..., size - 1/subdivisions along one axis.
  def grid_axis(size, subdivisions)
    size.times.flat_map do |e|
      Array.new(subdivisions) { |f| e + f / subdivisions.to_f }
    end
  end
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment