# Source: GitHub gist mehdidc/4fae166cf492590adedd72c9bd5628b5
# (created April 19, 2016 21:09).
# GitHub flagged this file as possibly containing bidirectional Unicode text;
# review it in an editor that reveals hidden Unicode characters.
import time
from itertools import product

import numpy as np
import matplotlib as mpl

# Pick the non-interactive Agg backend before importing pyplot; the script
# only writes figures to files with savefig and never opens a window.
mpl.use('Agg')
import matplotlib.pyplot as plt  # noqa: E402 -- must follow mpl.use
from pyearth import Earth  # noqa: E402

# Seed NumPy's global RNG so the synthetic benchmark data is reproducible.
np.random.seed(seed=2)
def train_model(m, n, p, k):
    """Fit a py-earth ``Earth`` model on a synthetic multi-output problem.

    Inputs are drawn uniformly from [0, 80). Targets are a random linear mix
    (an ``(n, p)`` uniform weight matrix) of ``5 * X**2 + 6 * X**3``. Callers
    time this call to measure how fitting cost scales with problem size.

    Parameters
    ----------
    m : int
        Number of examples (rows of ``X`` and ``y``).
    n : int
        Number of input features (columns of ``X``).
    p : int
        Number of outputs (columns of ``y``).
    k : int
        Maximum number of terms, passed as ``Earth(max_terms=k)``.

    Returns
    -------
    Earth
        The fitted model. Callers that only time the call may ignore it.
    """
    # Uses NumPy's global RNG, so the module-level seed makes data reproducible.
    X = 80 * np.random.uniform(size=(m, n))
    y = np.dot(5 * X**2 + 6 * X**3, np.random.uniform(size=(n, p)))
    # NOTE(review): these settings look chosen to make the forward pass as
    # exhaustive as possible (every row a knot candidate, no early-stopping
    # threshold, minimal knot spacing) -- confirm against the py-earth docs.
    model = Earth(max_terms=k,
                  check_every=1,
                  thresh=0,
                  minspan=1,
                  endspan=1)
    model.fit(X, y)
    return model
def train(f, params): | |
values = [] | |
durations = [] | |
for p in zip(*params): | |
print(p) | |
values.append(p) | |
start = time.time() | |
train_model(*p) | |
duration = time.time() - start | |
durations.append(duration) | |
return values, durations | |
def _plot_durations(values, durations, filename, title=None):
    """Plot fit time against the number of outputs and save it to ``filename``.

    ``values`` are the argument tuples returned by ``train``. The number of
    outputs is their third element (train_model's ``p`` argument). When
    given, ``title`` names the parameter swept together with the number of
    outputs, so the curve is not read as an outputs-only effect.
    """
    outputs = [args[2] for args in values]
    plt.clf()
    plt.plot(outputs, durations)
    plt.xlabel("number of outputs")
    plt.ylabel("duration in sec")
    if title is not None:
        plt.title(title)
    plt.savefig(filename)


# train_model argument order: nb_examples nb_features nb_outputs nb_terms
p = range(1, 100, 5)  # swept number of outputs, shared by every experiment
nb = len(p)
# Derive the co-varying sweeps from nb so every sweep has exactly nb points.
# With hard-coded stops, editing `p` would leave the sweeps with different
# lengths, and zip() in train() would silently truncate to the shortest.
# For nb == 20 these equal range(10, 30), range(10, 1000, 50) and
# range(1, 100, 5) respectively.
nb_terms = range(10, 10 + nb)
nb_examples = range(10, 10 + 50 * nb, 50)
nb_features = range(1, 1 + 5 * nb, 5)

# 1) Only the number of outputs varies.
v, durations = train(train_model, ([1000] * nb, [100] * nb, p, [10] * nb))
_plot_durations(v, durations, "nb_outputs.png")

# 2) Number of outputs and maximum number of terms grow together.
v, durations = train(train_model, ([1000] * nb, [100] * nb, p, nb_terms))
_plot_durations(v, durations, "nb_outputs_and_nb_terms.png",
                title="number of terms also increases")

# 3) Number of outputs and number of examples grow together.
v, durations = train(train_model, (nb_examples, [100] * nb, p, [10] * nb))
_plot_durations(v, durations, "nb_outputs_and_nb_examples.png",
                title="number of examples also increases")

# 4) Number of outputs and number of features grow together.
v, durations = train(train_model, ([1000] * nb, nb_features, p, [10] * nb))
_plot_durations(v, durations, "nb_outputs_and_nb_features.png",
                title="number of features also increases")
# Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.