@charlieoneill11
Created January 11, 2022 00:29
import torch
import numpy as np
import matplotlib.pyplot as plt


def training_loop(n_epochs, model, optimiser, loss_fn,
                  train_input, train_target, test_input, test_target):
    for i in range(n_epochs):
        # LBFGS-style optimisers re-evaluate the model several times per step,
        # so the forward/backward pass is wrapped in a closure
        def closure():
            optimiser.zero_grad()
            out = model(train_input)
            loss = loss_fn(out, train_target)
            loss.backward()
            return loss
        optimiser.step(closure)

        # evaluate on the test set and extrapolate 1000 steps into the future
        with torch.no_grad():
            future = 1000
            pred = model(test_input, future=future)
            # score only the part of the prediction that overlaps the test
            # targets, i.e. drop the extrapolated future steps
            loss = loss_fn(pred[:, :-future], test_target)
            y = pred.detach().numpy()

        # draw figures
        plt.figure(figsize=(12, 6))
        plt.title(f"Step {i+1}")
        plt.xlabel("x")
        plt.ylabel("y")
        plt.xticks(fontsize=20)
        plt.yticks(fontsize=20)
        n = train_input.shape[1]  # length of the observed sequence (999)

        def draw(yi, colour):
            # solid line: prediction over the observed range;
            # dotted line: extrapolation into the future
            plt.plot(np.arange(n), yi[:n], colour, linewidth=2.0)
            plt.plot(np.arange(n, n + future), yi[n:], colour + ":", linewidth=2.0)

        draw(y[0], 'r')
        draw(y[1], 'b')
        draw(y[2], 'g')
        plt.savefig("predict%d.png" % i, dpi=200)
        plt.close()

        # print the training loss
        with torch.no_grad():
            out = model(train_input)
            loss_print = loss_fn(out, train_target)
        print("Step: {}, Loss: {}".format(i + 1, loss_print.item()))
tamirpumba commented Mar 17, 2023

I think
pred = model(test_input, future=future)
should be
pred = model(test_input, future_preds=future)
so that it matches the forward function's signature
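
For illustration, the keyword in the call has to match the parameter name in the forward definition; the toy class below is a placeholder, not the gist's model:

import torch
import torch.nn as nn

class Toy(nn.Module):
    def forward(self, x, future_preds=0):
        # placeholder: a real model would append `future_preds` extrapolated steps
        return x

m = Toy()
m(torch.zeros(3, 999), future_preds=1000)  # matches the signature, runs fine
# m(torch.zeros(3, 999), future=1000)      # TypeError: unexpected keyword argument 'future'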