Skip to content

Instantly share code, notes, and snippets.

@keltecc
Created September 3, 2019 20:08
Show Gist options
  • Save keltecc/16b6a8a6827d92d8013e52cfd78abe4f to your computer and use it in GitHub Desktop.
Save keltecc/16b6a8a6827d92d8013e52cfd78abe4f to your computer and use it in GitHub Desktop.
Decision stump solver: finds the single-threshold split that minimizes squared error
#!/usr/bin/env python3
from collections import namedtuple
# A training sample: ``x`` is the input value and ``y`` is the target.
Point = namedtuple('Point', 'x y')
class Stump:
    """Decision stump: a threshold ``m`` with one output value per side.

    Attributes are stored as ``c1`` (output for inputs at or below the
    threshold), ``c2`` (output for inputs above it) and ``m`` (the threshold).
    """

    def __init__(self, c1, c2, m):
        # Assigned in this order so ``__dict__`` lists c1, c2, m.
        self.c1, self.c2, self.m = c1, c2, m

    def answer(self, question):
        """Return ``c1`` when ``question <= m``, otherwise ``c2``."""
        return self.c1 if question <= self.m else self.c2
def find_c1_c2(points, m):
    """Return the mean ``y`` on each side of the threshold ``m``.

    Points with ``x <= m`` count towards the left mean; every other point
    counts towards the right mean.

    Args:
        points: Iterable of objects exposing numeric ``x`` and ``y``.
        m: Threshold that splits the points by their ``x`` value.

    Returns:
        A ``(c1, c2)`` tuple: the left mean and the right mean. When one side
        has no points, both values are the mean of the populated side, so the
        result describes a constant predictor instead of dividing by zero.

    Raises:
        ValueError: If ``points`` is empty.
    """
    left, right = [], []
    for point in points:
        (left if point.x <= m else right).append(point.y)

    if not left and not right:
        raise ValueError('points must not be empty')
    if not left or not right:
        # No sample constrains the empty side, so reuse the populated side's
        # mean for it rather than failing with ZeroDivisionError.
        populated = left or right
        mean = sum(populated) / len(populated)
        return mean, mean
    return sum(left) / len(left), sum(right) / len(right)
def calculate_error(points, stump):
    """Return the sum of squared differences between ``stump`` and ``points``.

    Args:
        points: Iterable of objects exposing numeric ``x`` and ``y``.
        stump: Object whose ``answer(x)`` method returns a prediction.

    Returns:
        The accumulated squared error; ``0`` when ``points`` is empty.
    """
    total = 0
    for sample in points:
        residual = stump.answer(sample.x) - sample.y
        total += residual ** 2
    return total
def find_stump(points):
    """Fit the decision stump with the lowest squared error on ``points``.

    Candidate thresholds are the midpoints between consecutive points after
    sorting by ``x``. For each candidate the two output values come from
    ``find_c1_c2`` and the fit is scored with ``calculate_error``.

    Args:
        points: Iterable of objects exposing numeric ``x`` and ``y``.

    Returns:
        The ``Stump`` with the smallest error. On ties the candidate with the
        smallest threshold is kept.

    Raises:
        ValueError: If ``points`` has fewer than two distinct ``x`` values,
            so there is no threshold that splits it.
    """
    ordered = sorted(points, key=lambda point: point.x)
    # A midpoint between two equal x values separates nothing (and, at the
    # largest x, would leave the right side empty), so only neighbours with
    # distinct x values produce a candidate threshold.
    thresholds = [
        (left.x + right.x) / 2
        for left, right in zip(ordered, ordered[1:])
        if left.x < right.x
    ]
    if not thresholds:
        raise ValueError('need at least two points with distinct x values')

    best_stump, best_error = None, None
    for m in thresholds:
        c1, c2 = find_c1_c2(ordered, m)
        stump = Stump(c1, c2, m)
        error = calculate_error(ordered, stump)
        if best_stump is None or error < best_error:
            best_stump, best_error = stump, error
    # Return the tracked best candidate, not whichever stump was built last.
    return best_stump
def main():
    """Fit a stump to a fixed four-point sample and print its attributes."""
    samples = ((1, 1), (2, 1), (3, 0), (4, 0))
    points = [Point(x, y) for x, y in samples]
    # vars() returns the instance ``__dict__``.
    print(vars(find_stump(points)))
# Run the demo only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment