Created September 10, 2015 03:28
-
-
Save seinosuke/eea033ac85025ed0a20e to your computer and use it in GitHub Desktop.
Rubyでパーセプトロン (1次元2クラスにしか対応してない版) — Perceptron in Ruby (a version that supports only one-dimensional, two-class data)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Load perceptron.rb from this script's own directory (not the current
# working directory), so the script works no matter where it is run from.
require_relative "perceptron"
require 'gnuplot'

# Training patterns (scalars on a number line) for class 1 and class 2,
# respectively. Perceptron trains the first group towards g(x) > 0 and the
# second towards g(x) < 0.
learning_patterns = [
  [12, 5, 2, -2],
  [-4, -6, -15]
]

perceptron = Perceptron.new(learning_patterns)
perceptron.learn
#
# Plot the feature space: the number line, each class's training patterns,
# and the learned decision boundary.
#
output_file = "./feature_space.png"
Gnuplot.open do |gp|
  Gnuplot::Plot.new(gp) do |plot|
    plot.title "feature space"
    plot.set "terminal png size 640, 300"
    plot.output output_file
    plot.xrange "[-20:20]"
    plot.yrange "[-1:1]"
    plot.xlabel "x"

    # Number line (drawn along y = 0).
    lx = [*-20..20].map(&:to_f)
    ly = Array.new(lx.size, 0)
    plot.data << Gnuplot::DataSet.new([lx, ly]) do |ds|
      ds.with = "lines"
      ds.linecolor = "rgb 'black'"
      ds.notitle
    end

    # Training patterns of each class, placed on the number line; the class
    # index doubles as the gnuplot line-colour index.
    learning_patterns.each_with_index do |patterns, i|
      py = Array.new(patterns.size, 0)
      plot.data << Gnuplot::DataSet.new([patterns, py]) do |ds|
        ds.with = "points ls 1 pt 7 ps 2 lc #{i}"
        ds.notitle
      end
    end

    # Decision boundary: g(x) = w0 + w1 * x = 0  =>  x = -w0 / w1, drawn as a
    # vertical segment. With w1 == 0 there is no finite boundary, so skip it
    # rather than hand Infinity/NaN to gnuplot.
    w0, w1 = perceptron.weight_vector
    unless w1.zero?
      boundary = -w0.to_f / w1.to_f
      y = (-1..1).map(&:to_f)
      x = Array.new(y.size, boundary)
      plot.data << Gnuplot::DataSet.new([x, y]) do |ds|
        ds.with = "lines ls 1 lc 2 lw 2"
        ds.notitle
      end
    end
  end
end
#
# Plot the weight space (horizontal axis W1, vertical axis W0): one line per
# training pattern, plus the path the weight vector took during training.
#
output_file = "./weight_space.png"
Gnuplot.open do |gp|
  Gnuplot::Plot.new(gp) do |plot|
    plot.title "weight space"
    plot.set "terminal png size 400, 420"
    plot.output output_file
    plot.xrange "[-3:6]"
    plot.yrange "[-3:6]"
    plot.xlabel "W1"
    plot.ylabel "W0"

    # Hyperplane of each training pattern x (red): the weights for which
    # g(x) = W0 + W1 * x = 0, i.e. W0 = -x * W1. The W1 samples are the same
    # for every pattern, so build them once.
    w1_values = [*-20..20].map(&:to_f)
    learning_patterns.flatten.each do |pattern|
      w0_values = w1_values.map { |w1| -pattern * w1 }
      plot.data << Gnuplot::DataSet.new([w1_values, w0_values]) do |ds|
        ds.with = "lines ls 1 lc rgb 'red'"
        ds.notitle
      end
    end

    # Weight-vector update history (blue). Each log entry is [w0, w1]; build
    # explicit W1 (x) and W0 (y) series instead of passing [w1, w0] pairs
    # with a dummy zero column.
    history_w1 = perceptron.log.map { |_w0, w1| w1 }
    history_w0 = perceptron.log.map { |w0, _w1| w0 }
    plot.data << Gnuplot::DataSet.new([history_w1, history_w0]) do |ds|
      ds.with = "linesp ls 1 pt 7 lw 2 lc rgb 'blue'"
      ds.notitle
    end
  end
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Perceptron for one-dimensional, two-class problems.
#
# Every scalar training pattern x is augmented to [1, x], so the weight
# vector [w0, w1] defines the discriminant g(x) = w0 + w1 * x. Training
# drives class 0 patterns to g(x) > 0 and class 1 patterns to g(x) < 0.
class Perceptron
  # Default learning rate (rho) applied to every correction.
  RHO = 1

  # Default cap on training epochs. Without a cap, data that is not linearly
  # separable can make #learn cycle forever (e.g. classes [[0, 2], [1]]).
  MAX_EPOCHS = 10_000

  attr_accessor :weight_vector, :log

  # @param learning_patterns [Array<Array<Numeric>>] scalar patterns per
  #   class: index 0 is class 0, index 1 is class 1. Only these two classes
  #   take part in training; any further entries are ignored.
  # @param rho [Numeric] learning rate (defaults to RHO).
  # @param initial_weights [Array<Numeric>] starting [w0, w1].
  def initialize(learning_patterns, rho: RHO, initial_weights: [1, 1])
    @learning_patterns = learning_patterns.map do |patterns|
      patterns.map { |pattern| [1, pattern] }
    end
    @rho = rho
    @weight_vector = initial_weights.dup
    @log = []
    @converged = false
  end

  # Discriminant g(x): dot product of the weight vector and an augmented
  # pattern [1, x].
  #
  # @param pattern [Array<Numeric>] augmented pattern [1, x].
  # @return [Numeric]
  def discriminate(pattern)
    @weight_vector.zip(pattern).inject(0) { |acc, (w, x)| acc + w * x }
  end

  # Runs one epoch of error correction. Each misclassified pattern updates
  # the weight vector: w' = w + rho*x for class 0, w' = w - rho*x for class 1.
  # The weight vector in effect before each update is appended to #log.
  #
  # @return [Boolean] true if at least one correction was made.
  def correct_errors
    corrected = false
    @learning_patterns.first(2).each_with_index do |patterns, klass|
      patterns.each do |pattern|
        g = discriminate(pattern)
        # A pattern exactly on the boundary (g == 0) counts as misclassified.
        misclassified = klass.zero? ? g <= 0 : g >= 0
        next unless misclassified

        step = pattern.map { |x| @rho * x }
        @log << @weight_vector
        @weight_vector = @weight_vector.zip(step).map do |w, s|
          klass.zero? ? w + s : w - s
        end
        corrected = true
      end
    end
    corrected
  end

  # Repeats error-correction epochs until one epoch leaves the weight vector
  # unchanged, or until max_epochs epochs have run. The final weight vector
  # is then appended to #log.
  #
  # @param max_epochs [Integer, Float] upper bound on epochs; pass
  #   Float::INFINITY for unbounded training.
  # @return [Array<Array<Numeric>>] the weight-vector log.
  def learn(max_epochs: MAX_EPOCHS)
    @converged = false
    epochs = 0
    until @converged || epochs >= max_epochs
      previous = @weight_vector
      correct_errors
      epochs += 1
      @converged = previous == @weight_vector
    end
    @log << @weight_vector
  end

  # @return [Boolean] true if the last #learn call stopped because an epoch
  #   left the weights unchanged, false if it hit max_epochs (or #learn was
  #   never called). Corrections that cancel out within a single epoch also
  #   leave the weights unchanged, so this does not by itself guarantee that
  #   every pattern is classified correctly.
  def converged?
    @converged
  end
end
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment