Skip to content

Instantly share code, notes, and snippets.

@h3ik0th
Created October 29, 2021 19:09
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save h3ik0th/0cea3f41d879e877985398c0d3f9ea09 to your computer and use it in GitHub Desktop.
Save h3ik0th/0cea3f41d879e877985398c0d3f9ea09 to your computer and use it in GitHub Desktop.
# set up, fit, run, plot, and evaluate the RNN model
def run_RNN(flavor, ts, train, val):
    """Build, fit, forecast with, plot, and score one RNN variant.

    Parameters
    ----------
    flavor : str
        Value forwarded as ``RNNModel(model=...)`` and used to build the
        model name. Presumably one of darts' "RNN" / "LSTM" / "GRU" --
        confirm against the RNNModel in use. The value "RNN" is shown as
        "Vanilla" in the labels passed to ``fit_it`` / ``plot_fitted`` and
        in the printed header.
    ts :
        Series passed, together with the forecast, to ``plot_fitted`` and
        ``accuracy_metrics``.
    train, val :
        Data handed unchanged to ``fit_it``.

    Returns
    -------
    list
        ``[pred, res_acc]``: the forecast returned by ``model_RNN.predict``
        and the metrics mapping returned by ``accuracy_metrics``.

    Notes
    -----
    Relies on module-level names defined elsewhere in this file:
    ``RNNModel``, ``periodicity``, ``EPOCH``, ``FC_N``, ``covariates``,
    ``fit_it``, ``plot_fitted`` and ``accuracy_metrics``.
    """
    model_RNN = RNNModel(
        model=flavor,
        model_name=f"{flavor} RNN",
        input_chunk_length=periodicity,
        # NOTE(review): training_length is fixed at 20 while
        # input_chunk_length follows `periodicity`; darts recommends
        # training_length > input_chunk_length -- confirm for the data used.
        training_length=20,
        hidden_dim=20,
        batch_size=16,
        n_epochs=EPOCH,
        dropout=0,
        optimizer_kwargs={"lr": 1e-3},
        log_tensorboard=True,
        random_state=42,
        force_reset=True,
    )

    # The model name above uses the raw flavor. From here on, only the
    # human-facing label changes: the plain RNN is reported as "Vanilla".
    label = "Vanilla" if flavor == "RNN" else flavor

    fit_it(model_RNN, train, val, label)

    # Forecast FC_N steps, supplying the module-level covariates as
    # future covariates.
    pred = model_RNN.predict(n=FC_N, future_covariates=covariates)

    plot_fitted(pred, ts, label)

    res_acc = accuracy_metrics(pred, ts)
    print(f"{label} : ")
    # A plain loop rather than a list comprehension: we only want the
    # printing side effect, not a throwaway list of None values.
    for metric, value in res_acc.items():
        print(metric, ":", f"{value:.4f}")

    return [pred, res_acc]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment