Last active
August 29, 2015 14:10
-
-
Save sinhrks/b5ee53af5c02224ab05b to your computer and use it in GitHub Desktop.
Logistic Regression
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*- | |
import numpy as np | |
import pandas as pd | |
import matplotlib.pyplot as plt | |
import matplotlib.animation as animation | |
# rpy2 経由で R から iris をロード | |
# import pandas.rpy.common as com | |
# iris = com.load_data('iris') | |
# http://aima.cs.berkeley.edu/data/iris.csv | |
names = ['Sepal.Length', 'Sepal.Width', 'Petal.Length', 'Petal.Width', 'Species'] | |
iris = pd.read_csv('iris.csv', header=None, names=names) | |
np.random.seed(1) | |
# 描画領域のサイズ | |
FIGSIZE = (5, 3.5) | |
# 2 クラスにするため、setosa, versicolor のデータのみ抽出 | |
data = iris[:100] | |
# 説明変数は 2つ | |
columns = ['Petal.Width', 'Petal.Length'] | |
x = data[columns] # データ | |
y = data['Species'] # ラベル | |
def plot_x_by_y(x, y, colors, ax=None): | |
if ax is None: | |
# 描画領域を作成 | |
fig = plt.figure(figsize=FIGSIZE) | |
# 描画領域に Axes を追加、マージン調整 | |
ax = fig.add_subplot(1, 1, 1) | |
fig.subplots_adjust(bottom=0.15) | |
x1 = x.columns[0] | |
x2 = x.columns[1] | |
for (species, group), c in zip(x.groupby(y), colors): | |
ax = group.plot(kind='scatter', x=x1, y=x2, | |
color=c, ax=ax, figsize=FIGSIZE) | |
return ax | |
plot_x_by_y(x, y, colors=['red', 'blue']) | |
plt.show() | |
# ラベルを0, 1の列に変換 | |
y = (y == 'setosa').astype(int) | |
# index と同じ長さの配列を作成し、ランダムにシャッフル | |
indexer = np.arange(x.shape[0]) | |
np.random.shuffle(indexer) | |
# x, y を シャッフルされた index の順序に並び替え | |
x = x.iloc[indexer, ] | |
y = y.iloc[indexer, ] | |
def p_y_given_x(x, w, b): | |
# x, w, b から y の推測値を計算 | |
def sigmoid(a): | |
return 1.0 / (1.0 + np.exp(-a)) | |
return sigmoid(np.dot(x, w) + b) | |
def grad(x, y, w, b): | |
# 現予測値から勾配を計算 | |
error = y - p_y_given_x(x, w, b) | |
w_grad = -np.mean(x.T * error, axis=1) | |
b_grad = -np.mean(error) | |
return w_grad, b_grad | |
def gd(x, y, w, b, eta=0.1, num=100): | |
for i in range(1, num): | |
# 入力をまとめて処理 | |
w_grad, b_grad = grad(x, y, w, b) | |
w -= eta * w_grad | |
b -= eta * b_grad | |
e = np.mean(np.abs(y - p_y_given_x(x, w, b))) | |
yield i, w, b, e | |
def sgd(x, y, w, b, eta=0.1, num=4): | |
for i in range(1, num): | |
for index in range(x.shape[0]): | |
# 一行ずつ処理 | |
_x = x.iloc[[index], ] | |
_y = y.iloc[[index], ] | |
w_grad, b_grad = grad(_x, _y, w, b) | |
w -= eta * w_grad | |
b -= eta * b_grad | |
e = np.mean(np.abs(y - p_y_given_x(x, w, b))) | |
yield i, w, b, e | |
def msgd(x, y, w, b, eta=0.1, num=25, batch_size=10): | |
for i in range(1, num): | |
for index in range(0, x.shape[0], batch_size): | |
# index 番目のバッチを切り出して処理 | |
_x = x[index:index + batch_size] | |
_y = y[index:index + batch_size] | |
w_grad, b_grad = grad(_x, _y, w, b) | |
w -= eta * w_grad | |
b -= eta * b_grad | |
e = np.mean(np.abs(y - p_y_given_x(x, w, b))) | |
yield i, w, b, e | |
def plot_logreg(x, y, fitter, title): | |
# 描画領域を作成 | |
fig = plt.figure(figsize=FIGSIZE) | |
# 描画領域に Axes を追加、マージン調整 | |
ax = fig.add_subplot(1, 1, 1) | |
fig.subplots_adjust(bottom=0.15) | |
# 描画オブジェクト保存用 | |
objs = [] | |
# 回帰直線描画用の x 座標 | |
bx = np.arange(x.iloc[:, 0].min(), x.iloc[:, 0].max(), 0.1) | |
# w, b の初期値を作成 | |
w, b = np.zeros(2), 0 | |
# 勾配法の関数からジェネレータを生成 | |
gen = fitter(x, y, w, b) | |
# ジェネレータを実行し、勾配法 1ステップごとの結果を得る | |
for i, w, b, e in gen: | |
# 回帰直線の y 座標を計算 | |
by = -b/w[1] - w[0]/w[1]*bx | |
# 回帰直線を生成 | |
l = ax.plot(bx, by, color='gray', linestyle='dashed') | |
# 描画するテキストを生成 | |
wt = """Iteration = {0} times | |
w = [{1[0]:.2f}, {1[1]:.2f}] | |
b = {2:.2f} | |
error = {3:.3}""".format(i, w, b, e) | |
# axes 上の相対座標 (0.1, 0.9) に text の上部を合わせて描画 | |
t = ax.text(0.1, 0.9, wt, va='top', transform=ax.transAxes) | |
# 描画した line, text をアニメーション用の配列に入れる | |
objs.append(tuple(l) + (t, )) | |
# データ, 表題を描画 | |
ax = plot_x_by_y(x, y, colors=['red', 'blue'], ax=ax) | |
ax.set_title(title) | |
# アニメーション開始 | |
ani = animation.ArtistAnimation(fig, objs, interval=1, repeat=False) | |
plt.show() | |
# ファイル保存する場合は plt.show をコメントアウトして以下を使う | |
# ani.save(title + '.gif', fps=15) | |
plot_logreg(x, y, gd, title='Gradient descent') | |
plot_logreg(x, y, sgd, title='Stochastic gradient descent') | |
plot_logreg(x, y, msgd, title='Minibatch SGD') |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment