@jcmgray
Last active May 12, 2024 07:58
Batched tensor contraction - essentially batched tensordot, but using einsum notation.
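For example, with the equation 'bij,bjk->bik' the index 'b' appears on both inputs and the output and is treated as the batch dimension, while 'j', appearing only on the inputs, is contracted as usual; a short usage sketch follows the code below.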
from numpy import tensordot, stack


def einsum_to_axes(lhix, rhix):
    """Take two einsum inputs (strs) and produce
    the tuples of ints required for tensordot.
    """
    ix_rm = set(lhix) & set(rhix)
    lh_ax, rh_ax = [], []
    for ix in ix_rm:
        lh_ax.append(lhix.find(ix))
        rh_ax.append(rhix.find(ix))
    return lh_ax, rh_ax
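# illustrative note (added, not in the original gist): for inputs
# 'abc' and 'cbd' the shared indices are {'b', 'c'}, so this pairs
# axis 1 of the left operand with axis 1 of the right (index 'b'),
# and axis 2 with axis 0 (index 'c'); the order of the pairs follows
# set iteration order, which tensordot is happy with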


def batched_contract(eq, x, y):
    """Perform a batched (broadcast) contraction of two tensors.
    Two arguments only, uses einsum format, but tensordot backend.

    ``eq`` should be an einsum string which is a 'standard'
    contraction apart from a single index appearing on all arguments.
    """
    lhs, zix = eq.split('->')
    xix, yix = lhs.split(',')
    xix_set, yix_set, zix_set = set(xix), set(yix), set(zix)
    # the index that will be batched - it appears on both inputs and the output
    bi, = set.intersection(xix_set, yix_set, zix_set)
    # indices to be contracted
    ix_rm = set.intersection(xix_set, yix_set)
    # axis location of the batched index for each op
    xbax, ybax, zbax = (ix.find(bi) for ix in (xix, yix, zix))
    # size of the batched dimension
    b_size = x.shape[xbax]
    # each op's indices without the broadcast index
    xix_nob, yix_nob, zix_nob = (ix.replace(bi, "") for ix in (xix, yix, zix))
    # lists to slice the tensors with
    xsel, ysel = [slice(None)] * len(xix), [slice(None)] * len(yix)
    # turn it into a tensordot tuple of axes
    tdot_axes = einsum_to_axes(xix_nob, yix_nob)

    def gen_tdot_slices():
        for i in range(b_size):
            # update the slicers, indexing with tuples since numpy
            # no longer accepts lists of slices as an index
            xsel[xbax] = ysel[ybax] = i
            yield tensordot(x[tuple(xsel)], y[tuple(ysel)], tdot_axes)

    # materialize as a list - stack wants a sequence, not a generator
    out = stack(list(gen_tdot_slices()), axis=0)
    # current index order: batch index, then uncontracted x then y indices
    current_ix = [bi] + [i for i in xix + yix if i not in ix_rm]
    # permute into the requested output order
    perm = [current_ix.index(i) for i in zix]
    return out.transpose(*perm)
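
A minimal usage sketch (an added illustration, assuming only NumPy; the shapes are arbitrary):

import numpy as np

x = np.random.rand(10, 3, 4)  # indices 'bij'
y = np.random.rand(10, 4, 5)  # indices 'bjk'

# 'b' appears on both inputs and the output, so it is batched;
# 'j' appears on the inputs only, so it is contracted
z = batched_contract('bij,bjk->bik', x, y)
assert z.shape == (10, 3, 5)
assert np.allclose(z, np.einsum('bij,bjk->bik', x, y))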
sarda23 commented Apr 27, 2024

Does this work for multiple batching dims, which can be non-contiguous?

jcmgray (Author) commented May 7, 2024

Hi @sarda23, no, not this snippet. But it is superseded by https://github.com/jcmgray/einsum_bmm, which is itself incorporated into https://cotengra.readthedocs.io/en/latest/.

sarda23 commented May 12, 2024

Thanks, will look into this.
