Skip to content

Instantly share code, notes, and snippets.

@agoose77
Last active May 19, 2023 14:00
Show Gist options
  • Save agoose77/595cc3e3d0c44059e2c6b1f1ad183c81 to your computer and use it in GitHub Desktop.
Save agoose77/595cc3e3d0c44059e2c6b1f1ad183c81 to your computer and use it in GitHub Desktop.
from __future__ import annotations
from collections.abc import Sequence
import awkward as ak
import numpy as np
def _compute_starts(parts: Sequence[ak.contents.Content]) -> tuple[int, ...]:
# The following won't work for typetracer
lengths = np.array([p.length for p in parts], dtype=np.int64)
stops = np.cumsum(lengths)
return tuple(stops - lengths)
def _reduce_parts_positional(
    parts: Sequence[ak.contents.Content],
    reducer,
    keepdims: bool,
    mask_identity: bool,
    starts: tuple[int, ...],
):
    """Apply a positional reducer (``ak.argmin`` / ``ak.argmax``) across parts.

    Each part is reduced independently along ``axis=0`` to a local index
    (``None`` for an empty part), and the values at those indices are gathered.
    The reducer is then applied to the gathered values to pick the winning
    part. Finally, the winning local index is shifted by that part's global
    offset from ``starts``.

    Args:
        parts: layouts in concatenation order.
        reducer: a positional reducer such as ``ak.argmin`` or ``ak.argmax``.
        keepdims: if true, return the length-1 result array; otherwise
            return its single element.
        mask_identity: if false, a missing result (e.g. every part empty)
            is replaced by ``-1``.
        starts: global offset of each part, one entry per part.

    Returns:
        The global index of the selected element. It is a length-1 array
        when ``keepdims`` is true, otherwise the element itself.
    """
    # Always mask the identity here so an empty part yields None. An identity
    # index would otherwise be used to index into that part's values.
    local_indices = [
        reducer(p, keepdims=True, mask_identity=True, axis=0) for p in parts
    ]
    local_values = [p[idx] for p, idx in zip(parts, local_indices)]
    partial = ak.concatenate(local_indices, axis=0)
    partial_value = ak.concatenate(local_values, axis=0)
    # Reduce explicitly along axis=0, matching the per-part reductions above,
    # instead of relying on the reducer's default axis.
    final_value = reducer(
        partial_value, axis=0, keepdims=True, mask_identity=True
    )
    # Build the offsets from a plain int64 array rather than from an iterable
    # of NumPy scalars.
    offsets = ak.Array(np.asarray(starts, dtype=np.int64))
    final = partial[final_value] + offsets[final_value]
    if not mask_identity:
        final = ak.fill_none(final, -1, axis=0)
    if not keepdims:
        final = final[0]
    return final
def _reduce_parts_non_positional(
    parts: Sequence[ak.contents.Content],
    reducer,
    keepdims: bool,
    mask_identity: bool,
    starts: tuple[int, ...],
):
    """Apply a value reducer (``ak.sum``, ``ak.min``, ...) across parts.

    Each part is reduced along ``axis=0`` with ``keepdims=True`` and the
    identity masked, so an empty part contributes ``None`` rather than an
    identity value. The per-part results are concatenated and merged with a
    second reduction along ``axis=0``.

    The merging reduction is normally ``reducer`` itself (sum of sums, min of
    mins, ...). Counting reducers are the exception: the total of per-part
    counts is their *sum*, not their count. So ``ak.count`` and
    ``ak.count_nonzero`` partials are merged with ``ak.sum``.

    Args:
        parts: layouts in concatenation order.
        reducer: an awkward reducer accepting ``axis``, ``keepdims`` and
            ``mask_identity`` keyword arguments.
        keepdims: forwarded to the final reduction.
        mask_identity: forwarded to the final reduction. When false, the
            ``None`` placeholders from empty parts are dropped first.
        starts: unused; accepted so both part reducers share one signature.

    Returns:
        The result of the final (merging) reduction.
    """
    # Reducers whose per-part results must be merged by a different reducer.
    combiners = {ak.count: ak.sum, ak.count_nonzero: ak.sum}
    combine = combiners.get(reducer, reducer)

    partials = [reducer(p, keepdims=True, mask_identity=True, axis=0) for p in parts]
    partial = ak.concatenate(partials, axis=0)
    if not mask_identity:
        # Remove the None placeholders left by empty parts before merging.
        partial = ak.drop_none(partial, axis=0)
    return combine(partial, axis=0, keepdims=keepdims, mask_identity=mask_identity)
# Reducers that return an index into the data rather than a value. They need
# each part's global offset to produce a combined result.
_positional_reducers = {ak.argmin, ak.argmax}


def reduce_parts_axis_0(
    parts: Sequence[ak.contents.Content],
    reducer,
    keepdims: bool,
    mask_identity: bool,
):
    """Reduce a sequence of layouts along axis 0 as if they were concatenated.

    Dispatches to ``_reduce_parts_positional`` for ``ak.argmin`` / ``ak.argmax``
    and to ``_reduce_parts_non_positional`` for every other reducer. Part
    offsets are computed once by ``_compute_starts``.

    Args:
        parts: layouts in concatenation order; all must share one form.
        reducer: an awkward reducer function (e.g. ``ak.sum``, ``ak.argmax``).
        keepdims: whether to keep the reduced axis as a length-1 dimension.
        mask_identity: whether an empty reduction yields ``None``.

    Returns:
        Whatever the selected part reducer returns.

    Raises:
        ValueError: if ``parts`` is empty or the parts do not share one form.
    """
    # Explicit raises rather than ``assert`` so validation survives ``python -O``.
    if not parts:
        raise ValueError("reduce_parts_axis_0 requires at least one part")
    first_form = parts[0].form
    if not all(p.form == first_form for p in parts[1:]):
        raise ValueError("all parts must share the same form")

    partial_reducer = (
        _reduce_parts_positional
        if reducer in _positional_reducers
        else _reduce_parts_non_positional
    )
    starts = _compute_starts(parts)
    return partial_reducer(parts, reducer, keepdims, mask_identity, starts)
# Self-check: four parts, one of them empty, whose concatenation is
# [1, 2, 3, 0, -1, 4].
parts = [
    ak.to_layout(np.array(values, dtype=np.int64))
    for values in ([1, 2, 3], [], [0, -1], [4])
]

# (reducer, keepdims, mask_identity, expected), checked in this order.
# With keepdims the result is compared as a list; otherwise as a scalar.
_expected_results = [
    # Sum
    (ak.sum, False, False, 9),
    (ak.sum, True, False, [9]),
    # Min
    (ak.min, False, False, -1),
    (ak.min, True, False, [-1]),
    # Max
    (ak.max, False, False, 4),
    (ak.max, True, False, [4]),
    # ArgMax
    (ak.argmax, False, False, 5),
    (ak.argmax, True, False, [5]),
    (ak.argmax, False, True, 5),
    (ak.argmax, True, True, [5]),
    # ArgMin
    (ak.argmin, False, False, 4),
    (ak.argmin, True, False, [4]),
    (ak.argmin, False, True, 4),
    (ak.argmin, True, True, [4]),
]
for _reducer, _keepdims, _mask_identity, _expected in _expected_results:
    _result = reduce_parts_axis_0(parts, _reducer, _keepdims, _mask_identity)
    if _keepdims:
        assert _result.to_list() == _expected
    else:
        assert _result == _expected
@martindurant
Copy link

I'm sure we can work with this, although a tree reduction may be worthwhile for the case of very many partitions. That would require splitting `_reduce_parts_[non_]positional` into a per-batch step and a combine-batches step. I think that means passing along the start of each batch, i.e. the first element of the starts of the parts within that batch.

@agoose77
Copy link
Author

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment