# ARIMA using statsmodels
# GitHub Gist by @jetnew (last active June 1, 2019 17:41)
# Gist ID: cbf48edb1e611f2af64f3a3731d7ac0b
# (Instantly share code, notes, and snippets; save to your computer and use it in GitHub Desktop.)
from statsmodels.tsa.arima_model import ARIMA
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
# Walk-forward (rolling-origin) evaluation of an ARIMA(p, d, q) model.
# For every point in the hold-out set:
#   1. refit the model on all observations seen so far,
#   2. forecast one step ahead,
#   3. then reveal the true observation and add it to the history.
#
# NOTE(review): this snippet does not define `X` (the series to model) or
# import `plt` (presumably `matplotlib.pyplot`); both must exist before
# this runs.
# NOTE: `statsmodels.tsa.arima_model.ARIMA` was removed in statsmodels 0.13.
# On newer versions, import ARIMA from `statsmodels.tsa.arima.model` and
# drop the `disp` argument to `fit()`.
p = 5  # autoregressive order (number of lagged observations)
d = 1  # differencing order
q = 0  # moving-average order (number of lagged forecast errors)

# shuffle=False preserves chronological order, so the last 20% is the test set.
train, test = train_test_split(X, test_size=0.20, shuffle=False)
history = train.tolist()
predictions = []

# Iterate over the test values directly instead of indexing with `test[t]`.
# If X is a pandas Series, the split keeps the original index, so `test[t]`
# would be treated as a label lookup and fail or fetch the wrong value.
for obs in test:
    model = ARIMA(history, order=(p, d, q))
    fit = model.fit(disp=False)
    # The legacy arima_model API's forecast() returns
    # (forecast, stderr, conf_int); [0] is the one-step forecast array.
    pred = fit.forecast()[0]
    predictions.append(pred)
    history.append(obs)

print(f"MSE: {mean_squared_error(test, predictions):.3f}")
plt.plot(test)
plt.plot(predictions, color='red')
plt.show()
# Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.