Skip to content

Instantly share code, notes, and snippets.

@Xevaquor
Created February 10, 2015 21:13
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Xevaquor/e128c3c31c48dafd5766 to your computer and use it in GitHub Desktop.
Save Xevaquor/e128c3c31c48dafd5766 to your computer and use it in GitHub Desktop.
# -*- coding: utf-8 -*-
""" MIT License """
import matplotlib.pyplot as plt
import numpy as np
# Training samples: X holds the inputs and Y the observed outputs, paired by
# index. The script later scatter-plots them as (X[i], Y[i]) points.
X = list(range(10, -1, -1))
Y = [1, 2, 4, 3, 5, 4, 6, 7, 4, 8, 9]

# Number of samples; used as the divisor when averaging over the data set.
m = len(X)

# Step size that gradient_descent multiplies each partial derivative by.
alpha = 1e-2
def h(x, theta):
    """Evaluate the linear hypothesis ``theta[0] + theta[1] * x``.

    Args:
        x: Input value. The script calls this with plain numbers and with a
            NumPy array; anything supporting ``*`` and ``+`` works.
        theta: Indexable pair ``(intercept, slope)``.

    Returns:
        The predicted value(s) for ``x``.
    """
    return theta[0] + theta[1] * x
def J(theta):
    """Return the mean squared error of ``h`` over the data set.

    Evaluates ``(h(x, theta) - y) ** 2`` for every sample pair taken from
    the module-level ``X`` and ``Y`` and divides the total by ``m``.

    Args:
        theta: Indexable pair ``(intercept, slope)`` passed through to ``h``.

    Returns:
        The summed squared error divided by ``m``.
    """
    # A generator fed to the builtin sum() avoids shadowing ``sum`` with a
    # local accumulator and avoids indexing X and Y by position.
    squared_errors = ((h(x, theta) - y) ** 2 for x, y in zip(X, Y))
    return sum(squared_errors) / m
def J_partial0(theta):
    """Return the partial derivative of the cost with respect to ``theta[0]``.

    Computes ``2 / m * sum(h(x, theta) - y)`` over the module-level
    samples ``X`` and ``Y``.

    Args:
        theta: Indexable pair ``(intercept, slope)`` passed through to ``h``.

    Returns:
        Twice the summed residual divided by ``m``.
    """
    # Builtin sum() over paired samples; no local named ``sum`` any more.
    residual_total = sum(h(x, theta) - y for x, y in zip(X, Y))
    return 2 * residual_total / m
def J_partial1(theta):
    """Return the partial derivative of the cost with respect to ``theta[1]``.

    Computes ``2 / m * sum((h(x, theta) - y) * x)`` over the module-level
    samples ``X`` and ``Y``.

    Args:
        theta: Indexable pair ``(intercept, slope)`` passed through to ``h``.

    Returns:
        Twice the summed input-weighted residual divided by ``m``.
    """
    # Builtin sum() over paired samples; no local named ``sum`` any more.
    weighted_total = sum((h(x, theta) - y) * x for x, y in zip(X, Y))
    return 2 * weighted_total / m
def gradient_descent(startTheta, iterations=5000):
    """Run fixed-step gradient descent and return the resulting parameters.

    Args:
        startTheta: Initial ``[intercept, slope]`` pair. It is copied, so the
            caller's sequence is left untouched.
        iterations: Number of update steps to perform (default 5000).

    Returns:
        A new two-element list holding the parameters after the last step.
    """
    # Copy instead of aliasing: the in-place updates below would otherwise
    # overwrite the list object the caller passed in.
    theta = list(startTheta)
    for _ in range(iterations):
        # Evaluate both partial derivatives before touching theta so the two
        # components are updated from the same point.
        d0 = J_partial0(theta)
        d1 = J_partial1(theta)
        theta[0] -= alpha * d0
        theta[1] -= alpha * d1
    return theta
# Fit the line, starting the search from intercept 0 and slope 0.
t = gradient_descent([0, 0])

# Evaluate the fitted hypothesis at 1000 evenly spaced points on [0, 10].
XX = np.linspace(0, 10, num=1000)
YY = h(XX, t)

# Draw the fitted line, then the raw samples as red markers on top of it.
plt.plot(XX, YY)
plt.scatter(X, Y, color='r')
plt.xlabel('Opuszczone wykłady')
plt.ylabel('Punkty')
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment