Skip to content

Instantly share code, notes, and snippets.

@bkaankuguoglu
Created January 16, 2022 17:45
Show Gist options
  • Save bkaankuguoglu/889ae28238c4b4aa1b8f0158db283194 to your computer and use it in GitHub Desktop.
Save bkaankuguoglu/889ae28238c4b4aa1b8f0158db283194 to your computer and use it in GitHub Desktop.
Class Optimization
# ...
def forecast_with_lag_features(
    self, test_loader, batch_size=1, n_features=1, n_steps=100, *, target_device=None
):
    """Recursively forecast ``n_steps`` values beyond the last batch of ``test_loader``.

    The final ``(X, y)`` batch from ``test_loader`` seeds the forecast. Before
    every model call, the feature axis (dim 2) of ``X`` is rolled right by one.
    Feature 0 of the last time step is then overwritten with the newest value:
    first the observed ``y``, then each successive prediction. The model is
    therefore fed its own outputs.

    NOTE(review): this assumes feature k holds the value lagged by k + 1 steps,
    and that ``batch_size`` is 1 (or the final batch is full). With larger
    batches only the first row's value is propagated. Confirm with callers.

    Args:
        test_loader: Iterable of ``(X, y)`` tensor pairs; only the last pair is used.
        batch_size: Leading dimension used when reshaping ``X``.
        n_features: Size of the feature axis used when reshaping ``X``.
        n_steps: Number of future values to predict.
        target_device: Device to place ``X`` on. Defaults to the module-level
            ``device`` name (defined outside this method).

    Returns:
        A list of ``n_steps`` Python scalars, one per predicted step.

    Raises:
        ValueError: If ``test_loader`` yields no batches.
    """
    # Only evaluates the module-level ``device`` when no explicit device is given.
    dev = device if target_device is None else target_device

    # Walk the loader to its final batch; its last sample is the newest history.
    last_batch = None
    for last_batch in test_loader:
        pass
    if last_batch is None:
        raise ValueError("test_loader yielded no batches; nothing to forecast from")
    X, y = last_batch

    X = X.view([batch_size, -1, n_features]).to(dev)
    # .item() copies to the host, so this works for CUDA tensors too
    # (Tensor.numpy() only works on CPU tensors).
    newest = y.reshape(-1)[0].item()

    predictions = []
    with torch.no_grad():
        self.model.eval()
        for _ in range(n_steps):
            # Shift each lag one slot along the feature axis. The oldest value
            # wraps into slot 0 and is immediately replaced by the newest value.
            X = torch.roll(X, shifts=1, dims=2)
            X[..., -1, 0] = newest
            yhat = self.model(X)
            newest = yhat.reshape(-1)[0].item()
            predictions.append(newest)
    return predictions
# ...
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment