@lucidrains
Created April 30, 2021 18:22
from matt
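import torch

def split(arr: torch.Tensor, splits, dim=0):
    # Yield up to `splits` contiguous views of `arr` along `dim`;
    # chunk sizes differ by at most one element.
    # Note: a zero-length axis clamps `splits` to 0 and raises ZeroDivisionError.
    axis_len = arr.shape[dim]
    splits = min(axis_len, max(splits, 1))  # clamp to the range [1, axis_len]
    chunk_size = axis_len // splits
    remainder = axis_len - chunk_size * splits
    s = 0  # start offset of the next chunk along `dim`
    for i in range(splits):
        # The first `remainder` chunks each take one extra element.
        adjust, remainder = 1 if remainder > 0 else 0, remainder - 1
        yield torch.narrow(arr, dim, s, chunk_size + adjust)
        s += chunk_size + adjust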
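
The chunk-size arithmetic in split can be checked without torch. The sketch below is not part of the gist: chunk_sizes is a hypothetical helper that mirrors the same clamping and remainder logic and returns only the sizes that split would yield. It assumes axis_len is at least 1.

```python
from typing import List

def chunk_sizes(axis_len: int, splits: int) -> List[int]:
    """Hypothetical helper: sizes of the chunks split() would yield.

    Assumes axis_len >= 1 (a zero-length axis would divide by zero).
    """
    splits = min(axis_len, max(splits, 1))  # clamp to the range [1, axis_len]
    base, remainder = divmod(axis_len, splits)
    # The first `remainder` chunks each take one extra element.
    return [base + 1 if i < remainder else base for i in range(splits)]

print(chunk_sizes(10, 3))  # [4, 3, 3]
print(chunk_sizes(4, 10))  # [1, 1, 1, 1]
```

Asking for more chunks than there are elements produces one chunk per element, and any leftover elements go to the leading chunks, so the sizes always sum to the axis length.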