@rmurphy2718
Created September 8, 2021 23:40
A very simple strategy for mitigating a possible problem with `GridSearchCV` in `sklearn`: when the grid search crashes prematurely.

This post documents, for quick reference, a very simple strategy for mitigating a possible problem with GridSearchCV in sklearn: what happens if your process crashes in the middle of a grid search? You can lose everything, even hours of tuning. The solution described here can be set up in under a minute.

The idea, suggested in a StackOverflow post (stackoverflow.com/questions/48058044/gridsearchcv-save-result-each-iteration), is to:

  1. Use a verbose option, so that sklearn prints the performance of each model as soon as a fold finishes.
  2. Direct that output to a persistent file on disk.
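The two steps above can be sketched with the standard library alone. This is only an illustration of the mechanism, not the code used in this post: `noisy_search` is a hypothetical stand-in for a verbose `fit()` call, and the log file name and temporary directory are assumptions. The file is opened with line buffering so that each printed line is written out as soon as it is complete, which matters if the process is killed before the file is closed.

```python
import contextlib
import os
import tempfile


def noisy_search(n_steps, crash_at=None):
    """Hypothetical stand-in for a verbose fit(): prints one line per step."""
    for step in range(n_steps):
        if step == crash_at:
            raise RuntimeError('simulated crash')
        print(f'[step {step + 1}/{n_steps}] done')


# Assumed location for the demo log file.
log_path = os.path.join(tempfile.mkdtemp(), 'demo.log')

# buffering=1 requests line buffering in text mode, so each printed line is
# flushed to the file when its newline is written.
with open(log_path, 'w', buffering=1) as log_file:
    with contextlib.redirect_stdout(log_file):
        try:
            noisy_search(5, crash_at=2)
        except RuntimeError:
            pass  # lines printed before the crash are already in the file

with open(log_path) as log_file:
    surviving_lines = log_file.read().splitlines()

print(surviving_lines)
# → ['[step 1/5] done', '[step 2/5] done']
```

Using `contextlib.redirect_stdout` restores the original stdout automatically, even when the wrapped call raises.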

This solution does not save all the information that is returned after fit() completes. It also provides no way to jump right back into the grid search where the program crashed; completing the grid relies on setting up a new search. However, I think we can agree it is much better than losing hours of compute, and it has tremendous value for the minimal amount of work needed to set it up.
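Setting up the follow-up search by hand can be made a little easier by reading the log to see which candidates already finished. The sketch below is one possible approach, not part of the original recipe. It assumes a single `alpha` hyperparameter and log lines in the format printed at the end of this post, which may differ between sklearn versions; `remaining_alphas` and `partial_log` are hypothetical names introduced for illustration.

```python
import re
from collections import Counter

# Assumed line format (may vary across sklearn versions), e.g.:
# [CV 1/3] END .....alpha=8;, score=-1806.850 total time=   0.0s
FOLD_LINE = re.compile(r'\[CV \d+/(\d+)\] END \.*alpha=([^;]+);')


def remaining_alphas(log_lines, all_alphas):
    """Return the alphas that do not yet have every fold logged."""
    folds_done = Counter()
    n_folds = None
    for line in log_lines:
        match = FOLD_LINE.search(line)
        if match:
            n_folds = int(match.group(1))
            folds_done[match.group(2)] += 1
    if n_folds is None:
        return list(all_alphas)  # nothing finished; rerun the full grid
    return [a for a in all_alphas if folds_done[str(a)] < n_folds]


# Hypothetical log from a run that crashed while evaluating alpha=16.
partial_log = [
    'Fitting 3 folds for each of 6 candidates, totalling 18 fits',
    '[CV 1/3] END .......................alpha=8;, score=-1806.850 total time=   0.0s',
    '[CV 2/3] END ........................alpha=8;, score=-503.415 total time=   0.0s',
    '[CV 3/3] END ........................alpha=8;, score=-568.153 total time=   0.0s',
    '[CV 1/3] END ......................alpha=16;, score=-4585.145 total time=   0.0s',
]

todo = remaining_alphas(partial_log, [8, 16, 32, 64, 128, 256])
print(todo)
# → [16, 32, 64, 128, 256]
```

The resulting list could then be passed to a new search as `param_grid=[{'alpha': todo}]`. A candidate with only some folds logged is rerun in full, since the log holds no partial state to resume from.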

"""Demonstration of logging GridSearchCV results to disk.

As we run fit() on a sklearn GridSearchCV object, results will be written to a
persistent file.  This is nice as it can save at least some information if the
process crashes.

Inspired by:
stackoverflow.com/questions/48058044/gridsearchcv-save-result-each-iteration

Functions:
    * grid_search_with_logging - Runs a grid and saves basic results to disk
                                 as they finish.
"""

import os
import sys
from time import time

import numpy as np
from sklearn.datasets import make_regression
from sklearn.linear_model import Lasso
from sklearn.model_selection import GridSearchCV

LOG_FILE_NAME = 'results.log'


def grid_search_with_logging(grid_search_obj, X, y, log_file_path):
    """Run an sklearn grid search and write results after each fold finishes.

    Each time a fold finishes, the performance will be written to the log file.
    This does not contain nearly as much information as the final returned
    object, but it's better than nothing if the process crashes somewhere.

    Arguments
    ---------
    grid_search_obj : GridSearchCV
        Object that contains a model, a hyperparameter grid, and other
        cross-validation information.
    X : array-like
        Training features passed to fit().
    y : array-like
        Training targets passed to fit().
    log_file_path : str
        Filepath where basic results are written.

    Returns
    -------
    GridSearchCV
        The fitted grid search object.
    """
    # Redirect stdout to a file while fit() runs.  buffering=1 requests line
    # buffering, so each result line is written out as soon as it is printed
    # instead of sitting in a buffer that a crash could discard.
    old_stdout = sys.stdout
    with open(log_file_path, 'w', buffering=1) as log_file:
        sys.stdout = log_file
        try:
            grid_search_obj.fit(X, y)
        finally:
            # Always restore stdout, even if fit() raises.
            sys.stdout = old_stdout

    return grid_search_obj


# Set up a very simple regression problem and tune a small grid.
X, y = make_regression()

param_dict = [{'alpha': 2**np.arange(3, 9)}]
grid_searcher = GridSearchCV(Lasso(), param_dict, cv=3,
                             verbose=9,
                             scoring='neg_mean_squared_error')


grid_search_with_logging(grid_searcher, X, y, LOG_FILE_NAME)
GridSearchCV(cv=3, estimator=Lasso(),
             param_grid=[{'alpha': array([  8,  16,  32,  64, 128, 256])}],
             scoring='neg_mean_squared_error', verbose=9)
# Print the age of the log file, to show it was just created
file_age = time() - os.stat(LOG_FILE_NAME).st_mtime
print(f'File {LOG_FILE_NAME} is {file_age} seconds old')
File results.log is 0.10371923446655273 seconds old
# Print the contents.
with open(LOG_FILE_NAME) as log_file:
    for line in log_file.readlines():
        print(line)
Fitting 3 folds for each of 6 candidates, totalling 18 fits

[CV 1/3] END .......................alpha=8;, score=-1806.850 total time=   0.0s

[CV 2/3] END ........................alpha=8;, score=-503.415 total time=   0.0s

[CV 3/3] END ........................alpha=8;, score=-568.153 total time=   0.0s

[CV 1/3] END ......................alpha=16;, score=-4585.145 total time=   0.0s

[CV 2/3] END ......................alpha=16;, score=-1753.650 total time=   0.0s

[CV 3/3] END ......................alpha=16;, score=-2088.400 total time=   0.0s

[CV 1/3] END .....................alpha=32;, score=-10742.334 total time=   0.0s

[CV 2/3] END ......................alpha=32;, score=-6209.026 total time=   0.0s

[CV 3/3] END ......................alpha=32;, score=-6968.894 total time=   0.0s

[CV 1/3] END .....................alpha=64;, score=-23641.038 total time=   0.0s

[CV 2/3] END .....................alpha=64;, score=-20046.968 total time=   0.0s

[CV 3/3] END .....................alpha=64;, score=-13932.981 total time=   0.0s

[CV 1/3] END ....................alpha=128;, score=-36446.537 total time=   0.0s

[CV 2/3] END ....................alpha=128;, score=-32023.400 total time=   0.0s

[CV 3/3] END ....................alpha=128;, score=-27136.854 total time=   0.0s

[CV 1/3] END ....................alpha=256;, score=-36446.537 total time=   0.0s

[CV 2/3] END ....................alpha=256;, score=-32023.400 total time=   0.0s

[CV 3/3] END ....................alpha=256;, score=-27136.854 total time=   0.0s
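If the process does crash, the per-fold lines in the log can still be summarized. The following is a minimal sketch, assuming the line format shown above (which may vary between sklearn versions); `mean_scores_from_log` and `sample_log` are hypothetical names, and the sample reuses the first six result lines from the output above. It averages the fold scores for each candidate and picks the candidate with the highest mean, which is the best one for a `neg_mean_squared_error` scorer.

```python
import re
from collections import defaultdict
from statistics import mean

# Captures the parameter string (e.g. "alpha=8") and the fold score.
FOLD_LINE = re.compile(
    r'\[CV \d+/\d+\] END \.*(.+?);, score=(-?\d+(?:\.\d+)?)')


def mean_scores_from_log(log_lines):
    """Average the logged fold scores for each parameter string."""
    scores = defaultdict(list)
    for line in log_lines:
        match = FOLD_LINE.search(line)
        if match:
            scores[match.group(1)].append(float(match.group(2)))
    return {params: mean(values) for params, values in scores.items()}


sample_log = """\
Fitting 3 folds for each of 6 candidates, totalling 18 fits
[CV 1/3] END .......................alpha=8;, score=-1806.850 total time=   0.0s
[CV 2/3] END ........................alpha=8;, score=-503.415 total time=   0.0s
[CV 3/3] END ........................alpha=8;, score=-568.153 total time=   0.0s
[CV 1/3] END ......................alpha=16;, score=-4585.145 total time=   0.0s
[CV 2/3] END ......................alpha=16;, score=-1753.650 total time=   0.0s
[CV 3/3] END ......................alpha=16;, score=-2088.400 total time=   0.0s
"""

summary = mean_scores_from_log(sample_log.splitlines())
best_params = max(summary, key=summary.get)
print(best_params, round(summary[best_params], 3))
```

Candidates with fewer logged folds than the others are still averaged here, so a partially evaluated candidate should be treated with caution.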