Created
May 30, 2020 07:15
-
-
Save jcmgray/906ba067f0b4ab4a1a1adde06903b213 to your computer and use it in GitHub Desktop.
oinsum - an einsum implementation for numpy object arrays
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def oinsum(eq, *arrays):
    """Evaluate an ``einsum`` expression on ``numpy`` object arrays.

    Every output element is built with plain Python ``*`` and ``+`` on the
    array elements, so any element type supporting those two operators
    (``fractions.Fraction``, symbolic expressions, ...) can be contracted.

    Parameters
    ----------
    eq : str
        Contraction equation such as ``'ij,jk->ik'``. Spaces are ignored.
        If ``'->'`` is omitted, the output indices are those that appear
        exactly once among the inputs, in alphabetical order.
    *arrays
        One operand per comma separated input term. Each must expose
        ``.shape`` and support indexing with a tuple of integers.

    Returns
    -------
    numpy.ndarray
        Array of ``dtype=object`` whose shape follows the output indices
        (zero-dimensional when there are none).

    Raises
    ------
    ValueError
        If the number of terms differs from the number of operands, a term
        has a different number of indices than its operand has dimensions,
        or the same index is given two different sizes.
    KeyError
        If an output index does not appear in any input term.

    Notes
    -----
    When a summed index has size zero there is nothing to add up; the
    corresponding output elements are set to the integer ``0``.
    """
    import functools
    import operator

    import numpy as np

    eq = eq.replace(' ', '')
    if '->' in eq:
        lhs, output = eq.split('->')
    else:
        # Implicit mode: keep indices seen exactly once, sorted.
        lhs = eq
        output = ''.join(sorted(
            k for k in set(lhs) if k != ',' and lhs.count(k) == 1
        ))
    inputs = lhs.split(',')

    if len(inputs) != len(arrays):
        raise ValueError(
            "equation has {} input term(s) but {} array(s) were supplied"
            .format(len(inputs), len(arrays))
        )

    # Map every index label to its dimension, rejecting inconsistencies
    # instead of silently letting the last operand win.
    sizes = {}
    for term, array in zip(inputs, arrays):
        shape = array.shape
        if len(term) != len(shape):
            raise ValueError(
                "term '{}' has {} index(es) but its operand has {} "
                "dimension(s)".format(term, len(term), len(shape))
            )
        for k, d in zip(term, shape):
            if sizes.setdefault(k, d) != d:
                raise ValueError(
                    "index '{}' has inconsistent sizes {} and {}"
                    .format(k, sizes[k], d)
                )

    out_size = tuple(sizes[k] for k in output)
    out = np.empty(out_size, dtype=object)

    # Indices absent from the output are summed over, in first-seen order.
    inner = [k for k in sizes if k not in output]
    inner_size = [sizes[k] for k in inner]

    nothing = object()  # marks "no summand seen yet"

    for coo_o in np.ndindex(*out_size):
        coord = dict(zip(output, coo_o))
        total = nothing
        for coo_i in np.ndindex(*inner_size):
            coord.update(zip(inner, coo_i))
            elements = (
                array[tuple(coord[k] for k in term)]
                for term, array in zip(inputs, arrays)
            )
            product = functools.reduce(operator.mul, elements)
            # Seed the sum with the first product rather than ``0`` so that
            # element types which cannot be added to an int still work.
            total = product if total is nothing else total + product
        out[coo_o] = 0 if total is nothing else total

    return out
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment