Last active
May 12, 2024 07:58
-
-
Save jcmgray/5c6e4e586a5f962908f65304811e29b9 to your computer and use it in GitHub Desktop.
Batched tensor contraction - essentially batched tensordot, but using einsum notation.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from numpy import tensordot, stack | |
def einsum_to_axes(lhix, rhix): | |
"""Take two einsum inputs (strs) and produce | |
the tuples of ints required for tensordot. | |
""" | |
ix_rm = set(lhix) & set(rhix) | |
lh_ax, rh_ax = [], [] | |
for ix in ix_rm: | |
lh_ax.append(lhix.find(ix)) | |
rh_ax.append(rhix.find(ix)) | |
return lh_ax, rh_ax | |
def batched_contract(eq, x, y):
    """Perform a batched (broadcast) contraction of two tensors.

    Two arguments only; uses einsum notation but a ``tensordot``
    backend. ``eq`` should describe a 'standard' two-term contraction,
    apart from a single index that appears on both inputs *and* the
    output -- that index is treated as the batch dimension.

    Parameters
    ----------
    eq : str
        Einsum equation, e.g. ``'bij,bjk->bik'``.
    x, y : array_like
        The two operands; their batch dimensions must be the same size.

    Returns
    -------
    numpy.ndarray
        The contracted tensor, axes ordered as in the output term of
        ``eq``.
    """
    lhs, zix = eq.split('->')
    xix, yix = lhs.split(',')
    xix_set, yix_set, zix_set = set(xix), set(yix), set(zix)
    # the single index on both inputs and the output is the batched one
    bi, = xix_set & yix_set & zix_set
    # indices on both inputs (includes ``bi``); all but ``bi`` are summed
    ix_rm = xix_set & yix_set
    # axis location of the batched index on each input
    xbax, ybax = xix.index(bi), yix.index(bi)
    # size of the batched dimension
    b_size = x.shape[xbax]
    # each operand's indices with the batch index removed
    xix_nob, yix_nob = xix.replace(bi, ""), yix.replace(bi, "")
    # tensordot axes for the remaining shared (contracted) indices
    contracted = set(xix_nob) & set(yix_nob)
    tdot_axes = ([xix_nob.index(i) for i in contracted],
                 [yix_nob.index(i) for i in contracted])
    # reusable slicers, one entry per axis of each operand
    xsel = [slice(None)] * len(xix)
    ysel = [slice(None)] * len(yix)

    def gen_tdot_slices():
        for i in range(b_size):
            # select the i-th batch entry of each operand
            xsel[xbax] = ysel[ybax] = i
            # NB: index with a *tuple* -- a list of slices is rejected
            # by modern numpy (non-tuple sequence indexing was removed)
            yield tensordot(x[tuple(xsel)], y[tuple(ysel)], tdot_axes)

    # np.stack requires a sequence, not a generator, so materialize it;
    # the batch axis comes out first
    out = stack(list(gen_tdot_slices()), axis=0)
    # permute the axes to match the requested output term exactly
    current_ix = [bi] + [i for i in xix + yix if i not in ix_rm]
    perm = [current_ix.index(i) for i in zix]
    return out.transpose(perm)
hi @sarda23, not this snippet. But it is superseded by https://github.com/jcmgray/einsum_bmm, which itself is incorporated into https://cotengra.readthedocs.io/en/latest/.
Thanks will look into this.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
does this work for multiple batching dims which can be non-contiguous?