Skip to content

Instantly share code, notes, and snippets.

Last active January 17, 2016 09:37
Show Gist options
  • Save palloc/e658d61a1e5ab8455c39 to your computer and use it in GitHub Desktop.
Save palloc/e658d61a1e5ab8455c39 to your computer and use it in GitHub Desktop.
# -*- coding: utf-8 -*-
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from pandas import Series, DataFrame
from numpy.random import normal
# Module-level placeholders. main() binds its own local N and M (no `global`
# statement), so the visible code never reads these values.
N = 0  # number of training samples
M = [0] * 4  # polynomial degrees, one per subplot of the 2x2 figure
def h(x):
    """Return the noise-free target curve 10 * (x - 0.5)**2.

    Only arithmetic operators are used, so this works on scalars and
    element-wise on NumPy arrays alike.
    """
    centered = x - 0.5
    squared = centered ** 2
    return 10 * squared
def create_data(num):
    """Sample ``num`` noisy points from the target curve ``h`` on [0, 1].

    The x values are evenly spaced over [0, 1] inclusive (x_i = i / (num - 1)).
    Each y value is ``h(x)`` plus Gaussian noise with standard deviation 0.7
    drawn from ``numpy.random.normal``.

    Args:
        num: number of samples to generate. Zero or negative yields an empty
            frame; exactly one sample is placed at x = 0.

    Returns:
        A DataFrame with float columns ``'x'`` and ``'y'`` and ``num`` rows.
    """
    xs = []
    ys = []
    for i in range(num):
        # Evenly divide [0, 1]; with a single sample there is no interval to
        # divide, so place it at 0 instead of dividing by zero.
        x = float(i) / float(num - 1) if num > 1 else 0.0
        # Noise is drawn in the same order as before, so seeded runs reproduce.
        y = h(x) + normal(scale=0.7)
        xs.append(x)
        ys.append(y)
    # Build the frame once: DataFrame.append was removed in pandas 2.0, and
    # appending row by row copied the whole frame each iteration.
    return DataFrame({'x': xs, 'y': ys}, columns=['x', 'y'], dtype=float)
def M_l_estimation(dataset, m):
    """Fit a degree-``m`` polynomial to ``dataset`` by least squares.

    Finds w minimising sum_n (sum_i w_i * x_n**i - y_n)**2. When the problem is
    well-posed this is the normal-equation solution w = (Φ^T Φ)^(-1) Φ^T t,
    where Φ is the design matrix with columns x**0, x**1, ..., x**m and t
    holds the targets.

    Args:
        dataset: table exposing columns ``x`` and ``y`` as attributes.
        m: polynomial degree.

    Returns:
        A tuple ``(f, ws, rms)``:
          * ``f`` evaluates the fitted polynomial (element-wise on arrays).
          * ``ws`` is the ndarray of coefficients; ``ws[i]`` multiplies ``x**i``.
          * ``rms`` is the root-mean-square error of ``f`` on ``dataset``.
    """
    xs = np.asarray(dataset.x, dtype=float)
    t = np.asarray(dataset.y, dtype=float)
    # Design matrix Φ: column i holds xs**i (increasing powers).
    phi = np.vander(xs, m + 1, increasing=True)
    # lstsq yields the same minimiser as inv(Φ^T Φ) Φ^T t without forming the
    # explicit inverse. The inverse is badly conditioned for high degrees and
    # may raise or return garbage when Φ^T Φ is singular (e.g. m + 1 > points).
    ws = np.linalg.lstsq(phi, t, rcond=None)[0]

    def f(x):
        # Evaluate sum_i ws[i] * x**i; works for scalars and arrays.
        y = 0.0
        for i, w in enumerate(ws):
            y += w * (x ** i)
        return y

    # Root-mean-square error over the training points, computed vectorised.
    residuals = f(xs) - t
    return (f, ws, np.sqrt(np.mean(residuals ** 2)))
def main():
    """Interactively fit four polynomial models and plot them in a 2x2 grid.

    Steps:
      1. Read the training-set size N and four polynomial degrees from stdin.
      2. Generate a noisy training set with ``create_data``.
      3. Fit each degree with ``M_l_estimation`` and print the coefficients.
      4. Show one subplot per degree. Each subplot holds the true curve ``h``
         (green, dashed), the training points (blue crosses) and the fitted
         curve (red) with a +/- RMS-error band (red, dashed).

    Raises:
        ValueError: if the input is not an integer, or if the number of
            degrees given is not exactly 4.
    """
    # raw_input exists only on Python 2; Python 3 renamed it to input.
    try:
        read_line = raw_input  # noqa: F821 -- Python 2 only
    except NameError:
        read_line = input

    N = int(read_line("学習用データの個数N:"))
    # split() with no argument tolerates repeated or surrounding whitespace.
    M = [int(i) for i in read_line("多項式の次数M(スペース区切りで4つ入力):").split()]
    if len(M) != 4:
        # The figure is a fixed 2x2 grid, so exactly four degrees are needed.
        raise ValueError("expected 4 polynomial degrees, got %d" % len(M))

    training_set = create_data(N)
    coefficient_rows = []
    fig = plt.figure()
    for c, m in enumerate(M):
        f, ws, s = M_l_estimation(training_set, m)
        coefficient_rows.append(Series(ws, name="M=%d" % m))

        subplot = fig.add_subplot(2, 2, c + 1)
        subplot.set_xlim(-0.1, 1.1)
        subplot.set_ylim(-7, 7)
        subplot.set_title("M:%d" % m)

        # Same 101-point grid for the true curve and the fitted curve.
        line_x = np.linspace(0, 1, 101)

        # True (noise-free) curve and the noisy training samples.
        subplot.plot(line_x, h(line_x), color="green", linestyle="--")
        subplot.scatter(training_set.x, training_set.y, marker="x", color="blue")

        # Fitted polynomial with a band of +/- the RMS error.
        line_y = f(line_x)
        label = "sigma=%.2f" % s
        subplot.plot(line_x, line_y, color="red", label=label)
        subplot.plot(line_x, line_y + s, color="red", linestyle="--")
        subplot.plot(line_x, line_y - s, color="red", linestyle="--")
        subplot.legend(loc=1)

    # DataFrame.append was removed in pandas 2.0: collect rows and build once.
    # Rows are models ("M=<degree>"); shorter coefficient vectors pad with NaN.
    df_ws = DataFrame(coefficient_rows)
    print("Table of the coefficients")
    print(df_ws.transpose())
    # Without this the figure is created but never displayed.
    plt.show()
# Run the interactive demo only when executed as a script, not on import.
if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment