Last active: May 19, 2023 14:00
-
-
Save agoose77/595cc3e3d0c44059e2c6b1f1ad183c81 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import annotations | |
from collections.abc import Sequence | |
import awkward as ak | |
import numpy as np | |
def _compute_starts(parts: Sequence[ak.contents.Content]) -> tuple[int, ...]: | |
# The following won't work for typetracer | |
lengths = np.array([p.length for p in parts], dtype=np.int64) | |
stops = np.cumsum(lengths) | |
return tuple(stops - lengths) | |
def _reduce_parts_positional(
    parts: Sequence[ak.contents.Content],
    reducer,
    keepdims: bool,
    mask_identity: bool,
    starts: tuple[int, ...],
):
    """Apply a positional reducer (``ak.argmin`` / ``ak.argmax``) across *parts* on axis 0.

    Each part is reduced on its own to a *local* index. The values found at
    those local indices are gathered and reduced again to pick the winning
    part. The winner's local index is then shifted by that part's offset in
    *starts*, giving an index into the (virtual) concatenation of all parts.

    Args:
        parts: layouts to reduce, in concatenation order.
        reducer: the positional awkward reducer to apply.
        keepdims: if true, return a length-1 array; otherwise a scalar.
        mask_identity: if false, a reduction with no values yields -1 instead
            of None.
        starts: offset of each part in the concatenation (see
            ``_compute_starts``).

    Returns:
        The global index of the selected element, wrapped in a length-1 array
        when *keepdims* is true.
    """
    # Always mask the identity here, so empty parts give None rather than -1.
    # A -1 would be read as "last element" when used as an index below.
    local_indices = [
        reducer(part, axis=0, keepdims=True, mask_identity=True) for part in parts
    ]
    # The value each part contributes at its local index (None for empty parts).
    local_values = [part[index] for part, index in zip(parts, local_indices)]

    local_index = ak.concatenate(local_indices, axis=0)
    candidate_values = ak.concatenate(local_values, axis=0)

    # Position of the winning part among the candidates. The axis is explicit
    # (previously this relied on the default axis=None) to match the
    # per-part reductions above.
    winner = reducer(candidate_values, axis=0, keepdims=True, mask_identity=True)

    # Convert the winner's local index into a global one.
    final = local_index[winner] + ak.Array(starts)[winner]

    if not mask_identity:
        # No values at all: report -1 instead of None.
        final = ak.fill_none(final, -1, axis=0)
    if not keepdims:
        final = final[0]
    return final
def _reduce_parts_non_positional(
    parts: Sequence[ak.contents.Content],
    reducer,
    keepdims: bool,
    mask_identity: bool,
    starts: tuple[int, ...],
):
    """Apply a non-positional reducer (sum, prod, min, max, any, all, count, ...) on axis 0.

    Each part is reduced separately, and the per-part results are reduced
    again. That is only correct when the reducer is decomposable this way.

    ``ak.count`` and ``ak.count_nonzero`` produce per-part counts that must be
    *summed*. Counting them again would return the number of parts. They are
    therefore combined with ``ak.sum``. Reducers that do not decompose
    (e.g. a mean over parts of unequal length) are not supported.

    Args:
        parts: layouts to reduce, in concatenation order.
        reducer: the awkward reducer to apply.
        keepdims: if true, return a length-1 array; otherwise a scalar.
        mask_identity: if true, a reduction with no values yields None;
            otherwise it yields the reducer's identity.
        starts: unused; accepted so both reduction strategies share one
            signature.

    Returns:
        The reduced value, wrapped in a length-1 array when *keepdims* is true.
    """
    # Mask the identity per part, so empty parts show up as None.
    partials = [reducer(p, axis=0, keepdims=True, mask_identity=True) for p in parts]
    partial = ak.concatenate(partials, axis=0)
    if not mask_identity:
        # Remove placeholders for empty parts, so the final reduction
        # falls back to the identity if nothing remains.
        partial = ak.drop_none(partial, axis=0)
    # Counting reducers combine their partial results by summation.
    combine = {ak.count: ak.sum, ak.count_nonzero: ak.sum}.get(reducer, reducer)
    return combine(partial, axis=0, keepdims=keepdims, mask_identity=mask_identity)
# Reducers that return an index rather than a value. They need their local
# per-part indices translated into indices into the concatenation.
_positional_reducers = {ak.argmin, ak.argmax}


def reduce_parts_axis_0(
    parts: Sequence[ak.contents.Content],
    reducer,
    keepdims: bool,
    mask_identity: bool,
):
    """Reduce *parts* along axis 0 as if they had been concatenated first.

    Positional reducers (``ak.argmin`` / ``ak.argmax``) return an index into
    the concatenation of *parts*. All other reducers are handled by reducing
    each part and then reducing the partial results.

    Args:
        parts: one or more layouts that all share the same form.
        reducer: the awkward reducer to apply (e.g. ``ak.sum``, ``ak.argmax``).
        keepdims: if true, return a length-1 array; otherwise a scalar.
        mask_identity: controls what an empty reduction yields (see the
            strategy helpers).

    Raises:
        ValueError: if *parts* is empty or the parts do not all share one form.
    """
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if not parts:
        raise ValueError("reduce_parts_axis_0 requires at least one part")
    first_form = parts[0].form
    if not all(p.form == first_form for p in parts[1:]):
        raise ValueError("all parts must have the same form")

    partial_reducer = (
        _reduce_parts_positional
        if reducer in _positional_reducers
        else _reduce_parts_non_positional
    )
    starts = _compute_starts(parts)
    return partial_reducer(parts, reducer, keepdims, mask_identity, starts)
parts = [
    ak.to_layout(np.array(values, dtype=np.int64))
    for values in ([1, 2, 3], [], [0, -1], [4])
]


def _check_reductions():
    """Check reduce_parts_axis_0 on ``parts`` for each reducer and flag combination."""
    # (reducer, expected result, mask_identity values to exercise)
    # The flattened data is [1, 2, 3, 0, -1, 4], so positional results are
    # indices into that sequence.
    cases = (
        (ak.sum, 9, (False,)),
        (ak.min, -1, (False,)),
        (ak.max, 4, (False,)),
        (ak.argmax, 5, (False, True)),
        (ak.argmin, 4, (False, True)),
    )
    for reducer, expected, mask_choices in cases:
        for mask_identity in mask_choices:
            # Scalar form first, then the keepdims=True length-1 array form.
            assert reduce_parts_axis_0(parts, reducer, False, mask_identity) == expected
            kept = reduce_parts_axis_0(parts, reducer, True, mask_identity)
            assert kept.to_list() == [expected]


_check_reductions()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.
Already on it ;) dask-contrib/dask-awkward#267