@lostella
Created September 7, 2020 11:17
GluonTS data loader sanity check
import time
import mxnet as mx
from gluonts.dataset.repository.datasets import get_dataset
from gluonts.dataset.loader import TrainDataLoader
from gluonts.model.deepar import DeepAREstimator
from gluonts.mx.batchify import batchify as mx_batchify

# Load the "electricity" dataset from the GluonTS dataset repository.
dataset = get_dataset("electricity")
dataset_train = dataset.train
freq = dataset.metadata.freq
prediction_length = dataset.metadata.prediction_length

batch_size = 32
num_batches_per_epoch = 8

# The estimator is only used here to obtain its data transformation.
estimator = DeepAREstimator(freq=freq, prediction_length=prediction_length)
transform = estimator.create_transformation()

print("creating data loader")
# Use two workers; each batch is assembled into MXNet arrays on the CPU.
training_loader = TrainDataLoader(
    dataset=dataset_train,
    transform=transform,
    batch_size=batch_size,
    batchify_fn=lambda x: mx_batchify(x, mx.cpu()),
    num_batches_per_epoch=num_batches_per_epoch,
    num_workers=2,
    shuffle_buffer_length=20,
)

# The loader is never iterated: the script only waits briefly and then exits.
print("sleeping")
time.sleep(1.0)
print("done")