Skip to content

Instantly share code, notes, and snippets.

@GenevieveBuckley
Created July 19, 2021 08:06
Show Gist options
  • Save GenevieveBuckley/06c0d5adf3273e56d0e23ac7bd817ed2 to your computer and use it in GitHub Desktop.
Save GenevieveBuckley/06c0d5adf3273e56d0e23ac7bd817ed2 to your computer and use it in GitHub Desktop.
Tensordot shape output
def _tensordot_shape_output(a, b, axes):
    """Compute the shape and chunks of the result of ``tensordot(a, b, axes)``.

    Parameters
    ----------
    a, b : objects exposing ``shape`` and ``chunks``
        Both attributes are indexed per axis and must have one entry for
        each dimension of the operand.
    axes : int or pair
        * A non-negative integer ``N`` contracts the last ``N`` axes of
          ``a`` with the first ``N`` axes of ``b`` (``0`` is the outer
          product).
        * A pair ``(axes_a, axes_b)`` names the axes to contract in each
          operand; every element may be a single index or a sequence of
          indices, and negative indices count from the end.

    Returns
    -------
    (shape_out, chunks_out) : tuple of tuples
        The non-contracted axes of ``a`` followed by the non-contracted
        axes of ``b``.

    Raises
    ------
    ValueError
        If an integer ``axes`` is negative, not integral, or larger than
        the number of dimensions of either operand.
    """
    ndim_a = len(a.shape)
    ndim_b = len(b.shape)

    if isinstance(axes, (int, float)):
        n_contracted = int(axes)
        if n_contracted != axes or n_contracted < 0:
            raise ValueError(
                f"axes must be a non-negative integer, got {axes!r}"
            )
        if n_contracted > ndim_a or n_contracted > ndim_b:
            raise ValueError(
                f"axes={axes!r} exceeds the dimensionality of the operands"
            )
        # The contracted axes are the *trailing* ones of ``a`` and the
        # *leading* ones of ``b``; an explicit stop index also keeps the
        # ``axes == 0`` case correct (``[:-0]`` would be empty).
        keep_a = ndim_a - n_contracted
        shape_out = tuple(a.shape[:keep_a]) + tuple(b.shape[n_contracted:])
        chunks_out = tuple(a.chunks[:keep_a]) + tuple(b.chunks[n_contracted:])
        return shape_out, chunks_out

    axes_a, axes_b = axes
    if isinstance(axes_a, (int, float)):
        axes_a = [axes_a]
    if isinstance(axes_b, (int, float)):
        axes_b = [axes_b]

    # Map negative indices onto their positive equivalents so that the
    # membership tests below recognise them as contracted axes.
    contracted_a = {ax + ndim_a if ax < 0 else ax for ax in axes_a}
    contracted_b = {ax + ndim_b if ax < 0 else ax for ax in axes_b}

    outer_axes_a = [i for i in range(ndim_a) if i not in contracted_a]
    outer_axes_b = [i for i in range(ndim_b) if i not in contracted_b]

    shape_out = tuple(
        [a.shape[i] for i in outer_axes_a] + [b.shape[i] for i in outer_axes_b]
    )
    chunks_out = tuple(
        [a.chunks[i] for i in outer_axes_a] + [b.chunks[i] for i in outer_axes_b]
    )
    return shape_out, chunks_out
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment