Skip to content

Instantly share code, notes, and snippets.

@brandondube
Created December 29, 2021 15:34
Show Gist options
  • Save brandondube/43ab9e9f173252f5a97e0c0d5c3ca54f to your computer and use it in GitHub Desktop.
Save brandondube/43ab9e9f173252f5a97e0c0d5c3ca54f to your computer and use it in GitHub Desktop.
dot product batch with einsum
from concurrent.futures import ThreadPoolExecutor, wait

import numpy as np
def slices_for_split(N, m):
    """Return ``m`` slice objects that partition ``range(N)``.

    The first ``m - 1`` slices each cover ``N // m`` elements; the final
    slice absorbs the remainder, so it is the only chunk whose length can
    differ — it is *larger* by ``N % m`` elements, not smaller.

    Does not actually split anything; the slices are returned for reuse
    (e.g. to index several same-length arrays consistently).

    Parameters
    ----------
    N : int
        Total length to partition.
    m : int
        Number of chunks; must be >= 1.

    Returns
    -------
    list of slice
        ``m`` contiguous, non-overlapping slices covering ``[0, N)``.
    """
    # adapted from numpy's array_split, minus the even spreading of the
    # remainder (here the whole remainder lands in the last chunk)
    each, _ = divmod(N, m)
    slices = []
    start = 0
    for _ in range(m - 1):
        slices.append(slice(start, start + each))
        start += each
    # final chunk runs to N, picking up the N % m leftover elements
    slices.append(slice(start, N))
    return slices
def dot_batch_via_einsum(a, b, pool):
    """Row-wise dot products of ``a`` and ``b``, computed in parallel.

    Equivalent to ``np.einsum('ij,ij->i', a, b)``, but the rows are split
    into one chunk per pool worker so the einsum calls run concurrently
    (numpy releases the GIL during the heavy lifting, so threads help).

    Parameters
    ----------
    a, b : numpy.ndarray
        2D arrays of identical shape ``(N, k)``.
    pool : concurrent.futures.ThreadPoolExecutor
        Executor whose workers perform the per-chunk einsums.

    Returns
    -------
    numpy.ndarray
        1D array of length ``N`` with ``out[i] == a[i] @ b[i]``.

    Raises
    ------
    Exception
        Whatever a worker raised is re-raised here (previously, worker
        exceptions were silently swallowed by ``wait``).
    """
    # NOTE(review): _max_workers is private API, but ThreadPoolExecutor
    # exposes no public equivalent; stable across CPython versions so far.
    nworkers = pool._max_workers
    N = a.shape[0]
    out = np.zeros(N, dtype=a.dtype)
    slices = slices_for_split(N, nworkers)

    def work_func(j):
        # j = chunk index; each worker writes into its own disjoint
        # view of `out`, so no synchronization is needed
        slc = slices[j]
        np.einsum('ij,ij->i', a[slc], b[slc], out=out[slc])

    futures = [pool.submit(work_func, j) for j in range(nworkers)]
    wait(futures)
    # wait() does not surface worker exceptions; result() re-raises them
    # so a failed chunk cannot silently leave zeros in `out`
    for f in futures:
        f.result()
    return out
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment