Created
November 10, 2016 07:55
-
-
Save taotao54321/d925e41ca3c362ba23e239b8276606f3 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
# -*- coding: utf-8 -*- | |
# 単純パーセプトロン(Single Layer Perceptron) | |
# 入力ユニット数2(bias除く), 出力1 | |
# AND, OR, XOR でテスト | |
# AND, OR は正しく学習できるが、XOR は学習できないはず | |
# 勾配降下法により学習 | |
# 参考: http://hokuts.com/2015/11/25/ml2_perceptron/ | |
# | |
# 誤差関数は E = max(0, -y*u) を使用。不正解時のみ重み更新を行うので実 | |
# 質的には E = -y*u と考えてよい(つまり微分可能)。 | |
import numpy as np | |
class SLP: | |
def __init__(self): | |
self.w = np.zeros(3) | |
def learn(self, x, y, eta): | |
"""1件の教師データから1回学習""" | |
# 現在のモデルを使って計算 | |
# 結果が正しければ return | |
u = self.w.dot(x) | |
if self._activate(u) == y: return | |
# 結果が誤っていたら勾配降下法により学習 | |
# 誤差関数 E = -y * u | |
# 勾配 grad = -y * x | |
self.w -= eta * (-y * x) | |
def test(self, x): | |
u = self.w.dot(x) | |
return self._activate(u) | |
def _activate(self, u): | |
return 1 if u >= 0 else -1 | |
def test(xs, ys): | |
nn = SLP() | |
print("### Learning ###") | |
for i in range(10000): | |
for x, y in zip(xs, ys): | |
nn.learn(x, y, 0.01) | |
#print("iteration {}: w = {}".format(i, nn.w)) | |
print("w = {}".format(nn.w)) | |
print() | |
print("### Test ###") | |
for x in xs: | |
o = nn.test(x) | |
print("{} -> {}".format(x, o)) | |
print() | |
XS = tuple(map(np.array, ( | |
(1, 0, 0), | |
(1, 0, 1), | |
(1, 1, 0), | |
(1, 1, 1), | |
))) | |
YS_AND = ( -1, -1, -1, 1 ) | |
YS_OR = ( -1, 1, 1, 1 ) | |
YS_XOR = ( -1, 1, 1, -1 ) | |
def main(): | |
print("#------------------------------------------------------------") | |
print("# [AND]") | |
print("#------------------------------------------------------------") | |
test(XS, YS_AND) | |
print("#------------------------------------------------------------") | |
print("# [OR]") | |
print("#------------------------------------------------------------") | |
test(XS, YS_OR) | |
print("#------------------------------------------------------------") | |
print("# [XOR]") | |
print("#------------------------------------------------------------") | |
test(XS, YS_XOR) | |
if __name__ == "__main__": main() |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment