Created
November 11, 2013 15:26
-
-
Save mmtootmm/7414904 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
# -*- coding: utf-8 -*- | |
''' | |
パーセプトロンの学習規則に従い誤り訂正法を用いて | |
パーセプトロンの収束定理を検証するサンプルプログラム。 | |
線形分離可能な2つのクラスについて固定増分法で線形識別関数の重みを求める。 | |
''' | |
import numpy as np | |
def learn(weight, cluster_a, cluster_b, rho): | |
result_weight = weight | |
error = True | |
# 全てのパターンを正確に識別できるまで重みを調整する。 | |
# 識別関数{weight(dot)pattern}はクラスaに対して正の値、クラスbに対して負の値を返すものとする。 | |
while error: | |
print result_weight | |
error = False | |
for pattern in cluster_a: | |
ex_pattern = np.array([1., pattern]) | |
if result_weight.dot(ex_pattern) <= 0: | |
result_weight += ex_pattern * rho | |
error = True | |
for pattern in cluster_b: | |
ex_pattern = np.array([1., pattern]) | |
if result_weight.dot(ex_pattern) >= 0: | |
result_weight -= ex_pattern * rho | |
error = True | |
return result_weight | |
def run(): | |
# 重みに適当な初期値を入れる | |
weight = np.array([2., -7.]) | |
# 一次元の特徴空間で表されるパターン群 | |
cluster_a = np.array([ 1.2, 0.2, -0.2]) | |
cluster_b = np.array([-0.5, -1.0, -1.5]) | |
weight = learn(weight, cluster_a, cluster_b, 1.2) | |
print weight | |
if __name__ == '__main__': | |
run() |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment