Created April 30, 2021 18:22
from matt
def split(arr: torch.Tensor, splits: int, dim: int = 0):
    """Lazily split ``arr`` into roughly equal contiguous chunks along ``dim``.

    The requested number of chunks is clamped to ``[1, arr.shape[dim]]`` so
    that no chunk is ever empty. When the axis length is not evenly divisible,
    the first ``axis_len % splits`` chunks each receive one extra element
    (the same sizing rule ``torch.tensor_split`` uses for an integer section
    count).

    Args:
        arr: Tensor to split.
        splits: Desired number of chunks. Values below 1 are treated as 1;
            values above ``arr.shape[dim]`` are reduced to it.
        dim: Dimension to split along (negative values count from the end).

    Yields:
        Tensors produced by ``torch.narrow`` covering ``arr`` in order along
        ``dim``. Nothing is yielded when ``arr.shape[dim] == 0``.
    """
    axis_len = arr.shape[dim]
    splits = min(axis_len, max(splits, 1))
    if splits == 0:
        # Empty axis: the clamp above leaves zero chunks. Stop here instead
        # of dividing by zero below.
        return
    chunk_size, remainder = divmod(axis_len, splits)
    start = 0
    for i in range(splits):
        # The first `remainder` chunks absorb the leftover elements, one each.
        length = chunk_size + (1 if i < remainder else 0)
        yield torch.narrow(arr, dim, start, length)
        start += length
