Skip to content

Instantly share code, notes, and snippets.

@0x00b1
Created January 31, 2024 18:17
Show Gist options
  • Save 0x00b1/32a747a031a789dce9fc4dfd4a02ae0e to your computer and use it in GitHub Desktop.
Save 0x00b1/32a747a031a789dce9fc4dfd4a02ae0e to your computer and use it in GitHub Desktop.
PyTorch `iota` operator
from torch import Tensor
import torch
def iota(shape: tuple[int, ...], dim: int = 0, **kwargs) -> Tensor:
    """Return a tensor of ``shape`` whose values count 0, 1, 2, ... along ``dim``.

    Every element's value equals its index along ``dim``. The values are
    constant across all other dimensions, like XLA's ``iota`` operator.

    Args:
        shape: Sizes of the output tensor. Any sequence of ints works; lists
            are accepted as well as tuples.
        dim: Dimension along which values increase. Negative values count
            from the end, as in other PyTorch APIs (``-1`` is the last
            dimension).
        **kwargs: Forwarded unchanged to ``torch.arange`` (for example
            ``dtype`` or ``device``).

    Returns:
        A tensor of size ``shape``. It is an ``expand``-ed view of a 1-D
        ``arange``, so many elements share storage. Call ``.clone()`` (or
        ``.contiguous()``) before modifying it in place.

    Raises:
        IndexError: If ``dim`` is out of range for ``shape``. This includes
            an empty ``shape``.
    """
    ndim = len(shape)
    if not -ndim <= dim < ndim:
        raise IndexError(
            f"dim {dim} is out of range for a shape with {ndim} dimension(s)"
        )
    # Normalize negative dims. Without this, the comparison below never
    # matches and view() receives an all-ones shape.
    dim %= ndim

    # Size 1 everywhere except `dim`, so the 1-D arange can be viewed in
    # place and then broadcast (without copying) to the full shape.
    view_shape = [size if axis == dim else 1 for axis, size in enumerate(shape)]
    return torch.arange(shape[dim], **kwargs).view(view_shape).expand(*shape)
@0x00b1
Copy link
Author

0x00b1 commented Jan 31, 2024

iota([4, 8], 0, dtype=torch.int32)

tensor([[0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1],
        [2, 2, 2, 2, 2, 2, 2, 2],
        [3, 3, 3, 3, 3, 3, 3, 3]], dtype=torch.int32)
iota([4, 8], 1, dtype=torch.int32)

tensor([[0, 1, 2, 3, 4, 5, 6, 7],
        [0, 1, 2, 3, 4, 5, 6, 7],
        [0, 1, 2, 3, 4, 5, 6, 7],
        [0, 1, 2, 3, 4, 5, 6, 7]], dtype=torch.int32)
iota([2, 4, 8], 0, dtype=torch.int32)

tensor([[[0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0]],

        [[1, 1, 1, 1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1, 1, 1, 1]]], dtype=torch.int32)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment