Skip to content

Instantly share code, notes, and snippets.

@vittorio-nardone
Created November 18, 2020 09:15
Show Gist options
  • Save vittorio-nardone/716832a9abbafb7526ff863909c42119 to your computer and use it in GitHub Desktop.
Save vittorio-nardone/716832a9abbafb7526ff863909c42119 to your computer and use it in GitHub Desktop.
Metaflow steps to perform hyperparameter tuning in Prophet
@step
def hyper_tuning(self):
    """
    Build the hyperparameter search grid and fan out over it.

    Every combination of the candidate values listed below is stored in
    ``self.all_params`` as a dict mapping parameter name to value; the
    flow then branches into ``cross_validation`` once per entry.
    """
    # Candidate values for each hyperparameter under study
    search_space = {
        'changepoint_prior_scale': [0.001, 0.01, 0.1, 0.5],
        'seasonality_prior_scale': [0.01, 0.1, 1.0, 10.0],
    }
    names = list(search_space)
    # Cartesian product of the candidate lists, one dict per combination
    self.all_params = [
        dict(zip(names, combo))
        for combo in itertools.product(*search_space.values())
    ]
    # Score every combination in its own foreach branch
    self.next(self.cross_validation, foreach='all_params')
@step
def cross_validation(self):
    """
    Score one hyperparameter combination with Prophet cross-validation.

    ``self.input`` holds the parameter dict for this foreach branch. A
    model is fitted on ``self.df`` with those parameters, evaluated, and
    the resulting RMSE is stored in ``self.rmses`` for the join step.
    """
    # Build and fit a model for this branch's parameters
    model = Prophet(**self.input)
    model.fit(self.df)
    # NOTE(review): the bare name below is expected to resolve to the
    # module-level diagnostics function, not to this method - confirm
    # against the file's imports.
    folds = cross_validation(
        model,
        initial='730 days',
        period='180 days',
        horizon='365 days',
        parallel="processes",
    )
    # rolling_window=1 is presumably used so the metrics collapse to a
    # single aggregate row - see the Prophet diagnostics documentation.
    metrics = performance_metrics(folds, rolling_window=1)
    # Keep only the first RMSE entry as this branch's score
    self.rmses = metrics['rmse'].values[0]
    self.next(self.train)
@step
def train(self, inputs):
    """
    Join the cross-validation branches and fit the final model.

    Collects the RMSE stored by every branch, selects the parameter
    combination with the lowest score, and fits a new Prophet model on
    ``self.df`` with it. The chosen parameters are kept in
    ``self.hyperparameters`` and the fitted model in ``self.m``.

    Args:
        inputs: the joined foreach branches; each one exposes the
            ``rmses`` artifact written by the cross-validation step.

    Raises:
        ValueError: if every branch reported a NaN RMSE, since no
            combination can then be ranked.
    """
    # Carry shared artifacts forward; 'rmses' differs per branch, so it
    # is excluded from the merge and read from each branch below.
    self.merge_artifacts(inputs, exclude=['rmses'])
    # One score per branch ('branch' avoids shadowing the builtin input)
    branch_rmses = [branch.rmses for branch in inputs]
    # nanargmin skips NaN scores; plain argmin would return the position
    # of a NaN and so select a failed evaluation as the "best" one.
    # NOTE(review): the positional lookup assumes `inputs` is ordered
    # like `all_params` - confirm against Metaflow's foreach join order.
    best_index = int(np.nanargmin(branch_rmses))
    self.hyperparameters = self.all_params[best_index]
    # Fit a fresh model using the winning parameters
    self.m = Prophet(**self.hyperparameters)
    self.m.fit(self.df)
    self.next(self.end)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment