Created July 11, 2021 07:27
-
-
Save RikiyaOta/45aa2e674246c94c2fbecc264d9f3f23 to your computer and use it in GitHub Desktop.
Perceptron Implementation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
class Perceptron:
    """Rosenblatt perceptron binary classifier.

    Parameters
    ----------
    eta : float
        Learning rate: scale factor applied to every weight update.
    n_iter : int
        Number of passes (epochs) over the training data.
    random_state : int
        Seed for the RandomState used to initialize the weights.

    Attributes (set by ``fit``)
    ---------------------------
    w_ : numpy.ndarray
        Weight vector; ``w_[0]`` is the bias term, ``w_[1:]`` are the
        per-feature weights.
    errors_ : list of int
        Number of misclassified samples in each epoch.
    """

    def __init__(self, eta=0.01, n_iter=50, random_state=1):
        self.eta = eta
        self.n_iter = n_iter
        self.random_state = random_state

    def fit(self, X, y):
        """Train the weights with the perceptron learning rule.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training samples (lists are accepted and converted to arrays).
        y : array-like of shape (n_samples,)
            Target labels. They are compared against ``predict``'s outputs,
            so they should be 1 or -1.

        Returns
        -------
        self

        Raises
        ------
        ValueError
            If X is not 2-D or X and y contain different numbers of samples.
        """
        X = np.asarray(X)
        y = np.asarray(y)
        if X.ndim != 2:
            raise ValueError("X must be a 2-D array of shape (n_samples, n_features)")
        if X.shape[0] != y.shape[0]:
            # zip() would otherwise silently drop the unmatched tail.
            raise ValueError("X and y must contain the same number of samples")

        rgen = np.random.RandomState(self.random_state)
        n_features = X.shape[1]
        # Small random weights (one extra slot for the bias w_0).
        self.w_ = rgen.normal(loc=0.0, scale=0.01, size=1 + n_features)
        self.errors_ = []

        for _ in range(self.n_iter):
            errors = 0
            for xi, target in zip(X, y):
                prediction = self.predict(xi)
                update = self.eta * (target - prediction)
                # Bias weight w_0 (its input is the constant 1).
                self.w_[0] += update
                # Remaining feature weights.
                self.w_[1:] += update * xi
                # Count actual misclassifications; checking `update != 0`
                # would report 0 errors whenever eta == 0.
                errors += int(target != prediction)
            # Per-epoch misclassification count; it reaching 0 means every
            # training sample was classified correctly during that epoch.
            self.errors_.append(errors)
        return self

    def net_input(self, X):
        """Return the weighted sum ``X · w_[1:] + w_[0]``."""
        return np.dot(X, self.w_[1:]) + self.w_[0]

    def predict(self, X):
        """Return 1 where the net input is >= 0, otherwise -1."""
        return np.where(self.net_input(X) >= 0.0, 1, -1)
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.