# Source: GitHub Gist agoose77/7484748f78d79be59c3b265111ee9a62
# Author: @agoose77 (created May 15, 2023 21:55)
from __future__ import annotations
import vector
vector.register_awkward()
import awkward as ak
import numpy as np
from vector._compute.spatial import eta, theta, z
from typing import TypeVar
T = TypeVar("T")
def _sum_xy_z(lib, x_v: T, y_v: T, z_v: T) -> tuple[T, T, T]:
    """Sum Cartesian components independently along ``axis=1``.

    Args:
        lib: Array module exposing ``sum(array, axis=...)``; the dispatcher
            in this file passes ``numpy``.
        x_v: x components, one inner list per output vector.
        y_v: y components, same layout as ``x_v``.
        z_v: z components, same layout as ``x_v``.

    Returns:
        The summed ``(x, y, z)`` components, in that order.
    """
    # Cartesian coordinates are additive, so each component reduces on its own.
    return (lib.sum(x_v, axis=1), lib.sum(y_v, axis=1), lib.sum(z_v, axis=1))
def _sum_xy_theta(lib, theta_v: T, x_v: T, y_v: T) -> tuple[T, T, T]:
    """Sum vectors given as (x, y, theta) along ``axis=1``.

    Parameters arrive in sorted field-name order (``theta``, ``x``, ``y``)
    and the result is returned in that same order, which is how the
    dispatcher in this file pairs results with field names.

    Args:
        lib: Array module exposing ``sum(array, axis=...)`` and whatever the
            ``vector`` compute functions require; the dispatcher passes
            ``numpy``.
        theta_v: Polar angles of the input vectors.
        x_v: x components of the input vectors.
        y_v: y components of the input vectors.

    Returns:
        ``(theta, x, y)`` of the summed vectors.
    """
    # Polar angles are not additive: go through the longitudinal component z,
    # sum in Cartesian space, then convert the total back to theta.
    z_v = z.xy_theta(lib, x_v, y_v, theta_v)
    x_u = lib.sum(x_v, axis=1)
    y_u = lib.sum(y_v, axis=1)
    z_u = lib.sum(z_v, axis=1)
    return (theta.xy_z(lib, x_u, y_u, z_u), x_u, y_u)
def _sum_xy_eta_xy_eta(lib, eta_v: T, x_v: T, y_v: T) -> tuple[T, T, T]:
    """Sum vectors given as (x, y, eta) along ``axis=1``.

    Parameters arrive in sorted field-name order (``eta``, ``x``, ``y``) and
    the result is returned in that same order, because the dispatcher pairs
    the returned components positionally with the sorted field names.

    Args:
        lib: Array module exposing ``sum(array, axis=...)`` and whatever the
            ``vector`` compute functions require; the dispatcher passes
            ``numpy``.
        eta_v: Pseudorapidities of the input vectors.
        x_v: x components of the input vectors.
        y_v: y components of the input vectors.

    Returns:
        ``(eta, x, y)`` of the summed vectors.
    """
    # Pseudorapidity is not additive: go through the longitudinal component z,
    # sum in Cartesian space, then convert the total back to eta.
    z_v = z.xy_eta(lib, x_v, y_v, eta_v)
    x_u = lib.sum(x_v, axis=1)
    y_u = lib.sum(y_v, axis=1)
    z_u = lib.sum(z_v, axis=1)
    # Order must match the sorted field names ("eta", "x", "y"); returning
    # (x, y, eta) would label x as eta, y as x and eta as y in the result.
    return (eta.xy_z(lib, x_u, y_u, z_u), x_u, y_u)


def sorted_tuple(x):
    """Return the items of ``x`` as a tuple in ascending sorted order."""
    return tuple(sorted(x))


# Maps a record's field names (as a sorted tuple) to the function that sums
# vectors in that coordinate system. Each function takes ``lib`` followed by
# the components in sorted-field order and returns the summed components in
# that same order.
DISPATCH_MAP = {
    ("x", "y", "z"): _sum_xy_z,
    ("theta", "x", "y"): _sum_xy_theta,
    # eta needs its own conversion; routing it through the theta
    # implementation would treat pseudorapidity as a polar angle.
    ("eta", "x", "y"): _sum_xy_eta_xy_eta,
}
def _sum_vector3d(array, mask_identity):
    """Custom ``ak.sum`` reducer for ``"Vector3D"`` records.

    Picks the coordinate-specific summation from ``DISPATCH_MAP`` based on
    the record's field names and rebuilds a record layout from the result.

    Args:
        array: Awkward array of vector records. The summation helpers reduce
            along ``axis=1`` and the parameters are read from
            ``array.layout.content``, so a list-of-records layout is expected.
            NOTE(review): presumably Awkward always hands reducer overloads
            this 2D view regardless of the requested axis; confirm against
            the ``ak.behavior`` documentation.
        mask_identity: Accepted because the reducer overload is called with
            it; it is not used here.

    Returns:
        An ``ak.contents.RecordArray`` whose fields are the input's field
        names in sorted order, carrying the input records' parameters.

    Raises:
        KeyError: If the record's fields have no entry in ``DISPATCH_MAP``.
    """
    fields = sorted_tuple(array.fields)
    try:
        reducer = DISPATCH_MAP[fields]
    except KeyError:
        # Same exception type as the bare lookup, but say what is supported.
        raise KeyError(
            f"no Vector3D sum implementation for fields {fields!r}; "
            f"supported field sets: {sorted(DISPATCH_MAP)!r}"
        ) from None
    # Components are passed, and come back, in sorted-field order.
    components = reducer(np, *(array[name] for name in fields))
    return ak.contents.RecordArray(
        contents=[ak.to_layout(component) for component in components],
        fields=fields,
        # Keep the record parameters so the result retains the input's
        # record metadata.
        parameters=array.layout.content.parameters,
    )
# Route ``ak.sum`` over "Vector3D" records through the custom reducer.
ak.behavior[ak.sum, "Vector3D"] = _sum_vector3d

# Cartesian fixture: two lists holding 2 and 3 (x, y, z) vectors.
vecs = vector.Array(
    [
        [
            {"x": 1, "y": 2, "z": 3},
            {"x": 4, "y": 5, "z": 6},
        ],
        [
            {"x": 1, "y": 2, "z": 3},
            {"x": 4, "y": 5, "z": 6},
            {"x": 1, "y": 1, "z": 1},
        ],
    ]
)

# axis=0 adds the vectors position by position across the two lists.
summed_across_lists = ak.sum(vecs, axis=0, keepdims=True).to_list()
assert summed_across_lists == [
    [
        {"x": 2, "y": 4, "z": 6},
        {"x": 8, "y": 10, "z": 12},
        {"x": 1, "y": 1, "z": 1},
    ]
]

# axis=1 adds the vectors inside each list.
summed_within_lists = ak.sum(vecs, axis=1, keepdims=True).to_list()
assert summed_within_lists == [
    [{"x": 5, "y": 7, "z": 9}],
    [{"x": 6, "y": 8, "z": 10}],
]
# (x, y, theta) fixture: same list structure as the Cartesian case, but the
# longitudinal direction is given as a polar angle.
vecs_2 = vector.Array(
    [
        [
            {"x": 1, "y": 2, "theta": 0.3},
            {"x": 4, "y": 5, "theta": 0.6},
        ],
        [
            {"x": 1, "y": 2, "theta": 0.3},
            {"x": 4, "y": 5, "theta": 0.1},
            {"x": 1, "y": 1, "theta": 0.2},
        ],
    ]
)

# axis=0: position-wise sums across the two lists.
theta_summed_across_lists = ak.sum(vecs_2, axis=0, keepdims=True)
expected_across_lists = [
    [
        {"theta": 0.3, "x": 2, "y": 4},
        {"theta": 0.17325, "x": 8, "y": 10},
        {"theta": 0.2, "x": 1, "y": 1},
    ]
]
assert ak.almost_equal(
    theta_summed_across_lists,
    expected_across_lists,
    check_parameters=False,
)

# axis=1: sums within each list.
theta_summed_within_lists = ak.sum(vecs_2, axis=1, keepdims=True)
expected_within_lists = [
    [{"theta": 0.478406, "x": 5, "y": 7}],
    [{"theta": 0.127472, "x": 6, "y": 8}],
]
assert ak.almost_equal(
    theta_summed_within_lists,
    expected_within_lists,
    check_parameters=False,
    check_regular=False,
)
# End of gist.