Skip to content

Instantly share code, notes, and snippets.

@rockt
Created February 12, 2021 09:22
Show Gist options
  • Star 3 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save rockt/a3191f517728ea9a136a204f578d27c8 to your computer and use it in GitHub Desktop.
Save rockt/a3191f517728ea9a136a204f578d27c8 to your computer and use it in GitHub Desktop.
PyTorch einsum with multi-character (named) dimension labels
import re
import string

import torch
def einsumfy_exp(exp):
    """Rewrite an einsum equation with multi-character axis names for torch.einsum.

    Axis names are the runs of characters between the delimiters ``,``,
    space, ``(``, ``)`` and ``->``. Every name longer than one character,
    except the ellipsis ``...``, is replaced by a single ASCII letter.
    The letter is never one already used as a one-character name in ``exp``.
    Letters are assigned in order of first appearance, starting from ``'a'``
    and continuing into uppercase after ``'z'``, so the output is
    deterministic. Delimiters and whitespace are kept exactly as written.

    Args:
        exp: einsum equation, e.g.
            ``"batch time emb, emb hidden -> batch time hidden"``.

    Returns:
        The rewritten equation, e.g. ``"a b c, c d -> a b d"``.

    Raises:
        ValueError: if there are more distinct multi-character names than
            free ASCII letters.
    """
    # The capturing group keeps the delimiters in the result. Even indices
    # are then (possibly empty) names and odd indices are delimiters, so the
    # equation can be reassembled exactly with "".join.
    tokens = re.split(r"([, ()]|->)", exp)
    names = tokens[0::2]

    # One-character names are used verbatim by einsum, so they must never be
    # handed out as replacements.
    used = {name for name in names if len(name) == 1}
    free_letters = (c for c in string.ascii_letters if c not in used)

    mapping = {}
    for name in names:
        if len(name) > 1 and name != "..." and name not in mapping:
            try:
                mapping[name] = next(free_letters)
            except StopIteration:
                raise ValueError(
                    "too many distinct axis names for einsum: ran out of "
                    "ASCII letters"
                ) from None

    # Replace whole tokens rather than calling str.replace on the full
    # equation, so a name that contains another name as a substring
    # (e.g. "time" inside "timestep") is not corrupted.
    tokens[0::2] = [mapping.get(name, name) for name in names]
    return "".join(tokens)
def zweisum(exp, *args, **kwargs):
    """Evaluate torch.einsum on an equation whose axis names may be multi-character.

    Args:
        exp: einsum equation; axis names may be longer than one character.
            It is translated by ``einsumfy_exp`` before evaluation.
        *args: operand tensors, in the order their subscripts appear in
            ``exp``.
        **kwargs: accepted for backward compatibility but ignored; nothing
            is forwarded to ``torch.einsum``.

    Returns:
        The result of ``torch.einsum`` on the translated equation.
    """
    # Pass operands positionally. Handing torch.einsum a single tuple only
    # works through its legacy "operands as one list" calling convention.
    return torch.einsum(einsumfy_exp(exp), *args)
# Demo: contract a random (2, 3, 5) tensor with a random (5, 7) matrix over
# the shared "emb" axis, written with descriptive axis names. The printed
# result has shape (2, 3, 7).
x = torch.randn(2, 3, 5)
W = torch.randn(5, 7)
print(
    zweisum(
        "batch time emb, emb hidden -> batch time hidden",
        x,
        W,
    )
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment