Skip to content

Instantly share code, notes, and snippets.

@jcmgray
Created May 30, 2020 07:15
Show Gist options
  • Save jcmgray/906ba067f0b4ab4a1a1adde06903b213 to your computer and use it in GitHub Desktop.
Save jcmgray/906ba067f0b4ab4a1a1adde06903b213 to your computer and use it in GitHub Desktop.
oinsum - an einsum implementation for numpy object arrays
def oinsum(eq, *arrays):
    """An ``einsum`` implementation for ``numpy`` object arrays.

    Every output element is computed with plain Python ``*`` and ``+`` on
    the array elements, so any objects supporting those operators can be
    contracted. Ellipsis (``...``) subscripts are not supported.

    Parameters
    ----------
    eq : str
        The contraction equation, e.g. ``'ij,jk->ik'``. Spaces are ignored.
        If ``'->'`` is omitted the output subscripts are the indices that
        appear exactly once in the inputs, in sorted order.
    *arrays : numpy.ndarray
        The operands, one per comma separated term of ``eq``. Each must
        expose ``shape`` and support indexing with a tuple of integers.

    Returns
    -------
    numpy.ndarray
        An array with ``dtype=object`` whose shape follows the output
        subscripts. If any contracted dimension has size zero, every
        element is the integer ``0`` (the empty sum).

    Raises
    ------
    ValueError
        If the number of terms and operands differ, a term's length does
        not match its operand's number of dimensions, an index is used with
        inconsistent sizes, or an output index is repeated.
    KeyError
        If an output index does not appear in any input term.
    """
    import numpy as np
    import functools
    import operator

    eq = eq.replace(' ', '')
    if '->' in eq:
        lhs, output = eq.split('->')
    else:
        # implicit mode: keep the indices appearing exactly once, sorted
        lhs = eq
        output = ''.join(sorted(
            k for k in set(lhs) - {','} if lhs.count(k) == 1
        ))
    inputs = lhs.split(',')

    if len(inputs) != len(arrays):
        raise ValueError(
            "equation has {} input terms but {} arrays were supplied".format(
                len(inputs), len(arrays)
            )
        )
    if len(set(output)) != len(output):
        raise ValueError(
            "output subscripts {!r} contain a repeated index".format(output)
        )

    # map every index to its dimension, insisting that all uses agree
    sizes = {}
    for term, array in zip(inputs, arrays):
        shape = array.shape
        if len(term) != len(shape):
            raise ValueError(
                "term {!r} has {} indices but its array has {} "
                "dimensions".format(term, len(term), len(shape))
            )
        for k, d in zip(term, shape):
            if sizes.setdefault(k, d) != d:
                raise ValueError(
                    "index {!r} has inconsistent sizes {} and {}".format(
                        k, sizes[k], d
                    )
                )

    out_size = tuple(sizes[k] for k in output)
    out = np.empty(out_size, dtype=object)

    # contracted indices, kept in order of first appearance so that the
    # summation order is deterministic
    inner = [k for k in sizes if k not in output]
    inner_size = tuple(sizes[k] for k in inner)

    if 0 in inner_size:
        # nothing to sum over: ``reduce`` would fail on an empty iterable
        out.fill(0)
        return out

    # for each operand, where to find its indices within the concatenated
    # (output + inner) coordinate -- computed once rather than per element
    labels = list(output) + inner
    positions = [tuple(labels.index(k) for k in term) for term in inputs]

    def gen_inner_sum(coo_o):
        # yield the product of operand elements for each inner coordinate
        for coo_i in np.ndindex(*inner_size):
            coo = coo_o + coo_i
            elements = (
                array[tuple(coo[p] for p in pos)]
                for array, pos in zip(arrays, positions)
            )
            yield functools.reduce(operator.mul, elements)

    for coo_o in np.ndindex(*out_size):
        out[coo_o] = functools.reduce(operator.add, gen_inner_sum(coo_o))

    return out
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment