Last active
February 5, 2017 03:33
-
-
Save AbhinavMadahar/529de23fdce9a3c164b18794d740621c to your computer and use it in GitHub Desktop.
A simple Regression AI
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from collections import namedtuple
# Lightweight immutable 2-D point: fields are reachable by name (p.x, p.y)
# or by position (p[0], p[1]).
Point = namedtuple("Point", field_names=("x", "y"))
def mean(values):
    """Return the arithmetic mean of *values* as a float.

    *values* may be any iterable of numbers; it is materialised once so that
    iterators and generators work too.

    Raises ZeroDivisionError when *values* is empty.
    """
    values = list(values)
    # Divide by a float so integer inputs are not truncated by Python 2's
    # floor division.
    return sum(values) / float(len(values))
class Regression(object):
    """Straight-line regression whose slope is found by gradient descent.

    The line is kept in point-slope form: it passes through ``self.point``
    with gradient ``self.slope``.  Indexing an instance (``regression[x]``)
    evaluates the line at ``x``.
    """

    def __init__(self, points):
        """Create the model and immediately fit it to *points*.

        *points* is an iterable of objects exposing ``x`` and ``y``
        attributes (for example ``Point`` instances).
        """
        # Start from a float so the slope has the same type whether or not
        # gradient descent ends up taking a step.
        self.slope = 0.0  # refined by gradient descent
        self.point = Point(0, 0)  # replaced by the centroid of the data
        self.learn_from_data_set(points)

    def __str__(self):
        """Return a human-readable summary of the fitted line."""
        return "Regression with slope " + str(self.slope) + " and point " + str(self.point)

    # make it possible to write regression[x] to calculate the value at x
    def __getitem__(self, x):
        """Return the value of the line at *x*."""
        return self.slope * (x - self.point.x) + self.point.y

    # total inaccuracy
    def cost(self, data_points):
        """Return the sum of squared errors of the line over *data_points*."""
        return sum((self[point.x] - point.y) ** 2 for point in data_points)

    def derivative_of_cost_with_respect_to_slope(self, data_points):
        """Numerically estimate d(cost)/d(slope) at the current slope.

        A central difference is used.  The prediction is linear in the
        slope, so the cost is quadratic in it and the central difference
        carries no truncation error (only floating-point rounding), unlike
        a one-sided difference.

        *data_points* is iterated twice, so it must be a re-iterable
        collection.  ``self.slope`` is left exactly as it was found.
        """
        infinitesimal = 10 ** -5
        original_slope = self.slope
        try:
            self.slope = original_slope + infinitesimal
            cost_above = self.cost(data_points)
            self.slope = original_slope - infinitesimal
            cost_below = self.cost(data_points)
        finally:
            # Assign the saved value back rather than undoing the nudge
            # arithmetically: "+= h" followed by "-= h" does not always
            # round-trip in floating point.
            self.slope = original_slope
        return (cost_above - cost_below) / (2 * infinitesimal)

    def learn_from_data_set(self, data_points, eta=0.0001, tolerance=0.000001,
                            max_iterations=1000000):
        """Fit the line to *data_points*.

        The anchor point is set to (mean x, mean y) and the slope is then
        improved by gradient descent, starting from the current
        ``self.slope``.

        data_points    -- iterable of objects with ``x`` and ``y`` attributes
        eta            -- learning rate (step size multiplier)
        tolerance      -- descent stops once |d(cost)/d(slope)| is no larger
                          than this
        max_iterations -- upper bound on descent steps; when it is reached
                          the method returns with the best slope found so
                          far instead of looping forever

        The error raised for an empty data set comes from ``mean``.
        """
        # The data is traversed many times; materialise it once so that
        # iterators and generators are accepted as well.
        data_points = list(data_points)

        # (average x, average y) MUST be on the regression
        mean_x = mean([point.x for point in data_points])
        mean_y = mean([point.y for point in data_points])
        self.point = Point(mean_x, mean_y)

        # use gradient descent to find slope
        iterations = 0
        while iterations < max_iterations:
            # Evaluate the gradient once per step and reuse it for both the
            # convergence test and the update.
            gradient = self.derivative_of_cost_with_respect_to_slope(data_points)
            # Written as "not >" so that a NaN gradient also ends the loop.
            if not abs(gradient) > tolerance:
                break
            self.slope -= eta * gradient
            iterations += 1
# demo: fit a line to a small sample data set and print the result.
# print(...) with a single argument behaves the same on Python 2 and 3.
print(Regression([
    Point(30., 66.),
    Point(34., 79.),
    Point(27., 70.),
    Point(25., 60.),
    Point(17., 48.),
    Point(23., 55.),
    Point(20., 60.),
]))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment