Created February 14, 2015 16:32
Save Xevaquor/c963d640ff4e59a21cbb to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*- | |
""" | |
MIT license | |
Created on Wed Feb 11 20:03:52 2015 | |
@author: xevaquor | |
""" | |
import numpy as np | |
import matplotlib.pyplot as plt | |
from scipy.optimize import fmin_bfgs | |
# Training set: one sample per row, laid out as [1, x1, x2]. Column 0 is the
# constant 1 that multiplies the intercept parameter Theta[0] in h().
TrainingData = np.array([
    [1., 2., 2.],
    [1., 3., 2.],
    [1., 2., 3.],
    [1., 2., 1.],
    [1., 1., 2.],
    [1., 3., 3.],
    [1., 1., 1.],
    [1., 1., 3.],
    [1., 3., 1.],
    [1., 2., 4.],
    [1., 2., 0.],
    [1., 0., 2.],
    [1., 4., 2.],
])
m, n = TrainingData.shape  # m samples, n columns (bias column included)

# Exactly one label per row of TrainingData: the first five rows (the
# plus-shaped cluster around (2, 2)) are class 0, the other eight class 1.
Y = np.array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1])
if Y.shape[0] != m:
    # Guard against the label list drifting out of sync with the samples.
    raise ValueError("Y must hold exactly one label per row of TrainingData")
def sigmoid(x):
    """Return the logistic function 1 / (1 + exp(-x)), element-wise.

    Accepts a scalar or a NumPy array. Very negative inputs yield 0.0 rather
    than raising: np.exp overflows to inf (warning suppressed) instead of the
    OverflowError that Python-float exponentiation raises.
    """
    with np.errstate(over="ignore"):
        return 1.0 / (1.0 + np.exp(-x))
def h(Theta, X):
    """Return the logistic-regression hypothesis sigmoid(z) for input X.

    X is [1, x1, x2] (a row of TrainingData). z is a weighted sum of eleven
    polynomial features of (x1, x2), weighted by Theta[0]..Theta[10].

    Every operation is element-wise, so x1 and x2 may also be NumPy arrays
    of the same shape (e.g. meshgrid coordinates); an array is then returned.
    """
    x0, x1, x2 = X[0], X[1], X[2]
    z = (Theta[0] * x0 + Theta[1] * x1 + Theta[2] * x2
         + Theta[3] * x1 * x2
         + Theta[4] * x1**2 + Theta[5] * x2**2
         + Theta[6] * x1**2 * x2**2
         + Theta[7] * x1**2 * x2
         # Mirror of the Theta[7] term, so the cubic cross terms are
         # symmetric in x1/x2; x1**2 alone is already covered by Theta[4].
         + Theta[8] * x1 * x2**2
         + Theta[9] * x1**3 + Theta[10] * x2**3)
    return sigmoid(z)
# In the end I decided this (a separate per-sample helper) reads more clearly.
def single_penalty(Theta, x, y, eps=1e-15):
    """Return the logistic (cross-entropy) cost of a single sample.

    Theta -- parameter vector passed to h.
    x     -- one sample row, [1, x1, x2].
    y     -- its label; y == 1 is the positive class, any other value is
             treated as class 0.
    eps   -- the prediction is clamped to [eps, 1 - eps] before the log.
    """
    # If h saturates to exactly 0.0 or 1.0 in floating point (sigmoid of a
    # large |z|), log(0) would make the cost infinite and the finite-difference
    # gradient (fmin_bfgs is called without fprime) meaningless. Clamping
    # keeps the cost finite.
    p = np.clip(h(Theta, x), eps, 1. - eps)
    return -np.log(p) if y == 1 else -np.log(1. - p)
def J(Theta):
    """Return the mean per-sample cost of Theta over the m training rows.

    Uses the module-level TrainingData, Y and m; this is the objective
    minimised by fmin_bfgs.
    """
    # Builtin sum over a generator (no local named `sum` shadowing it);
    # indexing Y by row keeps an IndexError if Y were ever shorter than m.
    return sum(single_penalty(Theta, TrainingData[i], Y[i])
               for i in range(m)) / m
# Fit the eleven parameters used by h: start from the all-zero vector,
# minimise the cost J with BFGS, and print the fitted parameters.
initialTheta = [0] * 11
theta = fmin_bfgs(J, x0=initialTheta)
print(theta)
# Split the samples by their label rather than by a fixed row offset.
# Y[:m] keeps the boolean mask the same length as TrainingData's m rows.
labels = Y[:m]
zeros = TrainingData[labels == 0]
ones = TrainingData[labels == 1]
plt.scatter(zeros[:, 1], zeros[:, 2], s=120, marker='x', color='purple')
plt.scatter(ones[:, 1], ones[:, 2], s=120, marker='*', color='orange')

# Evaluate the fitted hypothesis on a 500x500 grid. h() is built only from
# element-wise operations, so it takes the whole meshgrid in one vectorised
# call instead of one Python-level call per grid point.
xscale = np.linspace(-3, 6, 500)
yscale = np.linspace(-3, 6, 500)
xmesh, ymesh = np.meshgrid(xscale, yscale)
zmesh = h(theta, [1, xmesh, ymesh])

# Iso-probability curves of h; 0.5 is the usual classification threshold.
plt.contour(xmesh, ymesh, zmesh, [0.01, 0.5, 0.99])
plt.xlabel('X1')
plt.ylabel('X2')
plt.xlim([-1, 5])
plt.ylim([-1, 5])
plt.gca().set_aspect('equal', adjustable='box')
# plt.savefig('6.png')  # uncomment to also write the figure to disk
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.