Skip to content

Instantly share code, notes, and snippets.

@pbruneau
Last active March 25, 2023 09:16
Show Gist options
  • Save pbruneau/5b58bca5f36e4114d38f1e91f1f0f5b8 to your computer and use it in GitHub Desktop.
Save pbruneau/5b58bca5f36e4114d38f1e91f1f0f5b8 to your computer and use it in GitHub Desktop.
Minimal GluonTS / DeepAR example
import pandas as pd
import numpy as np
import mxnet as mx
from gluonts.dataset.pandas import PandasDataset
from gluonts.dataset.split import split
from gluonts.mx import DeepAREstimator, Trainer
from gluonts.model.predictor import Predictor
from pathlib import Path
def train_model(
    data=None,
    path=Path("temp"),
    prediction_length=12,
    freq="M",
    epochs=5,
):
    """Train a DeepAR predictor and serialize it to disk.

    All arguments are optional; the defaults reproduce the original
    hard-coded behaviour of this script.

    Args:
        data: Training dataset. Defaults to the module-level
            ``training_data`` global, looked up at call time.
        path: Directory the predictor is serialized into; created if it
            does not exist yet.
        prediction_length: Forecast horizon passed to ``DeepAREstimator``.
        freq: Pandas frequency string of the series passed to
            ``DeepAREstimator``.
        epochs: Number of training epochs passed to ``Trainer``.

    Returns:
        The trained predictor, so callers can use it without reloading it
        from ``path``.
    """
    if data is None:
        # Resolved at call time: the global only has to exist when the
        # function is invoked, not when it is defined.
        data = training_data
    predictor = DeepAREstimator(
        prediction_length=prediction_length,
        freq=freq,
        trainer=Trainer(epochs=epochs),
    ).train(data)
    # NOTE(review): Predictor.serialize appears to write its files directly
    # under `path` without creating it, so make sure the directory exists.
    path = Path(path)
    path.mkdir(parents=True, exist_ok=True)
    predictor.serialize(path)
    return predictor
# Load the AirPassengers CSV; the first column is parsed as a datetime index.
df = pd.read_csv(
    "https://raw.githubusercontent.com/AileenNielsen/"
    "TimeSeriesAnalysisWithPython/master/data/AirPassengers.csv",
    index_col=0,
    parse_dates=True,
)
dataset = PandasDataset(df, target="#Passengers")

# Must match the prediction_length the serialized predictor was trained with.
PREDICTION_LENGTH = 12
NUM_WINDOWS = 3
# Directory train_model() serializes into by default.
MODEL_DIR = Path("temp")

# Hold out exactly enough trailing points for NUM_WINDOWS test windows of
# PREDICTION_LENGTH each (3 * 12 = 36, the original hard-coded offset).
training_data, test_gen = split(
    dataset, offset=-PREDICTION_LENGTH * NUM_WINDOWS
)

# Train only when no serialized predictor is present yet; otherwise reuse
# the saved one. Replaces the manual "comment out train_model()" toggle,
# which made the very first run fail at deserialize time.
if not MODEL_DIR.is_dir() or not any(MODEL_DIR.iterdir()):
    MODEL_DIR.mkdir(parents=True, exist_ok=True)
    train_model()

# Fall back to the CPU on machines without a CUDA device instead of failing.
ctx = mx.gpu(0) if mx.context.num_gpus() > 0 else mx.cpu()
model = Predictor.deserialize(MODEL_DIR, ctx=ctx)

test_data = test_gen.generate_instances(
    prediction_length=PREDICTION_LENGTH, windows=NUM_WINDOWS
)

# Seed the NumPy and MXNet RNGs right before predicting, intended to make
# the sampled forecasts repeatable across runs.
np.random.seed(42)
mx.random.seed(42)
forecasts = list(model.predict(test_data.input))
print(forecasts[0].median)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment