# Created May 15, 2023 21:55
# Save agoose77/7484748f78d79be59c3b265111ee9a62 to your computer and use it in GitHub Desktop.
# This file contains bidirectional Unicode text that may be interpreted or compiled
# differently than what appears below. To review, open the file in an editor that
# reveals hidden Unicode characters.
# Learn more about bidirectional Unicode characters
from __future__ import annotations

from typing import TypeVar

import awkward as ak
import numpy as np
import vector
from vector._compute.spatial import eta, theta, z

# Register vector's awkward behaviors (record names such as "Vector3D") with
# ak.behavior. Done after all imports so the import block itself has no side
# effects.
vector.register_awkward()

# Generic array type shared by the coordinate-summing helpers below.
T = TypeVar("T")
def _sum_xy_z(lib, x_v: T, y_v: T, z_v: T) -> tuple[T, T, T]:
    """Sum Cartesian ``(x, y, z)`` components along ``axis=1``.

    Cartesian components add linearly, so each one is reduced on its own.

    Args:
        lib: Array namespace that provides ``sum`` (this file passes ``numpy``).
        x_v, y_v, z_v: Per-element components to reduce.

    Returns:
        The summed ``(x, y, z)`` components, in sorted-field-name order.
    """
    return tuple(lib.sum(component, axis=1) for component in (x_v, y_v, z_v))
def _sum_xy_theta(lib, theta_v: T, x_v: T, y_v: T) -> tuple[T, T, T]:
    """Sum ``(x, y, theta)`` vectors along ``axis=1``.

    The polar angle is not additive. Each element's ``z`` is therefore rebuilt
    from ``(x, y, theta)``, the Cartesian components are summed, and theta is
    recomputed from the totals.

    Args:
        lib: Array namespace that provides ``sum`` (this file passes ``numpy``).
        theta_v, x_v, y_v: Per-element components, in sorted-field-name order.

    Returns:
        ``(theta, x, y)`` of the summed vectors, in sorted-field-name order.
    """
    x_total = lib.sum(x_v, axis=1)
    y_total = lib.sum(y_v, axis=1)
    z_total = lib.sum(z.xy_theta(lib, x_v, y_v, theta_v), axis=1)
    return theta.xy_z(lib, x_total, y_total, z_total), x_total, y_total
def _sum_xy_eta(lib, eta_v: T, x_v: T, y_v: T) -> tuple[T, T, T]:
    """Sum ``(x, y, eta)`` vectors along ``axis=1``.

    Pseudorapidity is not additive. Each element's ``z`` is therefore rebuilt
    from ``(x, y, eta)``, the Cartesian components are summed, and eta is
    recomputed from the totals.

    Args:
        lib: Array namespace that provides ``sum`` (this file passes ``numpy``).
        eta_v, x_v, y_v: Per-element components, in sorted-field-name order.

    Returns:
        ``(eta, x, y)`` of the summed vectors. This matches the sorted field
        names ``("eta", "x", "y")`` that the caller pairs results with. The
        previous ``(x, y, eta)`` order mislabelled every component.
    """
    z_v = z.xy_eta(lib, x_v, y_v, eta_v)
    x_u = lib.sum(x_v, axis=1)
    y_u = lib.sum(y_v, axis=1)
    z_u = lib.sum(z_v, axis=1)
    return (eta.xy_z(lib, x_u, y_u, z_u), x_u, y_u)


# Backward-compatible alias for the helper's original (duplicated) name.
_sum_xy_eta_xy_eta = _sum_xy_eta


def sorted_tuple(x):
    """Return the items of *x* as a tuple in ascending order."""
    return tuple(sorted(x))


# Map the sorted field names of a 3D vector record to the helper that sums it.
# Each helper takes its components in this sorted order and returns the summed
# components in the same order, so results line up with the field names.
# Eta-based vectors must use the eta helper. Routing them through the theta
# helper would treat pseudorapidity as a polar angle.
DISPATCH_MAP = {
    ("x", "y", "z"): _sum_xy_z,
    ("theta", "x", "y"): _sum_xy_theta,
    ("eta", "x", "y"): _sum_xy_eta,
}
def _sum_vector3d(array, mask_identity):
    """Sum lists of 3D vector records; registered as the ``ak.sum`` overload.

    The record's coordinate system is identified by its sorted field names.
    The matching ``DISPATCH_MAP`` helper sums each list along ``axis=1``.

    Args:
        array: Awkward array whose layout has a ``content`` holding the vector
            records (NOTE(review): assumed to be a list-of-records layout,
            as passed by awkward's reducer machinery; confirm).
        mask_identity: Supplied by awkward when calling the overload. It is
            currently ignored (NOTE(review): confirm whether empty lists
            should become None when this is True).

    Returns:
        ``ak.contents.RecordArray`` with one summed vector per list. It carries
        the same parameters as the input's record content.

    Raises:
        KeyError: If the record's fields match no supported coordinate system.
    """
    fields = sorted_tuple(array.fields)
    try:
        summer = DISPATCH_MAP[fields]
    except KeyError:
        # Replace the bare tuple KeyError with a message naming what is supported.
        raise KeyError(
            f"ak.sum is not implemented for 3D vectors with fields {fields!r}; "
            f"supported field sets: {sorted(DISPATCH_MAP)!r}"
        ) from None
    summed = summer(np, *(array[name] for name in fields))
    return ak.contents.RecordArray(
        contents=[ak.to_layout(component) for component in summed],
        fields=fields,
        parameters=array.layout.content.parameters,
    )


ak.behavior[ak.sum, "Vector3D"] = _sum_vector3d
# --- Cartesian (x, y, z) vectors ---------------------------------------------
# Two jagged lists of lengths 2 and 3. The first two records of each list match.
vecs = vector.Array(
    [
        [{"x": 1, "y": 2, "z": 3}, {"x": 4, "y": 5, "z": 6}],
        [
            {"x": 1, "y": 2, "z": 3},
            {"x": 4, "y": 5, "z": 6},
            {"x": 1, "y": 1, "z": 1},
        ],
    ]
)

# axis=0 adds records position by position across the lists.
expected_xyz_axis0 = [
    [{"x": 2, "y": 4, "z": 6}, {"x": 8, "y": 10, "z": 12}, {"x": 1, "y": 1, "z": 1}]
]
assert ak.sum(vecs, axis=0, keepdims=True).to_list() == expected_xyz_axis0

# axis=1 collapses each list to a single vector.
expected_xyz_axis1 = [
    [{"x": 5, "y": 7, "z": 9}],
    [{"x": 6, "y": 8, "z": 10}],
]
assert ak.sum(vecs, axis=1, keepdims=True).to_list() == expected_xyz_axis1

# --- (x, y, theta) vectors ---------------------------------------------------
# Same jagged shape. Theta is not additive, so the expected angles below are
# approximate values and are compared with ak.almost_equal.
vecs_2 = vector.Array(
    [
        [{"x": 1, "y": 2, "theta": 0.3}, {"x": 4, "y": 5, "theta": 0.6}],
        [
            {"x": 1, "y": 2, "theta": 0.3},
            {"x": 4, "y": 5, "theta": 0.1},
            {"x": 1, "y": 1, "theta": 0.2},
        ],
    ]
)

expected_theta_axis0 = [
    [
        {"theta": 0.3, "x": 2, "y": 4},
        {"theta": 0.17325, "x": 8, "y": 10},
        {"theta": 0.2, "x": 1, "y": 1},
    ]
]
assert ak.almost_equal(
    ak.sum(vecs_2, axis=0, keepdims=True),
    expected_theta_axis0,
    check_parameters=False,
)

expected_theta_axis1 = [
    [{"theta": 0.478406, "x": 5, "y": 7}],
    [{"theta": 0.127472, "x": 6, "y": 8}],
]
assert ak.almost_equal(
    ak.sum(vecs_2, axis=1, keepdims=True),
    expected_theta_axis1,
    check_parameters=False,
    check_regular=False,
)
# Sign up for free to join this conversation on GitHub.
# Already have an account? Sign in to comment.