rian-dolphin/e8df952f1895f471943f21c604ba3d49 (GitHub Gist), created November 3, 2021 21:43
import torch

def get_x_y_pairs(train_scaled, train_periods, prediction_periods):
    """
    train_scaled - training sequence (a 1D torch tensor)
    train_periods - How many data points to use as inputs
    prediction_periods - How many periods to output as predictions
    """
    # Number of complete (input, target) windows. The +1 keeps the final window,
    # whose targets end exactly at the last element of the sequence.
    n_windows = len(train_scaled) - train_periods - prediction_periods + 1
    x_train = [train_scaled[i:i+train_periods] for i in range(n_windows)]
    y_train = [train_scaled[i+train_periods:i+train_periods+prediction_periods] for i in range(n_windows)]
    #-- use the stack function to convert the list of 1D tensors
    # into a 2D tensor where each element of the list is now a row
    x_train = torch.stack(x_train)
    y_train = torch.stack(y_train)
    return x_train, y_train
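# train_scaled (the scaled training tensor) and test_periods are defined earlier; they are not part of this gist
train_periods = 16 #-- number of quarters for input
prediction_periods = test_periods
x_train, y_train = get_x_y_pairs(train_scaled, train_periods, prediction_periods)
print(x_train.shape) # (number of windows, train_periods)
print(y_train.shape) # (number of windows, prediction_periods)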
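The windowing above can be checked on a small, self-contained example. This is a sketch under stated assumptions: `train_scaled` is taken to be a 1D float tensor, `toy_series` is hypothetical stand-in data (not from the gist), and `get_x_y_pairs_loop` is a local copy of the list-comprehension approach, named differently for illustration. For a sequence of length N there are N - train_periods - prediction_periods + 1 complete windows. The second function, `get_x_y_pairs_unfold`, is an alternative (not the gist's method) that builds every window at once with PyTorch's `Tensor.unfold` and then splits each window into its input and target parts.

```python
import torch


def get_x_y_pairs_loop(train_scaled, train_periods, prediction_periods):
    """Loop-based sliding windows (local copy of the gist's approach)."""
    n_windows = len(train_scaled) - train_periods - prediction_periods + 1
    x = [train_scaled[i:i + train_periods] for i in range(n_windows)]
    y = [train_scaled[i + train_periods:i + train_periods + prediction_periods]
         for i in range(n_windows)]
    # Each list element becomes one row of a 2D tensor
    return torch.stack(x), torch.stack(y)


def get_x_y_pairs_unfold(train_scaled, train_periods, prediction_periods):
    """Vectorized alternative: take every window of length
    train_periods + prediction_periods, then split it into inputs and targets."""
    window = train_periods + prediction_periods
    # unfold(dimension, size, step) on a 1D tensor returns all length-`window`
    # slices with stride 1, shape (len - window + 1, window)
    windows = train_scaled.unfold(0, window, 1)
    return windows[:, :train_periods], windows[:, train_periods:]


# Hypothetical toy series: 40 "quarters" with values 0..39
toy_series = torch.arange(40, dtype=torch.float32)
x_loop, y_loop = get_x_y_pairs_loop(toy_series, train_periods=16, prediction_periods=4)
x_vec, y_vec = get_x_y_pairs_unfold(toy_series, train_periods=16, prediction_periods=4)

print(x_loop.shape, y_loop.shape)  # torch.Size([21, 16]) torch.Size([21, 4])
print(torch.equal(x_loop, x_vec), torch.equal(y_loop, y_vec))  # True True
```

Both versions produce the same pairs; `unfold` returns views into the original tensor instead of copying each slice in a Python loop, which can matter for long series.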