Skip to content

Instantly share code, notes, and snippets.

@madisonmay
Last active December 12, 2021 16:36
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save madisonmay/d0f2c9b6ad2135b8f9bf6ac66b171990 to your computer and use it in GitHub Desktop.
Save madisonmay/d0f2c9b6ad2135b8f9bf6ac66b171990 to your computer and use it in GitHub Desktop.
Named Einsum
import numpy as np
from string import ascii_lowercase, ascii_uppercase
def _translate_axes_names(seq: str, translation_table: dict=None):
if translation_table is None:
translation_table = {}
if not seq:
return "", translation_table
operands = seq.split(',')
valid_axis_targets = ascii_lowercase + ascii_uppercase
variable_offset = len(translation_table)
translated_operands = []
for operand in operands:
translated_axes = []
for axis_name in operand.strip().split('.'):
if axis_name in translation_table:
translated_axis_name = translation_table[axis_name]
else:
translated_axis_name = valid_axis_targets[variable_offset]
translation_table[axis_name] = translated_axis_name
variable_offset += 1
translated_axes.append(translated_axis_name)
translated_operand = "".join(translated_axes)
translated_operands.append(translated_operand)
return ",".join(translated_operands), translation_table
def named_einsum(subscripts, *args, einsum_fn=None, **kwargs):
    """Evaluate an einsum whose subscripts use readable, dot-separated axis names.

    Example::

        named_einsum("batch.time,time.feat->batch.feat", x, w)

    Axis names are translated to single letters with ``_translate_axes_names``.
    The result is then passed to ``einsum_fn``.

    Args:
        subscripts: Named subscripts, either ``"inputs->outputs"`` (explicit
            mode) or just ``"inputs"`` (implicit mode). In implicit mode the
            output axis order is whatever the backend's implicit rule yields
            for the translated letters.
        *args: Operands forwarded to ``einsum_fn``.
        einsum_fn: The einsum implementation to call. Defaults to
            ``numpy.einsum``; pass e.g. ``tf.einsum`` to use another backend.
        **kwargs: Extra keyword arguments forwarded to ``einsum_fn``.

    Returns:
        Whatever ``einsum_fn`` returns.

    Raises:
        ValueError: If ``subscripts`` contains more than one ``"->"``, or if
            an output axis name does not appear in the inputs. Errors raised
            by ``_translate_axes_names`` also propagate.
    """
    if einsum_fn is None:
        # This code previously called tf.einsum, but TensorFlow is never
        # imported here, so every call raised NameError. Default to NumPy,
        # which is imported; callers can still pass tf.einsum explicitly.
        einsum_fn = np.einsum

    parts = subscripts.split("->")
    if len(parts) > 2:
        raise ValueError(f"Subscripts may contain at most one '->': {subscripts!r}")

    translated_inputs, translation_table = _translate_axes_names(parts[0])
    if len(parts) == 1:
        # Implicit mode: let the backend infer the output axes.
        translated_seq = translated_inputs
    else:
        input_names = set(translation_table)
        translated_outputs, _ = _translate_axes_names(parts[1], translation_table)
        # Translation would give an unknown output name a fresh letter. The
        # backend would then reject it with a letter-based message, so report
        # the offending names here instead.
        unknown = [name for name in translation_table if name not in input_names]
        if unknown:
            raise ValueError(f"Output axis names not present in inputs: {unknown}")
        translated_seq = "->".join([translated_inputs, translated_outputs])

    return einsum_fn(translated_seq, *args, **kwargs)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment