Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save jpmallette/8cdfda4f1734ad8c7174b3ffe6c0d416 to your computer and use it in GitHub Desktop.
Save jpmallette/8cdfda4f1734ad8c7174b3ffe6c0d416 to your computer and use it in GitHub Desktop.
Execute Cross Validation and Performance Loop
def execute_cross_validation_and_performance_loop(cross_valid_params, metric='mse'):
    """Run Prophet cross-validation for several configurations and rank them.

    Parameters
    ----------
    cross_valid_params : list of dict
        Each dict is passed as keyword arguments to ``cross_validation``
        (typically ``model``, ``horizon``, ``period``, ``initial``).
    metric : str, default 'mse'
        Performance metric used to sort the result in ascending order;
        one of 'mse', 'rmse', 'mae' or 'mape'.

    Returns
    -------
    pd.DataFrame
        One row per configuration with columns ``initial``, ``horizon``,
        ``period``, ``mse``, ``rmse``, ``mae``, ``mape`` and ``coverage``,
        sorted ascending by ``metric``. When ``initial`` or ``period`` is
        absent from a configuration dict, that cell is ``None``. An empty
        frame with those columns is returned if ``cross_valid_params`` is
        empty.

    Raises
    ------
    ValueError
        If ``metric`` is not one of 'mse', 'rmse', 'mae' or 'mape'.

    Example
    -------
    >>> m = Prophet()
    >>> df = pd.read_csv('/examples/example_wp_log_peyton_manning.csv')
    >>> m.fit(df)
    >>> cross_valid_params = [{'model': m,
    ...                        'initial': '730 days',
    ...                        'period': '180 days',
    ...                        'horizon': '365 days'},
    ...                       {'model': m,
    ...                        'initial': '500 days',
    ...                        'period': '180 days',
    ...                        'horizon': '365 days'}]
    >>> execute_cross_validation_and_performance_loop(cross_valid_params, metric='mse')
    index initial  horizon  period   mse      rmse     mae      mape     coverage
    4332  500 days 365 days 180 days 0.663628 0.814634 0.627102 0.075824 0.572352
    3987  730 days 365 days 180 days 0.670460 0.818816 0.628407 0.075577 0.589017
    """
    valid_metrics = ('mse', 'rmse', 'mae', 'mape')
    # Explicit raise rather than ``assert``: asserts are stripped under
    # ``python -O``, which would let an unsupported metric slip through.
    if metric not in valid_metrics:
        raise ValueError('metric must be either mse, rmse, mae or mape')

    columns = ['initial', 'horizon', 'period',
               'mse', 'rmse', 'mae', 'mape', 'coverage']

    frames = []
    for cross_valid_param in cross_valid_params:
        df_cv = cross_validation(**cross_valid_param)
        # rolling_window=1 summarises each configuration's forecast errors
        # over its full horizon (one summary row per configuration, as the
        # docstring example shows).
        df_p = performance_metrics(df_cv, rolling_window=1)
        # Tag the rows with the configuration that produced them. ``initial``
        # and ``period`` are optional in ``cross_validation``, so record None
        # instead of failing when the caller left them out.
        df_p['initial'] = cross_valid_param.get('initial')
        df_p['period'] = cross_valid_param.get('period')
        frames.append(df_p)

    if not frames:
        # Nothing to evaluate: return an empty frame with the expected schema
        # instead of failing on column selection.
        return pd.DataFrame(columns=columns)

    # Concatenate once at the end: ``DataFrame.append`` was removed in
    # pandas 2.0 and re-copied the accumulated frame on every iteration.
    # The original per-frame index is preserved, matching the old behaviour.
    df_ps = pd.concat(frames)[columns]
    return df_ps.sort_values(metric)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment