Skip to content

Instantly share code, notes, and snippets.

@MilesCranmer
Created January 5, 2021 23:18
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save MilesCranmer/585e39b76a10f409ce612be7562c33a5 to your computer and use it in GitHub Desktop.
Save MilesCranmer/585e39b76a10f409ce612be7562c33a5 to your computer and use it in GitHub Desktop.
Generic einops operation that performs a repeat, rearrange, or reduce, chosen automatically from the axis names (indices) on each side of the pattern.
# Copy this into your code. Call with, e.g., einop(x, 'i j -> j', reduction='mean')
import functools
import einops as _einops
from einops.parsing import ParsedExpression
@functools.lru_cache(256)
def _match_einop(pattern: str, reduction=None, **axes_lengths: int):
    """Find the einops operation ('rearrange', 'reduce' or 'repeat') matching a pattern.

    The operation is inferred only from the axis identifiers on each side of
    the ``->`` arrow:

    * some identifier appears only on the left   -> ``'reduce'``
    * some identifier appears only on the right  -> ``'repeat'``
    * neither of the above                        -> ``'rearrange'``

    Args:
        pattern: einops pattern, e.g. ``'i j -> j'``.
        reduction: Not used for classification. Accepted so that callers
            forwarding it keep working.
        **axes_lengths: Not used for classification. Accepted for the same
            reason.

    Returns:
        The name of the einops function to call: ``'rearrange'``,
        ``'reduce'`` or ``'repeat'``.

    Raises:
        ValueError: If ``pattern`` does not contain exactly one ``->``.
        RuntimeError: If the pattern would need both a reduce and a repeat.
    """
    sides = pattern.split('->')
    if len(sides) != 2:
        # Previously this surfaced as an opaque tuple-unpacking ValueError.
        raise ValueError(
            '[Einops] Pattern must contain exactly one "->": {}'.format(pattern))
    left = ParsedExpression(sides[0])
    rght = ParsedExpression(sides[1])

    # An axis that disappears means a reduction; an axis that appears means a repeat.
    reduces = any(index not in rght.identifiers for index in left.identifiers)
    repeats = any(index not in left.identifiers for index in rght.identifiers)

    if reduces and repeats:
        raise RuntimeError('[Einops] You must perform a reduce and repeat separately: {}'.format(pattern))
    if reduces:
        return 'reduce'
    if repeats:
        return 'repeat'
    return 'rearrange'
def einop(tensor, pattern: str, reduction=None, **axes_lengths: int):
    """Perform either reduce, rearrange, or repeat depending on pattern.

    The einops function is chosen by ``_match_einop`` from the axis names on
    each side of ``pattern``. ``tensor`` and ``**axes_lengths`` are then
    forwarded to that function unchanged.

    Args:
        tensor: Input tensor passed to the selected einops function.
        pattern: einops pattern, e.g. ``'i j -> j'``.
        reduction: Reduction forwarded to ``einops.reduce``. It is required
            for reduce patterns and must be ``None`` for rearrange and repeat
            patterns.
        **axes_lengths: Axis sizes forwarded to the einops call.

    Returns:
        Whatever the selected einops function returns.

    Raises:
        RuntimeError: If ``reduction`` is missing for a reduce pattern, is
            given for a rearrange or repeat pattern, or if the classifier
            returns an unexpected operation. Errors raised while classifying
            or parsing the pattern propagate unchanged.
    """
    # Classification depends only on the pattern. Passing reduction and
    # axes_lengths too would split the lru_cache into one entry per
    # combination, keep reduction callables alive, and fail on unhashable
    # values.
    op = _match_einop(pattern)
    if op == 'rearrange':
        if reduction is not None:
            raise RuntimeError('[Einops] Do not pass reduction for rearrange pattern: {}'.format(pattern))
        return _einops.rearrange(tensor, pattern, **axes_lengths)
    elif op == 'reduce':
        if reduction is None:
            raise RuntimeError('[Einops] Missing reduction operation for reduce pattern: {}'.format(pattern))
        return _einops.reduce(tensor, pattern, reduction, **axes_lengths)
    elif op == 'repeat':
        if reduction is not None:
            raise RuntimeError('[Einops] Do not pass reduction for repeat pattern: {}'.format(pattern))
        return _einops.repeat(tensor, pattern, **axes_lengths)
    # Fail loudly rather than silently returning None if the classifier changes.
    raise RuntimeError('[Einops] Unrecognized operation {!r} for pattern: {}'.format(op, pattern))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment