Skip to content

Instantly share code, notes, and snippets.

@agoose77
Last active May 9, 2023 13:17
Show Gist options
  • Save agoose77/9fc1938bf15d1dedb0bd72a4cd501ed1 to your computer and use it in GitHub Desktop.
Save agoose77/9fc1938bf15d1dedb0bd72a4cd501ed1 to your computer and use it in GitHub Desktop.
from __future__ import annotations
from functools import wraps
from inspect import signature
import awkward as ak
# Awkward functions that are not public
def behavior_of(args) -> dict | None:
    """Merge the ``behavior`` mappings of the high-level arrays in *args*.

    Only ``ak.Array`` and ``ak.Record`` items contribute; anything else is
    ignored. Items whose ``behavior`` is ``None`` are skipped, so a leading
    array without a behavior no longer breaks the merge.

    Args:
        args: iterable of arbitrary objects; it is consumed once.

    Returns:
        ``None`` if no contributing item has a behavior. Otherwise a mapping
        in which keys from later items override keys from earlier ones. When
        exactly one item contributes, its own behavior object is returned
        unchanged (not copied).
    """
    behavior = None
    for array in args:
        if not isinstance(array, (ak.Array, ak.Record)) or array.behavior is None:
            continue
        if behavior is None:
            behavior = array.behavior
        else:
            # Build a fresh dict so no array's own behavior mapping is mutated.
            behavior = {**behavior, **array.behavior}
    return behavior
def maybe_wrap_highlevel(obj, highlevel: bool, behavior: dict | None):
if highlevel:
if isinstance(obj, ak.contents.Content):
return ak.Array(obj, behavior=behavior)
elif isinstance(obj, ak.record.Record):
return ak.Record(obj, behavior=behavior)
return obj
# "Transformer" implementation
def do_transform(func, /, *args, highlevel=True, **kwargs):
    """Drive *func* through ``ak.transform`` over its ``ak.Array`` arguments.

    The call arguments are bound to *func*'s signature. Every bound argument
    that is an ``ak.Array`` becomes an input to ``ak.transform``, in parameter
    order; all other arguments are forwarded unchanged. Each time
    ``ak.transform`` invokes the transformation, the array arguments are
    replaced by the layouts it supplies. They are wrapped as high-level
    objects if *highlevel* is true. *func* is then called with them plus the
    extra keyword arguments ``ak.transform`` passes. See the ``ak.transform``
    documentation for how *func*'s return value is interpreted.

    Args:
        func: the transformation. It must accept ``**kwargs`` to receive the
            extra keyword arguments from ``ak.transform``.
        *args, **kwargs: arguments bound to *func*'s signature.
        highlevel: wrap each supplied layout as ``ak.Array``/``ak.Record``,
            sharing the merged behavior of the input arrays.

    Returns:
        The result of ``ak.transform`` over the array arguments.

    Raises:
        TypeError: if the arguments do not bind to *func*'s signature, or if
            none of them is an ``ak.Array``.
    """
    # Bind the complete set of arguments to obtain a `BoundArguments` object.
    bound_args = signature(func).bind(*args, **kwargs)
    # Names and initial values of the array arguments, in parameter order.
    array_items = {
        name: value
        for name, value in bound_args.arguments.items()
        if isinstance(value, ak.Array)
    }
    if not array_items:
        raise TypeError("do_transform requires at least one ak.Array argument")
    behavior = behavior_of(array_items.values())

    def wrapper(inputs, **other):
        # ak.transform may hand over a lone layout instead of a sequence
        # (e.g. when only one array is transformed). Wrap it so the zip below
        # pairs it with its parameter name instead of iterating its contents.
        if isinstance(inputs, (ak.contents.Content, ak.record.Record)):
            inputs = (inputs,)
        # Update the array arguments with the layouts for the current call.
        for name, new_value in zip(array_items, inputs):
            bound_args.arguments[name] = maybe_wrap_highlevel(
                new_value, highlevel, behavior
            )
        return func(*bound_args.args, **bound_args.kwargs, **other)

    # Pass the arrays themselves (not their parameter names) to ak.transform.
    return ak.transform(wrapper, *array_items.values())
# "Transformer" decorator
def transformer(highlevel: bool = True):
    """Decorator form of ``do_transform``.

    Supported usages: ``@transformer()``, ``@transformer(False)``,
    ``@transformer(highlevel=False)`` and bare ``@transformer``.

    Args:
        highlevel: whether the array arguments passed to the decorated
            function should be wrapped as high-level ``ak.Array``/``ak.Record``
            objects. With bare ``@transformer``, this slot receives the
            decorated function and ``highlevel`` defaults to ``True``.

    Returns:
        A decorator. With bare ``@transformer``, the already-decorated
        function is returned instead. Each call to the decorated function
        runs ``do_transform(func, *args, highlevel=highlevel, **kwargs)``.
    """
    if callable(highlevel):
        # Used as bare ``@transformer``: the function landed in `highlevel`.
        return transformer()(highlevel)

    # Capture `func`
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            return do_transform(func, *args, highlevel=highlevel, **kwargs)

        return wrapper

    return decorator
# Example function
def func(arr_1: ak.Array, scalar_1: int, arr_2: ak.Array, scalar_2: int, **other):
    """Example transformation.

    Prints its array, scalar and extra keyword arguments, then returns
    ``(arr_1 + arr_2) * scalar_2 + scalar_1`` converted to a layout.
    """
    print("arrays", arr_1, arr_2)
    print("scalars", scalar_1, scalar_2)
    print("Other args", other)
    pairwise_sum = arr_1 + arr_2
    return ak.to_layout(pairwise_sum * scalar_2 + scalar_1)


result = do_transform(
    func,
    ak.Array([1, 2, 3]),
    1,
    ak.Array([[1, 2, 3], [4, 5], [6, 7, 8]]),
    2,
)
# Example function using decorator
@transformer()
def func_2(arr_1: ak.Array, scalar_1: int, arr_2: ak.Array, scalar_2: int, **other):
    """Example transformation applied through the ``transformer`` decorator.

    Prints its array, scalar and extra keyword arguments, then returns
    ``(arr_1 + arr_2) * scalar_2 + scalar_1`` converted to a layout.
    """
    for label, values in (("arrays", (arr_1, arr_2)), ("scalars", (scalar_1, scalar_2))):
        print(label, *values)
    print("Other args", other)
    return ak.to_layout((arr_1 + arr_2) * scalar_2 + scalar_1)


result_2 = func_2(
    ak.Array([1, 2, 3]),
    1,
    ak.Array([[1, 2, 3], [4, 5], [6, 7, 8]]),
    2,
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment