Last active
July 15, 2019 17:09
-
-
Save cshimmin/09480cf671d33a2331e8da1f2ad66419 to your computer and use it in GitHub Desktop.
Calculate energy correlation functions (ECFs) on jet constituents using `numpy.einsum`.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import itertools as it | |
import numpy as np | |
def ecf_numpy(N, beta, x):
    """Compute the N-point energy correlation function (ECF) for a batch of jets.

    For each jet this evaluates

        ECF_N(beta) = sum_{i_1 < i_2 < ... < i_N}
                      (prod_k pT_{i_k}) * prod_{a < b} R_{i_a i_b} ** beta

    with R_ij = sqrt((eta_i - eta_j)**2 + (phi_i - phi_j)**2).

    Parameters
    ----------
    N : int
        Number of constituents correlated (N >= 0).
    beta : float
        Exponent applied to every pairwise angular distance R_ij.
    x : array_like, shape (batch, n_particles, 3)
        Jet constituents; the last axis holds (pT, eta, phi).

    Returns
    -------
    float or numpy.ndarray
        ``1.`` when ``N == 0``; otherwise an array of shape ``(batch,)``
        holding one ECF value per jet. When N exceeds ``n_particles`` no
        N-tuple exists and the result is all zeros.

    Notes
    -----
    The selection tensor ``eps`` has ``n_particles ** N`` entries, so memory
    grows exponentially with N. einsum subscripts use the letters a..z, so
    N is limited to 26.
    NOTE(review): phi differences are not wrapped into [-pi, pi]; this
    presumably assumes inputs are already centred so no pair straddles the
    phi branch cut -- confirm against the caller.
    """
    x = np.asarray(x)
    pt = x[:, :, 0]
    if N == 0:
        return 1.
    if N == 1:
        return np.sum(pt, axis=-1)

    eta = x[:, :, 1]
    phi = x[:, :, 2]
    n_particles = x.shape[1]

    # Pairwise angular distances, shape (batch, n_particles, n_particles).
    # Broadcasting (batch, P, 1) against (batch, 1, P) keeps the batch axis
    # aligned with itself, so this is correct for any batch size.
    deta = eta[:, :, None] - eta[:, None, :]
    dphi = phi[:, :, None] - phi[:, None, :]
    R_beta = np.sqrt(deta**2 + dphi**2) ** beta

    # eps[i, j, k, ...] == 1 only for strictly increasing indices i < j < k,
    # so each unordered N-tuple of distinct particles is counted exactly once
    # (and self-pairs, where R == 0, never contribute).
    eps = np.zeros((n_particles,) * N)
    for idx in it.combinations(range(n_particles), r=N):
        eps[idx] = 1

    # Build the einsum subscripts, e.g. for N == 3:
    #   'abc,...a,...b,...c,...ab,...ac,...bc'
    # one pT factor per particle index and one R**beta factor per index pair.
    # Every letter appears more than once, so implicit mode sums them all and
    # only the broadcast (batch) axis '...' survives in the output.
    letters = [chr(ord('a') + k) for k in range(N)]
    subscripts = ','.join(
        [''.join(letters)]
        + ['...' + c for c in letters]
        + ['...' + a + b for a, b in it.combinations(letters, r=2)]
    )
    operands = (eps,) + (pt,) * N + (R_beta,) * (N * (N - 1) // 2)
    return np.einsum(subscripts, *operands)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment