Skip to content

Instantly share code, notes, and snippets.

@raamana
Created September 13, 2017 12:50
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save raamana/24a1c8ed0cc4d66944742ff96ec4c510 to your computer and use it in GitHub Desktop.
Save raamana/24a1c8ed0cc4d66944742ff96ec4c510 to your computer and use it in GitHub Desktop.
Code to reproduce grid search freeze
import sys
import timeit
from os.path import join as pjoin
import logging
import traceback
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.feature_selection import mutual_info_classif, SelectKBest
from sklearn.model_selection import GridSearchCV, ShuffleSplit
from sklearn.pipeline import Pipeline
# Two-stage model: keep the 10 features with the highest mutual information
# with the labels, then classify with a random forest. The forest settings
# given here (max_features, n_estimators) are overridden per candidate by
# the grid below.
rf = RandomForestClassifier(max_features=8, n_estimators=100, oob_score=True)
feat_selector = SelectKBest(score_func=mutual_info_classif, k=10)

fs_name = 'MI_top_K'
clf_name = 'random_forest_clf'
steps = [
    (fs_name, feat_selector),
    (clf_name, rf),
]
pipeline = Pipeline(steps)


def param_name(string):
    """Return ``string`` prefixed with the classifier step name.

    Pipeline hyper-parameters are addressed as ``<step>__<param>``.
    """
    return '{}__{}'.format(clf_name, string)


# 2 (min_samples_leaf) x 3 (max_features) x 4 (n_estimators) = 24 candidates.
# With 25 shuffle splits, this is 600 pipeline fits per call to gs.fit.
param_grid = {
    param_name(hyper): values
    for hyper, values in (
        ('min_samples_leaf', range(1, 5, 2)),
        ('max_features', range(1, 6, 2)),
        ('n_estimators', range(50, 250, 50)),
    )
}

inner_cv = ShuffleSplit(n_splits=25, train_size=0.8)
gs = GridSearchCV(
    estimator=pipeline,
    param_grid=param_grid,
    cv=inner_cv,
    verbose=2,
)
print(gs)
# All inputs and the log file live in the current working directory.
cur_dir = '.'

# Feature matrix and integer class labels, both whitespace-delimited text.
# (An earlier variant loaded a comma-separated 'test_data20.csv' instead.)
train_data_full = np.genfromtxt(pjoin(cur_dir, 'JS_sklearn_test.txt'))
train_labels = np.genfromtxt(
    pjoin(cur_dir, 'labels_sklearn_test.txt'),
    dtype='int',
)

# INFO-level records (grid-search summaries) go to this file.
log_file = pjoin(cur_dir, 'logfile_cv.txt')
logging.basicConfig(filename=log_file, level=logging.INFO)
def get_stop_time(start):
    """Return how much time has passed since ``start``.

    ``start`` should be a value previously read from
    ``timeit.default_timer()``. The result uses that clock's unit (seconds).
    """
    now = timeit.default_timer()
    return now - start
# Run the grid search on progressively wider column subsets and log the
# wall-clock time and best result for each one.
# NOTE(review): range(100, 5000, 62000) yields only 100 because the step is
# larger than the stop value, so this loop runs exactly once. Confirm whether
# a sweep (e.g. range(100, 5000, 500)) was intended.
for data_dim in range(100, 5000, 62000):
    # Assumes train_data_full is 2-D (samples x features). Slicing past the
    # last column keeps all columns, so the printed shape is the real one.
    train_data_subset = train_data_full[:, :data_dim]
    print('\ndata size: {} \n'.format(train_data_subset.shape))
    start = timeit.default_timer()
    try:
        gs.fit(train_data_subset, train_labels)
    except Exception:
        # Catch Exception, not everything, so Ctrl-C still interrupts a hung
        # fit. Skip reporting: gs.best_score_ / best_params_ would either not
        # exist yet or belong to a previous fit.
        print('fit failed ')
        traceback.print_exc()
        logging.exception('fit failed at dimensionality %s',
                          train_data_subset.shape)
        continue
    # timeit.default_timer() measures seconds, so report seconds.
    elapsed = get_stop_time(start)
    log_msg = ('gridsearch at dimensionality {} just done after {:.3f} secs.\n'
               ' Best score: {}\nBest params: {}').format(
        train_data_subset.shape, elapsed, gs.best_score_, gs.best_params_)
    print(log_msg)
    logging.info(log_msg)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment