# Load the dataset and read forecasting settings from its metadata
dataset = get_dataset(args.datasource, regenerate=False)
prediction_length = dataset.metadata.prediction_length
freq = dataset.metadata.freq
# Parse the cardinality of the first static categorical feature into a Python literal
cardinality = ast.literal_eval(dataset.metadata.feat_static_cat[0].cardinality)
train_ds = dataset.train
test_ds = dataset.test
# Train on the GPU only when one is available and --use-cuda is set
trainer = Trainer(ctx=mx.context.gpu() if is_gpu and args.use_cuda else mx.context.cpu(),
                  batch_size=args.batch_size,
                  learning_rate=args.learning_rate,
                  epochs=20,

usage: deeprenewal [-h] [--use-cuda USE_CUDA]
                   [--datasource {retail_dataset}]
                   [--regenerate-datasource REGENERATE_DATASOURCE]
                   [--model-save-dir MODEL_SAVE_DIR]
                   [--point-forecast {median,mean}]
                   [--calculate-spec CALCULATE_SPEC]
                   [--batch_size BATCH_SIZE]
                   [--learning-rate LEARNING_RATE]
                   [--max-epochs MAX_EPOCHS]
                   [--number-of-batches-per-epoch NUMBER_OF_BATCHES_PER_EPOCH]