Skip to content

Instantly share code, notes, and snippets.

@cshimmin
Last active July 15, 2019 17:09
Show Gist options
  • Save cshimmin/09480cf671d33a2331e8da1f2ad66419 to your computer and use it in GitHub Desktop.
Save cshimmin/09480cf671d33a2331e8da1f2ad66419 to your computer and use it in GitHub Desktop.
Calculate energy correlation functions (ECFs) on jet constituents using einsum.
import itertools as it
import numpy as np
# Energy correlation functions (ECFs) for batches of jet constituents,
# evaluated with a single np.einsum contraction per call.
def ecf_numpy(N, beta, x):
    """Compute the N-point energy correlation function for a batch of jets.

    ECF_N^(beta) = sum over i < j < ... of
                   pT_i * pT_j * ... * prod_{pairs (a, b)} R_ab ** beta,
    where the sum runs over unordered N-tuples of distinct constituents and
    R_ab = sqrt(deta_ab**2 + dphi_ab**2), with dphi wrapped into [-pi, pi).

    Args:
        N: Number of constituents correlated (non-negative int; at most 26,
            the number of einsum index letters used).
        beta: Angular exponent applied to every pairwise distance R_ab.
        x: Array-like of shape (..., n_particles, F) with F >= 3. The first
            three entries of the last axis are (pT, eta, phi), phi in radians.
            Any number of leading batch axes is accepted.

    Returns:
        The float ``1.`` when N == 0; otherwise an array of shape
        ``x.shape[:-2]`` holding one ECF value per jet.

    Raises:
        ValueError: if N is negative.
    """
    if N < 0:
        raise ValueError("N must be non-negative, got %r" % (N,))
    if N == 0:
        return 1.
    x = np.asarray(x)
    pt = x[..., 0]
    if N == 1:
        return np.sum(pt, axis=-1)

    # R_beta[..., i, j] = R_ij ** beta for every pair of constituents.
    R_beta = _pairwise_delta_r(x[..., 1], x[..., 2]) ** beta

    # eps[i, j, ...] is 1 exactly when i < j < ..., so each unordered N-tuple
    # of distinct constituents is counted once and self-pairs drop out.
    # NOTE: eps is dense with n_particles**N entries, so memory grows fast in N.
    n_particles = x.shape[-2]
    eps = np.zeros((n_particles,) * N)
    for idx in it.combinations(range(n_particles), N):
        eps[idx] = 1

    operands = (eps,) + (pt,) * N + (R_beta,) * (N * (N - 1) // 2)
    return np.einsum(_ecf_subscripts(N), *operands)


def _pairwise_delta_r(eta, phi):
    """Return R[..., i, j], the (eta, phi) distance between constituents i and j."""
    # Explicit new axes keep every batch axis aligned with itself; indexing
    # eta[:, i] on a (B, P, 1) array would broadcast batch against particles.
    deta = eta[..., :, None] - eta[..., None, :]
    dphi = phi[..., :, None] - phi[..., None, :]
    # Azimuth is periodic: wrap into [-pi, pi) so constituents on either side
    # of the phi = +/-pi seam are measured as close together.
    dphi = (dphi + np.pi) % (2 * np.pi) - np.pi
    return np.sqrt(deta ** 2 + dphi ** 2)


def _ecf_subscripts(N):
    """Build implicit-mode einsum subscripts for the N-point ECF (N >= 2)."""
    # e.g. N=3 -> 'abc,...a,...b,...c,...ab,...ac,...bc'; every letter repeats,
    # so the implicit output is '...' (one value per jet).
    letters = [chr(ord('a') + k) for k in range(N)]
    terms = [''.join(letters)]
    terms += ['...' + c for c in letters]
    terms += ['...' + a + b for a, b in it.combinations(letters, 2)]
    return ','.join(terms)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment