Last active
April 5, 2023 17:49
-
-
Save cshimmin/2798ce4ad47d7648f73f54a3c83f9066 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import scipy.optimize | |
import scipy.optimize | |
def norm_or_zero(v):
    """Normalize vectors along the last axis, mapping zero vectors to zero.

    Parameters
    ----------
    v : array_like, shape (..., D)
        Vectors to normalize; the last axis holds the components.

    Returns
    -------
    numpy.ndarray, shape (..., D)
        ``v / |v|`` for every vector with a positive norm, and an all-zero
        vector wherever the norm is zero (or NaN).
    """
    v = np.asarray(v)
    denom = np.linalg.norm(v, axis=-1, keepdims=True)
    # Divide only where the norm is positive. np.where(cond, v / denom, 0)
    # would still evaluate 0/0 for zero vectors and emit a RuntimeWarning.
    out = np.zeros(v.shape, dtype=np.result_type(v, denom))
    return np.divide(v, denom, out=out, where=denom > 0)
def get_thrust(p, *, verbose=False):
    """Compute the thrust of each jet by numerically maximizing sum_i |p_i . n|.

    Parameters
    ----------
    p : array_like, shape (*, N, 3)
        (px, py, pz) of up to N particles for each jet in an arbitrary batch
        shape (*,). Zero-pad along the N axis for jets with fewer than N
        particles.
    verbose : bool, optional
        If True, print the ``scipy.optimize.OptimizeResult``. Defaults to
        False so that library calls do not write to stdout.

    Returns
    -------
    numpy.ndarray, shape (*,)
        ``sum_i |p_i . n| / sum_i |p_i|`` at the axis ``n`` found by the
        optimizer, or 0 for jets whose particles are all zero. The optimizer
        is local, so the result may fall short of the true (global) thrust.
    """
    p = np.asarray(p, dtype=float)
    batch_shape = p.shape[:-2]
    if p.size == 0:
        # No jets or no particles: every thrust is zero and there is nothing
        # to optimize (scipy cannot minimize over an empty vector).
        return np.zeros(batch_shape)

    pmag = np.linalg.norm(p, axis=-1)  # shape: (*, N)

    # Initial guess for nhat: the axis of the jet's momentum sum.
    psum = np.sum(p, axis=-2)  # shape: (*, 3)
    # When the momenta cancel exactly (e.g. back-to-back particles) the sum
    # has no direction; start from the hardest particle instead of the
    # origin, where the normalized objective below is degenerate.
    hardest = np.argmax(pmag, axis=-1)[..., None, None]  # shape: (*, 1, 1)
    p_hardest = np.take_along_axis(p, hardest, axis=-2)[..., 0, :]  # shape: (*, 3)
    has_sum = np.linalg.norm(psum, axis=-1, keepdims=True) > 0  # shape: (*, 1)
    nhat_0 = np.where(has_sum, norm_or_zero(psum), norm_or_zero(p_hardest))  # shape: (*, 3)

    def neg_total_projection(x):
        # scipy.optimize works on a flat vector; restore the (*, 3) shape.
        # Normalizing here (rather than constraining |n| = 1) leaves a flat
        # direction along n itself; a proper constraint or a 2-D
        # parameterization of the sphere would avoid that.
        nhat = norm_or_zero(x.reshape(nhat_0.shape))[..., None, :]  # shape: (*, 1, 3)
        projection = np.abs(np.sum(p * nhat, axis=-1))  # shape: (*, N)
        # The objective is separable across jets. Summing (instead of taking
        # the mean) keeps each jet's gradient independent of the batch size,
        # so the optimizer's absolute gradient tolerance applies per jet.
        return -projection.sum()  # negative because minimize() minimizes

    # NOTE: the default unconstrained method (BFGS) keeps a dense estimate of
    # a (3*B) x (3*B) matrix for B jets, so very large batches should be
    # processed in chunks.
    result = scipy.optimize.minimize(neg_total_projection, nhat_0.ravel())
    if verbose:
        print(result)

    # Evaluate the thrust at the solved axes.
    nhat = norm_or_zero(result.x.reshape(nhat_0.shape))[..., None, :]  # shape: (*, 1, 3)
    numerator = np.sum(np.abs(np.sum(p * nhat, axis=-1)), axis=-1)  # shape: (*,)
    denominator = pmag.sum(axis=-1)  # shape: (*,)
    # Divide only where defined so all-zero jets give 0 without a 0/0 warning.
    safe = np.where(denominator > 0, denominator, 1.0)
    return np.where(denominator > 0, numerator / safe, 0.0)  # shape: (*,)
def get_thrust_T(pt):
    """Compute the exact transverse thrust of each jet.

    Transverse thrust is ``max_n sum_i |pt_i . n| / sum_i |pt_i|`` over unit
    vectors ``n`` in the (px, py) plane. The maximum equals the largest
    ``|sum_i s_i pt_i|`` over sign choices ``s_i = +-1``, and it is attained
    by splitting the particles with a line through the origin. After sorting
    the particles by their angle modulo pi, every such split is a prefix of
    the sorted list, so all candidates follow from one cumulative sum.

    Parameters
    ----------
    pt : array_like, shape (*, N, 2)
        px and py of up to N particles for each jet in an arbitrary batch
        shape (*,). Zero-pad along the N axis for jets with fewer than N
        particles.

    Returns
    -------
    numpy.ndarray, shape (*,)
        Transverse thrust per jet; 0 for jets whose particles are all zero.
    """
    pt = np.asarray(pt, dtype=float)

    # |pt . n| is unchanged by pt -> -pt, so fold every particle into the
    # half-plane with angle in [0, pi).
    angle = np.arctan2(pt[..., 1], pt[..., 0])  # shape: (*, N), in [-pi, pi]
    flip = (angle < 0) | (angle >= np.pi)
    folded = np.where(flip[..., None], -pt, pt)  # shape: (*, N, 2)

    # Sort by folded angle; a dividing line through the origin then puts a
    # prefix of the sorted particles on one side and the rest on the other.
    order = np.argsort(np.arctan2(folded[..., 1], folded[..., 0]), axis=-1)
    folded = np.take_along_axis(folded, order[..., None], axis=-2)

    total = folded.sum(axis=-2, keepdims=True)  # shape: (*, 1, 2)
    prefix = np.cumsum(folded, axis=-2)  # shape: (*, N, 2)
    # Split after k particles: (rest) - (prefix) = total - 2 * prefix.
    # The k = 0 split (|total|) has the same norm as k = N, which is included.
    split_norms = np.linalg.norm(total - 2.0 * prefix, axis=-1)  # shape: (*, N)
    numerator = np.max(split_norms, axis=-1, initial=0.0)  # shape: (*,)

    denominator = np.linalg.norm(pt, axis=-1).sum(axis=-1)  # shape: (*,)
    # Divide only where defined so all-zero jets give 0 without a 0/0 warning.
    safe = np.where(denominator > 0, denominator, 1.0)
    return np.where(denominator > 0, numerator / safe, 0.0)  # shape: (*,)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment