Created
May 9, 2017 06:18
-
-
Save ytbilly3636/e3ba8b95e24b06ed18ad30b040ad90bc to your computer and use it in GitHub Desktop.
パブロフの犬モデル?
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*-
import sys | |
import numpy as np | |
class Net(object):
    """Single linear unit whose weights are adjusted by a reward signal.

    The output is the dot product of the input vector and one weight per
    input. `update` nudges each weight up or down depending on the most
    recent input and on whether a reward was given, then keeps every
    weight inside [0, 1].
    """

    def __init__(self, input_size=2):
        """Create a unit with `input_size` inputs and all weights at zero.

        Writes an error message to stderr and exits the process with
        status 1 (raises SystemExit) when `input_size` is smaller than 1.
        """
        if input_size < 1:
            sys.stderr.write('Error: input_size (argument of Net.__init__) should be > 0.\n')
            sys.exit(1)

        # Most recent stimulus passed to give(); update() reads it back.
        self.inputs_buf = [0.0] * input_size
        # One weight per input. An ndarray from the start so the attribute
        # has the same type before and after the first update().
        self.weight_val = np.zeros(input_size)

    def give(self, inputs):
        """Present `inputs` to the unit and return its output.

        The inputs are remembered for the next call to `update`. The
        return value is `np.dot(inputs, weights)` computed with the
        current weights.

        Writes an error message to stderr and exits the process with
        status 1 (raises SystemExit) when `len(inputs)` differs from the
        number of inputs the unit was created with.
        """
        if len(inputs) != len(self.inputs_buf):
            sys.stderr.write('Error: Size of inputs (argument of Net.give) is different.\n')
            sys.exit(1)

        # Keep a private copy so that a caller mutating its own list
        # afterwards cannot change what update() learns from.
        self.inputs_buf = list(inputs)
        return np.dot(inputs, self.weight_val)

    def update(self, reward=True, learning_rate=0.05):
        """Adjust the weights using the inputs from the last `give` call.

        With a truthy `reward`, the weight of every non-zero input grows
        by `learning_rate * input`, and the weight of every zero input
        shrinks by `learning_rate`. With a falsy `reward`, every weight
        shrinks by `learning_rate`. The result is clipped to [0, 1].
        """
        stimulus = np.asarray(self.inputs_buf, dtype=float)
        if reward:
            delta = np.where(stimulus == 0.0,
                             -learning_rate,
                             learning_rate * stimulus)
        else:
            delta = -learning_rate

        weights = np.asarray(self.weight_val, dtype=float)
        self.weight_val = np.clip(weights + delta, 0, 1)
# ---
# Shared two-input network: feed() drives the first input and bell() the
# second; feed_and_bell() drives both.
net = Net(2)
def feed():
    """Run one rewarded trial with only the first input active.

    Presents [1.0, 0.0] to the module-level `net`, applies a rewarded
    update, and returns the output that was computed before the update.
    """
    stimulus = [1.0, 0.0]
    response = net.give(stimulus)
    net.update(reward=True)
    return response
def bell():
    """Run one unrewarded trial with only the second input active.

    Presents [0.0, 1.0] to the module-level `net`, applies an unrewarded
    update, and returns the output that was computed before the update.
    """
    stimulus = [0.0, 1.0]
    response = net.give(stimulus)
    net.update(reward=False)
    return response
def feed_and_bell():
    """Run one rewarded trial with both inputs active.

    Presents [1.0, 1.0] to the module-level `net`, applies a rewarded
    update, and returns the output that was computed before the update.
    """
    stimulus = [1.0, 1.0]
    response = net.give(stimulus)
    net.update(reward=True)
    return response
# ---
if __name__ == '__main__':
    # Output level above which the dog is judged to be salivating.
    threshold = 0.8

    # Experiment phases, run in order, 20 trials each:
    #   1. food only
    #   2. bell only
    #   3. food together with the bell
    #   4. bell only again
    phases = [
        ('Only Feed', feed),
        ('Only Bell', bell),
        ('Feed and Bell', feed_and_bell),
        ('Only Bell', bell),
    ]

    for title, trial in phases:
        print(title)
        for i in range(20):
            saliva = trial()
            # A single pre-formatted string prints identically on
            # Python 2 and Python 3: "<trial index> <reaction>".
            print('%d %s' % (i, 'わんわん' if saliva > threshold else '...'))

    sys.exit(0)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment