Skip to content

Instantly share code, notes, and snippets.

@Guitaricet
Created December 8, 2022 02:39
Show Gist options
  • Save Guitaricet/1aca93323b0d3f94a35a9001aa736467 to your computer and use it in GitHub Desktop.
Save Guitaricet/1aca93323b0d3f94a35a9001aa736467 to your computer and use it in GitHub Desktop.
Add a NamedShape property to torch.Tensor
import torch
class NamedShape:
    """A convenience class to make beautifully named shapes.

    Stores the ``names`` and ``shape`` attributes of the object it is built
    from and renders them as ``NamedShape[name=size, ...]``, one entry per
    dimension, in dimension order.
    """

    def __init__(self, tensor) -> None:
        """Capture the dimension names and sizes of *tensor*.

        Args:
            tensor: Any object exposing ``names`` and ``shape`` attributes
                (the file uses this with named ``torch.Tensor`` objects).
        """
        self.names = tensor.names
        self.shape = tensor.shape

    def __repr__(self) -> str:
        """Return ``NamedShape[name=size, ...]`` with one entry per dimension."""
        # Pair names with sizes directly instead of going through a dict:
        # dimensions that share a name (every unnamed dimension is named
        # ``None``) would overwrite each other as dict keys and silently
        # disappear from the output.
        dims = ", ".join(
            f"{name}={size}" for name, size in zip(self.names, self.shape)
        )
        return f"NamedShape[{dims}]"
# Install ``named_shape`` on every tensor. The property getter is called with
# the tensor instance, so ``t.named_shape`` evaluates ``NamedShape(t)``.
torch.Tensor.named_shape = property(fget=NamedShape)

# Demo: multiply a (batch, features) input by a (features, neurons) weight
# matrix and print the named shape of the product.
x = torch.rand(5, 3, names=("batch", "features"))
W = torch.randn(3, 7, names=("features", "neurons"))
y = torch.matmul(x, W)
# Presumably prints "NamedShape[batch=5, neurons=7]" -- the resulting names
# depend on PyTorch's name inference for matmul; verify against the docs.
print(y.named_shape)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment