Created February 12, 2016 03:24
-
-
Save buyoh/f18e79e80d38557784dc to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#---------------------- | |
#ロジスティック回帰 | |
require "matrix" | |
# Logistic (sigmoid) function 1 / (1 + e^-x).
# In exact arithmetic the result lies in (0, 1); in floating point it saturates
# to exactly 1.0 / 0.0 for large |x|, because Math.exp underflows to 0.0 or
# overflows to Infinity instead of raising.
def sigmoid(x)
  denominator = 1.0 + Math.exp(-x)
  1.0 / denominator
end
# Online (per-sample) logistic-regression learner trained by gradient descent
# on the log-loss. Calls a top-level `sigmoid` helper, which must be defined.
class LearnLogistic
  attr_reader :dim
  attr_accessor :weight, :eta

  # @param dim [Integer] number of features (including any bias column)
  # @param eta [Float] learning rate used by #update
  def initialize(dim, eta)
    @dim = dim
    @eta = eta
    @weight = Vector.elements([0.0] * dim)
  end

  # Raw linear score w·x (the logit, before the sigmoid); > 0 means the
  # model favours the positive class.
  # @param val [Vector] feature vector of length dim
  # @return [Float]
  # @raise [ExceptionForMatrix::ErrDimensionMismatch] if the sizes differ
  def test(val)
    @weight.inner_product(val)
  end

  # One stochastic-gradient step on a single labelled sample:
  #   w <- w - eta * (sigmoid(w·x) - cls) * x
  # @param data [Vector] feature vector
  # @param cls [Numeric] target in [0, 1] (1.0 positive, 0.0 negative)
  # @return [Vector] the updated weight vector
  def update(data, cls)
    yn = sigmoid(test(data))
    @weight = @weight - @eta * (yn - cls) * data
  end

  # Runs n epochs of #update over every (sample, label) pair, in order.
  # @param data [Matrix, Array<Vector, Array>] samples; a Matrix is read
  #   column-wise (one column per sample), anything else through #to_a
  # @param cls [Vector, Array<Numeric>] one target label per sample
  # @param n [Integer] number of passes over the data
  # @return [LearnLogistic] self
  # @raise [ArgumentError] if the sample and label counts differ
  def learn(data, cls, n)
    rows = data.respond_to?(:column_vectors) ? data.column_vectors : data.to_a
    samples = rows.map { |x| Vector[*x] } # accept plain arrays as well as Vectors
    labels = cls.to_a
    unless samples.size == labels.size
      raise ArgumentError, "got #{samples.size} samples but #{labels.size} labels"
    end

    n.times do
      samples.zip(labels) { |x, y| update(x, y) }
    end
    self
  end
end
# --- Data input ---
# Read samples from ARGF (stdin, or files named on the command line), one per
# line as "label,x,y". Label "t" is the positive class; anything else is
# negative. A constant 1 is appended so the last weight acts as a bias term.
data = []
ARGF.each_line do |line|
  # Skip blank lines; otherwise they would parse as a bogus negative sample
  # at (0, 0), because nil.to_i == 0.
  next if line.strip.empty?

  cl = line.split(",")
  data << [cl[0] == "t", [cl[1].to_i, cl[2].to_i, 1]]
end

# --- Unpack the data set ---
size = data.size
dim = 3
ax = data.map { |_label, features| features }
at = data.map { |label, _features| label ? 1.0 : 0.0 }
mx = Matrix.columns(ax) # inputs: one column per sample
vt = Vector.elements(at) # targets: 1.0 for class "t", 0.0 otherwise

# --- Training ---
mlearn = LearnLogistic.new(dim, 0.2)
100.times do |cnt|
  # One stochastic-gradient pass over every sample.
  size.times { |i| mlearn.update(mx.column(i), vt[i]) }

  # Training accuracy: a positive score predicts class "t".
  acc = size.times.count { |i| (mlearn.test(mx.column(i)) > 0) == data[i][0] }
  if acc == size
    puts "solve #{cnt}"
    break
  end

  # Decay the learning rate geometrically until it falls to about 1e-3.
  mlearn.eta = mlearn.eta * 0.95 if mlearn.eta > 1.0e-3
end

puts "Result Vector"
p mlearn.weight
puts "Test..."
# Try it out: print the raw score of every training sample next to its data.
size.times do |i|
  puts format("res=%8.3f Data={%s}", mlearn.test(mx.column(i)), data[i] * ",")
end
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment