Last active
July 27, 2021 09:47
-
-
Save GenevieveBuckley/8a2c5068d065ce81284fe606ac72c77d to your computer and use it in GitHub Desktop.
Tensordot auto-rechunking experiments
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import dask | |
import dask.array as da | |
import numpy as np | |
def _inner_axes(a_ndim, b_ndim, axes): | |
"""Given tensordot axes argument, return list of axes to sum over.""" | |
if isinstance(axes, (int, float)): | |
if axes == 0: | |
inner_axes_a = [] | |
inner_axes_b = [] | |
elif axes > 0: | |
inner_axes_a = list(range(a_ndim))[-axes:] | |
inner_axes_b = list(range(b_ndim))[:axes] | |
else: | |
axes_a, axes_b = axes | |
if isinstance(axes_a, (int, float)): | |
axes_a = [axes_a] | |
if isinstance(axes_b, (int, float)): | |
axes_b = [axes_b] | |
inner_axes_a = [i for i in range(a.ndim) if i in axes_a] | |
inner_axes_b = [i for i in range(b.ndim) if i in axes_b] | |
return inner_axes_a, inner_axes_b | |
def find_optimal_chunks(array, inner_axes, limit=None):
    """Suggest a rechunk spec for ``array`` that keeps contracted axes large.

    Each inner (contracted) axis starts at its full length. The largest
    inner chunk is then repeatedly halved (integer division) until a block
    spanning just the inner axes fits in ``limit`` bytes. Chunk lengths never
    drop below 1. All other axes are set to ``'auto'`` so that
    ``Array.rechunk`` chooses their size.

    Parameters
    ----------
    array : array-like
        Anything with ``shape``, ``ndim`` and ``dtype`` (e.g. a dask array).
    inner_axes : sequence of int
        Non-negative indices of the axes summed over (see ``_inner_axes``).
    limit : int, optional
        Byte budget for the inner-axes block. Defaults to dask's
        ``array.chunk-size`` configuration value.

    Returns
    -------
    list
        One entry per axis of ``array``: a chunk length (>= 1) for each inner
        axis, and ``'auto'`` for every other axis.
    """
    if limit is None:
        limit = dask.utils.parse_bytes(dask.config.get('array.chunk-size'))
    itemsize = array.dtype.itemsize

    def _block_nbytes(lengths):
        # Exact Python-int product; avoids silent int64 overflow for huge shapes.
        total = itemsize
        for length in lengths:
            total *= length
        return total

    inner_chunks = [array.shape[ax] for ax in inner_axes]
    while inner_chunks and limit < _block_nbytes(inner_chunks):
        # First index of the largest chunk (same tie-break as np.argmax).
        largest = max(range(len(inner_chunks)), key=inner_chunks.__getitem__)
        if inner_chunks[largest] <= 1:
            # Every inner chunk is already 1; halving further would create
            # invalid zero-length chunks, so accept the smallest valid block.
            break
        inner_chunks[largest] //= 2

    # First occurrence wins if an axis is listed twice (mirrors list.index).
    chunk_for_axis = {}
    for ax, length in zip(inner_axes, inner_chunks):
        chunk_for_axis.setdefault(ax, length)
    return [chunk_for_axis.get(ax, 'auto') for ax in range(array.ndim)]
if __name__ == "__main__":
    print("Example 1")
    # Two single-chunk operands contracted over their 10_000_000-long axis.
    # NOTE: the names ``a`` and ``b`` are kept module-level on purpose.
    a = da.ones((100, 10_000_000), chunks=(100, 10_000_000))
    b = da.ones((10_000_000, 100), chunks=(10_000_000, 100))
    # Alternative shapes tried during the experiments:
    # a = da.ones((50, 20_000_000), chunks=(50, 20_000_000))  # optimal chunks (1, 10_000_000)
    # b = da.ones((20_000_000, 50), chunks=(20_000_000, 50))  # optimal_chunks (10_000_000, 1)
    axes = [1, 0]

    summed_a, summed_b = _inner_axes(a.ndim, b.ndim, axes)
    plan_a = find_optimal_chunks(a, summed_a)
    plan_b = find_optimal_chunks(b, summed_b)
    print("optimal_chunks_a:", plan_a)
    print("optimal_chunks_b:", plan_b)

    z = da.tensordot(a.rechunk(plan_a), b.rechunk(plan_b), axes=axes)
    z = z.rechunk(["auto"] * z.ndim)
    # z.compute()  # time consuming, but stays nicely within memory
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import dask | |
import dask.array as da | |
import numpy as np | |
def max_dtype(items):
    """Return the dtype with the largest number of bytes per element.

    Only ``dtype.itemsize`` is compared; this is not NumPy type promotion
    (e.g. int64 vs float32 returns int64). On ties, the first item wins.

    Parameters
    ----------
    items : Iterable of arrays
        Any iterable (including generators) of objects with a ``dtype``.

    Returns
    -------
    numpy.dtype
        The dtype of the item with the largest ``dtype.itemsize``.

    Raises
    ------
    ValueError
        If ``items`` is empty.
    """
    # Materialize once so generators and other one-shot iterables work.
    items = list(items)
    if not items:
        raise ValueError("max_dtype() requires at least one array")
    # max() returns the first maximal element, matching np.argmax's tie-break.
    return max(items, key=lambda item: item.dtype.itemsize).dtype
def _outer_axes(a_ndim, b_ndim, axes): | |
"""Given tensordot axes argument, return outer axes lists.""" | |
if isinstance(axes, (int, float)): | |
if axes == 0: | |
outer_axes_a = list(range(a_ndim)) | |
outer_axes_b = list(range(b_ndim)) | |
elif axes > 0: | |
outer_axes_a = list(range(a_ndim))[:axes] | |
outer_axes_b = list(range(b_ndim))[-axes:] | |
else: | |
axes_a, axes_b = axes | |
if isinstance(axes_a, (int, float)): | |
axes_a = [axes_a] | |
if isinstance(axes_b, (int, float)): | |
axes_b = [axes_b] | |
outer_axes_a = [i for i in range(a.ndim) if i not in axes_a] | |
outer_axes_b = [i for i in range(b.ndim) if i not in axes_b] | |
return outer_axes_a, outer_axes_b | |
def _ideal_chunks(ndim, chunks, output_chunks, outer_axes): | |
"""Determine ideal chunk size for tensordot input array.""" | |
inner_axes = [i for i in range(ndim) if i not in outer_axes] | |
ideal_chunks = {} | |
for i in range(ndim): | |
if i in inner_axes: | |
ideal_chunks[i] = chunks[i] # FIXME! Should we rechunk to be smaller here too? | |
else: | |
chunk_values = output_chunks[outer_axes.index(i)] // 2 # FIXME! check this value | |
ideal_chunks[i] = chunk_values | |
return ideal_chunks | |
def _tensordot_rechunk_size(a, b, axes):
    """Get rechunk dictionaries for the two tensordot input arrays.

    A desired chunk shape for the tensordot *output* is computed with
    ``dask.array.core.auto_chunks``. The output's axes are ``a``'s outer axes
    followed by ``b``'s outer axes. Each operand's share of those output
    chunks is then turned into a rechunk spec via ``_ideal_chunks``.

    Parameters
    ----------
    a, b : dask.array.Array
        Left and right tensordot operands.
    axes : int or pair
        The ``axes`` argument that will be passed to ``tensordot``.

    Returns
    -------
    rechunk_a, rechunk_b : dict
        ``axis -> chunk spec`` mappings suitable for ``Array.rechunk``.
    """
    outer_axes_a, outer_axes_b = _outer_axes(a.ndim, b.ndim, axes)
    shape = [a.shape[i] for i in outer_axes_a] + [b.shape[i] for i in outer_axes_b]
    previous_chunks = [a.chunks[i] for i in outer_axes_a] + [b.chunks[i] for i in outer_axes_b]
    chunks = ['auto' for _ in shape]
    dtype = max_dtype([a, b])
    limit = dask.config.get('array.chunk-size')
    desired_output_chunks = tuple(
        da.core.auto_chunks(chunks, shape, limit, dtype, previous_chunks=previous_chunks)
    )
    # Split the output chunks back into each operand's share: b's outer axes
    # come after all of a's in the output. Passing the whole list for b would
    # make b read a's output chunk sizes.
    n_outer_a = len(outer_axes_a)
    rechunk_a = _ideal_chunks(a.ndim, a.chunks, desired_output_chunks[:n_outer_a], outer_axes_a)
    rechunk_b = _ideal_chunks(b.ndim, b.chunks, desired_output_chunks[n_outer_a:], outer_axes_b)
    return rechunk_a, rechunk_b
if __name__ == "__main__":
    # NOTE: ``a`` and ``b`` are deliberately module-level names.

    # Example 1: contract axis 0 of a 3-D array with axis 0 of a 2-D array.
    print("Example 1")
    a = da.ones((1000, 1000, 1000), chunks=(100, 200, 100))
    b = da.ones((1000, 1000), chunks=(100, 200))
    axes = (0, 0)
    plan_a, plan_b = _tensordot_rechunk_size(a, b, axes)

    a_new = a.rechunk(plan_a)
    print("Original chunksize, array A:", a.chunksize)
    print("Rechunked chunksize, array A:", a_new.chunksize)

    b_new = b.rechunk(plan_b)
    print("Original chunksize, array B:", b.chunksize)
    print("Rechunked chunksize, array B:", b_new.chunksize, "\n")

    # Example 2: integer ``axes`` with dask's default input chunks.
    # NOTE(review): with axes=2 the contracted extents are (1000000, 1) for a
    # and (1, 1000000) for b, which do not match, so an actual tensordot on
    # these arrays would fail; this only exercises the rechunk-size
    # computation. Confirm whether axes=1 was intended.
    print("Example 2")
    a = da.ones((1000000, 1))
    b = da.ones((1, 1000000))
    axes = 2
    plan_a, plan_b = _tensordot_rechunk_size(a, b, axes)

    a_new = a.rechunk(plan_a)
    print("Original chunksize, array A:", a.chunksize)
    print("Rechunked chunksize, array A:", a_new.chunksize)

    b_new = b.rechunk(plan_b)
    print("Original chunksize, array B:", b.chunksize)
    print("Rechunked chunksize, array B:", b_new.chunksize, "\n")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment