@jpata
Created April 24, 2021 05:52
import tensorflow as tf
import numpy as np
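# Download model_full from https://jpata.web.cern.ch/jpata/2101.08578/v2/model_full.tar.gz
# and untar the archive so that the "model_full" directory is in the working directory.
# Note that the custom loss is not included, so skip restoring the compiled
# training state; it is not needed for inference.
model = tf.keras.models.load_model("model_full", compile=False)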
model.summary()
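# Random dummy input with shape (batch size, elements in event, features).
# np.random.randn returns float64, so cast to float32 for TensorFlow.
inputs = np.random.randn(2, 6400, 12).astype(np.float32)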
outputs = model(inputs)
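# Split the last axis of the output into the per-element predictions.
# As written, index 6 is not covered by any of these slices; compare
# against the output width printed by model.summary().
pred_id = outputs[:, :, :6]
pred_charge = outputs[:, :, 7]
pred_momentum = outputs[:, :, 8:]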
print(pred_id)
print(pred_charge)
print(pred_momentum)
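
The sketch below shows one way the `pred_id` slice might be reduced to a single label per element. It assumes that `pred_id` holds one score per class along the last axis, which the gist does not state. A random array stands in for the model output so the snippet runs without the archive; `fake_outputs`, `fake_pred_id`, `pred_class`, and `counts` are illustrative names, not part of the original script.

```python
import numpy as np

# Stand-in for the model output: (batch size, elements in event, output width).
# The width of 12 is an assumption made only so the example runs on its own.
rng = np.random.default_rng(0)
fake_outputs = rng.normal(size=(2, 6400, 12)).astype(np.float32)

# Same slice as in the script above: the first six channels.
fake_pred_id = fake_outputs[:, :, :6]

# Pick the highest-scoring channel for every element (assumed to be class scores).
pred_class = np.argmax(fake_pred_id, axis=-1)

# Count how many elements fall into each of the six assumed classes.
counts = np.bincount(pred_class.ravel(), minlength=6)

print(pred_class.shape)
print(counts)
```

With the real model, `fake_outputs` would be replaced by `outputs.numpy()`, assuming the model returns a single tensor as the indexing in the script implies.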