# Gist agoose77/595cc3e3d0c44059e2c6b1f1ad183c81 (last active May 19, 2023 14:00).
# Save it to your computer and use it in GitHub Desktop.
# GitHub notes that this file contains bidirectional Unicode text that may be
# interpreted or compiled differently than what appears below. To review it,
# open the file in an editor that reveals hidden Unicode characters.
# (Learn more about bidirectional Unicode characters.)
from __future__ import annotations | |
from collections.abc import Sequence | |
import awkward as ak | |
import numpy as np | |
def _compute_starts(parts: Sequence[ak.contents.Content]) -> tuple[int, ...]: | |
# The following won't work for typetracer | |
lengths = np.array([p.length for p in parts], dtype=np.int64) | |
stops = np.cumsum(lengths) | |
return tuple(stops - lengths) | |
def _reduce_parts_positional(
    parts: Sequence[ak.contents.Content],
    reducer,
    keepdims: bool,
    mask_identity: bool,
    starts: tuple[int, ...],
):
    """Combine a positional reducer (``ak.argmin``/``ak.argmax``) across parts.

    Steps:

    1. Reduce every part to its local winning index. With
       ``mask_identity=True`` an empty part yields ``None``.
    2. Gather the value at each part's local winner.
    3. Reduce those winning values to find the winning part.
    4. Translate the winner's local index into a global one by adding that
       part's start offset.

    Args:
        parts: The partitions, in order.
        reducer: ``ak.argmin`` or ``ak.argmax``.
        keepdims: If true, return a length-1 array; otherwise return its
            single element.
        mask_identity: If false, a missing result (no valid element in any
            part) is replaced by ``-1``; otherwise it stays ``None``.
        starts: ``starts[i]`` is the global offset of ``parts[i]``.

    Returns:
        The global index of the winning element, shaped per ``keepdims``.
    """
    local_indices = [
        reducer(part, axis=0, keepdims=True, mask_identity=True) for part in parts
    ]
    local_winners = [part[index] for part, index in zip(parts, local_indices)]
    partial = ak.concatenate(local_indices, axis=0)
    partial_value = ak.concatenate(local_winners, axis=0)
    # Pass axis=0 explicitly, as every other reducer call here does, rather
    # than relying on the reducer's default axis (flattening) semantics. The
    # result is used as an index into ``partial`` and ``starts``.
    winning_part = reducer(partial_value, axis=0, keepdims=True, mask_identity=True)
    final = partial[winning_part] + ak.Array(starts)[winning_part]
    if not mask_identity:
        final = ak.fill_none(final, -1, axis=0)
    if not keepdims:
        final = final[0]
    return final
def _reduce_parts_non_positional(
    parts: Sequence[ak.contents.Content],
    reducer,
    keepdims: bool,
    mask_identity: bool,
    starts: tuple[int, ...],
):
    """Combine a value reducer (``ak.sum``, ``ak.min``, ``ak.max``, ...) across parts.

    Each part is reduced with ``mask_identity=True``, so an empty part yields
    ``None`` instead of an identity value. The per-part results are then
    concatenated and reduced once more.

    Counting reducers do not combine with themselves: counting the per-part
    counts gives the number of non-empty parts, not the number of elements.
    For ``ak.count`` and ``ak.count_nonzero`` the partial results are
    therefore combined with ``ak.sum``.

    NOTE: statistics that are not decomposable this way (e.g. a mean) are not
    handled here.

    Args:
        parts: The partitions, in order.
        reducer: The awkward reducer to apply.
        keepdims: Forwarded to the final reduction.
        mask_identity: Forwarded to the final reduction. If false, the
            ``None`` placeholders of empty parts are dropped first.
        starts: Unused. Accepted so this function shares a signature with
            ``_reduce_parts_positional``.

    Returns:
        The reduction of the concatenated parts along axis 0.
    """
    partials = [
        reducer(part, axis=0, keepdims=True, mask_identity=True) for part in parts
    ]
    partial = ak.concatenate(partials, axis=0)
    if not mask_identity:
        # Remove the ``None`` placeholders produced for empty parts.
        partial = ak.drop_none(partial, axis=0)
    combine = ak.sum if reducer in (ak.count, ak.count_nonzero) else reducer
    return combine(partial, axis=0, keepdims=keepdims, mask_identity=mask_identity)
# Reducers whose result is an index into the array rather than a value. Their
# per-part results must be offset by each part's start, so
# reduce_parts_axis_0 dispatches them to _reduce_parts_positional.
# A frozenset so this dispatch table cannot be mutated at runtime.
_positional_reducers = frozenset({ak.argmin, ak.argmax})
def reduce_parts_axis_0(
    parts: Sequence[ak.contents.Content],
    reducer,
    keepdims: bool,
    mask_identity: bool,
):
    """Reduce a list of partitions along axis 0 as if they were one array.

    Each partition is reduced on its own and the partial results are then
    combined, so the partitions never need to be concatenated in full.

    Args:
        parts: The partitions, in order. All must share the form of
            ``parts[0]``.
        reducer: An awkward reducer such as ``ak.sum`` or ``ak.argmax``.
            Positional reducers return indices into the logical
            concatenation of ``parts``.
        keepdims: If true, the result is a length-1 array instead of a
            scalar.
        mask_identity: Whether a reduction with no valid elements yields
            ``None`` rather than an identity value.

    Returns:
        The combined reduction, shaped per ``keepdims``.

    Raises:
        ValueError: If ``parts`` is empty or the parts do not all share one
            form.
    """
    if not parts:
        raise ValueError("reduce_parts_axis_0 requires at least one part")
    # Raise explicitly instead of using ``assert``, so the check survives
    # ``python -O``.
    first_form = parts[0].form
    for index, part in enumerate(parts[1:], start=1):
        if part.form != first_form:
            raise ValueError(
                f"all parts must share one form; part {index} differs from part 0"
            )
    partial_reducer = (
        _reduce_parts_positional
        if reducer in _positional_reducers
        else _reduce_parts_non_positional
    )
    starts = _compute_starts(parts)
    return partial_reducer(parts, reducer, keepdims, mask_identity, starts)
def _self_test() -> None:
    """Check reduce_parts_axis_0 against hand-computed results.

    The partitions concatenate to ``[1, 2, 3, 0, -1, 4]``. The empty
    partition exercises the ``None`` path of each per-part reduction.
    Run only when this file is executed as a script, not on import.
    """
    parts = [
        ak.to_layout(b)
        for b in (
            np.array([1, 2, 3], dtype=np.int64),
            np.array([], dtype=np.int64),
            np.array([0, -1], dtype=np.int64),
            np.array([4], dtype=np.int64),
        )
    ]
    # (reducer, mask_identity values to check, expected scalar result)
    cases = [
        (ak.sum, (False,), 9),
        (ak.min, (False,), -1),
        (ak.max, (False,), 4),
        (ak.argmax, (False, True), 5),
        (ak.argmin, (False, True), 4),
    ]
    for reducer, mask_identities, expected in cases:
        for mask_identity in mask_identities:
            scalar = reduce_parts_axis_0(parts, reducer, False, mask_identity)
            assert scalar == expected, (reducer, mask_identity, scalar)
            kept = reduce_parts_axis_0(parts, reducer, True, mask_identity)
            assert kept.to_list() == [expected], (reducer, mask_identity, kept)


if __name__ == "__main__":
    _self_test()
# Sign up for free to join this conversation on GitHub.
# Already have an account? Sign in to comment.
# I'm sure we can work with this, although a tree-shaped reduction may be
# worthwhile for the very-many-partitions case. That would require splitting
# `_reduce_parts_[non_]positional` into a per-batch step and an across-batches
# step. I think that means passing along the start of each batch, i.e. the
# first element of the starts of the parts within that batch.