Created
May 1, 2017 05:57
-
-
Save hnakagawa/2994c5639c20580b1f3878bb88ff3e4d to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
def phi(x, n):
    """Return the polynomial feature column phi(x) = [x^0, x^1, ..., x^(n-1)]^T.

    Args:
        x: Input value. For a scalar the result has shape (n, 1). Because of
            the final ``reshape(n, -1)``, a 1-D array of k points yields an
            (n, k) array with one column per point.
        n: Number of basis functions (polynomial degree + 1).

    Returns:
        An ndarray whose row ``d`` holds ``x ** d``.
    """
    terms = []
    for degree in range(n):
        terms.append(x ** degree)
    column = np.array(terms)
    return column.reshape(n, -1)
def mean(x, xl, tl, beta, s, n):
    """Predictive mean m(x) (Eq. 1.70; presumably Bishop's PRML).

        m(x) = beta * phi(x)^T . S . sum_i phi(xl[i]) * tl[i]

    Args:
        x: Point at which to evaluate the predictive mean.
        xl: Training inputs.
        tl: Training targets; ``tl[i]`` is paired with ``xl[i]``.
        beta: Precision scalar multiplying the result (presumably the noise
            precision -- confirm).
        s: The n x n matrix S (in this file, the inverse of ``S_inverse``).
            May be an ``np.matrix`` or a plain 2-D ndarray.
        n: Number of polynomial basis functions.

    Returns:
        A 1 x 1 array holding m(x). It is an ``np.matrix`` when ``s`` is one.
        Index it with ``[0, 0]`` to get the scalar.
    """
    # sum_i phi(xl[i]) * tl[i], accumulated as an n x 1 column vector.
    weighted = np.zeros((n, 1))
    for i, xi in enumerate(xl):
        weighted += phi(xi, n) * tl[i]
    # Use `@` for the matrix products. `*` is a matrix product only when an
    # operand is np.matrix. With a plain ndarray `s`, `(1 x n) * (n x n)`
    # would broadcast elementwise and silently give a wrong result.
    return beta * (phi(x, n).T @ s @ weighted)
def variance(x, xl, beta, s, n):
    """Predictive variance s^2(x) (Eq. 1.71; presumably Bishop's PRML).

        s^2(x) = 1 / beta + phi(x)^T . S . phi(x)

    Args:
        x: Point at which to evaluate the predictive variance.
        xl: Training inputs. Not used by this computation; kept so the
            signature stays compatible with existing callers.
        beta: Precision scalar whose reciprocal is added to the quadratic
            form (presumably the noise precision -- confirm).
        s: The n x n matrix S. May be an ``np.matrix`` or a plain 2-D ndarray.
        n: Number of polynomial basis functions.

    Returns:
        A 1 x 1 array holding s^2(x). It is an ``np.matrix`` when ``s`` is one.
        Index it with ``[0, 0]`` to get the scalar.
    """
    p = phi(x, n)
    # `@` keeps this a true quadratic form even when `s` is a plain ndarray.
    # With `*`, broadcasting would produce an n x n array instead.
    return beta**-1 + p.T @ s @ p
def S_inverse(xl, alpha, beta, n):
    """Build the matrix S^-1 (Eq. 1.72; presumably Bishop's PRML).

        S^-1 = alpha * I + beta * sum_i phi(xl[i]) . phi(xl[i])^T

    Args:
        xl: Training inputs.
        alpha: Scalar multiplying the identity term (presumably the prior
            precision -- confirm).
        beta: Scalar multiplying the summed outer products.
        n: Number of polynomial basis functions; the result is n x n.

    Returns:
        An n x n ``np.matrix``. ``train`` inverts it with ``.getI()``, so the
        matrix type is part of the contract.
    """
    features = [phi(xi, n) for xi in xl]
    gram = np.matrix(np.zeros((n, n)))
    for f in features:
        # f is an n x 1 ndarray, so `f * f.T` broadcasts to the n x n
        # outer product phi . phi^T.
        gram += f * f.T
    eye = np.matrix(np.identity(n))
    return alpha * eye + beta * gram
def test(x, xl, tl, beta, s, n):
    """Evaluate the predictive distribution at a single point x.

    Returns:
        A tuple ``(mean, mean + std, mean - std)``. ``std`` is the square root
        of the predictive variance, and both values come from ``mean`` and
        ``variance`` indexed at ``[0, 0]``.
    """
    centre = mean(x, xl, tl, beta, s, n)[0, 0]
    spread = np.sqrt(variance(x, xl, beta, s, n)[0, 0])
    upper, lower = centre + spread, centre - spread
    return centre, upper, lower
def train(xn, xl, tl, alpha, beta, n):
    """Compute S from the training data and evaluate predictions at each x in xn.

    S^-1 is built once by ``S_inverse`` from the training inputs ``xl`` and
    inverted with ``np.matrix.getI``. The same S is then reused for every
    query point.

    Args:
        xn: Query points at which to predict.
        xl: Training inputs.
        tl: Training targets paired with ``xl``.
        alpha: Identity coefficient forwarded to ``S_inverse``.
        beta: Precision scalar forwarded to ``S_inverse`` and ``test``.
        n: Number of polynomial basis functions.

    Returns:
        A list with one ``(mean, mean + std, mean - std)`` tuple per element
        of ``xn``, in the same order.
    """
    s_mat = S_inverse(xl, alpha, beta, n).getI()
    predictions = [test(point, xl, tl, beta, s_mat, n) for point in xn]
    return predictions
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment