Skip to content

Instantly share code, notes, and snippets.

@AbhinavMadahar
Last active February 5, 2017 03:33
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save AbhinavMadahar/529de23fdce9a3c164b18794d740621c to your computer and use it in GitHub Desktop.
Save AbhinavMadahar/529de23fdce9a3c164b18794d740621c to your computer and use it in GitHub Desktop.
A simple linear regression, with the slope fitted by gradient descent
from collections import namedtuple
# Immutable (x, y) pair: Point(1, 2).x == 1 and Point(1, 2).y == 2.
Point = namedtuple("Point", "x y")
def mean(values):
    """Return the arithmetic mean of ``values``.

    Uses true division, so integer inputs give a fractional mean on Python 2
    as well as on Python 3 (plain ``/`` truncates ints on Python 2).

    Args:
        values: finite iterable of numbers. It is materialized into a list,
            so generators are accepted.

    Raises:
        ZeroDivisionError: if ``values`` is empty.
    """
    from operator import truediv  # true division on every Python version

    values = list(values)
    return truediv(sum(values), len(values))
class Regression(object):
    """Simple linear regression: a line through the data centroid with a fitted slope.

    The model predicts ``y = slope * (x - point.x) + point.y``. ``point`` is set
    to (mean x, mean y) of the training data, and ``slope`` is found by gradient
    descent on the sum of squared residuals.
    """

    def __init__(self, points):
        """Fit a regression to ``points``.

        Args:
            points: iterable of objects with ``x`` and ``y`` attributes,
                e.g. ``Point`` instances.
        """
        self.slope = 0  # refined by gradient descent in learn_from_data_set
        self.point = Point(0, 0)  # replaced by the data centroid in learn_from_data_set
        self.learn_from_data_set(points)

    def __str__(self):
        """Return a readable summary of the fitted slope and anchor point."""
        return "Regression with slope " + str(self.slope) + " and point " + str(self.point)

    def __getitem__(self, x):
        """Return the predicted y at ``x``, so ``regression[x]`` evaluates the line."""
        return self.slope * (x - self.point.x) + self.point.y

    def cost(self, data_points):
        """Return the total squared error of the line over ``data_points``."""
        return sum((self[point.x] - point.y) ** 2 for point in data_points)

    def derivative_of_cost_with_respect_to_slope(self, data_points):
        """Return d(cost)/d(slope) at the current slope, computed analytically.

        Each residual is ``slope * (x - point.x) + point.y - y``. Its derivative
        with respect to the slope is ``x - point.x``, so by the chain rule the
        derivative of the summed squares is ``sum(2 * residual * (x - point.x))``.
        The closed form has no finite-difference bias and does not temporarily
        modify ``self.slope``.
        """
        return sum(
            2 * (self[point.x] - point.y) * (point.x - self.point.x)
            for point in data_points
        )

    def learn_from_data_set(self, data_points, eta=0.0001, tolerance=0.000001,
                            max_iterations=1000000):
        """Fit ``self.point`` and ``self.slope`` to ``data_points``.

        Gradient descent starts from the current ``self.slope``.

        Args:
            data_points: iterable of objects with ``x`` and ``y`` attributes.
                It is materialized into a list, so one-shot iterators work.
            eta: requested learning rate. It is capped at ``1 / (2 * Sxx)``,
                where ``Sxx`` is the sum of squared x-deviations, so the
                descent cannot diverge on widely spread data.
            tolerance: stop once ``abs(d cost / d slope)`` is no longer
                greater than this.
            max_iterations: upper bound on descent steps. It guards against
                looping forever when rounding noise in the gradient stays
                above ``tolerance``.

        Raises:
            ZeroDivisionError: if ``data_points`` is empty (raised by ``mean``).
        """
        data_points = list(data_points)

        # The least-squares line always passes through (mean x, mean y), so
        # fix that point and only search for the slope.
        mean_x = mean([point.x for point in data_points])
        mean_y = mean([point.y for point in data_points])
        self.point = Point(mean_x, mean_y)

        # The cost is a parabola in the slope with second derivative 2 * Sxx.
        # A gradient step converges only when eta < 1 / Sxx, and a step of
        # 1 / (2 * Sxx) lands exactly on the minimum, so never exceed that.
        # When Sxx == 0 the gradient is identically zero and no step is taken.
        spread = sum((point.x - mean_x) ** 2 for point in data_points)
        if spread > 0:
            eta = min(eta, 1.0 / (2 * spread))

        # Evaluate the gradient once per step. A NaN gradient compares false
        # against the tolerance and therefore also ends the loop.
        gradient = self.derivative_of_cost_with_respect_to_slope(data_points)
        iterations = 0
        while abs(gradient) > tolerance and iterations < max_iterations:
            self.slope -= eta * gradient
            gradient = self.derivative_of_cost_with_respect_to_slope(data_points)
            iterations += 1
# demo: fit the regression to a small sample data set and print the result.
# print(...) with a single argument prints the same on Python 2 and Python 3.
print(Regression([
    Point(30., 66.),
    Point(34., 79.),
    Point(27., 70.),
    Point(25., 60.),
    Point(17., 48.),
    Point(23., 55.),
    Point(20., 60.),
]))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment