Named Einsum
import numpy as np | |
from string import ascii_lowercase, ascii_uppercase | |
def _translate_axes_names(seq: str, translation_table: dict=None): | |
if translation_table is None: | |
translation_table = {} | |
if not seq: | |
return "", translation_table | |
operands = seq.split(',') | |
valid_axis_targets = ascii_lowercase + ascii_uppercase | |
variable_offset = len(translation_table) | |
translated_operands = [] | |
for operand in operands: | |
translated_axes = [] | |
for axis_name in operand.strip().split('.'): | |
if axis_name in translation_table: | |
translated_axis_name = translation_table[axis_name] | |
else: | |
translated_axis_name = valid_axis_targets[variable_offset] | |
translation_table[axis_name] = translated_axis_name | |
variable_offset += 1 | |
translated_axes.append(translated_axis_name) | |
translated_operand = "".join(translated_axes) | |
translated_operands.append(translated_operand) | |
return ",".join(translated_operands), translation_table | |
def named_einsum(subscripts, *args, einsum_fn=None, **kwargs):
    """Evaluate an einsum whose axes are written as readable names.

    Example: ``named_einsum("batch.time,time.feat->batch.feat", x, w)`` is
    translated to ``einsum("ab,bc->ac", x, w)``.

    Args:
        subscripts: ``"<inputs>-><output>"``. The inputs are comma-separated
            operands of dot-separated axis names. The output is a single
            dot-separated list, and it may be empty for a full reduction.
        *args: Operands forwarded to the einsum backend.
        einsum_fn: Backend callable with an ``einsum(subscripts, *operands,
            **kwargs)`` signature, e.g. ``tf.einsum`` or ``torch.einsum``.
            Defaults to ``np.einsum``.
        **kwargs: Forwarded unchanged to ``einsum_fn``.

    Returns:
        Whatever ``einsum_fn`` returns for the translated expression.

    Raises:
        ValueError: if ``subscripts`` does not contain exactly one ``'->'``,
            or if the output uses an axis name that no input operand has.
    """
    # NumPy is the only array library this file imports, so it is the default
    # backend. Pass einsum_fn=tf.einsum (etc.) to use another framework.
    if einsum_fn is None:
        einsum_fn = np.einsum

    parts = subscripts.split('->')
    if len(parts) != 2:
        # Implicit-output einsum is not supported: its alphabetical output
        # ordering would apply to the generated letters, not to the names.
        raise ValueError(
            f"subscripts must contain exactly one '->', got {subscripts!r}"
        )
    input_seq, output_seq = parts

    translated_inputs, translation_table = _translate_axes_names(input_seq)
    input_names = set(translation_table)
    translated_outputs, _ = _translate_axes_names(output_seq, translation_table)
    # Any name the output pass added never appeared in an input. Report it by
    # name rather than letting the backend complain about a generated letter.
    unknown = [name for name in translation_table if name not in input_names]
    if unknown:
        raise ValueError(f"output axis names not present in inputs: {unknown}")

    translated_seq = "->".join([translated_inputs, translated_outputs])
    return einsum_fn(translated_seq, *args, **kwargs)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment