Last active
May 9, 2023 13:17
-
-
Save agoose77/9fc1938bf15d1dedb0bd72a4cd501ed1 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import annotations | |
from functools import wraps | |
from inspect import signature | |
import awkward as ak | |
# Awkward functions that are not public
def behavior_of(args) -> dict | None:
    """Merge the ``behavior`` mappings of the high-level objects in *args*.

    Only ``ak.Array`` and ``ak.Record`` instances contribute; any other objects
    are ignored. When several of them define a behavior, keys from later
    objects override keys from earlier ones.

    Args:
        args: iterable of objects, some of which may be high-level arrays or
            records.

    Returns:
        The merged behavior dict, or ``None`` if no high-level object in *args*
        carries a behavior. When exactly one object carries a behavior, that
        same dict object is returned (not a copy).
    """
    behavior = None
    for array in args:
        if not isinstance(array, (ak.Array, ak.Record)):
            continue
        # A high-level object without a behavior contributes nothing. The
        # original code crashed here (``{**None, ...}``) when the *first*
        # high-level object had no behavior but a later one did.
        if array.behavior is None:
            continue
        if behavior is None:
            behavior = array.behavior
        else:
            behavior = {**behavior, **array.behavior}
    return behavior
def maybe_wrap_highlevel(obj, highlevel: bool, behavior: dict | None):
    """Optionally promote a low-level awkward object to its high-level wrapper.

    When *highlevel* is true:
    - an ``ak.contents.Content`` becomes an ``ak.Array`` carrying *behavior*;
    - an ``ak.record.Record`` becomes an ``ak.Record`` carrying *behavior*.

    Anything else, or any input when *highlevel* is false, is returned
    unchanged.
    """
    if not highlevel:
        return obj
    if isinstance(obj, ak.contents.Content):
        return ak.Array(obj, behavior=behavior)
    if isinstance(obj, ak.record.Record):
        return ak.Record(obj, behavior=behavior)
    return obj
# "Transformer" implementation
def do_transform(func, /, *args, highlevel=True, **kwargs):
    """Run ``ak.transform`` over every ``ak.Array`` argument of *func*.

    The arguments are bound to *func*'s signature. Every bound argument that
    is an ``ak.Array`` is handed to ``ak.transform``, in signature order. At
    each node visited by ``ak.transform``, *func* is re-invoked with those
    arguments replaced by the current node values. Non-array arguments are
    passed through unchanged.

    Args:
        func: callable applied at each node. It receives its bound arguments
            (with arrays substituted) plus whatever extra keyword arguments
            ``ak.transform`` supplies, so it should accept ``**kwargs``.
        *args: positional arguments for *func*.
        highlevel: if true, node values are wrapped as ``ak.Array`` /
            ``ak.Record`` (with the merged behavior of the input arrays)
            before being passed to *func*.
        **kwargs: keyword arguments for *func*.

    Returns:
        Whatever ``ak.transform`` returns.

    Raises:
        TypeError: if the arguments do not bind to *func*'s signature, or if
            none of the bound arguments is an ``ak.Array``.
    """
    # Bind the complete set of arguments to obtain a `BoundArguments` object
    bound_args = signature(func).bind(*args, **kwargs)
    # Names and initial values of the array-valued arguments, in signature order
    array_items = {
        name: value
        for name, value in bound_args.arguments.items()
        if isinstance(value, ak.Array)
    }
    if not array_items:
        raise TypeError(
            f"do_transform requires at least one ak.Array argument to {func!r}"
        )
    behavior = behavior_of(array_items.values())

    def wrapper(inputs, **other):
        # NOTE(review): ak.transform appears to pass a single layout (not a
        # list) when given only one array. Normalize so zip() pairs names
        # with node values in both cases. TODO confirm for the ak version used.
        if not isinstance(inputs, (list, tuple)):
            inputs = [inputs]
        # Update array-arguments with the values for the current node
        for name, new_value in zip(array_items, inputs):
            bound_args.arguments[name] = maybe_wrap_highlevel(
                new_value, highlevel, behavior
            )
        # Apply and return the result
        return func(*bound_args.args, **bound_args.kwargs, **other)

    # Pass the array *values*; previously the parameter names (strings) were
    # passed here by mistake.
    return ak.transform(wrapper, *array_items.values())
# "Transformer" decorator
def transformer(highlevel: bool = True):
    """Build a decorator that routes calls through ``do_transform``.

    After decoration, calling ``f(*args, **kwargs)`` is equivalent to
    ``do_transform(f, *args, highlevel=highlevel, **kwargs)``.

    Args:
        highlevel: whether the array arguments seen by the decorated function
            are wrapped as high-level ``ak.Array`` / ``ak.Record`` objects.

    Returns:
        A decorator (not a transform result). It must be applied with
        parentheses, i.e. ``@transformer()``, not a bare ``@transformer``.
    """

    def decorator(func):
        # Capture `func`; `wraps` preserves its name, docstring and
        # `__wrapped__` so introspection still sees the original function.
        @wraps(func)
        def wrapper(*args, **kwargs):
            return do_transform(func, *args, highlevel=highlevel, **kwargs)

        return wrapper

    return decorator
# Example function
def func(arr_1: ak.Array, scalar_1: int, arr_2: ak.Array, scalar_2: int, **other):
    """Example node function for ``do_transform``.

    Prints its array, scalar and extra keyword arguments, then returns
    ``(arr_1 + arr_2) * scalar_2 + scalar_1`` converted to a layout.
    """
    print("arrays", arr_1, arr_2)
    print("scalars", scalar_1, scalar_2)
    print("Other args", other)
    summed = arr_1 + arr_2
    return ak.to_layout(summed * scalar_2 + scalar_1)


result = do_transform(
    func,
    ak.Array([1, 2, 3]),
    1,
    ak.Array([[1, 2, 3], [4, 5], [6, 7, 8]]),
    2,
)
# Example function using decorator
@transformer()
def func_2(arr_1: ak.Array, scalar_1: int, arr_2: ak.Array, scalar_2: int, **other):
    """Decorated example: same computation as the plain example above.

    Prints its array, scalar and extra keyword arguments, then returns
    ``(arr_1 + arr_2) * scalar_2 + scalar_1`` converted to a layout.
    """
    print("arrays", arr_1, arr_2)
    print("scalars", scalar_1, scalar_2)
    print("Other args", other)
    summed = arr_1 + arr_2
    return ak.to_layout(summed * scalar_2 + scalar_1)


result_2 = func_2(
    ak.Array([1, 2, 3]),
    1,
    ak.Array([[1, 2, 3], [4, 5], [6, 7, 8]]),
    2,
)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment