Skip to content

Instantly share code, notes, and snippets.

@zacbaum
Created June 29, 2021 09:33
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save zacbaum/8b22f0c02c5847919a2b38e09fb63cf5 to your computer and use it in GitHub Desktop.
Save zacbaum/8b22f0c02c5847919a2b38e09fb63cf5 to your computer and use it in GitHub Desktop.
3D Slicer Process for Tensorflow Predictions
import logging
import os
import pickle

from Processes import Process, ProcessesLogic
class LivePredictionProcess(Process):
    """Slicer Process that streams volumes to a long-lived TF prediction script.

    The child script stays alive between predictions. Each call to
    ``onStarted`` pushes one framed message onto the child's stdin:
    a 1-byte control flag, an 8-byte big-endian payload length, and the
    pickled payload.
    """

    def __init__(self, scriptPath, volume, model_path, active=True):
        """Create the process wrapper.

        scriptPath -- path to the child script run by the Process base class.
        volume     -- numpy array, to use as input for the model.
        model_path -- path to the TF model to load; TF models are not
                      picklable, so only the path crosses the process boundary.
        active     -- initial state of the control byte sent to the child.
        """
        Process.__init__(self, scriptPath)
        self.volume = volume
        self.model_path = model_path
        # Single control byte: b'\x01' keeps the child predicting,
        # b'\x00' tells it to terminate.
        self.active = bytes([1]) if active else bytes([0])
        self.name = f"LivePrediction-{os.path.basename(model_path)}"
        self.output = None  # Most recent unpickled prediction dict (or None).

    def setActive(self, active=True):
        """Set the control byte written on the next onStarted() call."""
        self.active = bytes([1]) if active else bytes([0])

    def onStarted(self):
        """Write one framed message (flag, length, payload) to the child."""
        payload = self.prepareProcessInput()
        payload_len = len(payload).to_bytes(8, byteorder='big')
        self.write(self.active)   # Control byte: keep running or stop.
        self.write(payload_len)   # Length of the payload still to be read from the buffer.
        self.write(payload)       # Pickled inputs for the model.

    def prepareProcessInput(self):
        """Return the pickled dict of model path and input volume."""
        return pickle.dumps({
            'model_path': self.model_path,
            'volume': self.volume,
        })

    def useProcessOutput(self, processOutput):
        """Unpickle the child's stdout into self.output.

        Partial reads of stdout (the stream is consumed mid-write) raise
        UnpicklingError, and an empty read raises EOFError; both are
        treated as "no complete prediction yet".
        """
        try:
            self.output = pickle.loads(processOutput)
        except (EOFError, pickle.UnpicklingError):
            self.output = None
#####################
# Running the process
#####################
# NOTE(review): 'slicer', 'images', 'modelPath', and
# 'CallbackToUsePredictedImages' are assumed to exist in the surrounding
# (Slicer scripted-module) scope — confirm against the hosting module.
scriptFolder = slicer.modules.YOURMODULEHERE.path.replace('YOURMODULEHERE.py', '/Resources/ProcessScripts/')
scriptPath = os.path.join(scriptFolder, "LivePrediction.slicer.py")
livePredictionProcess = LivePredictionProcess(scriptPath, images, modelPath)
# Invoke the callback whenever the child writes a prediction to stdout.
livePredictionProcess.connect('readyReadStandardOutput()', CallbackToUsePredictedImages)

def onLivePredictProcessCompleted():
    """Disconnect the stdout callback once the child process exits."""
    logging.info('Live Prediction: Process Finished')
    livePredictionProcess.disconnect('readyReadStandardOutput()', CallbackToUsePredictedImages)

logic = ProcessesLogic(completedCallback=onLivePredictProcessCompleted)
# Bug fix: original read 'self.livePredictionProcess', but this is
# module-level code with no 'self' in scope (NameError at runtime).
logic.addProcess(livePredictionProcess)
logic.run()
logging.info('Live Prediction: Process Started')
#####################
# Send data to the process
#####################
# Push a fresh frame: replace the input volume, then re-send the framed
# message (control byte + length + pickled payload) to the child's stdin.
livePredictionProcess.volume = newData  # presumably a numpy array matching the model's input shape — confirm
livePredictionProcess.onStarted()
#####################
# Receive data from the process
#####################
def CallbackToUsePredictedImages(self):
    # NOTE(review): the 'self' parameter looks like a leftover from a class
    # context — as connected at module level via
    # connect('readyReadStandardOutput()', ...) the slot is likely invoked
    # with no arguments; confirm against the Process.connect signature.
    stdout = livePredictionProcess.readAllStandardOutput().data()
    livePredictionProcess.useProcessOutput(stdout)
    # NOTE(review): useProcessOutput sets output to None on a partial read,
    # in which case this subscript raises TypeError — consider guarding.
    y = livePredictionProcess.output['prediction']
    # Do something with the outputs (update segmented image array, etc.)
#####################
# Stop the process
#####################
# Shutdown handshake: flip the control byte to NOT_ACTIVE, then re-send a
# frame so the child reads the byte and breaks out of its read loop.
livePredictionProcess.setActive(False)  # Disable the active byte
livePredictionProcess.onStarted()  # Send the modified byte to the process to terminate it
# LivePrediction.slicer.py — long-lived child process script.
# Protocol per message (read from stdin): 1 control byte
# (b'\x01' = run, b'\x00' = stop), an 8-byte big-endian payload length,
# then the pickled {'model_path', 'volume'} payload.
import pickle
import sys

import tensorflow as tf
from tensorflow.keras.models import load_model

ACTIVE = bytes([1])
NOT_ACTIVE = bytes([0])


def read_message():
    """Read one framed payload from stdin; return None on stop or EOF.

    Returns the raw pickled payload bytes when the control byte is ACTIVE.
    Returns None when the parent sent NOT_ACTIVE or the pipe closed
    (read returns b''), so the caller terminates cleanly.  The original
    code handled only the two explicit byte values: on EOF the first read
    hit an undefined 'input_data' (NameError) and the main loop re-pickled
    the STALE previous payload forever — both fixed here.
    """
    control = sys.stdin.buffer.read(1)
    if control != ACTIVE:
        # NOT_ACTIVE or EOF both mean "shut down".
        return None
    length = int.from_bytes(sys.stdin.buffer.read(8), byteorder='big')
    return sys.stdin.buffer.read(length)


# Do the first read at the same time we instantiate the model: this ensures
# the TF graph is written/ready right away when live predictions start.
raw = read_message()
if raw is None:
    sys.exit()
input_data = pickle.loads(raw)
model = load_model(input_data['model_path'], compile=False)
# For batch-size-1 predictions, wrapping .call in a tf.function tends to
# speed things up.
model.call = tf.function(model.call, experimental_relax_shapes=True)
# Run a dummy prediction on the first (blanked) volume to instantiate the
# TF graph before real frames arrive.
_ = model(input_data['volume'], training=False)

while True:
    raw = read_message()
    if raw is None:
        break  # Parent requested shutdown, or the pipe closed.
    input_data = pickle.loads(raw)
    output = {'prediction': model(input_data['volume'], training=False).numpy()}
    sys.stdout.buffer.write(pickle.dumps(output))
    sys.stdout.buffer.flush()
import os
import pickle

from Processes import Process, ProcessesLogic


class OfflinePredictionProcess(Process):
    """One-shot Slicer Process: hand a volume to a TF script via a pickle file.

    Unlike the live variant, the input crosses the process boundary through
    'data.pkl' on disk rather than through the child's stdin.
    """

    def __init__(self, scriptPath, volume, model_path):
        """scriptPath -- child script path; volume -- numpy array input;
        model_path -- TF model path (models aren't picklable, only the
        path is passed)."""
        Process.__init__(self, scriptPath)
        self.volume = volume
        self.model_path = model_path
        self.name = f"OfflinePrediction-{os.path.basename(model_path)}"
        self.output = None  # Unpickled prediction dict once the child finishes.

    def prepareProcessInput(self):
        """Write the model path and input volume to 'data.pkl' for the child."""
        payload = {
            'volume': self.volume,
            'model_path': self.model_path,
        }
        with open('data.pkl', 'wb') as f:
            pickle.dump(payload, f)

    def useProcessOutput(self, processOutput):
        """Unpickle the child's stdout and remove the hand-off file."""
        # Bug fix: 'os' was used here (and in __init__) without being
        # imported anywhere in this file.
        output = pickle.loads(processOutput)
        os.remove('data.pkl')
        self.output = output
#####################
# Running the process
#####################
# 'slicer', 'images', and 'modelPath' come from the surrounding
# Slicer scripted-module scope.
script_dir = slicer.modules.YOURMODULEHERE.path.replace('YOURMODULEHERE.py', '/Resources/ProcessScripts/')
offline_script = os.path.join(script_dir, "OfflinePrediction.slicer.py")
predictionProcess = OfflinePredictionProcess(offline_script, images, modelPath)

def onOfflinePredictProcessCompleted():
    """Consume the prediction once the child process has exited."""
    logging.info('Offline Prediction: Process Finished')
    y = predictionProcess.output["prediction"]
    # Do something with the outputs (update segmented volume, use as input for volume reconstruction, etc.)

logic = ProcessesLogic(completedCallback=onOfflinePredictProcessCompleted)
logic.addProcess(predictionProcess)
logic.run()
logging.info('Offline Prediction: Process Started')
# OfflinePrediction.slicer.py — one-shot child process script.
# Reads {'model_path', 'volume'} from 'data.pkl' (written by the parent's
# prepareProcessInput), runs a single prediction, and pickles the result
# to stdout for useProcessOutput.
import pickle
import sys

from tensorflow.keras.models import load_model

# Renamed from 'input' to avoid shadowing the builtin.
with open('data.pkl', 'rb') as f:
    payload = pickle.load(f)

model = load_model(payload["model_path"], compile=False)
pred = model.predict(payload["volume"])

output = {'prediction': pred}
sys.stdout.buffer.write(pickle.dumps(output))
# Flush explicitly so the parent reliably receives the result
# (consistent with the live-prediction script).
sys.stdout.buffer.flush()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment