Created
July 17, 2021 09:54
-
-
Save RikiyaOta/211b7391ca1d80cb8efa77f363573cf0 to your computer and use it in GitHub Desktop.
Logistic回帰(勾配降下法)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np
class LogisticRegressionGradientDescent:
    """Binary logistic regression classifier trained by batch gradient descent.

    Parameters
    ----------
    eta : float, default 0.05
        Learning rate; scales every weight update.
    n_iter : int, default 100
        Number of full passes over the training data made by ``fit``.
    random_state : int, default 1
        Seed of the ``numpy.random.RandomState`` used to draw the initial
        weights.

    Attributes
    ----------
    w_ : numpy.ndarray of shape (1 + n_features,)
        Weights after ``fit``. ``w_[0]`` is the bias term and ``w_[1:]`` are
        the per-feature weights. Not defined until ``fit`` has been called.
    """

    def __init__(self, eta=0.05, n_iter=100, random_state=1):
        self.eta = eta
        self.n_iter = n_iter
        self.random_state = random_state

    def fit(self, X, y):
        """Fit the weights to the training data.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training samples.
        y : array-like with n_samples elements
            Target values. The update rule compares them against sigmoid
            outputs in (0, 1), so 0/1 labels are presumably expected.

        Returns
        -------
        self
            The fitted estimator, so calls can be chained.

        Raises
        ------
        ValueError
            If ``X`` is not two-dimensional or if ``y`` does not contain
            exactly one target per row of ``X``.
        """
        X = np.asarray(X)
        # Flatten the targets: a column vector would otherwise broadcast
        # ``y - output`` into an (n_samples, n_samples) matrix.
        y = np.asarray(y).ravel()

        if X.ndim != 2:
            raise ValueError(
                "X must be 2-dimensional (n_samples, n_features); "
                f"got {X.ndim} dimension(s)"
            )
        n_samples, n_features = X.shape
        if y.shape[0] != n_samples:
            raise ValueError(
                f"X has {n_samples} samples but y has {y.shape[0]} targets"
            )

        # Small random initial weights; one extra slot for the bias term.
        rgen = np.random.RandomState(self.random_state)
        self.w_ = rgen.normal(loc=0.0, scale=0.01, size=1 + n_features)

        for _ in range(self.n_iter):
            output = self.activation(self.net_input(X))
            errors = y - output
            # Every sample contributes to each update (batch update).
            self.w_[1:] += self.eta * np.dot(X.T, errors)
            self.w_[0] += self.eta * errors.sum()
        return self

    def net_input(self, X):
        """Return the linear combination ``X . w_[1:] + w_[0]``."""
        return np.dot(X, self.w_[1:]) + self.w_[0]

    def activation(self, z):
        """Apply the logistic sigmoid function to ``z``.

        ``z`` is clipped to [-250, 250] before exponentiation so that
        ``np.exp`` does not overflow for extreme inputs.
        """
        return 1.0 / (1.0 + np.exp(-np.clip(z, -250, 250)))

    def predict(self, X):
        """Return class label 1 where the sigmoid output is >= 0.5, else 0."""
        # Mathematically the same as thresholding ``net_input(X)`` at 0.0.
        return np.where(self.activation(self.net_input(X)) >= 0.5, 1, 0)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment