Skip to content

Instantly share code, notes, and snippets.

@xmodar
Last active July 31, 2021 20:08
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save xmodar/9387bb6c13963c38ce280dc260872196 to your computer and use it in GitHub Desktop.
Save xmodar/9387bb6c13963c38ce280dc260872196 to your computer and use it in GitHub Desktop.
Mimics einops.rearrange for simple cases. Could be simplified with named tensors and optimized with tracing.
import math
def chunk_dim(tensor, chunks, dim=0):
    """Split a dimension of a tensor into two dimensions.

    Dimension ``dim`` of size ``n`` is replaced by the pair
    ``(chunks, n // chunks)`` via ``tensor.view``. ``dim`` may be negative.
    If ``n`` is not divisible by ``chunks`` the floored inner size no longer
    matches the element count, so ``view`` raises.
    """
    sizes = list(tensor.shape)
    # index first so an out-of-range dim raises IndexError before dividing
    outer, inner = chunks, sizes[dim] // chunks
    # normalize a negative axis so the slices below land on the right spot
    axis = dim if dim >= 0 else dim + len(sizes)
    return tensor.view(sizes[:axis] + [outer, inner] + sizes[axis + 1:])
def rearrange(tensor, input_shape, output_shape, **dims):
    """Rearrange the tensor dims using string patterns (einops.rearrange).

    Only a small subset of ``einops`` syntax is supported: dim names are
    separated by whitespace and ``a.b`` names a single tensor dim that is the
    flattening of ``a`` (outer) and ``b`` (inner), e.g.::

        rearrange(x, 'B I.H W C', 'B.I C H W', I=2)

    Args:
        tensor: tensor-like object exposing ``shape``, ``view``, ``permute``
            and ``reshape`` (e.g. ``torch.Tensor``).
        input_shape: pattern with exactly one group per dim of ``tensor``.
        output_shape: pattern for the result; it must use the same atomic
            names as ``input_shape``, in any order and grouping.
        **dims: sizes of atomic names that cannot be read off ``tensor``;
            each composite input group may leave at most one name unknown.

    Returns:
        The rearranged tensor: a view of ``tensor`` when the strides allow
        it, otherwise a copy (the semantics of ``tensor.reshape``).

    Raises:
        AssertionError: if the patterns disagree with each other, with the
            rank of ``tensor``, or with the sizes given in ``dims``.
    """
    assert all('.' not in d for d in dims), 'only provide singular dim names'
    # split on any whitespace so repeated/trailing spaces don't create '' dims
    in_groups = input_shape.split()
    out_groups = output_shape.split()
    # zip() below would silently truncate on a rank mismatch
    assert len(in_groups) == len(tensor.shape), (
        f'{input_shape!r} has {len(in_groups)} dims '
        f'but the tensor has {len(tensor.shape)}'
    )
    # parse input_shape to flattened dims dictionary {dim_name: dim_size}
    dims.update(zip(in_groups, tensor.shape))
    for group in [g for g in dims if '.' in g]:
        total = size = dims.pop(group)
        remaining = None
        for d in group.split('.'):
            if d in dims:
                size, remainder = divmod(size, dims[d])
                assert remainder == 0, f'{d}={dims[d]} must divide {group}={total}'
            else:
                assert remaining is None, f'specify either {remaining} or {d}'
                remaining = d
        if remaining is None:
            # every part was given: their product must equal the group size
            assert size == 1, f'sizes of {group} must multiply to {total}'
        else:
            dims[remaining] = size
    # flattened (atomic) dim names on both sides
    in_dims = input_shape.replace('.', ' ').split()
    out_dims = output_shape.replace('.', ' ').split()
    assert sorted(in_dims) == sorted(out_dims), (
        f'{output_shape!r} must use the same dims as {input_shape!r}'
    )
    # view the tensor as the flattened input shape (this only splits dims)
    tensor = tensor.view([dims[d] for d in in_dims])
    # permute the tensor to the flattened output shape
    tensor = tensor.permute([in_dims.index(d) for d in out_dims])
    # merge the output groups; reshape returns a view when possible, else copies
    shape = [
        math.prod(dims[d] for d in group.split('.')) for group in out_groups
    ]
    return tensor.reshape(shape)
if __name__ == '__main__':
    import torch

    # dim 1 (size 64) is split as I=2 x H=32; the output merges B with I
    # and moves C ahead of H, W -> prints torch.Size([4, 3, 32, 32])
    batch = torch.randn(2, 64, 32, 3)
    rearranged = rearrange(batch, 'B I.H W C', 'B.I C H W', I=2)
    print(rearranged.shape)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment