Created
December 15, 2013 01:37
-
-
Save qmaurmann/7967607 to your computer and use it in GitHub Desktop.
Code to produce the first set of pictures (PAC learning of axis-aligned rectangles) for the blog post "Evolvability and Learning" at qmaurmann.wordpress.com.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import division

import operator as op
import random
from functools import reduce

import matplotlib.pyplot as plt
# Module-level globals: the shared figure and axes that Simulation's plotting
# methods draw on. make_pretty_pics and make_error_pic rebind both names
# when they open a fresh figure.
fig, ax = plt.subplots()
ax.set_aspect('equal')
ax.set(xlim=[0, 1], ylim=[0, 1])  # view the unit square
def make_oracle(dist, target):  # could have been a class
    """Build a PAC example oracle from a distribution and a target function.

    Args:
        dist: zero-argument callable that draws one input point.
        target: callable mapping an input point to its label.

    Returns:
        A zero-argument function. Each call draws a fresh point from ``dist``
        and returns the labelled pair ``(point, target(point))``.
    """
    def oracle():
        point = dist()
        label = target(point)
        return point, label

    return oracle
class Simulation(object):
    """A class to keep a learning simulation organized.

    Assumes an online learning scenario, updating the current hypothesis after
    every point seen.

    Much of the data carried around (like record and errors) is only there
    for making pretty pictures.

    Attributes:
        oracle: zero-argument callable returning one labelled ``(x, y)`` pair,
            built by ``make_oracle(dist, target)``.
        hyp: current hypothesis. It must provide ``update(x, y)`` and
            ``unif_area()``, and for plotting a ``components`` pair of
            ``(low, high)`` intervals.
        target: target concept. It must be callable and provide
            ``unif_area()``, and for plotting ``components``.
        record: every ``(input, label)`` pair seen, in order.
        errors: the value of ``unif_area_diff()`` after each sample.
        m: number of samples seen so far.
    """

    def __init__(self, dist, target, init_hyp):
        self.oracle = make_oracle(dist, target)
        self.hyp = init_hyp
        self.target = target
        self.record = []  # keeps list of all (input, label) pairs seen
        self.errors = []  # optional, except that I want to make an error plot
        self.m = 0  # num samples seen so far

    def get_samples(self, m):
        """Draw m more samples from the oracle, updating after each one.

        Each sample is fed to ``hyp.update``, appended to ``record``, followed
        by an entry in ``errors``, and counted in ``m``.
        """
        # range instead of the Python-2-only xrange keeps this runnable on
        # Python 3. The loop index was never used.
        for _ in range(m):
            x, y = self.oracle()
            self.hyp.update(x, y)
            self.record.append((x, y))
            self.errors.append(self.unif_area_diff())
            self.m += 1

    def plot_record(self):
        """Assuming the input space is a subset of the plane, plot all points
        seen by the learner, in black for positive examples or white for
        negative."""
        ax.set_title("After {0} samples".format(self.m))
        for (x1, x2), y in self.record:
            marker = 'ko' if y else 'wo'
            ax.plot(x1, x2, marker)

    def _shade(self, rect, color):
        """Shade the 2-D box ``rect.components`` on the current axes.

        Draws nothing for an empty box, i.e. one where some interval has
        low > high. A Rect hypothesis that has not yet seen a positive
        example still holds its initial ``(inf, -inf)`` intervals, and those
        do not describe a drawable region.
        """
        (xmin, xmax), (ymin, ymax) = rect.components
        if xmin > xmax or ymin > ymax:
            return
        # axhspan takes ymin/ymax in data coordinates but xmin/xmax as
        # fractions of the axes width. The two agree only while the x-axis
        # spans [0, 1], as the unit-square axes used for these pictures do.
        plt.axhspan(ymin, ymax, xmin=xmin, xmax=xmax, facecolor=color, alpha=0.3)

    def plot_hyp(self):
        """Assuming the input space is a subset of the plane, plot the current
        hypothesis in blue (nothing is drawn while it is still empty)."""
        self._shade(self.hyp, 'blue')

    def plot_target(self):
        """Assuming the input space is a subset of the plane, plot the target
        function in green."""
        self._shade(self.target, 'green')

    def unif_area_diff(self):
        """Measures error, assuming uniform distribution and hypothesis subset
        of target."""
        return self.target.unif_area() - self.hyp.unif_area()

    def plot_errors(self):
        """Plot the per-sample error history on the global axes.

        Rescales the axes to the number of samples along x and to [0, 0.05]
        along y before drawing.
        """
        ax.set_xlim([0, len(self.errors)])
        ax.set_ylim([0., .05])
        plt.plot(self.errors)
### Represent the axis-aligned rectangles prod_i [a_i, b_i] as length-n lists
### of pairs of floats, with (a_i, b_i) = rect.components[i]
infty = float('inf')


class Rect(object):
    """Callable n-dimensional axis-aligned rectangle (box).

    ``components[i]`` is the closed interval ``(a_i, b_i)`` for coordinate i.
    A fresh Rect has every interval set to ``(infty, -infty)``, so it contains
    no point. When used as a hypothesis, ``update`` grows it to the smallest
    box containing every positive example seen.
    """

    def __init__(self, n):
        # Tuples are immutable, so letting list multiplication share one
        # (infty, -infty) tuple across the n slots is safe.
        self.components = [(infty, -infty)] * n

    def __call__(self, x):
        """Return True iff point x lies in the closed rectangle."""
        return all(ai <= xi <= bi for (ai, bi), xi in zip(self.components, x))

    def update(self, x, y):
        """Stretch the box just enough to contain x if y is truthy (a positive
        example). Negative examples are ignored."""
        if not y:
            return
        for i, xi in enumerate(x):
            ai, bi = self.components[i]
            if not ai <= xi <= bi:
                self.components[i] = (min(ai, xi), max(bi, xi))

    def unif_area(self):
        """Return the volume of the box, or 0 if it is empty.

        Each interval contributes its length b - a, or 0 when b <= a. Nothing
        clips the box to the unit cube, so this equals the uniform-measure
        probability only if the box already lies inside it. A 0-dimensional
        Rect has volume 1 (the empty product) instead of raising TypeError.
        """
        lengths = (b - a if b > a else 0 for a, b in self.components)
        # functools.reduce (not the Python-2-only builtin), with initializer 1
        # so an empty product is well defined.
        return reduce(op.mul, lengths, 1)
def rect_dist():
    """Draw a point uniformly at random from the unit square [0, 1)^2."""
    return (random.random(), random.random())


# The fixed target rectangle the learner is trying to recover.
rect_targ = Rect(2)
rect_targ.components = [(0.3115810690880324, 0.9659225174906901),
                        (0.15314502453517027, 0.7996318770633059)]

# The main simulation. ``rs`` is the short alias the picture functions use.
rect_sim = rs = Simulation(rect_dist, rect_targ, Rect(2))
def _fresh_unit_axes():
    """Open a new figure with square axes showing the unit square.

    Rebinds the module-level ``fig`` and ``ax`` that Simulation's plotting
    methods draw on.
    """
    global fig, ax
    fig, ax = plt.subplots()  # clear!
    ax.set_aspect('equal')
    ax.set_xlim([0, 1])
    ax.set_ylim([0, 1])


def _title_with_error():
    """Title the global axes with rs's sample count and its latest error."""
    ax.set_title("After {0} samples, error = {1:.4f}".format(rs.m, rs.errors[-1]))


def make_pretty_pics():
    """Run the main simulation ``rs`` and save the blog post's pictures.

    After 10 samples, saves three stages on the same axes: the points alone
    ("1-just10"), plus the hypothesis ("2-hyp10"), plus the target
    ("3-targ10"). It then adds 10, 80 and 900 more samples, saving each state
    on a fresh figure as "<i>-targ<m>" for i = 4, 5, 6. Finally it saves the
    error plots via make_error_pic.
    """
    rs.get_samples(10)
    rs.plot_record()
    plt.savefig("1-just10")
    rs.plot_hyp()
    plt.savefig("2-hyp10")
    rs.plot_target()
    _title_with_error()
    plt.savefig("3-targ10")
    for i, delta in enumerate([10, 80, 900], 4):
        _fresh_unit_axes()
        rs.get_samples(delta)
        rs.plot_record()
        rs.plot_hyp()
        rs.plot_target()
        _title_with_error()
        plt.savefig("{0}-targ{1}".format(i, rs.m))
    # now clear again and make the error plots
    make_error_pic()
    make_error_pic(99)
def make_error_pic(reruns=None, num_samples=1000):
    """Plot error against sample size for ``rs`` on a new figure and save it.

    Args:
        reruns: if truthy, also run this many independent simulations and
            draw their error curves in gray behind the main blue curve. The
            title and file name then record the total number of runs
            (reruns + 1).
        num_samples: samples drawn in each rerun. Defaults to 1000, the total
            that make_pretty_pics draws (10 + 10 + 80 + 900).

    The figure is saved as "errors", or as "errors<reruns + 1>" when reruns
    is truthy.
    """
    global fig, ax
    fig, ax = plt.subplots()  # clear!
    ax.set_ylim([0., .2])
    title = "Error vs sample size"
    filename = "errors"
    if reruns:
        title += " ({0} runs)".format(reruns + 1)
        filename += str(reruns + 1)
        for _ in range(reruns):  # independent simulations
            sim = Simulation(rect_dist, rect_targ, Rect(2))
            sim.get_samples(num_samples)
            ax.plot(sim.errors, '0.75')  # gray for reruns
    # The main curve is drawn exactly once, last, so it sits on top of the
    # gray reruns.
    ax.plot(rs.errors, 'b')  # blue for main plot
    ax.set_title(title)
    ax.set_xlabel("m")
    ax.set_ylabel("error")
    plt.savefig(filename)
if __name__ == '__main__':
    # Executed as a script: run the simulation and write out every picture.
    make_pretty_pics()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment