Created
July 13, 2019 19:41
-
-
Save ikanez/941209c70bfc0cb7351479242fbf8289 to your computer and use it in GitHub Desktop.
Evaluate the performance of the best SARIMA model over multiple time windows and log the results to MLflow
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from datetime import timedelta

# Number of backtest windows evaluated by the loop below.
num_window = 10
def last_day_of_month(any_day):
    """Return the last day of the month that contains ``any_day``.

    ``any_day`` must support ``.replace(day=...)`` and ``timedelta``
    arithmetic (e.g. ``datetime.date`` or ``datetime.datetime``). Any
    time-of-day component is carried through unchanged.
    """
    # Day 28 exists in every month, and adding 4 days always lands in the
    # following month, so this never fails.
    next_month = any_day.replace(day=28) + timedelta(days=4)
    # Stepping back by that date's day-of-month yields the final day of the
    # original month.
    return next_month - timedelta(days=next_month.day)
# Backtest the fitted SARIMA model over `num_window` successive windows and
# record the parameters, one forecast plot per window and the per-window RMSE
# in a single MLflow run.
with mlflow.start_run(run_name='sarima_backtest'):
    t1 = pd.to_datetime('2017-01-31')
    t2 = pd.to_datetime('2017-04-30')

    # log param
    mlflow.log_param('num_window', num_window)
    mlflow.log_param('init_start_dt', t1)
    mlflow.log_param('init_end_dt', t2)

    for j in range(num_window):
        # Advance the window, snapping both ends to a month end. The shift
        # happens before forecasting, so the first evaluated window already
        # lies after the initial dates logged above.
        # NOTE(review): `add_num_days` is not defined in the visible code;
        # presumably set in an earlier cell -- confirm.
        t1 = last_day_of_month(t1 + timedelta(add_num_days))
        t2 = last_day_of_month(t2 + timedelta(add_num_days))

        # Get the dynamic forecast between dates t1 and t2
        btc_month_dynamic = best_model.get_prediction(
            start=t1, end=t2, dynamic=True, full_results=True)
        # Predictions are passed through `invboxcox` with `lmbda` before being
        # stored; the column is overwritten on every iteration.
        btc_month_sarima['dynamic_forecast'] = invboxcox(
            btc_month_dynamic.predicted_mean, lmbda)

        # Taking 80% confidence interval because the 95% blows out too high to visualise
        pred_dynamic_ci = btc_month_dynamic.conf_int(alpha=0.2)
        pred_dynamic_ci['lower close_box'] = invboxcox(
            pred_dynamic_ci['lower close_box'], lmbda)
        pred_dynamic_ci['upper close_box'] = invboxcox(
            pred_dynamic_ci['upper close_box'], lmbda)

        # Plot
        plt.ylim((0, 20000))
        btc_month_sarima.Close['2016':'2018-01'].plot(label='close', linewidth=3)
        btc_month_sarima.dynamic_forecast.plot(
            color='r', ls='--', label='dynamic forecast', linewidth=3)
        plt.fill_between(pred_dynamic_ci.index,
                         pred_dynamic_ci.iloc[:, 0],
                         pred_dynamic_ci.iloc[:, 1], color='k', alpha=.25)
        # Shade the forecast window itself.
        plt.fill_betweenx(plt.ylim(), t1, t2, alpha=.1, zorder=-1)
        plt.legend()
        plt.title('Bitcoin Dynamic Monthly Forecast (backtesting)')
        plt.grid()
        plt.ylabel('USD')
        # NOTE(review): bare `display()` is presumably the notebook helper
        # that renders the current figure -- confirm.
        display()

        fig_fn = 'plot_{}_{}.png'.format(t1, t2)
        plt.savefig(fig_fn)
        mlflow.log_artifact(fig_fn)  # logging to mlflow
        # Close the figure so the next window starts from a clean canvas.
        plt.close()

        # calculate metrics
        df_pred = pd.DataFrame(btc_month_sarima.dynamic_forecast).dropna()
        df_ori = pd.DataFrame(btc_month_sarima.Close['2016':'2018-01'])
        # Look up the observed close for every forecast date; a forecast date
        # outside the '2016':'2018-01' slice raises KeyError here.
        df_ori_res = [df_ori.Close.loc[i] for i in df_pred.index]
        df_pred_res = df_pred.dynamic_forecast.values
        rmse = sqrt(mean_squared_error(df_ori_res, df_pred_res))

        # log metrics
        mlflow.log_metric('rmse', rmse, step=j)  # log metric to mlflow
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment