Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save GenevieveBuckley/8a2c5068d065ce81284fe606ac72c77d to your computer and use it in GitHub Desktop.
Save GenevieveBuckley/8a2c5068d065ce81284fe606ac72c77d to your computer and use it in GitHub Desktop.
Tensordot auto-rechunking experiments
import dask
import dask.array as da
import numpy as np
def _inner_axes(a_ndim, b_ndim, axes):
"""Given tensordot axes argument, return list of axes to sum over."""
if isinstance(axes, (int, float)):
if axes == 0:
inner_axes_a = []
inner_axes_b = []
elif axes > 0:
inner_axes_a = list(range(a_ndim))[-axes:]
inner_axes_b = list(range(b_ndim))[:axes]
else:
axes_a, axes_b = axes
if isinstance(axes_a, (int, float)):
axes_a = [axes_a]
if isinstance(axes_b, (int, float)):
axes_b = [axes_b]
inner_axes_a = [i for i in range(a.ndim) if i in axes_a]
inner_axes_b = [i for i in range(b.ndim) if i in axes_b]
return inner_axes_a, inner_axes_b
def find_optimal_chunks(array, inner_axes, limit=None):
    """Pick chunk sizes so each block's summed (inner) axes fit in ``limit`` bytes.

    Every axis in ``inner_axes`` starts at its full length. The largest inner
    chunk is then repeatedly halved (floor division) until the product of the
    inner chunk lengths times ``array.dtype.itemsize`` is at most ``limit``.
    A chunk is never halved below 1. Axes not in ``inner_axes`` are returned
    as ``'auto'`` so that ``rechunk`` can choose them.

    Parameters
    ----------
    array : array-like
        Any object with ``shape``, ``ndim`` and ``dtype`` attributes
        (e.g. a dask or numpy array).
    inner_axes : list of int
        Axes of ``array`` that tensordot sums over.
    limit : int or str, optional
        Byte budget for the inner block. Defaults to the dask config value
        ``array.chunk-size``. Strings such as ``"128MiB"`` are parsed with
        ``dask.utils.parse_bytes``.

    Returns
    -------
    list
        One entry per axis: an int chunk length for inner axes, ``'auto'``
        for all other axes.
    """
    if limit is None:
        limit = dask.utils.parse_bytes(dask.config.get('array.chunk-size'))
    elif isinstance(limit, str):
        limit = dask.utils.parse_bytes(limit)

    itemsize = array.dtype.itemsize
    inner_chunks = [int(array.shape[ax]) for ax in inner_axes]
    # With no inner axes there is nothing to shrink. The loop also stops once
    # the largest chunk is 1, because halving further would give zero-length
    # chunks and would never satisfy a limit smaller than one item.
    while inner_chunks and limit < np.prod(inner_chunks) * itemsize:
        # First index of the maximum, the same tie-breaking as np.argmax.
        largest = max(range(len(inner_chunks)), key=inner_chunks.__getitem__)
        if inner_chunks[largest] <= 1:
            break
        inner_chunks[largest] //= 2

    return [
        inner_chunks[inner_axes.index(ax)] if ax in inner_axes else 'auto'
        for ax in range(array.ndim)
    ]
if __name__ == "__main__":
    print("Example 1")
    # Alternative inputs with a longer inner axis (original notes kept):
    #     a = da.ones((50, 20_000_000), chunks=(50, 20_000_000))  # optimal chunks (1, 10_000_000)
    #     b = da.ones((20_000_000, 50), chunks=(20_000_000, 50))  # optimal_chunks (10_000_000, 1)
    a = da.ones((100, 10_000_000), chunks=(100, 10_000_000))
    b = da.ones((10_000_000, 100), chunks=(10_000_000, 100))
    axes = [1, 0]
    inner_axes_a, inner_axes_b = _inner_axes(a.ndim, b.ndim, axes)
    # Compute the rechunk target for each operand, in the order a, then b.
    optimal_chunks_a, optimal_chunks_b = (
        find_optimal_chunks(operand, summed)
        for operand, summed in ((a, inner_axes_a), (b, inner_axes_b))
    )
    print("optimal_chunks_a:", optimal_chunks_a)
    print("optimal_chunks_b:", optimal_chunks_b)
    aa, bb = a.rechunk(optimal_chunks_a), b.rechunk(optimal_chunks_b)
    z = da.tensordot(aa, bb, axes=axes)
    # Hand the output back to dask's automatic chunking on every axis.
    z = z.rechunk(["auto"] * z.ndim)
    # z.compute()  # time consuming, but stays nicely within memory
import dask
import dask.array as da
import numpy as np
def max_dtype(items):
    """Returns dtype with largest number of bytes.

    Parameters
    ----------
    items : Iterable of numpy arrays.
        Any objects with a ``dtype`` attribute (numpy or dask arrays).
        Both sequences and one-shot iterators/generators are accepted.

    Returns
    -------
    numpy.dtype
        The dtype of the first item whose ``dtype.itemsize`` is maximal.
        NOTE: only byte sizes are compared. This is not ``np.result_type``;
        for example, int64 wins over float32 here even though their
        promoted type is float64.

    Raises
    ------
    ValueError
        If ``items`` is empty.
    """
    # Materialize first so that generators work (the old code indexed items).
    candidates = list(items)
    if not candidates:
        raise ValueError("max_dtype() requires at least one item")
    # max() keeps the first maximal element, the same tie-breaking as np.argmax.
    return max(candidates, key=lambda item: item.dtype.itemsize).dtype
def _outer_axes(a_ndim, b_ndim, axes):
"""Given tensordot axes argument, return outer axes lists."""
if isinstance(axes, (int, float)):
if axes == 0:
outer_axes_a = list(range(a_ndim))
outer_axes_b = list(range(b_ndim))
elif axes > 0:
outer_axes_a = list(range(a_ndim))[:axes]
outer_axes_b = list(range(b_ndim))[-axes:]
else:
axes_a, axes_b = axes
if isinstance(axes_a, (int, float)):
axes_a = [axes_a]
if isinstance(axes_b, (int, float)):
axes_b = [axes_b]
outer_axes_a = [i for i in range(a.ndim) if i not in axes_a]
outer_axes_b = [i for i in range(b.ndim) if i not in axes_b]
return outer_axes_a, outer_axes_b
def _ideal_chunks(ndim, chunks, output_chunks, outer_axes):
"""Determine ideal chunk size for tensordot input array."""
inner_axes = [i for i in range(ndim) if i not in outer_axes]
ideal_chunks = {}
for i in range(ndim):
if i in inner_axes:
ideal_chunks[i] = chunks[i] # FIXME! Should we rechunk to be smaller here too?
else:
chunk_values = output_chunks[outer_axes.index(i)] // 2 # FIXME! check this value
ideal_chunks[i] = chunk_values
return ideal_chunks
def _tensordot_rechunk_size(a, b, axes):
    """Get rechunk dictionaries for tensordot input arrays.

    Asks dask's ``auto_chunks`` for a good chunking of the tensordot *output*
    (the outer axes of ``a`` followed by the outer axes of ``b``). Each input
    is then rechunked to match its own part of that output.

    Parameters
    ----------
    a, b : dask.array.Array
        The tensordot operands.
    axes : int or pair
        The tensordot ``axes`` argument (see ``_outer_axes``).

    Returns
    -------
    rechunk_a, rechunk_b : dict
        Per-axis chunk specs suitable for ``a.rechunk`` / ``b.rechunk``.
    """
    outer_axes_a, outer_axes_b = _outer_axes(a.ndim, b.ndim, axes)
    shape = [a.shape[i] for i in outer_axes_a] + [b.shape[i] for i in outer_axes_b]
    previous_chunks = [a.chunks[i] for i in outer_axes_a] + [b.chunks[i] for i in outer_axes_b]
    chunks = ['auto' for _ in shape]
    dtype = max_dtype([a, b])
    limit = dask.config.get('array.chunk-size')
    desired_output_chunks = da.core.auto_chunks(chunks, shape, limit, dtype, previous_chunks=previous_chunks)
    # The output lays out a's outer axes first, then b's. Give each input only
    # its own slice; previously b read a's entries.
    n_outer_a = len(outer_axes_a)
    rechunk_a = _ideal_chunks(a.ndim, a.chunks, desired_output_chunks[:n_outer_a], outer_axes_a)
    rechunk_b = _ideal_chunks(b.ndim, b.chunks, desired_output_chunks[n_outer_a:], outer_axes_b)
    return rechunk_a, rechunk_b
if __name__ == "__main__":
    # Each entry: (title, array A, array B, tensordot ``axes`` argument).
    # The loop variables ``a``, ``b`` and ``axes`` are module-level names on
    # purpose, exactly as in the original straight-line script.
    examples = [
        (
            "Example 1",
            da.ones((1000, 1000, 1000), chunks=(100, 200, 100)),
            da.ones((1000, 1000), chunks=(100, 200)),
            (0, 0),
        ),
        # The original used axes=2, which would require
        # a.shape[-2:] == b.shape[:2], i.e. (1000000, 1) == (1, 1000000).
        # That is not a valid contraction for these shapes. axes=1 contracts
        # the shared length-1 axis, giving a (1000000, 1000000) output, which
        # is presumably the intended case.
        ("Example 2", da.ones((1000000, 1)), da.ones((1, 1000000)), 1),
    ]
    for title, a, b, axes in examples:
        print(title)
        rechunk_a, rechunk_b = _tensordot_rechunk_size(a, b, axes)
        aa = a.rechunk(rechunk_a)
        print("Original chunksize, array A:", a.chunksize)
        print("Rechunked chunksize, array A:", aa.chunksize)
        bb = b.rechunk(rechunk_b)
        print("Original chunksize, array B:", b.chunksize)
        print("Rechunked chunksize, array B:", bb.chunksize, "\n")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment