Skip to content

Instantly share code, notes, and snippets.

@pbruneau
Created January 11, 2023 07:49
Show Gist options
  • Save pbruneau/164dbe40b994185ea722aa80d27fae6c to your computer and use it in GitHub Desktop.
Save pbruneau/164dbe40b994185ea722aa80d27fae6c to your computer and use it in GitHub Desktop.
Minimal example using the Torch DeepAR implementation
from gluonts.dataset.repository.datasets import get_dataset
from gluonts.torch.model.deepar import DeepARModel, DeepARLightningModule
from gluonts.transform import (
AddObservedValuesIndicator,
InstanceSplitter,
ExpectedNumInstanceSampler,
)
from gluonts.dataset.field_names import FieldName
from gluonts.dataset.loader import TrainDataLoader
from gluonts.itertools import Cached
from gluonts.torch.batchify import batchify
from gluonts.torch.distributions import StudentTOutput
import pytorch_lightning as pl
import torch
# Load the "electricity" dataset from the GluonTS dataset repository.
dataset = get_dataset("electricity")

# History window fed to the model: 2 * 7 * 24 time steps (two weeks,
# assuming the series are hourly).
context_length = 2 * 7 * 24
# Forecast horizon, as declared in the dataset's own metadata.
prediction_length = dataset.metadata.prediction_length

# DeepARModel's initializer requires explicit values for the num_feat_*
# arguments and for cardinality. No additional features are configured in
# this script, so they are set to zero / empty.
# NOTE(review): confirm that zero-sized feature inputs are accepted by the
# installed GluonTS version; this script does not add any feature fields.
model = DeepARModel(
    # Read the frequency from the dataset metadata (like prediction_length
    # above) instead of hard-coding it, so that the model stays consistent
    # with whichever dataset is loaded.
    freq=dataset.metadata.freq,
    context_length=context_length,
    prediction_length=prediction_length,
    distr_output=StudentTOutput(),
    num_feat_dynamic_real=0,
    num_feat_static_real=0,
    num_feat_static_cat=0,
    cardinality=[],
)
# Wrap the network in its Lightning module so pl.Trainer can train it.
module = DeepARLightningModule(model)
# Transformation that writes an observed-values indicator field, derived
# from the target field, into each data entry.
mask_unobserved = AddObservedValuesIndicator(
    output_field=FieldName.OBSERVED_VALUES,
    target_field=FieldName.TARGET,
)

# Sampler deciding where training windows are cut; it is configured with
# num_instances=1 and requires at least `prediction_length` future steps.
train_instance_sampler = ExpectedNumInstanceSampler(
    min_future=prediction_length,
    num_instances=1,
)

# Splits each series into a past window of `context_length` steps and a
# future window of `prediction_length` steps; the observed-values
# indicator is listed as a time-series field to be split along with the
# target.
training_splitter = InstanceSplitter(
    instance_sampler=train_instance_sampler,
    past_length=context_length,
    future_length=prediction_length,
    target_field=FieldName.TARGET,
    start_field=FieldName.START,
    forecast_start_field=FieldName.FORECAST_START,
    is_pad_field=FieldName.IS_PAD,
    time_series_fields=[FieldName.OBSERVED_VALUES],
)
# Training-loop sizing: samples per batch and batches drawn per epoch.
batch_size = 32
num_batches_per_epoch = 50

# Cache the training split so that it is not re-read on every pass,
# which makes training faster.
cached_train_data = Cached(dataset.train)
# Chain the transformations: observed-values indicator first, then the
# instance splitter.
training_transform = mask_unobserved + training_splitter

# Loader producing batches stacked by `batchify` for the Torch model.
data_loader = TrainDataLoader(
    cached_train_data,
    transform=training_transform,
    stack_fn=batchify,
    batch_size=batch_size,
    num_batches_per_epoch=num_batches_per_epoch,
)
# Select the training device explicitly. The former `gpus=` Trainer argument
# is deprecated (and rejected by newer PyTorch Lightning releases); the
# `accelerator=` / `devices=` pair expresses the same choice:
# all visible CUDA GPUs when available, otherwise a single CPU process.
use_cuda = torch.cuda.is_available()
trainer = pl.Trainer(
    max_epochs=10,
    accelerator="gpu" if use_cuda else "cpu",
    devices=-1 if use_cuda else 1,
)
# Train the Lightning module on batches drawn from the training data loader.
trainer.fit(module, data_loader)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment