Skip to content

Instantly share code, notes, and snippets.

@taotao54321
Created November 10, 2016 07:55
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save taotao54321/d925e41ca3c362ba23e239b8276606f3 to your computer and use it in GitHub Desktop.
Save taotao54321/d925e41ca3c362ba23e239b8276606f3 to your computer and use it in GitHub Desktop.
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# 単純パーセプトロン(Single Layer Perceptron)
# 入力ユニット数2(bias除く), 出力1
# AND, OR, XOR でテスト
# AND, OR は正しく学習できるが、XOR は学習できないはず
# 勾配降下法により学習
# 参考: http://hokuts.com/2015/11/25/ml2_perceptron/
#
# 誤差関数は E = max(0, -y*u) を使用。不正解時のみ重み更新を行うので実
# 質的には E = -y*u と考えてよい(つまり微分可能)。
import numpy as np
class SLP:
def __init__(self):
self.w = np.zeros(3)
def learn(self, x, y, eta):
"""1件の教師データから1回学習"""
# 現在のモデルを使って計算
# 結果が正しければ return
u = self.w.dot(x)
if self._activate(u) == y: return
# 結果が誤っていたら勾配降下法により学習
# 誤差関数 E = -y * u
# 勾配 grad = -y * x
self.w -= eta * (-y * x)
def test(self, x):
u = self.w.dot(x)
return self._activate(u)
def _activate(self, u):
return 1 if u >= 0 else -1
def test(xs, ys):
nn = SLP()
print("### Learning ###")
for i in range(10000):
for x, y in zip(xs, ys):
nn.learn(x, y, 0.01)
#print("iteration {}: w = {}".format(i, nn.w))
print("w = {}".format(nn.w))
print()
print("### Test ###")
for x in xs:
o = nn.test(x)
print("{} -> {}".format(x, o))
print()
XS = tuple(map(np.array, (
(1, 0, 0),
(1, 0, 1),
(1, 1, 0),
(1, 1, 1),
)))
YS_AND = ( -1, -1, -1, 1 )
YS_OR = ( -1, 1, 1, 1 )
YS_XOR = ( -1, 1, 1, -1 )
def main():
print("#------------------------------------------------------------")
print("# [AND]")
print("#------------------------------------------------------------")
test(XS, YS_AND)
print("#------------------------------------------------------------")
print("# [OR]")
print("#------------------------------------------------------------")
test(XS, YS_OR)
print("#------------------------------------------------------------")
print("# [XOR]")
print("#------------------------------------------------------------")
test(XS, YS_XOR)
if __name__ == "__main__": main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment