Skip to content

Instantly share code, notes, and snippets.

@Saurabh7
Created June 8, 2016 13:53
Show Gist options
  • Save Saurabh7/ab36961c51cd5e69213096527b4d2810 to your computer and use it in GitHub Desktop.
Save Saurabh7/ab36961c51cd5e69213096527b4d2810 to your computer and use it in GitHub Desktop.
from modshogun import RealFeatures, RegressionLabels, LeastAngleRegression, PruneVarSubMean, PNorm
import numpy as np
from time import time
from sklearn.linear_model import LassoLars, Lars
from sklearn.metrics import mean_squared_error as mse
from sklearn import preprocessing
# Benchmark grid swept by run(): every entry of ran1 is passed to gen_data()
# as the feature count and every entry of ran2 as the number of training
# vectors.
ran1, ran2 = [3000], [5000]  # ran2 was previously also tried with 100000
# NOTE(review): num_feat / num_vec are not referenced by the code visible
# here (gen_data and run take their sizes as arguments); kept so any
# external users of these names keep working.
num_feat, num_vec = 10, 100
def gen_data(num_feat, num_vec, seed=None):
    """Generate a sparse, noise-free linear regression problem.

    ``num_feat // 3`` randomly chosen features receive a random integer
    weight in [0, 100); every other feature has weight zero. Targets are
    the exact linear combination ``X . w``, with no noise and no intercept.

    Parameters
    ----------
    num_feat : int
        Number of features (columns of the returned matrices).
    num_vec : int
        Number of training examples. The test set has ``num_vec // 10`` rows.
    seed : int, optional
        Seed for a private ``numpy.random.RandomState``. When None (the
        default), the global NumPy random state is used, in the same draw
        order as before.

    Returns
    -------
    X : ndarray, shape (num_vec, num_feat)
    y : ndarray, shape (num_vec,)
    xtest : ndarray, shape (num_vec // 10, num_feat)
    ytest : ndarray, shape (num_vec // 10,)
    """
    rng = np.random if seed is None else np.random.RandomState(seed)
    # Floor division: '/' yields a float on Python 3, which rand()/range() reject.
    num_test = num_vec // 10
    X = rng.rand(num_vec, num_feat)
    xtest = rng.rand(num_test, num_feat)

    # Choose num_feat // 3 distinct feature indices to carry weight.
    w = np.zeros(num_feat)
    informative = rng.permutation(num_feat)[:num_feat // 3]
    for idx in informative:
        w[idx] = rng.randint(100)

    # With a 1-D weight vector, the products are already 1-D: no reshape needed.
    y = np.dot(X, w)
    ytest = np.dot(xtest, w)
    return X, y, xtest, ytest
def _shogun_features(X, xtest):
    """Wrap the train/test matrices as Shogun features and preprocess them.

    Each preprocessor is initialised on the training features only, then
    applied to both train and test features, so no test statistics leak
    into the fitted model.

    Returns (train_features, test_features).
    """
    # Rows of X are examples; Shogun's dense features take one example per
    # column, hence the transposes.
    feat = RealFeatures(X.T)
    ftest = RealFeatures(xtest.T)
    for prep in (PruneVarSubMean(), PNorm(2)):
        prep.init(feat)
        prep.apply_to_feature_matrix(feat)
        prep.apply_to_feature_matrix(ftest)
    return feat, ftest


def _time_shogun_lars(feat, y, max_non_zero, num_threads):
    """Train Shogun's LeastAngleRegression on *feat*; return (model, seconds)."""
    lab = RegressionLabels(y)
    start = time()
    # False presumably selects plain LARS rather than the LASSO variant,
    # mirroring sklearn's Lars below -- confirm for your Shogun version.
    model = LeastAngleRegression(False)
    model.set_max_non_zero(max_non_zero)
    model.set_labels(lab)
    model.parallel.set_num_threads(num_threads)
    model.train(feat)
    return model, time() - start


def _time_sklearn_lars(X, y, max_non_zero):
    """Fit scikit-learn's Lars on the raw features; return (model, seconds)."""
    start = time()
    model = Lars(n_nonzero_coefs=max_non_zero, precompute=False)
    model.fit(X, y)
    return model, time() - start


def run(feature_counts=None, vector_counts=None, max_non_zero=700,
        num_threads=3):
    """Benchmark Shogun's LARS against scikit-learn's Lars on synthetic data.

    For every (feature count, vector count) pair, a dataset is generated
    with gen_data() and both libraries are trained on it. The function
    prints Shogun's path size and sklearn's iteration count, followed by
    each library's training time and test-set mean squared error.

    Parameters
    ----------
    feature_counts : iterable of int, optional
        Feature dimensions to sweep. Defaults to the module-level ``ran1``.
    vector_counts : iterable of int, optional
        Training-set sizes to sweep. Defaults to the module-level ``ran2``.
    max_non_zero : int
        Limit on non-zero coefficients. It is passed to Shogun as
        ``set_max_non_zero`` and to sklearn as ``n_nonzero_coefs``.
    num_threads : int
        Number of threads Shogun may use for training.
    """
    if feature_counts is None:
        feature_counts = ran1
    if vector_counts is None:
        vector_counts = ran2
    for f in feature_counts:
        for v in vector_counts:
            X, y, xtest, ytest = gen_data(f, v)
            # Centre both target vectors on the *training* mean; the models
            # are then scored against the centred test targets.
            y_mean = y.mean(axis=0)
            y -= y_mean
            ytest -= y_mean

            feat, ftest = _shogun_features(X, xtest)
            # NOTE(review): Shogun's preprocessing happens outside its timed
            # region, while Lars gets the raw X and does any scaling inside
            # fit() -- confirm this is the intended comparison.
            modelsg, timesg = _time_shogun_lars(feat, y, max_non_zero,
                                                num_threads)
            modelsk, timesk = _time_sklearn_lars(X, y, max_non_zero)

            print(modelsg.get_path_size())
            print(modelsk.n_iter_)

            outsg = mse(ytest, modelsg.apply(ftest).get_labels())
            outsk = mse(ytest, modelsk.predict(xtest))
            # Single pre-formatted strings print identically on Python 2 and 3.
            print("------------\n")
            print('N: %s D: %s shogun: %.6f , mse: %.6f | '
                  'scikit: %.6f , mse: %.6f ' % (v, f, timesg, outsg,
                                                   timesk, outsk))
# Run the (long) benchmark only when executed as a script, not on import.
if __name__ == "__main__":
    run()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment