Last active
April 26, 2016 07:10
-
-
Save masaponto/b2eab035acc8198a1f4b to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
# -*- coding: utf-8 -*- | |
import numpy as np | |
import random | |
class MLP:
    """Multi-layer perceptron with one hidden layer and sigmoid activations.

    Trained by online (per-sample) backpropagation of the squared error.
    """

    def __init__(self, mid_num, out_num, epochs, r=0.5, a=1, seed=None):
        """MLP using sigmoid activations.

        mid_num: number of hidden-layer nodes. The last hidden node is
            clamped to -1 and serves as the bias input of the output layer,
            so only ``mid_num - 1`` hidden nodes carry learned features.
        out_num: number of output-layer nodes.
        epochs: number of passes over the training data.
        r: learning rate.
        a: gain (slope) constant of the sigmoid.
        seed: optional seed for weight initialisation. ``None`` draws from
            NumPy's global random state.
        """
        self.mid_num = mid_num
        self.out_num = out_num
        self.epochs = epochs
        self.r = r
        self.a = a
        self.seed = seed

    def sigmoid(self, x):
        """Logistic sigmoid with gain ``a``: 1 / (1 + exp(-a * x))."""
        return 1 / (1 + np.exp(-self.a * x))

    def sigmoid_(self, x):
        """Derivative of the sigmoid, expressed via its output.

        ``x`` must already be a sigmoid output ``s``; returns
        ``a * s * (1 - s)``.
        """
        return self.a * x * (1.0 - x)

    def calc_out(self, w_vs, x_v):
        """Activate one layer: sigmoid of ``w_vs`` times input ``x_v``."""
        return self.sigmoid(np.dot(w_vs, x_v))

    def out_error(self, d_v, out_v):
        """Output-layer error term for target ``d_v`` and output ``out_v``."""
        return (out_v - d_v) * self.sigmoid_(out_v)

    def mid_error(self, mid_v, eo_v):
        """Hidden-layer error term, back-propagated through ``self.wo_vs``."""
        return np.dot(self.wo_vs.T, eo_v) * self.sigmoid_(mid_v)

    def w_update(self, w_vs, e_v, i_v):
        """Return ``w_vs`` after one gradient-descent step.

        The step is the outer product of the error ``e_v`` and the layer
        input ``i_v``, scaled by the learning rate ``r``.
        """
        e_v = np.atleast_2d(e_v)
        i_v = np.atleast_2d(i_v)
        return w_vs - self.r * np.dot(e_v.T, i_v)

    def add_bias(self, x_v):
        """Return ``x_v`` with a constant bias input of 1 appended."""
        return np.append(x_v, 1)

    def _forward(self, x_v):
        """Run one forward pass on a bias-augmented input vector.

        Returns ``(mid_v, out_v)``. The last hidden activation is clamped
        to -1 so it acts as the output layer's bias input. Both ``fit`` and
        ``predict`` go through here, so training and inference evaluate
        the same network.
        """
        mid_v = self.calc_out(self.wm_vs, x_v)
        # The last row of wm_vs is effectively unused: its activation is
        # overwritten here, so its weight updates never reach the output.
        mid_v[-1] = -1
        out_v = self.calc_out(self.wo_vs, mid_v)
        return mid_v, out_v

    def fit(self, X, y):
        """Train the network with online backpropagation.

        X: sequence of input vectors; a bias input 1 is appended to each.
        y: sequence of targets, one per input vector.
        Raises ValueError if X and y have different lengths.
        """
        x_vs = [self.add_bias(x) for x in X]
        # Materialise the targets so they can be re-iterated every epoch
        # (a one-shot iterator would otherwise be exhausted after epoch 1).
        d_vs = list(y)
        if len(x_vs) != len(d_vs):
            raise ValueError(
                "X and y must have the same length ({} != {})".format(
                    len(x_vs), len(d_vs)))
        x_vd = len(x_vs[0])

        # Weights are drawn uniformly from [-1, 1). Without a seed, NumPy's
        # global random state is used.
        rng = np.random if self.seed is None else np.random.RandomState(self.seed)
        self.wm_vs = rng.uniform(-1., 1., (self.mid_num, x_vd))
        self.wo_vs = rng.uniform(-1., 1., (self.out_num, self.mid_num))

        for _ in range(self.epochs):
            for d_v, x_v in zip(d_vs, x_vs):
                # forward phase
                mid_v, out_v = self._forward(x_v)
                # backward phase: both error terms use the pre-update weights
                eo_v = self.out_error(d_v, out_v)
                em_v = self.mid_error(mid_v, eo_v)
                # weight update: hidden layer, then output layer
                self.wm_vs = self.w_update(self.wm_vs, em_v, x_v)
                self.wo_vs = self.w_update(self.wo_vs, eo_v, mid_v)

    def predict(self, x_v):
        """Return the output-layer activations for one input vector."""
        x_v = self.add_bias(x_v)
        _, out_v = self._forward(x_v)
        return out_v
if __name__ == "__main__":
    # XOR training data
    X_train = [[0, 0], [1, 0], [0, 1], [1, 1]]
    y_train = [0, 1, 1, 0]

    mid_num = 5
    out_num = 1
    epochs = 10000

    mlp = MLP(mid_num, out_num, epochs)
    mlp.fit(X_train, y_train)

    result = [mlp.predict(x) for x in X_train]
    # Printing is a side effect, so use a plain loop rather than building
    # a throwaway list of print() return values.
    for r, y in zip(result, y_train):
        print(r, ":", y)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment