Skip to content

Instantly share code, notes, and snippets.

@seinosuke
Created September 10, 2015 03:28
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save seinosuke/eea033ac85025ed0a20e to your computer and use it in GitHub Desktop.
Save seinosuke/eea033ac85025ed0a20e to your computer and use it in GitHub Desktop.
Rubyでパーセプトロン (1次元2クラスにしか対応してない版)
require "./perceptron"
require 'gnuplot'
# Training samples on the 1-D feature axis, grouped by class:
# learning_patterns[0] is class 1, learning_patterns[1] is class 2.
class_one_samples = [12, 5, 2, -2]
class_two_samples = [-4, -6, -15]
learning_patterns = [class_one_samples, class_two_samples]

# Build the classifier and train it in one step (tap returns the receiver).
perceptron = Perceptron.new(learning_patterns).tap(&:learn)
#
# Feature space: the 1-D number line, the training samples of each class,
# and the learned decision boundary where g(x) = w0 + w1 * x = 0.
#
output_file = "./feature_space.png"
Gnuplot.open do |gp|
  Gnuplot::Plot.new(gp) do |plot|
    plot.title "feature space"
    plot.set "terminal png size 640, 300"
    plot.output output_file
    plot.xrange "[-20:20]"
    plot.yrange "[-1:1]"
    plot.xlabel "x"

    # Number line drawn along y = 0.
    axis_x = (-20..20).map(&:to_f)
    axis_y = Array.new(axis_x.size, 0)
    plot.data << Gnuplot::DataSet.new([axis_x, axis_y]) do |ds|
      ds.with = "lines"
      ds.linecolor = "rgb 'black'"
      ds.notitle
    end

    # Training samples, placed on the number line; the gnuplot colour index
    # follows the class index.
    learning_patterns.each_with_index do |samples, class_index|
      on_axis = Array.new(samples.size, 0)
      plot.data << Gnuplot::DataSet.new([samples, on_axis]) do |ds|
        ds.with = "points ls 1 pt 7 ps 2 lc #{class_index}"
        ds.notitle
      end
    end

    # Decision boundary: w0 + w1 * x = 0  =>  x = -w0 / w1, drawn as a
    # vertical segment spanning the visible y range.
    w0, w1 = perceptron.weight_vector
    boundary_x = -w0.to_f / w1.to_f
    boundary_y = [-1.0, 0.0, 1.0]
    boundary_xs = Array.new(boundary_y.size, boundary_x)
    plot.data << Gnuplot::DataSet.new([boundary_xs, boundary_y]) do |ds|
      ds.with = "lines ls 1 lc 2 lw 2"
      ds.notitle
    end
  end
end
#
# Weight space: every training pattern x (augmented to [1, x]) defines the
# hyperplane W0 + W1 * x = 0, i.e. W0 = -x * W1. The update history of the
# weight vector is drawn on top of these lines.
#
output_file = "./weight_space.png"
Gnuplot.open do |gp|
  Gnuplot::Plot.new(gp) do |plot|
    plot.title "weight space"
    plot.set "terminal png size 400, 420"
    plot.output output_file
    plot.xrange "[-3:6]"
    plot.yrange "[-3:6]"
    plot.xlabel "W1"
    plot.ylabel "W0"

    # One hyperplane per training pattern (red). The W1 sample points do not
    # depend on the pattern, so they are built once outside the loop.
    w1_samples = [*-20..20].map(&:to_f)
    learning_patterns.flatten.each do |pattern|
      w0_samples = w1_samples.map { |w1| -pattern * w1 }
      plot.data << Gnuplot::DataSet.new([w1_samples, w0_samples]) do |ds|
        ds.with = "lines ls 1 lc rgb 'red'"
        ds.notitle
      end
    end

    # Weight-vector update history (blue). Each log entry is [W0, W1]; pass
    # W1 (x axis) and W0 (y axis) as two explicit columns instead of an array
    # of pairs plus a dummy zero column.
    history_w1 = perceptron.log.map { |w| w[1] }
    history_w0 = perceptron.log.map { |w| w[0] }
    plot.data << Gnuplot::DataSet.new([history_w1, history_w0]) do |ds|
      ds.with = "linesp ls 1 pt 7 lw 2 lc rgb 'blue'"
      ds.notitle
    end
  end
end
# Two-class perceptron trained with the fixed-increment error-correction rule.
#
# Every pattern x is augmented with a leading bias component, becoming
# [1, x] (or [1, *x] for an array-valued pattern), and classified with the
# linear discriminant g(x) = w . [1, x]. Class 0 is learned as the g(x) > 0
# side and class 1 as the g(x) < 0 side.
class Perceptron
  # Learning rate (rho) applied to every correction.
  RHO = 1

  # Default cap on training passes. The error-correction loop only settles
  # for linearly separable data; without a cap it would never return otherwise.
  DEFAULT_MAX_EPOCHS = 10_000

  attr_accessor :weight_vector, :log

  # @param learning_patterns [Array<Array>] training data grouped by class:
  #   index 0 holds class-0 patterns and index 1 holds class-1 patterns.
  #   A pattern is a number (1-D) or an array of numbers (n-D). Groups past
  #   index 1 are ignored.
  def initialize(learning_patterns)
    @learning_patterns = learning_patterns.map do |patterns|
      patterns.map { |pattern| [1, *Array(pattern)] }
    end
    # One initial weight of 1 per augmented component; [1, 1] (the 1-D
    # layout) when no patterns are supplied.
    sample = @learning_patterns.flatten(1).first
    @weight_vector = Array.new(sample ? sample.size : 2, 1)
    @log = []
    @converged = false
  end

  # Discriminant function g(x) = w . x.
  # @param pattern [Array<Numeric>] an augmented pattern ([1, x...]).
  # @return [Numeric]
  def discriminate(pattern)
    @weight_vector.zip(pattern).inject(0) { |acc, (w, x)| acc + w * x }
  end

  # Makes one pass over every training pattern and corrects the weights for
  # each misclassified one:
  #   class 0 with g(x) <= 0:  w' = w + rho * x
  #   class 1 with g(x) >= 0:  w' = w - rho * x
  # The weight vector in effect before each correction is appended to +log+.
  # @return [Integer] number of corrections made during this pass.
  def correct_errors
    corrections = 0
    @learning_patterns.first(2).each_with_index do |patterns, klass|
      # Class 0 is pushed toward g(x) > 0, class 1 toward g(x) < 0.
      sign = klass.zero? ? 1 : -1
      patterns.each do |pattern|
        # sign * g(x) > 0 exactly when the pattern is already on its side.
        next if sign * discriminate(pattern) > 0

        @log << @weight_vector
        # Build a new array (no in-place mutation) so log entries stay snapshots.
        @weight_vector = @weight_vector.zip(pattern).map { |w, x| w + sign * RHO * x }
        corrections += 1
      end
    end
    corrections
  end

  # Repeats error-correction passes until a pass makes no correction
  # (converged) or +max_epochs+ passes have run, then appends the final
  # weight vector to +log+.
  # @param max_epochs [Integer, nil] pass limit; nil means unlimited, which
  #   only terminates for linearly separable data.
  # @return [Array<Array<Numeric>>] the weight-vector history (+log+).
  def learn(max_epochs: DEFAULT_MAX_EPOCHS)
    @converged = false
    epochs = 0
    until @converged || (max_epochs && epochs >= max_epochs)
      # Convergence means an error-free pass, not merely an unchanged
      # weight vector: opposite corrections can cancel within one pass.
      @converged = correct_errors.zero?
      epochs += 1
    end
    @log << @weight_vector
  end

  # @return [Boolean] whether the last +learn+ call ended with an
  #   error-free pass over the training data.
  def converged?
    @converged
  end
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment