Skip to content

Instantly share code, notes, and snippets.

@mehdidc
Created April 19, 2016 21:09
Show Gist options
  • Save mehdidc/4fae166cf492590adedd72c9bd5628b5 to your computer and use it in GitHub Desktop.
Save mehdidc/4fae166cf492590adedd72c9bd5628b5 to your computer and use it in GitHub Desktop.
import numpy as np
import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt
from pyearth import Earth
import time
from itertools import product
# Seed NumPy's global RNG so the synthetic datasets that train_model draws with
# np.random.uniform are the same from one run of this script to the next.
np.random.seed(2)
def train_model(m, n, p, k, model_factory=None):
    """Fit a model on a synthetic polynomial regression problem and return it.

    The data is drawn from NumPy's global RNG, so it depends on the module-level
    seed and on how many draws happened before this call. Note that when this
    function is timed as a whole, data generation is included in the timing.

    Args:
        m: number of examples (rows of X).
        n: number of input features (columns of X).
        p: number of regression outputs (columns of y).
        k: passed as ``max_terms`` to the default ``Earth`` model, or as the
            single argument to ``model_factory``.
        model_factory: optional callable ``model_factory(k)`` returning an
            unfitted estimator with a ``fit(X, y)`` method. When None (the
            default) an ``Earth`` model configured as below is used.

    Returns:
        The fitted model.
    """
    # X has shape (m, n) with entries uniform in [0, 80).
    X = 80 * np.random.uniform(size=(m, n))
    # y has shape (m, p): every output is a random mix, with weights uniform
    # in [0, 1), of 5*x**2 + 6*x**3 applied elementwise to each feature.
    y = np.dot(5 * X**2 + 6 * X**3, np.random.uniform(size=(n, p)))
    if model_factory is None:
        # NOTE(review): check_every=1, thresh=0, minspan=1, endspan=1 appear to
        # be chosen so Earth considers every sample as a knot candidate and
        # does not stop the forward pass early (see pyearth docs) -- confirm.
        model = Earth(max_terms=k,
                      check_every=1,
                      thresh=0,
                      minspan=1,
                      endspan=1)
    else:
        model = model_factory(k)
    model.fit(X, y)
    return model
def train(f, params):
    """Call ``f`` once per argument tuple and time each call.

    Args:
        f: the callable to benchmark (e.g. ``train_model``). Its return value
            is ignored.
        params: a sequence of per-argument sequences. They are zipped, so the
            i-th call is ``f(params[0][i], params[1][i], ...)``, and iteration
            stops at the shortest sequence.

    Returns:
        ``(values, durations)``: the list of argument tuples that were used
        and the matching list of wall-clock durations in seconds.
    """
    values = []
    durations = []
    for args in zip(*params):
        print(args)
        values.append(args)
        start = time.time()
        # Call the callable we were given (previously train_model was
        # hard-wired here and ``f`` was silently ignored).
        f(*args)
        durations.append(time.time() - start)
    return values, durations
# Benchmark: Earth fit time as a function of the number of outputs.
# train_model argument order: nb_examples, nb_features, nb_outputs, nb_terms.
# Every experiment sweeps nb_outputs over `p`. Experiments 2-4 also sweep a
# second argument in lockstep (train zips the i-th entries together), so their
# curves mix the effect of both parameters; the plot title names the second one.
p = range(1, 100, 5)
nb = len(p)

NB_EXAMPLES = 1000
NB_FEATURES = 100
NB_TERMS = 10


def _plot_durations(values, durations, filename, title=None):
    """Plot fit duration against number of outputs and save it to *filename*.

    *values* is the list of argument tuples returned by ``train``; position 2
    of each tuple is the nb_outputs argument that was passed to ``train_model``.
    """
    nb_outputs = [args[2] for args in values]
    plt.clf()  # reuse the current figure, dropping the previous experiment's curve
    plt.plot(nb_outputs, durations)
    plt.xlabel("number of outputs")
    plt.ylabel("duration in sec")
    if title is not None:
        plt.title(title)
    plt.savefig(filename)


# 1. Only nb_outputs varies.
v, durations = train(train_model, ([NB_EXAMPLES] * nb, [NB_FEATURES] * nb, p, [NB_TERMS] * nb))
_plot_durations(v, durations, "nb_outputs.png")

# 2. nb_terms grows 10..29 alongside nb_outputs.
v, durations = train(train_model, ([NB_EXAMPLES] * nb, [NB_FEATURES] * nb, p, range(10, 30)))
_plot_durations(v, durations, "nb_outputs_and_nb_terms.png",
                title="nb_terms also varies (10 to 29)")

# 3. nb_examples grows 10..960 alongside nb_outputs.
v, durations = train(train_model, (range(10, 1000, 50), [NB_FEATURES] * nb, p, [NB_TERMS] * nb))
_plot_durations(v, durations, "nb_outputs_and_nb_examples.png",
                title="nb_examples also varies (10 to 960)")

# 4. nb_features grows 1..96 alongside nb_outputs.
v, durations = train(train_model, ([NB_EXAMPLES] * nb, range(1, 100, 5), p, [NB_TERMS] * nb))
_plot_durations(v, durations, "nb_outputs_and_nb_features.png",
                title="nb_features also varies (1 to 96)")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment