Skip to content

Instantly share code, notes, and snippets.

@qmaurmann
Created December 15, 2013 01:37
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save qmaurmann/7967607 to your computer and use it in GitHub Desktop.
Save qmaurmann/7967607 to your computer and use it in GitHub Desktop.
Code to produce the first set of pictures (PAC learning for rectangles) for the blog post "Evolvability and Learning" at qmaurmann.wordpress.com.
from __future__ import division
import random
import operator as op
import matplotlib.pyplot as plt
# globals: the figure/axes pair that the plotting code below draws on.
# The axes show the unit square with equal scaling on both axes.
fig, ax = plt.subplots()
ax.set_aspect('equal')
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
def make_oracle(dist, target):  # could have been a class
    """Build a PAC example oracle out of a distribution and a target function.

    dist   -- zero-argument callable; each call yields one sample point
    target -- one-argument callable giving the label of a point

    Returns a zero-argument callable. Every call draws a fresh point from
    dist and returns the pair (point, target(point)).
    """
    def draw_labelled_example():
        point = dist()
        label = target(point)
        return (point, label)

    return draw_labelled_example
class Simulation(object):
    """A class to keep a learning simulation organized.

    Assumes an online learning scenario, updating the current hypothesis after
    every point seen.

    Much of the data carried around (like record and errors) is only there
    for making pretty pictures.

    The plotting methods draw on the module-level ``ax`` / ``plt`` state.
    """

    def __init__(self, dist, target, init_hyp):
        """Set up a fresh simulation.

        dist     -- zero-argument callable returning one sample point
        target   -- callable labelling a point; must also offer unif_area()
                    (and a ``components`` attribute if plot_target is used)
        init_hyp -- starting hypothesis; must offer update(x, y) and
                    unif_area() (and ``components`` if plot_hyp is used)
        """
        self.oracle = make_oracle(dist, target)
        self.hyp = init_hyp
        self.target = target
        self.record = []  # keeps list of all (input,label) pairs seen
        self.errors = []  # optional, except that I want to make an error plot
        self.m = 0  # num samples seen so far

    def get_samples(self, m):
        """Draw m samples from the oracle, and update the hypothesis, record,
        errors, and m (sample size) fields.

        The error is re-measured after every single sample, so ``errors``
        gains exactly one entry per sample drawn.
        """
        # range() rather than the Python-2-only xrange(): the loop index was
        # never used, so only the iteration count matters.
        for _ in range(m):
            x, y = self.oracle()
            self.hyp.update(x, y)
            self.record.append((x, y))
            self.errors.append(self.unif_area_diff())
            self.m += 1

    def plot_record(self):
        """Assuming the input space is a subset of the plane, plot all points
        seen by the learner, in black for positive examples or white for
        negative."""
        ax.set_title("After {0} samples".format(self.m))
        for (x1, x2), y in self.record:
            marker = 'ko' if y else 'wo'
            ax.plot(x1, x2, marker)

    def _shade_rect(self, rect, color):
        """Shade the 2-d rectangle ``rect`` (anything with a two-interval
        ``components`` attribute) on the current axes in ``color``."""
        (xmin, xmax), (ymin, ymax) = rect.components
        # NOTE: axhspan reads xmin/xmax as fractions of the axes width, not as
        # data coordinates; the two agree only while the x-limits are [0, 1].
        plt.axhspan(ymin, ymax, xmin=xmin, xmax=xmax, facecolor=color, alpha=0.3)

    def plot_hyp(self):
        """Assuming the input space is a subset of the plane, plot the current
        hypothesis in blue."""
        self._shade_rect(self.hyp, 'blue')

    def plot_target(self):
        """Assuming the input space is a subset of the plane, plot the target
        function in green."""
        self._shade_rect(self.target, 'green')

    def unif_area_diff(self):
        """Measures error, assuming uniform distribution and hypothesis subset
        of target.

        Returns target.unif_area() minus hyp.unif_area().
        """
        return self.target.unif_area() - self.hyp.unif_area()

    def plot_errors(self):
        """Plot the recorded error after each sample, with the x-axis spanning
        the number of recorded errors and the y-axis fixed to [0, 0.05]."""
        ax.set_xlim([0, len(self.errors)])
        ax.set_ylim([0., .05])
        plt.plot(self.errors)
### Represent the axis-aligned (a.a.) rectangles prod_i [a_i, b_i] as length-n
### lists of pairs of floats, with (a_i, b_i) = rect[i].
infty = float('inf')


class Rect(object):
    """Callable n-dimensional rectangle class.

    ``components`` is a list of n (low, high) pairs, one closed interval per
    coordinate. A new rectangle starts out empty: every interval is
    (infty, -infty), which no number can satisfy. Calling the rectangle on a
    point returns whether the point lies inside it.
    """

    def __init__(self, n):
        """Create the empty rectangle in n dimensions."""
        # Tuples are immutable, so sharing one pair across all n slots is safe.
        self.components = [(infty, -infty)] * n

    def __call__(self, x):
        """Return True iff every coordinate of x lies in its interval.

        Coordinates are paired up with zip, so any extra coordinates (or
        extra intervals) beyond the shorter of the two are ignored.
        """
        return all(lo <= xi <= hi for (lo, hi), xi in zip(self.components, x))

    def update(self, x, y):
        """Grow the rectangle just enough to contain x when y is truthy.

        Negative examples (falsy y) leave the rectangle untouched.
        """
        if not y:
            return
        for i, xi in enumerate(x):
            lo, hi = self.components[i]
            if not lo <= xi <= hi:
                self.components[i] = (min(lo, xi), max(hi, xi))

    def unif_area(self):
        """Return the product of the side lengths (the rectangle's volume).

        An interval whose upper end does not exceed its lower end contributes
        0, so the empty rectangle has area 0. With no components at all the
        result is the empty product, 1.
        """
        # A plain loop instead of reduce(): reduce is not a builtin on
        # Python 3, and without an initial value it fails on an empty list.
        area = 1
        for lo, hi in self.components:
            area *= (hi - lo) if hi > lo else 0
        return area
def rect_dist():
    """Draw one point uniformly at random from the unit square."""
    return (random.random(), random.random())  # uniform


# Fixed target rectangle that the learner tries to recover.
rect_targ = Rect(2)
rect_targ.components = [
    (0.3115810690880324, 0.9659225174906901),
    (0.15314502453517027, 0.7996318770633059),
]

# Main simulation; starts from the empty hypothesis. ``rs`` is a short alias.
rect_sim = rs = Simulation(rect_dist, rect_targ, Rect(2))
def make_pretty_pics():
    """Run the main simulation ``rs`` and save the whole series of pictures.

    Pictures 1-3 are drawn on the current figure after the first 10 samples:
    the sample points alone, then with the hypothesis, then with the target.
    Pictures 4-6 each start a new figure after 10, 80 and 900 further
    samples. Finally the two error plots are produced through
    make_error_pic. Rebinds the module-level ``fig`` and ``ax``.
    """
    global fig, ax
    caption = "After {0} samples, error = {1:.4f}"

    # First batch: three snapshots layered onto the same figure.
    rs.get_samples(10)
    rs.plot_record()
    plt.savefig("1-just10")
    rs.plot_hyp()
    plt.savefig("2-hyp10")
    rs.plot_target()
    ax.set_title(caption.format(rs.m, rs.errors[-1]))
    plt.savefig("3-targ10")

    # Later batches: one complete picture per batch, numbered from 4.
    pic_number = 4
    for batch in (10, 80, 900):
        fig, ax = plt.subplots()  # clear!
        ax.set_aspect('equal')
        ax.set_xlim([0, 1])
        ax.set_ylim([0, 1])
        rs.get_samples(batch)
        rs.plot_record()
        rs.plot_hyp()
        rs.plot_target()
        ax.set_title(caption.format(rs.m, rs.errors[-1]))
        plt.savefig("{0}-targ{1}".format(pic_number, rs.m))
        pic_number += 1

    # now clear again and make the error plots
    make_error_pic()
    make_error_pic(99)
def make_error_pic(reruns=None):
    """Save a plot of error against sample size for the main simulation.

    reruns -- optional number of extra, independent simulations (1000
              samples each) to draw in gray behind the main curve. When it
              is falsy, only the main curve is drawn.

    The figure is saved as "errors", or as "errors<reruns+1>" when reruns
    are requested; the file format is left to matplotlib's savefig default.
    Rebinds the module-level ``fig`` and ``ax`` to a fresh figure.
    """
    global fig, ax
    fig, ax = plt.subplots()  # clear!
    plt.plot(rs.errors)
    ax.set_ylim([0., .2])
    title = "Error vs sample size"
    filename = "errors"
    if reruns:
        title += " ({0} runs)".format(reruns+1)
        filename += str(reruns+1)
        # range() rather than the Python-2-only xrange().
        for _ in range(reruns):  # independent simulations
            sim = Simulation(rect_dist, rect_targ, Rect(2))
            sim.get_samples(1000)
            ax.plot(sim.errors, '0.75')  # gray for reruns
        # Redraw the main curve last so it sits on top of the gray reruns.
        ax.plot(rs.errors, 'b')  # blue for main plot
    ax.set_title(title)
    ax.set_xlabel("m")
    ax.set_ylabel("error")
    plt.savefig(filename)
# Script entry point: generate every picture when the file is run directly.
if __name__ == "__main__":
    make_pretty_pics()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment