Skip to content

Instantly share code, notes, and snippets.

@pbruneau
Created April 19, 2021 06:35
Show Gist options
  • Save pbruneau/ae733ac6e5923943d7d0bc4bb0f8c103 to your computer and use it in GitHub Desktop.
benchmark_m4 simplified sample
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
"""
This example shows how to fit a model and evaluate its predictions.
"""
import pprint
from functools import partial
import pandas as pd
from gluonts.dataset.repository.datasets import get_dataset
from gluonts.distribution.piecewise_linear import PiecewiseLinearOutput
from gluonts.evaluation import Evaluator
from gluonts.evaluation.backtest import make_evaluation_predictions
from gluonts.model.deepar import DeepAREstimator
from gluonts.model.seq2seq import MQCNNEstimator
from gluonts.model.simple_feedforward import SimpleFeedForwardEstimator
from gluonts.mx.trainer import Trainer
# Training budget shared by every estimator registered below.
epochs = 100
num_batches_per_epoch = 50

# Registry mapping an estimator name to a factory for it. Each factory is a
# ``partial`` with its own pre-built ``Trainer``; the remaining keyword
# arguments are supplied when the factory is called.
estimators = {
    label: partial(
        estimator_cls,
        trainer=Trainer(
            ctx='gpu',
            hybridize=True,
            num_batches_per_epoch=num_batches_per_epoch,
            epochs=epochs,
        ),
    )
    for label, estimator_cls in (
        ('DeepAREstimator', DeepAREstimator),
        ('MQCNNEstimator', MQCNNEstimator),
    )
}
def evaluate(dataset_name, estimator_name):
    """Train a registered estimator on a dataset and evaluate its forecasts.

    Parameters
    ----------
    dataset_name
        Name handed to ``get_dataset`` to load the dataset.
    estimator_name
        Key into the module-level ``estimators`` registry.

    Returns
    -------
    dict
        A copy of the aggregate metrics returned by ``Evaluator``, extended
        with ``"dataset"`` and ``"estimator"`` entries holding the two
        arguments.

    Raises
    ------
    KeyError
        If ``estimator_name`` is not a key of ``estimators``.
    """
    # Resolve the factory first so that a mistyped name fails before any
    # dataset loading or training is attempted.
    try:
        estimator_factory = estimators[estimator_name]
    except KeyError:
        raise KeyError(
            f"unknown estimator {estimator_name!r}; "
            f"expected one of {sorted(estimators)}"
        ) from None

    dataset = get_dataset(dataset_name)

    arguments = {
        'prediction_length': dataset.metadata.prediction_length,
        'freq': dataset.metadata.freq,
    }
    # NOTE(review): the metadata may carry cardinalities as strings; coerce
    # to int so the estimator receives integers either way.
    cardinality = [
        int(feat_static_cat.cardinality)
        for feat_static_cat in dataset.metadata.feat_static_cat
    ]
    # Only request static categorical features when the dataset declares
    # some; otherwise leave the estimator's own defaults untouched instead
    # of pairing use_feat_static_cat=True with an empty cardinality list.
    if cardinality:
        arguments['use_feat_static_cat'] = True
        arguments['cardinality'] = cardinality

    estimator = estimator_factory(**arguments)
    print(f"evaluating {estimator_name} on {dataset_name}")
    predictor = estimator.train(dataset.train)

    forecast_it, ts_it = make_evaluation_predictions(
        dataset.test, predictor=predictor, num_samples=100
    )
    # Per-item metrics are not used by this script.
    agg_metrics, _ = Evaluator()(
        ts_it, forecast_it, num_series=len(dataset.test)
    )
    pprint.pprint(agg_metrics)

    # Copy so the bookkeeping keys do not leak into the evaluator's dict.
    eval_dict = dict(agg_metrics)
    eval_dict["dataset"] = dataset_name
    eval_dict["estimator"] = estimator_name
    return eval_dict
if __name__ == "__main__":
    # Script entry point: run a single benchmark, MQ-CNN on the M4 weekly data.
    evaluate(dataset_name="m4_weekly", estimator_name="MQCNNEstimator")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment