Skip to content

Instantly share code, notes, and snippets.

@polvalente
Created March 1, 2024 23:35
Show Gist options
  • Save polvalente/b3701372ab9c22ebc7dab4fdc1433ded to your computer and use it in GitHub Desktop.
Save polvalente/b3701372ab9c22ebc7dab4fdc1433ded to your computer and use it in GitHub Desktop.
Equivalência de tensordot para einsum
In [1]: import torch
In [2]: a = torch.arange(3, 4, 5)
In [3]: a = torch.arange(60).reshape(3, 4, 5)
In [4]: b = torch.arange(24).reshape(1, 4, 3, 2)
# Aqui, (junto com o output da linha Out[7]) a gente vê que
# que o tensordot usando contraction axes específicas tem
# esse resultado.
# Na prática, a gente pode pensar que o cálculo é equivalente
# a transpor esses eixos pro final dos respectivos tensores,
# e "flatten" eles em um grande vetor, aí rola soma dos produtos ponto a ponto normalmente.
In [5]: torch.tensordot(a, b, dims=((0, 1), (2, 1)))
Out[5]:
tensor([[[4400, 4730]],
[[4532, 4874]],
[[4664, 5018]],
[[4796, 5162]],
[[4928, 5306]]])
In [7]: torch.tensordot(a, b, dims=((0, 1), (2, 1))).shape
Out[7]: torch.Size([5, 1, 2])
# Essa expressão de einsum encoda a mesma conta:
# a dimensão 0 da esquerda corresponde à dim -2 da direita,
# a dim 1 da esquerda, à -3 da direita, e as outras são encaixadas,
# na ordem em que aparecem, como leading axes do resultado.
In [9]: torch.einsum('ijk,...jiw->k...w', a, b).shape
Out[9]: torch.Size([5, 1, 2])
In [10]: torch.einsum('ijk,...jiw->k...w', a, b)
Out[10]:
tensor([[[4400, 4730]],
[[4532, 4874]],
[[4664, 5018]],
[[4796, 5162]],
[[4928, 5306]]])
# abaixo, o desenvolvimento mais intuitivo de o que as duas operações estão fazendo por debaixo dos panos:
In [17]: a_flat = torch.permute(a, (2, 0, 1)).reshape((a.shape[2], -1))
In [18]: a_flat.shape
Out[18]: torch.Size([5, 12])
In [21]: b_flat = torch.permute(b, (0, 3, 2, 1)).reshape((b.shape[0], b.shape[3], -1))
In [22]: b_flat.shape
Out[22]: torch.Size([1, 2, 12])
In [23]: torch.tensordot(a_flat, b_flat, dims=((-1,), (-1,)))
Out[23]:
tensor([[[4400, 4730]],
[[4532, 4874]],
[[4664, 5018]],
[[4796, 5162]],
[[4928, 5306]]])
In [24]: torch.tensordot(a_flat, b_flat, dims=((-1,), (-1,))).shape
Out[24]: torch.Size([5, 1, 2])
# Repare que os resultados de Out[23] e de Out[24] são iguais aos respectivos das outras 2 implementações
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment