@pbruneau
Created January 26, 2023 07:14
Minimal example using DeepAREstimator
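from gluonts.dataset.repository.datasets import get_dataset
from gluonts.transform import ExpectedNumInstanceSampler
from gluonts.torch.model.deepar import DeepAREstimator
import torch

# Hourly electricity-consumption dataset from the GluonTS dataset repository
# (downloaded and cached locally on first use).
dataset = get_dataset("electricity")

# Condition each forecast on two weeks of hourly history: 2 * 7 * 24 = 336 steps.
context_length = 2 * 7 * 24
prediction_length = dataset.metadata.prediction_length

model = DeepAREstimator(
    prediction_length=prediction_length,
    context_length=context_length,
    freq='1H',
    hidden_size=100,
    # Draw on average one training window per series, keeping only cut points
    # with a full context before them and a full prediction horizon after them.
    train_sampler=ExpectedNumInstanceSampler(
        num_instances=1,
        min_future=prediction_length,
        min_past=context_length,
    ),
    # One "epoch" is 50 batches of 32 windows, i.e. 1,600 sampled windows.
    batch_size=32,
    num_batches_per_epoch=50,
    # Forwarded to the PyTorch Lightning Trainer; gpus=-1 uses all visible GPUs.
    # Note: `gpus` was deprecated in Lightning 1.7 and removed in 2.0, where
    # `accelerator`/`devices` replace it.
    trainer_kwargs={
        'max_epochs': 10,
        'gpus': -1 if torch.cuda.is_available() else None,
    },
)

# Train the model, passing the test split as validation data.
predictor = model.train(dataset.train, dataset.test, shuffle_buffer_length=10, num_workers=4)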
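To make the `train_sampler` arguments concrete, here is a simplified pure-Python sketch of what an expected-count instance sampler does with `num_instances`, `min_past` and `min_future`. It is not the gluonts implementation: the functions `sample_cut_points` and `split_window` are hypothetical, the series is synthetic, `prediction_length = 24` is an illustrative value, and the real `ExpectedNumInstanceSampler` may compute its keep probability differently (for example from a running average over series rather than per series).

```python
import random
from typing import List, Tuple


def sample_cut_points(
    series_length: int,
    min_past: int,
    min_future: int,
    num_instances: float,
    rng: random.Random,
) -> List[int]:
    """Hypothetical, simplified instance sampler (not the gluonts implementation).

    A cut point t is valid when at least `min_past` observations precede it
    and at least `min_future` observations follow it. Each valid point is kept
    independently with probability num_instances / n_valid, so a series yields
    about `num_instances` cut points on average.
    """
    first = min_past                   # earliest valid cut point
    last = series_length - min_future  # latest valid cut point (inclusive)
    n_valid = last - first + 1
    if n_valid <= 0:
        return []  # series too short for a full context plus a full target
    p = min(1.0, num_instances / n_valid)
    return [t for t in range(first, last + 1) if rng.random() < p]


def split_window(
    series: List[float], t: int, context_length: int, prediction_length: int
) -> Tuple[List[float], List[float]]:
    """Split `series` at cut point t into (past context, future target)."""
    past = series[t - context_length : t]
    future = series[t : t + prediction_length]
    return past, future


if __name__ == "__main__":
    context_length = 2 * 7 * 24  # 336 hourly steps, as in the gist
    prediction_length = 24       # hypothetical value for illustration
    series = [float(i) for i in range(1000)]  # synthetic stand-in series
    rng = random.Random(0)
    cuts = sample_cut_points(len(series), context_length, prediction_length, 1, rng)
    for t in cuts:
        past, future = split_window(series, t, context_length, prediction_length)
        print(t, len(past), len(future))
```

Setting `min_past=context_length` and `min_future=prediction_length` is what guarantees that every sampled window holds a full context and a full target; a series shorter than their sum contributes no training windows in this sketch.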