Skip to content

Instantly share code, notes, and snippets.

@agramfort
Created April 26, 2020 12:36
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save agramfort/ca54cc1bc12a37d7a426a7799cc236ce to your computer and use it in GitHub Desktop.
Save agramfort/ca54cc1bc12a37d7a426a7799cc236ce to your computer and use it in GitHub Desktop.
"""
Benchmark of MultiTaskLasso
"""
import gc
from itertools import product
from time import time
import numpy as np
import pandas as pd
from sklearn.datasets import make_regression
from sklearn.linear_model import MultiTaskLasso
def compute_bench(alpha, n_samples, n_features, n_tasks, random_state=0):
    """Time ``MultiTaskLasso.fit`` over a grid of problem sizes.

    For every combination of ``(n_samples, n_features, n_tasks)`` a synthetic
    regression problem is generated with ``make_regression``. The columns of
    ``X`` are scaled to unit Euclidean norm, and a ``MultiTaskLasso`` without
    intercept is fitted. Only the ``fit`` call is timed.

    Parameters
    ----------
    alpha : float
        Regularization strength passed to ``MultiTaskLasso``.
    n_samples, n_features, n_tasks : sequence of int
        Values to sweep. The full Cartesian product is benchmarked.
    random_state : int, RandomState instance or None, default=0
        Seed forwarded to ``make_regression``. A fixed seed makes separate
        runs (e.g. against an old and a new scikit-learn) time the fit on
        identical data, so their timings are directly comparable. Pass
        ``None`` to draw fresh random data on every call.

    Returns
    -------
    pandas.DataFrame
        One row per configuration with columns ``n_samples``,
        ``n_features``, ``n_tasks`` and ``time`` (fit duration in seconds).
    """
    # perf_counter is monotonic and high-resolution, unlike time.time(),
    # which can jump if the system clock is adjusted mid-benchmark.
    from time import perf_counter

    grid = list(product(n_samples, n_features, n_tasks))
    n_bench = len(grid)
    results = []
    # Count from 1 so the last message reads "Iteration N of N".
    for it, (ns, nf, nt) in enumerate(grid, start=1):
        print('==================')
        print('Iteration %s of %s' % (it, n_bench))
        print('==================')
        # Roughly 10% informative features, but at least one so the target
        # is not pure noise when n_features < 10.
        n_informative = max(1, nf // 10)
        X, Y = make_regression(n_samples=ns, n_features=nf,
                               n_informative=n_informative,
                               n_targets=nt,
                               noise=0.1,
                               random_state=random_state)
        # make_regression returns a 1-D target when n_targets == 1, but
        # MultiTaskLasso requires a 2-D Y.
        Y = Y.reshape(ns, -1)
        X /= np.linalg.norm(X, axis=0)  # Normalize each column to unit norm
        # Collect garbage up front so a GC pause does not land inside the
        # timed region.
        gc.collect()
        clf = MultiTaskLasso(alpha=alpha, fit_intercept=False)
        tstart = perf_counter()
        clf.fit(X, Y)
        elapsed = perf_counter() - tstart
        results.append(
            dict(n_samples=ns, n_features=nf, n_tasks=nt, time=elapsed)
        )
    return pd.DataFrame(results)
def compare_results(new_fname='mlt_new.csv', old_fname='mlt_old.csv'):
    """Print old and new benchmark timings and their speed-up ratio.

    Both files are expected to be CSVs written by ``compute_bench`` (columns
    ``n_samples``, ``n_features``, ``n_tasks``, ``time``). Rows are matched
    on the ``(n_samples, n_features, n_tasks)`` triple. Configurations
    present in only one file yield NaN in the ratio.

    Parameters
    ----------
    new_fname : str, default='mlt_new.csv'
        Timings of the version under test.
    old_fname : str, default='mlt_old.csv'
        Reference (baseline) timings.

    Returns
    -------
    pandas.DataFrame or None
        Single-column table ``'time (old) / time (new)'`` indexed by the
        problem size (values > 1 mean the new version is faster), or
        ``None`` when ``old_fname`` does not exist yet.
    """
    index_cols = ['n_samples', 'n_features', 'n_tasks']
    results_new = pd.read_csv(new_fname).set_index(index_cols)
    try:
        results_old = pd.read_csv(old_fname).set_index(index_cols)
    except FileNotFoundError:
        # First run of the workflow: there is no baseline to compare to yet.
        print(results_new)
        print('No reference timings found in %r; record them with the old '
              'version first to get a comparison.' % old_fname)
        return None
    # Divide the timing columns explicitly instead of whole frames plus a
    # positional rename, which would break if extra columns were present.
    results_ratio = (results_old['time'] / results_new['time']).to_frame(
        name='time (old) / time (new)')
    print(results_new)
    print(results_old)
    print(results_ratio)
    return results_ratio
if __name__ == '__main__':
    # Workflow: run once with the reference scikit-learn and ``--old`` to
    # record baseline timings in mlt_old.csv. Then run again (without
    # ``--old``) with the version under test, which writes mlt_new.csv and
    # prints the speed-up table.
    import argparse

    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        '--old', action='store_true',
        help='write timings to mlt_old.csv (baseline) instead of '
             'mlt_new.csv, and skip the comparison')
    args = parser.parse_args()

    alpha = 0.01  # regularization parameter
    list_n_features = [300, 1000, 4000]
    list_n_samples = [100, 500]
    list_n_tasks = [2, 10, 20, 50]
    results = compute_bench(alpha, list_n_samples,
                            list_n_features, list_n_tasks)
    if args.old:
        results.to_csv('mlt_old.csv', index=False)
    else:
        results.to_csv('mlt_new.csv', index=False)
        compare_results()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment