Last active
September 9, 2015 09:35
-
-
Save seinosuke/eef6495c1b463b156d87 to your computer and use it in GitHub Desktop.
Rubyでパーセプトロンを実装してみた (Implementing a perceptron in Ruby) 参照 (reference) → http://syoshinsyakangeisagi.blogspot.com/2015/09/ruby.html
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Example 01: train a 3-class perceptron on 1-D patterns and plot the
# patterns on a number line together with the learned decision boundaries.
# Resolve perceptron.rb next to this script, not relative to the current
# working directory.
require_relative "perceptron"
require 'gnuplot'

# One list of 1-D patterns per class; the class index is the list's position.
learning_patterns = [
  [[14], [17]],
  [[12], [5], [2], [-2]],
  [[-4], [-6], [-15]]
]

perceptron = Perceptron.new(learning_patterns, 1)
perceptron.learn

output_file = "./example01.png"

Gnuplot.open do |gp|
  Gnuplot::Plot.new(gp) do |plot|
    plot.title "Example 01"
    plot.set "terminal png size 640, 300"
    plot.output output_file
    plot.xrange "[-20:20]"
    plot.yrange "[-1:1]"
    plot.xlabel "x"

    # Number line: a horizontal black line at y = 0.
    lx = (-20..20).map(&:to_f)
    ly = Array.new(lx.size, 0)
    plot.data << Gnuplot::DataSet.new([lx, ly]) do |ds|
      ds.with = "lines"
      ds.linecolor = "rgb 'black'"
      ds.notitle
    end

    # Training patterns, drawn on the number line with one colour per class.
    learning_patterns.each_with_index do |patterns, i|
      # Each pattern is a one-element array; plot its single coordinate.
      px = patterns.map(&:first)
      py = Array.new(px.size, 0)
      plot.data << Gnuplot::DataSet.new([px, py]) do |ds|
        ds.with = "points ls 1 pt 7 ps 2 lc #{i}"
        ds.notitle
      end
    end

    # Decision boundaries between consecutive classes i and i+1: the point
    # where g_i(x) = w_i0 + w_i1*x equals g_j(x) = w_j0 + w_j1*x.
    # NOTE: only adjacent class indices are compared, which assumes the
    # classes are listed in order along the x axis (as they are above).
    perceptron.weight_vectors.map { |v| v.map(&:to_f) }.each_cons(2) do |wi, wj|
      slope_diff = wj[1] - wi[1]
      # Equal slopes mean the two discriminants never cross; skip rather
      # than plot an infinite/NaN coordinate.
      next if slope_diff.zero?

      boundary = (wi[0] - wj[0]) / slope_diff
      y = (-1..1).map(&:to_f)
      x = Array.new(y.size, boundary)
      plot.data << Gnuplot::DataSet.new([x, y]) do |ds|
        ds.with = "lines ls 1 lc 2 lw 2"
        ds.notitle
      end
    end
  end
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Example 02: train a 2-class perceptron on 2-D patterns and plot the
# patterns together with the learned linear decision boundary.
# Resolve perceptron.rb next to this script, not relative to the current
# working directory.
require_relative "perceptron"
require 'gnuplot'

# One list of [x, y] patterns per class; the class index is the list's position.
learning_patterns = [
  [[-3, 3], [2, 8], [5, 3], [8, 5], [-7, 4], [-5, -2]],
  [[-8, -9], [-2, -8], [5, -9], [5, 0], [0, -5]]
]

perceptron = Perceptron.new(learning_patterns, 2)
perceptron.learn

output_file = "./example02.png"

Gnuplot.open do |gp|
  Gnuplot::Plot.new(gp) do |plot|
    plot.title "Example 02"
    plot.set "terminal png size 480, 480"
    plot.output output_file
    plot.xrange "[-10:10]"
    plot.yrange "[-10:10]"
    plot.xlabel "x"
    plot.ylabel "y"

    # Training patterns, one colour per class. Split each [x, y] pair into
    # explicit x and y columns.
    learning_patterns.each_with_index do |patterns, i|
      px = patterns.map(&:first)
      py = patterns.map(&:last)
      plot.data << Gnuplot::DataSet.new([px, py]) do |ds|
        ds.with = "points ls 1 pt 7 ps 2 lc #{i}"
        ds.notitle
      end
    end

    # Decision boundary between consecutive classes i and j: the line where
    # g_i(x, y) = g_j(x, y), i.e. c + a*x + b*y = 0 with
    # a = w_i1 - w_j1, b = w_i2 - w_j2, c = w_i0 - w_j0.
    perceptron.weight_vectors.map { |v| v.map(&:to_f) }.each_cons(2) do |wi, wj|
      a = wi[1] - wj[1]
      b = wi[2] - wj[2]
      c = wi[0] - wj[0]
      # Both feature coefficients equal: the discriminants are parallel and
      # there is no boundary line to draw.
      next if a.zero? && b.zero?

      if b.zero?
        # Vertical boundary x = -c / a (solving for y would divide by zero).
        y = (-20..20).map(&:to_f)
        x = Array.new(y.size, -c / a)
      else
        x = (-20..20).map(&:to_f)
        y = x.map { |v| -(c + a * v) / b }
      end
      plot.data << Gnuplot::DataSet.new([x, y]) do |ds|
        ds.with = "lines ls 1 lc 2 lw 2"
        ds.notitle
      end
    end
  end
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
require 'matrix'

# Multi-class perceptron trained with the fixed-increment error-correction
# rule. Class c owns a weight vector w_c; a pattern x is assigned to the
# class whose discriminant g_c(x) = w_c . (1, x) is largest.
class Perceptron
  # Step size applied to a pattern when correcting the weight vectors.
  RHO = 1

  # Array of Vector, one per class. Element 0 of each vector multiplies the
  # constant 1 prepended to every pattern (the bias term).
  attr_accessor :weight_vectors

  # Builds augmented training patterns and the initial weight vectors.
  #
  # @param learning_patterns [Array<Array<Array<Numeric>>>] one list of
  #   training patterns per class; the class index is the list's position.
  # @param dimension [Integer] number of features in each pattern.
  def initialize(learning_patterns, dimension)
    # Prepend a 1 to every pattern so the bias is learned as weight element 0.
    @learning_patterns = learning_patterns.map do |patterns|
      patterns.map { |pattern| Vector[1, *pattern] }
    end
    @class_num = @learning_patterns.size
    # Class i starts with every weight equal to i, so the classes begin with
    # distinct weight vectors.
    @weight_vectors = Array.new(@class_num) do |i|
      Vector[*Array.new(dimension + 1, i)]
    end
  end

  # Discriminant function g_i(x) for class +i+.
  #
  # @param pattern [Vector] augmented pattern (leading 1 followed by features)
  # @param i [Integer] class index
  # @return [Numeric] inner product of w_i and +pattern+
  def discriminate(pattern, i)
    @weight_vectors[i].inner_product(pattern)
  end

  # Runs one pass over every training pattern. When a pattern of class i is
  # assigned to another class j, w_i is moved towards it and w_j away from it.
  # Enumerable#max_by keeps the first maximum, so ties in the discriminant go
  # to the lowest class index.
  #
  # @return [Integer] number of misclassified patterns seen in this pass
  def correct_errors
    errors = 0
    @learning_patterns.each_with_index do |patterns, i|
      patterns.each do |pattern|
        # j is the class the current weights predict for this pattern.
        j = @class_num.times.max_by { |c| discriminate(pattern, c) }
        next if j == i

        @weight_vectors[i] += RHO * pattern
        @weight_vectors[j] -= RHO * pattern
        errors += 1
      end
    end
    errors
  end

  # Repeats error-correction passes until a whole pass makes no mistakes.
  #
  # Convergence is judged by the error count, not by comparing the weights
  # before and after a pass. Corrections within one pass can cancel out
  # exactly (e.g. the same point appears in two classes), and that would
  # otherwise be mistaken for convergence.
  #
  # @param max_epochs [Integer, nil] maximum number of passes; nil means no
  #   limit. Without a limit this only returns once a pass is error-free,
  #   which never happens when the classes are not linearly separable.
  # @return [Boolean] true if an error-free pass was reached, false if
  #   +max_epochs+ passes ran without reaching one
  def learn(max_epochs = nil)
    epochs = 0
    until max_epochs && epochs >= max_epochs
      return true if correct_errors.zero?

      epochs += 1
    end
    false
  end
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment