Skip to content

Instantly share code, notes, and snippets.

@gngdb
Created April 10, 2024 22:13
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save gngdb/6ff17112942f4f12d1af18e282de3470 to your computer and use it in GitHub Desktop.
Save gngdb/6ff17112942f4f12d1af18e282de3470 to your computer and use it in GitHub Desktop.
einsum implemented with `itertools.product`
import torch
import itertools
from collections import OrderedDict
def _parse_einsum_equation(equation):
    """Split an einsum equation into input subscripts and output subscripts.

    Whitespace anywhere in ``equation`` is ignored. Without an explicit
    ``->`` (implicit mode) the output subscripts are the labels that occur
    exactly once across all inputs, in sorted order, which is the
    convention used by ``torch.einsum`` / ``numpy.einsum``.

    Returns:
        ``(input_labels, output_labels)``: a list with one subscript string
        per operand, and the output subscript string.

    Raises:
        NotImplementedError: if the equation uses an ellipsis (``...``).
        ValueError: if the equation contains more than one ``->``.
    """
    equation = "".join(equation.split())
    if "." in equation:
        raise NotImplementedError("ellipsis ('...') is not supported")
    parts = equation.split("->")
    if len(parts) > 2:
        raise ValueError(f"equation has more than one '->': {equation!r}")
    input_labels = parts[0].split(",")
    if len(parts) == 2:
        output_labels = parts[1]
    else:
        # Implicit mode: keep the labels that are not summed over, sorted.
        all_labels = "".join(input_labels)
        output_labels = "".join(
            sorted(
                label for label in set(all_labels)
                if all_labels.count(label) == 1
            )
        )
    return input_labels, output_labels


def einsum_itertools(equation, *operands, verbose=False):
    """Evaluate an einsum expression by brute-force iteration.

    Every combination of label indices is enumerated with
    ``itertools.product``. For each combination, the indexed entries of all
    operands are multiplied together and added into the output entry
    selected by the output labels. This is a readable reference
    implementation, not a fast one: the work is the product of all label
    sizes.

    Args:
        equation: einsum subscripts such as ``"ij,jk->ik"``. Whitespace is
            ignored. If ``->`` is omitted, the output labels are those that
            appear exactly once across the inputs, sorted. Ellipsis is not
            supported.
        *operands: one tensor (or array-like accepted by
            ``torch.as_tensor``) per comma-separated input subscript.
        verbose: if true, print intermediate values while computing.

    Returns:
        A tensor with one dimension per output label, whose dtype is the
        promotion of the operand dtypes and which is placed on the first
        operand's device.

    Raises:
        ValueError: if the number of operands, an operand's rank, or the
            size of a repeated label does not match the equation.
        KeyError: if an output label does not appear in any input.
        NotImplementedError: if the equation uses an ellipsis.
    """
    # Parse the equation
    input_labels, output_labels = _parse_einsum_equation(equation)
    if verbose:
        print(f"{input_labels=} {output_labels=}")
    # Accept array-likes as well as tensors. as_tensor returns tensors
    # unchanged, so autograd history on tensor inputs is preserved.
    operands = [torch.as_tensor(op) for op in operands]
    if len(input_labels) != len(operands):
        raise ValueError(
            f"equation has {len(input_labels)} input subscripts but "
            f"{len(operands)} operands were given"
        )
    # Get the dimensions of each operand
    input_dims = [list(op.shape) for op in operands]
    if verbose:
        print(f"{input_dims=}")
    # Map each label to its dimension size, checking that operands agree.
    # zip() would silently truncate a rank mismatch, so check it explicitly.
    label_dims = OrderedDict()
    for labels, dims in zip(input_labels, input_dims):
        if len(labels) != len(dims):
            raise ValueError(
                f"subscripts {labels!r} have {len(labels)} labels but the "
                f"operand has {len(dims)} dimensions"
            )
        for label, dim in zip(labels, dims):
            expected = label_dims.setdefault(label, dim)
            if expected != dim:
                raise ValueError(
                    f"label {label!r} has inconsistent sizes "
                    f"{expected} and {dim}"
                )
    if verbose:
        print(f"{label_dims=}")
    # Compute the output shape (KeyError if an output label is unknown)
    output_shape = [label_dims[label] for label in output_labels]
    if verbose:
        print(f"{output_shape=}")
    # Allocate the output with the promoted operand dtype on the operands'
    # device, rather than the default float dtype on the CPU.
    dtype = operands[0].dtype
    for op in operands[1:]:
        dtype = torch.promote_types(dtype, op.dtype)
    output = torch.zeros(output_shape, dtype=dtype, device=operands[0].device)
    # Generate the indices for iteration
    indices = [range(dim) for dim in label_dims.values()]
    if verbose:
        print(f"{[len(i) for i in indices]=}")
    # Perform the einsum operation using nested iteration
    for idx in itertools.product(*indices):
        if verbose:
            print(f"  {idx=}")
        # Create a dictionary mapping labels to indices
        label_idx = OrderedDict(zip(label_dims.keys(), idx))
        if verbose:
            print(f"  {label_idx=}")
        # Compute the product of the operands at the current indices.
        # Start from the first entry rather than the int 1, which would
        # promote bool operands. `entry` is a view into the operand, so the
        # multiply must be out of place to avoid mutating caller data.
        product = None
        for op, labels in zip(operands, input_labels):
            entry = op[tuple(label_idx[label] for label in labels)]
            product = entry if product is None else product * entry
        if verbose:
            print(f"  {product=}")
        # Update the output tensor at the corresponding indices
        output_idx = tuple(label_idx[label] for label in output_labels)
        if verbose:
            print(f"  {output_idx=}")
        output[output_idx] += product
    return output
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment