Skip to content

Instantly share code, notes, and snippets.

@eigenein
Created August 26, 2018 12:29
Show Gist options
  • Save eigenein/7743271fd2125919f03cceb8db6a1692 to your computer and use it in GitHub Desktop.
Save eigenein/7743271fd2125919f03cceb8db6a1692 to your computer and use it in GitHub Desktop.
TTestSearchCV
class TTestSearchCV:
    """Grid search that only switches to a new candidate when it is significantly better.

    Every combination in ``param_grid`` is scored with ``cross_val_score``. The
    first candidate is always accepted. Each later candidate replaces the
    current best only if two conditions hold:

    * its mean CV score is not lower than the current best's;
    * a two-sample t-test (``scipy.stats.ttest_ind``) on the per-fold scores
      gives a p-value below ``1 - alpha``.

    Candidates that are merely "not significantly different" are never
    adopted, so the grid's iteration order matters: earlier combinations win
    ties.

    Args:
        estimator: Object exposing ``set_params(**params)``. It is
            reconfigured and passed to ``cross_val_score`` for every
            candidate.
        param_grid: Mapping of parameter name to an iterable of candidate
            values. The Cartesian product of all values is searched.
        cv: Forwarded unchanged to ``cross_val_score``.
        scoring: Forwarded unchanged to ``cross_val_score``.
        alpha: Confidence level. ``1 - alpha`` is the p-value threshold, and
            ``alpha`` is the level of ``best_confidence_interval_``.

    Attributes set by ``fit`` (``None`` before it, or if no candidate was
    evaluated):
        best_params_: Parameter dict of the accepted candidate.
        best_score_: Its mean CV score.
        best_scores_: Its per-fold CV scores.
        best_confidence_interval_: ``(low, high)`` Student-t confidence
            interval for the mean of ``best_scores_``.
    """

    def __init__(self, estimator, param_grid, *, cv, scoring, alpha=0.95):
        self.estimator = estimator
        self.param_grid: Dict[str, Any] = param_grid
        self.cv = cv
        self.scoring = scoring
        self.alpha = alpha
        # Significance threshold compared against the t-test p-value.
        self.p = 1.0 - alpha
        self._reset_results()

    def _reset_results(self) -> None:
        """Clear the fitted ``best_*`` attributes so a new search starts from scratch."""
        self.best_params_: Optional[Dict[str, Any]] = None
        self.best_score_: Optional[float] = None
        self.best_scores_: Optional[numpy.ndarray] = None
        # stats.t.interval returns a (low, high) pair, not an array.
        self.best_confidence_interval_: Optional[tuple] = None

    def fit(self, x, y):
        """Search the whole grid on ``(x, y)`` and record the accepted candidate.

        Results from any previous call are discarded first. When at least one
        candidate was evaluated, ``self.estimator`` is left configured with
        ``best_params_`` rather than with the last grid point tried.

        Args:
            x: Features, forwarded to ``cross_val_score``.
            y: Targets, forwarded to ``cross_val_score``.

        Returns:
            ``self``, following the scikit-learn convention so calls can be
            chained.
        """
        # Without this reset, a second fit would compare against (and could
        # keep) the best result from a previous, possibly different, dataset.
        self._reset_results()
        names = list(self.param_grid.keys())
        for values in product(*self.param_grid.values()):
            params = dict(zip(names, values))
            self.estimator.set_params(**params)
            logger.log(SPAM, f'🤖 CV started: {params}')
            scores: numpy.ndarray = cross_val_score(self.estimator, x, y, scoring=self.scoring, cv=self.cv)
            score: float = scores.mean()
            logger.debug(f'🤖 Score: {score:.4f} with {params}.')
            if not self.is_better_score(score, scores):
                continue
            logger.debug(f'🤖 Found significantly better score: {score:.4f}.')
            self.best_params_ = params
            self.best_score_ = score
            self.best_scores_ = scores
            self.best_confidence_interval_ = self._confidence_interval(scores, self.alpha)
        if self.best_params_ is not None:
            # Leave the estimator in the configuration the search selected.
            self.estimator.set_params(**self.best_params_)
        return self

    def is_better_score(self, score: float, scores: numpy.ndarray) -> bool:
        """Return whether a candidate should replace the current best.

        Args:
            score: Mean of ``scores``.
            scores: The candidate's per-fold CV scores.

        Returns:
            ``True`` when no candidate has been accepted yet. ``False`` when
            ``score`` is below ``best_score_``. Otherwise, whether the
            ``ttest_ind`` p-value against ``best_scores_`` is below ``self.p``.
        """
        if self.best_params_ is None:
            return True
        if score < self.best_score_:
            return False
        # NOTE(review): ttest_ind is two-sided by default. Combined with the
        # direction check above, it acts as a one-sided test at level p / 2.
        # If every candidate sees the same CV folds, the scores are paired and
        # stats.ttest_rel may be more appropriate. It is kept as-is to
        # preserve existing results.
        # If p_value comes back NaN, ``NaN < self.p`` is False, so the
        # candidate is rejected.
        _, p_value = stats.ttest_ind(self.best_scores_, scores)
        logger.log(SPAM, f'🤖 P-value: {p_value:.4f}.')
        return bool(p_value < self.p)

    @staticmethod
    def _confidence_interval(scores: numpy.ndarray, alpha: float) -> tuple:
        """Return the Student-t ``alpha``-level confidence interval for the mean of ``scores``."""
        mean = scores.mean()
        sem = stats.sem(scores)
        if sem == 0:
            # All fold scores are identical. scipy treats scale=0 as invalid
            # and would return (nan, nan), so the interval collapses to the
            # mean instead.
            return mean, mean
        return stats.t.interval(alpha, len(scores) - 1, loc=mean, scale=sem)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment