Created
January 5, 2021 23:18
-
-
Save MilesCranmer/585e39b76a10f409ce612be7562c33a5 to your computer and use it in GitHub Desktop.
Generic einops operation that performs a repeat, rearrange, or reduce, chosen automatically from the axis indices in the pattern
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Copy this into your code. Example: einop(x, 'i j -> j', reduction='mean')
import functools | |
import einops as _einops | |
from einops.parsing import ParsedExpression | |
def _match_einop(pattern: str, reduction=None, **axes_lengths: int) -> str:
    """Find the einops operation ('rearrange', 'reduce' or 'repeat') matching ``pattern``.

    The decision compares the axis identifiers parsed from each side of the
    ``->`` arrow:

    * an identifier on the left that is absent on the right -> ``'reduce'``
    * an identifier on the right that is absent on the left -> ``'repeat'``
    * otherwise -> ``'rearrange'``

    Args:
        pattern: einops pattern of the form ``'<left> -> <right>'``.
        reduction: Accepted so call sites can forward their arguments
            unchanged; it does not influence the result.
        **axes_lengths: Likewise accepted but unused.

    Returns:
        The operation name as a string.

    Raises:
        ValueError: if ``pattern`` does not contain exactly one ``->``.
        RuntimeError: if the pattern both drops and introduces axes; that has
            to be written as a separate reduce and repeat.
    """
    # Only ``pattern`` determines the answer, so only ``pattern`` is used as
    # the cache key. Caching on ``reduction``/``axes_lengths`` too would
    # create one cache entry per distinct axis-length value and would raise
    # TypeError whenever a forwarded value is unhashable.
    return _match_einop_cached(pattern)


@functools.lru_cache(256)
def _match_einop_cached(pattern: str) -> str:
    """Cached worker for ``_match_einop``: classify ``pattern`` alone."""
    sides = pattern.split('->')
    if len(sides) != 2:
        # Reject a missing or repeated arrow up front with a clear message
        # instead of an opaque tuple-unpacking error.
        raise ValueError(
            '[Einops] Pattern must contain exactly one "->": {}'.format(pattern))
    left = ParsedExpression(sides[0])
    rght = ParsedExpression(sides[1])

    drops_axes = any(index not in rght.identifiers for index in left.identifiers)
    adds_axes = any(index not in left.identifiers for index in rght.identifiers)

    if drops_axes and adds_axes:
        raise RuntimeError(
            '[Einops] You must perform a reduce and repeat separately: {}'.format(pattern))
    if drops_axes:
        return 'reduce'
    if adds_axes:
        return 'repeat'
    return 'rearrange'
def einop(tensor, pattern: str, reduction=None, **axes_lengths: int):
    """Apply ``einops.rearrange``, ``einops.reduce`` or ``einops.repeat`` to ``tensor``.

    Which of the three is used is decided from ``pattern`` by ``_match_einop``;
    any error it raises propagates unchanged.

    Args:
        tensor: Input passed straight through to the selected einops function.
        pattern: einops pattern ``'<left> -> <right>'``.
        reduction: Reduction forwarded to ``einops.reduce``. It is required for
            reduce patterns and must be ``None`` for rearrange/repeat patterns.
        **axes_lengths: Axis sizes forwarded to the selected einops function.

    Returns:
        Whatever the selected einops function returns.

    Raises:
        RuntimeError: if ``reduction`` is missing for a reduce pattern, or is
            supplied for a rearrange or repeat pattern.
    """
    kind = _match_einop(pattern, reduction, **axes_lengths)

    if kind == 'reduce':
        # Only reduce consumes a reduction, so it must be present here.
        if reduction is None:
            raise RuntimeError(
                '[Einops] Missing reduction operation for reduce pattern: {}'.format(pattern))
        return _einops.reduce(tensor, pattern, reduction, **axes_lengths)

    if kind == 'repeat':
        if reduction is not None:
            raise RuntimeError(
                '[Einops] Do not pass reduction for repeat pattern: {}'.format(pattern))
        return _einops.repeat(tensor, pattern, **axes_lengths)

    if kind == 'rearrange':
        if reduction is not None:
            raise RuntimeError(
                '[Einops] Do not pass reduction for rearrange pattern: {}'.format(pattern))
        return _einops.rearrange(tensor, pattern, **axes_lengths)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment