Skip to content

Instantly share code, notes, and snippets.

@spezold
Last active August 19, 2021 19:20
Show Gist options
  • Save spezold/c90b310de7f3245feb19a84f35ed3dc5 to your computer and use it in GitHub Desktop.
Save spezold/c90b310de7f3245feb19a84f35ed3dc5 to your computer and use it in GitHub Desktop.
Return both the values inside and outside of a rolling window over a 1D PyTorch tensor, as two 2D tensors (one row per window position)
from typing import Tuple
import torch
from torch import Tensor
# The straightforward solution
def rolling_window_inside_and_outside(t: Tensor, size: int, stride: int = 1) -> Tuple[Tensor, Tensor]:
    """
    Given a 1D tensor, provide both the values inside the rolling window and outside the rolling window for each window
    position with the given window size and stride.

    The window positions are the same as for ``t.unfold(0, size, stride)``, so there are
    ``n = (len(t) - size) // stride + 1`` of them. The values outside the window are given in wrap-around order,
    starting with the value directly after the window (e.g. for ``t = [0, ..., 7]``, ``size = 5`` and the window
    starting at index 2, the outside values are ``[7, 0, 1]``).

    :param t: 1D tensor to be rolled over
    :param size: window size, with ``0 <= size <= len(t)``
    :param stride: step size for rolling (must be positive, as required by ``Tensor.unfold``)
    :return: values inside the rolling window, shape ``(n, size)``; values outside the rolling window, shape
        ``(n, len(t) - size)``; the n-th window position is in the n-th row of both
    :raises AssertionError: if ``t`` is not 1D or ``size`` is outside ``[0, len(t)]``
    """
    assert t.ndim == 1, f"Expected a 1D tensor, got {t.ndim} dimensions"
    len_t = len(t)
    # Check the lower bound, too: a negative size would otherwise fail later with an unrelated error in split()
    assert 0 <= size <= len_t, f"Window size must be in [0, {len_t}], got {size}"
    # Append the first len_t - size values, so that each length-len_t stretch starting at a window position holds the
    # inside values followed by the (wrapped-around) outside values
    padded = torch.cat((t, t[:len_t - size]))
    # One row per window position: (len(padded) - len_t) // stride + 1 == (len_t - size) // stride + 1 rows
    unfolded = padded.unfold(0, len_t, stride)
    # The first `size` columns lie inside the window, the remaining ones outside
    inside, outside = unfolded.split((size, len_t - size), dim=1)
    return inside, outside
# The tedious solution: by unfolding the inner and outer parts separately, we may save memory (the inner part can stay a view of the input; only the outer part needs a copy)
def rolling_window_inside_and_outside_2(t: Tensor, size: int, stride: int = 1) -> Tuple[Tensor, Tensor]:
    """
    Given a 1D tensor, provide both the values inside the rolling window and outside the rolling window for each window
    position with the given window size and stride.

    The window positions are the same as for ``t.unfold(0, size, stride)``, so there are
    ``n = (len(t) - size) // stride + 1`` of them. The values outside the window are given in wrap-around order,
    starting with the value directly after the window (e.g. for ``t = [0, ..., 7]``, ``size = 5`` and the window
    starting at index 2, the outside values are ``[7, 0, 1]``).

    In contrast to unfolding one padded tensor, the inside and outside parts are unfolded separately here: the inside
    values are an unfolded view of ``t``, and only the outside values are built from a copy.

    :param t: 1D tensor to be rolled over
    :param size: window size, with ``0 <= size <= len(t)``
    :param stride: step size for rolling (must be positive, as required by ``Tensor.unfold``)
    :return: values inside the rolling window, shape ``(n, size)``; values outside the rolling window, shape
        ``(n, len(t) - size)``; the n-th window position is in the n-th row of both
    :raises AssertionError: if ``t`` is not 1D or ``size`` is outside ``[0, len(t)]``
    """
    assert t.ndim == 1, f"Expected a 1D tensor, got {t.ndim} dimensions"
    len_t = len(t)
    # Check the lower bound, too: a negative size would otherwise fail later with an unrelated error in unfold()
    assert 0 <= size <= len_t, f"Window size must be in [0, {len_t}], got {size}"
    # Values inside the window are straightforward: we just need to unfold
    inside = t.unfold(0, size, stride)
    # Values outside the window need to be repeated or cropped, depending on the window size
    o = t.roll(-size)  # Outside values start at the value directly after the window
    size_o = len_t - size  # Outside window size is the length minus the inside window size
    # Unfolding 2 * size_o values gives size_o // stride + 1 == (len_t - size) // stride + 1 rows, i.e. exactly as
    # many as `inside` has
    len_o = 2 * size_o
    # Bring to the required length: crop if len_o <= len_t, otherwise append the start of o once more (one repetition
    # always suffices, as len_o - len_t == len_t - 2 * size <= len_t). Same as
    # o = torch.cat((o, o[:len_o - len_t])) if (len_o > len_t) else o[:len_o]
    o = torch.cat((o[:len_o], o[:max(len_o - len_t, 0)]))
    outside = o.unfold(0, size_o, stride)
    return inside, outside
if __name__ == "__main__":
    # Include a tensor with repeated values, so that the permutation check below is a real multiset check
    test_tensors = [torch.arange(7), torch.arange(8), torch.tensor([2, 0, 2, 1, 0, 1, 1])]
    for t in test_tensors:
        sorted_t = t.sort().values
        for size in [0, 3, 4, 5, len(t)]:
            for stride in [1, 2, 3]:
                i, o = rolling_window_inside_and_outside(t, size=size, stride=stride)
                print()
                print()
                print(f"Length {len(t)}, window size {size}, stride {stride}:")
                print()
                print("Inside:")
                print(i)
                print()
                print("Outside:")
                print(o)
                # One row per window position, as many as t.unfold(0, size, stride) would give
                expected_rows = (len(t) - size) // stride + 1
                assert i.shape == (expected_rows, size)
                assert o.shape == (expected_rows, len(t) - size)
                # Concatenating the inner and outer values for each position should again give all values of t;
                # compare sorted rows rather than sets, so that repeated values are counted correctly
                concatenated = torch.cat((i, o), dim=1)
                for row in concatenated:
                    assert torch.equal(row.sort().values, sorted_t)
                # The alternative implementation must give exactly the same result
                i2, o2 = rolling_window_inside_and_outside_2(t, size=size, stride=stride)
                assert torch.equal(i, i2) and torch.equal(o, o2)
# This prints for example:
#
# Length 8, window size 5, stride 2:
#
# Inside:
# tensor([[0, 1, 2, 3, 4],
#         [2, 3, 4, 5, 6]])
#
# Outside:
# tensor([[5, 6, 7],
#         [7, 0, 1]])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment