Skip to content

Instantly share code, notes, and snippets.

Created February 3, 2019 23:36
Show Gist options
  • Save eigenfoo/673063880decd9f41009b6054bd77e7f to your computer and use it in GitHub Desktop.
Save eigenfoo/673063880decd9f41009b6054bd77e7f to your computer and use it in GitHub Desktop.
Example of how Einstein notation (`torch.einsum`) simplifies tensor manipulations compared with broadcasting and batched matmul.
from time import perf_counter
from time import time

import torch
batch_size = 128
image_width = 64
image_height = 64
num_channels = 3  # RGB, for instance.

# Suppose we wanted to scale each channel of each image by a per-image,
# per-channel factor, and add the scaled channels together at the end.
image_channels = torch.randn(batch_size, image_width, image_height, num_channels)
scale_factors = torch.randn(batch_size, num_channels)

# NOTE: the timings below are single-shot and machine-dependent; the first op
# may also pay one-time warm-up costs. Use `timeit` for a real comparison.

# With einsum, this is straightforward, once you get the hang of thinking in
# terms of subscripts: i=batch, j=width, k=height, l=channel. The index l
# is absent from the output, so einsum sums over it.
start = perf_counter()
images = torch.einsum("ijkl,il->ijk", image_channels, scale_factors)
print(perf_counter() - start)  # ~ 0.002s (machine-dependent)

# Alternatively, we could pointwise multiply (Hadamard product) with
# broadcasting. Notice the unsqueezing, and how we have to sum separately.
start = perf_counter()
# (batch, channels) -> (batch, 1, 1, channels) so it broadcasts over width/height.
unsqueezed_scale_factors = scale_factors.unsqueeze(1).unsqueeze(2)
images_2 = torch.sum(image_channels * unsqueezed_scale_factors, dim=3)
print(perf_counter() - start)  # ~ 0.010s (machine-dependent)
# The approaches accumulate in different orders, so float32 results may differ
# in the last bits: compare with a tolerance rather than exact equality.
assert torch.allclose(images, images_2, rtol=1e-5, atol=1e-5)

# Finally, we could do a batched matmul. The matmul sums for us, but we still
# need some squeezing/unsqueezing.
start = perf_counter()
# (batch, channels) -> (batch, 1, channels, 1): the leading (batch, 1) dims
# broadcast against (batch, width), and each (height, channels) matrix is
# multiplied by a (channels, 1) column, giving (batch, width, height, 1).
unsqueezed_scale_factors = scale_factors.unsqueeze(1).unsqueeze(3)
# Squeeze only the trailing size-1 matmul dim; a bare squeeze() would also
# drop batch/width/height whenever one of them happens to be 1.
images_3 = torch.matmul(image_channels, unsqueezed_scale_factors).squeeze(-1)
print(perf_counter() - start)  # ~ 0.005s (machine-dependent)
assert torch.allclose(images, images_3, rtol=1e-5, atol=1e-5)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment