Created February 7, 2018 04:32
wassname/f40e18d90e1b67880649bc934e8e977b
pytorch: stack a moving window over a timeseries
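import torch
import torch.nn.functional as F


def window_stack(x, window=4, pad=True):
    """
    Stack along a moving window of a pytorch timeseries.

    Inputs:
        x: tensor of dims (batches/time, channels)
        window: number of timesteps in each window
        pad: if True, the left side of the time axis is zero-padded so the
            output has one row per input timestep
    Outputs:
        if pad=True: a tensor of size (batches, channels, window)
        else: a tensor of size (batches-window, channels, window)

    With pad=True, row t holds the `window` timesteps before step t
    (x[t-window:t], zero-filled before the start of the series); with
    pad=False the rows start at t=window. Step t itself is not included.
    """
    x = x.transpose(0, 1)  # -> (channels, time)
    if pad:
        # Left-pad the time (last) axis with `window` zeros.
        x = F.pad(x, (window, 0))
    # Each slice is (channels, window); stacking gives (n_windows, channels, window).
    x = torch.stack([x[:, i:i + window] for i in range(x.size(1) - window)])
    return x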
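

# Toy series: 10 timesteps, 2 channels (channel 0 = 0..9, channel 1 = 10..19).
series = torch.arange(20, dtype=torch.float32).view(2, 10).transpose(1, 0)  # (batches/time, channels)
print(series)
x = window_stack(series, window=4, pad=False)
print(x.size())  # torch.Size([6, 2, 4])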
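The Python list comprehension in `window_stack` builds one slice per timestep. A possible vectorised alternative swaps in PyTorch's `Tensor.unfold(dimension, size, step)`, which returns a view containing every length-`size` slice along one dimension. This is a minimal sketch, not the original gist's method: the name `window_stack_unfold` is hypothetical, and it assumes the same `(time, channels)` input layout as the docstring. Unfolding dim 0 of a `(time, channels)` tensor yields `(time - window + 1, channels, window)`, which is already the output layout; dropping the final window reproduces the loop's "steps before t" alignment.

```python
import torch
import torch.nn.functional as F


def window_stack_unfold(x, window=4, pad=True):
    """Hypothetical vectorised variant of window_stack using Tensor.unfold.

    x: (time, channels). Returns (n_windows, channels, window) where each row
    holds the `window` steps before step t, as in the loop version.
    """
    if pad:
        # F.pad lists pads from the last dim backwards:
        # (0, 0) leaves channels alone, (window, 0) left-pads time (dim 0).
        x = F.pad(x, (0, 0, window, 0))
    # unfold(dim=0, size=window, step=1) -> (time - window + 1, channels, window)
    windows = x.unfold(0, window, 1)
    # Drop the last window so each row excludes the current step itself.
    return windows[:-1]


if __name__ == "__main__":
    series = torch.arange(20, dtype=torch.float32).view(2, 10).t()  # (time, channels)
    print(window_stack_unfold(series, window=4, pad=False).shape)  # torch.Size([6, 2, 4])
    print(window_stack_unfold(series, window=4, pad=True).shape)   # torch.Size([10, 2, 4])
```

Because `unfold` returns a view rather than copying each slice, this avoids the per-timestep Python loop; call `.contiguous()` on the result if a later operation needs contiguous memory.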