Skip to content

Instantly share code, notes, and snippets.

@lostella
Created May 4, 2021 13:56
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save lostella/ee40ffde4431c5f6f38404b5fb557c47 to your computer and use it in GitHub Desktop.
Save lostella/ee40ffde4431c5f6f38404b5fb557c47 to your computer and use it in GitHub Desktop.
Sanity check on GluonTS-provided datasets
from gluonts.dataset.repository.datasets import get_dataset, dataset_names
def _series_end(entry):
    """Return the point in time one step past the last value of ``entry``.

    ``entry`` is indexed by ``"start"`` and ``"target"``; the start value
    must expose a ``freq`` attribute that can be scaled by an integer and
    added back onto the start.
    """
    start = entry["start"]
    return start + len(entry["target"]) * start.freq


def check_train_test_split(dataset):
    """Check that each test series extends far enough past its train series.

    For every item id present in both ``dataset.train`` and ``dataset.test``,
    the test series must end at least ``dataset.metadata.prediction_length``
    steps after the end of the corresponding train series. Test entries
    whose item id does not appear in the train split are skipped. If several
    test entries share an item id, the last one seen is the one compared.

    Raises:
        AssertionError: if an item id occurs more than once in the train
            split, or if a test series ends before the expected point.
    """
    prediction_length = dataset.metadata.prediction_length

    train_end = {}
    for entry in dataset.train:
        item_id = entry["item_id"]
        # Raised explicitly (rather than via ``assert``) so the check is not
        # stripped under ``python -O``; callers still catch AssertionError.
        if item_id in train_end:
            raise AssertionError(f"item {item_id} is duplicate")
        train_end[item_id] = _series_end(entry)

    test_end = {}
    for entry in dataset.test:
        test_end[entry["item_id"]] = _series_end(entry)

    for k in test_end:
        if k not in train_end:
            continue
        expected_end = train_end[k] + prediction_length * train_end[k].freq
        if not test_end[k] >= expected_end:
            raise AssertionError(
                f"test entry for item {k} ends at {test_end[k]} < {expected_end}"
            )
# Run the train/test sanity check over every dataset known to the repository,
# printing one status line per dataset.
for name in dataset_names:
    # Step 1: obtain the dataset; a RuntimeError here is reported as a
    # warning and the dataset is skipped.
    try:
        dataset = get_dataset(name)
    except RuntimeError:
        print(f"WARN dataset '{name}' could not be obtained")
        continue

    # Step 2: run the check; a failed assertion is reported with its message.
    try:
        check_train_test_split(dataset)
    except AssertionError as err:
        print(f"ERR dataset '{name}' has issues: {err}")
        continue

    print(f"✓ dataset '{name}' looks good")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment