Created August 6, 2019 19:03
-
-
Save marcopeix/a4809a1207c114f93a6af8d051db0e52 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Discard the sensor / weather columns that are not part of the modelled series
cols_to_drop = ['PT08.S1(CO)', 'C6H6(GT)', 'PT08.S2(NMHC)', 'PT08.S4(NO2)', 'PT08.S5(O3)', 'T', 'RH', 'AH']
weekly_data = weekly_data.drop(columns=cols_to_drop)

# Forecasting library, plus logging so its output can be quietened
import logging
from fbprophet import Prophet

# Raise the root logger's threshold so only ERROR-level records are emitted
logging.root.setLevel(logging.ERROR)

# Prophet reads the timestamp from a column named 'ds' and the value from 'y'.
# NOTE(review): the rename assumes exactly one value column remains after the
# drop above (plus the former index) — confirm against how weekly_data is built.
df = weekly_data.reset_index()
df.columns = ('ds', 'y')
df.head()
# Hold out the last `prediction_size` observations as a test set
prediction_size = 30
train_df = df[:-prediction_size]

# Initialize and train a model on the training portion only
m = Prophet()
m.fit(train_df)

# Make predictions for every timestamp present in df (training + held-out),
# so each forecast row lines up with an observed value for evaluation.
# make_future_dataframe(periods=...) defaults to freq='D'; on this weekly
# series it would generate daily timestamps reaching only ~4 weeks past the
# training data, leaving most held-out weeks without a prediction.
future = df[['ds']]
forecast = m.predict(future)
forecast.head()

# Plot forecast
m.plot(forecast)
# Plot forecast's components
m.plot_components(forecast)

# Evaluate the model
def make_comparison_dataframe(historical, forecast):
    """Join forecast intervals with the observed data on the 'ds' timestamp.

    Parameters
    ----------
    historical : pandas.DataFrame
        Observed data with a 'ds' column; all of its other columns are joined.
    forecast : pandas.DataFrame
        Prophet-style output with 'ds', 'yhat', 'yhat_lower', 'yhat_upper'.

    Returns
    -------
    pandas.DataFrame
        Indexed by forecast 'ds', holding the three forecast columns followed
        by the historical columns. This is a left join, so forecast dates
        absent from *historical* get NaN there.
    """
    interval_cols = ['yhat', 'yhat_lower', 'yhat_upper']
    predicted = forecast.set_index('ds')[interval_cols]
    observed = historical.set_index('ds')
    return predicted.join(observed)
# Line up each prediction (and its bounds) with the observed value on 'ds'
cmp_df = make_comparison_dataframe(historical=df, forecast=forecast)
cmp_df.head()
def calculate_forecast_errors(df, prediction_size):
    """Compute MAPE and MAE over the last ``prediction_size`` rows of *df*.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain 'y' (observed) and 'yhat' (predicted) columns, ordered so
        that the evaluation period is at the end. It is not modified.
    prediction_size : int
        Number of trailing rows to score. ``0`` scores no rows.

    Returns
    -------
    dict
        ``{'MAPE': float, 'MAE': float}``. MAPE is in percent.
        - NaN errors (missing 'y' or 'yhat') are skipped by the mean.
        - Rows with an observed value of 0 are excluded from MAPE only.
        - A metric with no usable rows is NaN.
    """
    # tail() rather than df[-prediction_size:]: df[-0:] would silently select
    # the whole frame instead of nothing when prediction_size == 0.
    predicted_part = df.tail(prediction_size)

    errors = predicted_part['y'] - predicted_part['yhat']
    # Percentage error is undefined where the observed value is 0; mask those
    # points as NaN so they are skipped instead of turning MAPE into inf.
    observed = predicted_part['y'].replace(0, np.nan)
    pct_errors = 100 * errors / observed

    def _mean_abs(values):
        # Mean absolute value; np.mean on a pandas Series skips NaN.
        return np.mean(np.abs(values))

    return {'MAPE': _mean_abs(pct_errors), 'MAE': _mean_abs(errors)}
# Report the forecast errors over the held-out period
for err_name, err_value in calculate_forecast_errors(cmp_df, prediction_size).items():
    print(err_name, err_value)

# Plot forecast with upper and lower bounds against the observed values.
# Every series is labelled so the legend tells the prediction, its bounds
# and the actual data apart (the unlabelled lines were indistinguishable).
plt.figure(figsize=(17, 8))
plt.plot(cmp_df['yhat'], label='yhat (forecast)')
plt.plot(cmp_df['yhat_lower'], label='yhat_lower')
plt.plot(cmp_df['yhat_upper'], label='yhat_upper')
plt.plot(cmp_df['y'], label='y (observed)')
plt.xlabel('Time')
plt.ylabel('Average Weekly NOx Concentration')
plt.legend()
plt.grid(False)
plt.show()
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.