@aflansburg
Last active August 1, 2021 18:52
Calculate GridSearchCV runtime
# Compare two estimates of GridSearchCV runtime: one summed from the fit and score
# times in cv_results_, and one measured with time() around the call to fit().
# Based on an answer by Naveen Vuppula on Data Science Stack Exchange:
# https://datascience.stackexchange.com/a/93524/41883
from time import time

import numpy as np


def gridsearch_runtime(grid_obj, X_train, y_train):
    '''
    Parameters:
        grid_obj: GridSearchCV object that has not yet been fit to training data
        X_train: training split, independent variables
        y_train: training split, dependent variable
    Returns:
        (time_from_cv_result, time_from_sys_time), both in seconds
    '''
    start = time()
    grid_obj.fit(X_train, y_train)
    end = time()
    mean_fit_time = grid_obj.cv_results_['mean_fit_time']
    mean_score_time = grid_obj.cv_results_['mean_score_time']
    n_splits = grid_obj.n_splits_  # number of cross-validation splits
    n_iter = len(mean_fit_time)  # number of parameter combinations tried
    # Mean per-fold time (fit + score), scaled up to every fold of every candidate.
    # This excludes the final refit; with n_jobs > 1 it is summed over parallel
    # workers, so it can exceed the wall-clock time.
    time_from_cv_result = np.mean(mean_fit_time + mean_score_time) * n_splits * n_iter
    time_from_sys_time = end - start
    return time_from_cv_result, time_from_sys_time


# This is meant as a very basic tool for comparing compute performance;
# adapt it to your own workflow.
# Assumes an estimator (tuned_estimator), a parameter grid (parameters), a scorer
# (acc_scorer), and training data (X_train, y_train) are already defined.
from sklearn.model_selection import GridSearchCV

grid_obj = GridSearchCV(tuned_estimator, parameters, scoring=acc_scorer, cv=5)
res_time, res_sys_time = gridsearch_runtime(grid_obj, X_train, y_train)
print('calculated runtime from GridSearchCV cv_results_ attributes')
print(res_time)
print('calculated runtime from using time()')
print(res_sys_time)
# Output (time units are in seconds):
# calculated runtime from GridSearchCV cv_results_ attributes
# 293.37535548210144
# calculated runtime from using time()
# 295.12124705314636
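
The arithmetic behind the cv_results_ estimate can be checked by hand without scikit-learn. The sketch below is only an illustration: the timings are made-up values, and `estimate_total_seconds` is a hypothetical helper that mirrors the formula in `gridsearch_runtime` using plain lists. Multiplying the mean per-fold time by the number of candidates simply recovers the sum over candidates, which is then multiplied by the number of folds.

```python
# Hypothetical timings in seconds, shaped like cv_results_ for 3 candidates.
mean_fit_time = [2.0, 4.0, 6.0]
mean_score_time = [0.5, 0.5, 0.5]
n_splits = 5  # assumed number of cross-validation folds


def estimate_total_seconds(mean_fit_time, mean_score_time, n_splits):
    """Mirror of the cv_results_ formula using plain Python lists."""
    # Average fit + score time per fold, for each candidate
    per_candidate = [f + s for f, s in zip(mean_fit_time, mean_score_time)]
    n_candidates = len(per_candidate)
    mean_per_fold = sum(per_candidate) / n_candidates
    # Scale the mean per-fold time up to every fold of every candidate
    return mean_per_fold * n_splits * n_candidates


total = estimate_total_seconds(mean_fit_time, mean_score_time, n_splits)
print(total)  # 67.5
```

Here the per-candidate times are 2.5, 4.5, and 6.5 seconds per fold, so the estimate is (2.5 + 4.5 + 6.5) * 5 = 67.5 seconds.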