Skip to content

Instantly share code, notes, and snippets.

@h1dia
Created March 21, 2017 14:12
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save h1dia/6efa9472c00f6ee1af561353c3e238a9 to your computer and use it in GitHub Desktop.
Save h1dia/6efa9472c00f6ee1af561353c3e238a9 to your computer and use it in GitHub Desktop.
# 1クラス認識しかできないよ
import numpy as np
def sigmoid(x):
return 1 / (1 + np.exp(-x))
def d_sigmoid(x):
return sigmoid(x) * (1.0 - sigmoid(x))
# 各層の素子数(項数)定義 + 1はバイアス項の分
input_symbol = 3 + 1
hidden_symbol = 15 + 1
output_symbol = 1
# 学習回数の指定(データセットの学習を何週分するか)
max_epoch = 10
# 学習係数
p = 0.05
error = []
#196人分
#年齢, 身長, 体重, B, W, H
idol = [
[23, 161, 43, 82, 56, 85],
[18, 163, 47, 84, 58, 85],
[22, 160, 51, 92, 58, 90],
[18, 158, 42, 81, 57, 80],
[11, 140, 36, 75, 55, 78],
[16, 154, 55, 92, 59, 88],
[19, 156, 43, 85, 57, 85],
[16, 160, 48, 88, 59, 84],
[14, 151, 41, 78, 55, 77],
[15, 165, 43, 80, 54, 80],
[17, 158, 46, 83, 56, 82],
[17, 161, 46, 85, 57, 84],
[20, 157, 43, 83, 57, 82],
[19, 155, 44, 78, 57, 80],
[16, 156, 41, 78, 55, 77],
[15, 154, 43, 81, 58, 80],
[14, 148, 39, 75, 53, 74],
[21, 160, 44, 86, 56, 81],
[18, 161, 43, 83, 57, 82],
[9, 128, 29, 61, 57, 67],
[19, 165, 44, 81, 56, 80],
[16, 153, 41, 81, 56, 79],
[17, 163, 48, 85, 60, 88],
[14, 156, 42, 76, 55, 79],
[13, 152, 42, 78, 57, 80],
[19, 172, 49, 86, 58, 85],
[19, 161, 46, 80, 57, 80],
[22, 160, 45, 84, 56, 80],
[17, 162, 58, 92, 65, 93],
[16, 170, 56, 105,64, 92],
[15, 157, 41, 83, 55, 82],
[21, 159, 45, 89, 57, 87],
[17, 155, 42, 84, 56, 83],
[20, 156, 44, 81, 58, 83],
[15, 153, 40, 78, 55, 80],
[16, 153, 43, 79, 55, 80],
[16, 153, 42, 79, 57, 80],
[19, 156, 47, 83, 57, 81],
[13, 164, 40, 70, 53, 74],
[28, 152, 47, 92, 58, 84],
[18, 156, 42, 79, 56, 80],
[17, 154, 44, 83, 58, 81],
[28, 159, 44, 87, 57, 85],
[14, 156, 41, 81, 57, 80],
[16, 152, 41, 83, 56, 80],
[17, 159, 44, 75, 57, 78],
[16, 162, 41, 72, 55, 78],
[19, 162, 46, 89, 59, 85],
[17, 158, 43, 75, 57, 79],
[15, 151, 38, 78, 56, 78],
[15, 156, 43, 82, 57, 82],
[25, 172, 50, 88, 60, 89],
[18, 159, 41, 82, 57, 83],
[15, 168, 49, 83, 56, 85],
[19, 160, 43, 86, 56, 86],
[18, 164, 45, 83, 55, 82],
[20, 148, 40, 77, 54, 78],
[16, 154, 41, 78, 54, 81],
[20, 166, 45, 80, 55, 82],
[15, 161, 44, 77, 54, 78],
[20, 163, 45, 86, 57, 86],
[20, 157, 45, 83, 59, 85],
[12, 140, 35, 72, 54, 77],
[14, 142, 37, 74, 52, 75],
[13, 148, 41, 75, 50, 77],
[15, 148, 42, 78, 56, 80],
[17, 155, 42, 82, 59, 86],
[19, 165, 48, 85, 59, 88],
[19, 164, 45, 84, 56, 82],
[17, 156, 46, 87, 57, 85],
[20, 157, 46, 85, 57, 82],
[15, 153, 43, 76, 58, 78],
[17, 162, 46, 91, 56, 86],
[19, 162, 45, 84, 54, 81],
[16, 153, 40, 78, 54, 80],
[12, 145, 39, 72, 53, 75],
[11, 139, 33, 73, 49, 73],
[10, 137, 30, 63, 47, 65],
[26, 166, 47, 87, 57, 87],
[21, 168, 46, 83, 55, 85],
[13, 147, 38, 76, 55, 79],
[18, 163, 45, 82, 56, 81],
[18, 169, 49, 90, 62, 92],
[27, 171, 49, 93, 58, 88],
[15, 165, 44, 80, 56, 81],
[17, 159, 45, 83, 59, 87],
[13, 145, 39, 73, 53, 75],
[13, 156, 42, 77, 53, 76],
[13, 142, 34, 65, 50, 70],
[17, 162, 43, 80, 56, 82],
[12, 149, 36, 72, 54, 75],
[18, 162, 45, 88, 58, 86],
[19, 158, 45, 82, 57, 83],
[14, 155, 43, 78, 55, 80],
[19, 165, 48, 85, 58, 83],
[21, 156, 45, 81, 55, 81],
[25, 160, 46, 83, 60, 89],
[25, 171, 49, 81, 57, 83],
[14, 145, 37, 74, 54, 78],
[31, 167, 51, 91, 62, 90],
[20, 160, 43, 88, 57, 88],
[24, 168, 48, 87, 55, 86],
[16, 155, 42, 74, 60, 79],
[17, 152, 41, 80, 55, 81],
[12, 141, 34, 68, 52, 67],
[18, 153, 44, 86, 56, 82],
[15, 156, 42, 85, 54, 83],
[23, 167, 45, 82, 57, 83],
[18, 161, 46, 86, 58, 88],
[23, 160, 45, 84, 56, 85],
[17, 155, 43, 80, 57, 83],
[18, 149, 40, 77, 57, 81],
[16, 161, 45, 83, 56, 85],
[14, 155, 43, 84, 55, 86],
[22, 160, 44, 80, 57, 82],
[13, 150, 40, 72, 51, 73],
[14, 140, 41, 79, 58, 80],
[17, 158, 45, 82, 56, 80],
[16, 155, 55, 88, 60, 86],
[19, 172, 49, 86, 59, 83],
[19, 165, 45, 82, 55, 85],
[14, 154, 42, 75, 55, 78],
[18, 157, 43, 81, 55, 79],
[15, 157, 46, 84, 57, 85],
[17, 155, 42, 81, 56, 81],
[25, 169, 48, 78, 57, 80],
[22, 168, 50, 92, 58, 85],
[15, 154, 42, 78, 55, 80],
[14, 147, 39, 75, 54, 77],
[17, 162, 43, 86, 55, 84],
[20, 163, 46, 86, 59, 85],
[31, 167, 43, 84, 54, 83],
[17, 148, 40, 80, 60, 82],
[20, 161, 44, 80, 57, 80],
[27, 167, 48, 92, 56, 84],
[10, 132, 28, 64, 56, 70],
[19, 163, 45, 78, 57, 83],
[18, 154, 41, 77, 55, 80],
[16, 161, 43, 80, 55, 84],
[13, 158, 42, 78, 55, 77],
[13, 158, 42, 78, 55, 77],
[17, 165, 45, 80, 59, 83],
[24, 158, 46, 90, 58, 81],
[26, 163, 45, 83, 56, 84],
[16, 155, 42, 83, 55, 81],
[15, 161, 45, 86, 55, 83],
[15, 142, 35, 73, 53, 75],
[16, 157, 44, 81, 58, 80],
[15, 161, 46, 84, 58, 87],
[15, 152, 45, 85, 55, 81],
[19, 162, 46, 86, 57, 91],
[28, 166, 45, 82, 56, 83],
[15, 161, 45, 78, 54, 81],
[18, 160, 47, 90, 56, 86],
[18, 156, 43, 83, 58, 86],
[22, 165, 48, 92, 58, 85],
[21, 161, 44, 81, 56, 81],
[12, 143, 38, 71, 58, 73],
[20, 160, 46, 84, 57, 85],
[22, 165, 46, 85, 57, 83],
[21, 168, 48, 91, 59, 86],
[23, 155, 43, 82, 55, 80],
[18, 164, 47, 80, 54, 81],
[15, 155, 42, 81, 56, 82],
[15, 153, 40, 77, 54, 79],
[26, 165, 46, 85, 60, 85],
[17, 153, 52, 90, 65, 89],
[19, 164, 46, 83, 57, 85],
[14, 149, 39, 75, 56, 80],
[18, 163, 53, 95, 60, 87],
[14, 151, 41, 73, 56, 75],
[13, 146, 37, 74, 53, 76],
[15, 145, 38, 75, 55, 77],
[11, 150, 40, 70, 58, 72],
[21, 156, 45, 77, 54, 76],
[13, 150, 37, 82, 56, 86],
[15, 145, 40, 80, 55, 78],
[14, 149, 38, 73, 55, 76],
[17, 182, 60, 91, 64, 86],
[15, 152, 41, 82, 58, 84],
[18, 160, 45, 85, 56, 83],
[14, 150, 41, 81, 56, 80],
[23, 158, 47, 85, 58, 86],
[14, 144, 33, 75, 54, 77],
[21, 165, 51, 92, 60, 85],
[12, 140, 37, 74, 55, 78],
[11, 130, 28, 62, 50, 65],
[9 , 127, 31, 60, 55, 65],
[17, 166, 43, 86, 60, 85],
[16, 151, 40, 73, 53, 73],
[16, 150, 40, 75, 54, 78],
[9, 132, 32, 65, 51, 70],
[19, 157, 45, 77, 56, 82],
[17, 156, 45, 82, 57, 83],
[16, 145, 38, 72, 53, 75],
[26, 168, 49, 81, 60, 86],
]
# 係数行列の初期化, 0-1の範囲
# 引数は 列,行のサイズ
W1 = np.random.random_sample((input_symbol, hidden_symbol - 1))
W2 = np.random.random_sample((hidden_symbol, output_symbol))
# ネットワークの学習
for epoch in range(0, max_epoch):
err = 0
for i in range(0, 195):
# 学習データの入力
X1 = np.array([1,
idol[i][0] / 31,
idol[i][1] /182,
idol[i][2] / 60])
# 教師信号(出力してほしい値)
S = np.array(idol[i][3])
# 中間層の計算
# 入力と係数のドット積
X2 = np.dot(X1, W1)
# 結果に活性化関数をかける
O2 = sigmoid(X2)
# 出力層の計算
# 行列が逆になっているので転置する
O2 = np.r_[1, O2]
# 中間層の出力と係数のドット積
X3 = np.dot(O2, W2)
# back proapgation
# 誤差表示用
err += abs(X3 - S)
# δoutの計算
D2 = X3 - S
# δhiddenの計算
D1 = d_sigmoid(np.r_[1, X2]) * D2 * W2.T
# 重みの修正
# X1の転置
X1 = X1[np.newaxis, :]
# δoutのバイアス項はδhiddenに伝播しないので除外する
D1 = D1[0][1:]
W1 = W1 - p * (X1.T * D1)
# D2, O2の転置
D2 = D2[:, np.newaxis]
O2 = O2[np.newaxis, :]
W2 = W2 - p * (O2.T * D2)
print(epoch, ", err = ", err)
# 訓練済みネットワークへの入力
X1 = np.array([1,
17 / 31,
139 / 182,
30 / 60])
# 入力と係数のドット積
X2 = np.dot(X1, W1)
# 結果に活性化関数をかける
O2 = sigmoid(X2)
# 転置
O2 = np.r_[1, O2]
X3 = np.dot(O2, W2)
print("Anzu Futaba ", ": ", X3)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment