Created
June 29, 2021 09:33
-
-
Save zacbaum/8b22f0c02c5847919a2b38e09fb63cf5 to your computer and use it in GitHub Desktop.
3D Slicer Process for TensorFlow Predictions
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from Processes import Process, ProcessesLogic | |
import pickle | |
class LivePredictionProcess(Process):
    """Slicer ``Process`` that streams volumes to a long-running TF prediction script.

    Every message written to the child's stdin has the same frame: a 1-byte
    control flag (``bytes([1])`` = keep predicting, ``bytes([0])`` = stop),
    an 8-byte big-endian length, then that many bytes of pickled input.
    """

    def __init__(self, scriptPath, volume, model_path, active=True):
        """Create the process.

        Args:
            scriptPath: Path of the child script, passed on to ``Process``.
            volume: Numpy array to use as input for the model.
            model_path: Path to the TF model to load in the child. TF models
                are not picklable, so only the path is sent.
            active: Initial state of the control byte sent with each message.
        """
        import os  # local import: this file's top-level imports do not include os

        Process.__init__(self, scriptPath)
        self.volume = volume
        self.model_path = model_path
        self.active = self._controlByte(active)  # used to stop the child script
        self.name = f"LivePrediction-{os.path.basename(model_path)}"
        self.output = None

    @staticmethod
    def _controlByte(active):
        """Return the 1-byte control flag for ``active``."""
        return bytes([1]) if active else bytes([0])

    def setActive(self, active=True):
        """Set the control byte for the next message; ``False`` tells the child to stop."""
        self.active = self._controlByte(active)

    def onStarted(self):
        """Write one framed message (control byte, length, pickled input) to the child.

        This runs when the process starts. User code also calls it again to
        send each new volume, or to send the stop signal.
        """
        payload = self.prepareProcessInput()
        payload_len = len(payload).to_bytes(8, byteorder='big')
        self.write(self.active)  # whether predictions are still running
        self.write(payload_len)  # length of the payload the child has yet to receive
        self.write(payload)  # pickled inputs for the model

    def prepareProcessInput(self):
        """Return the pickled dict ``{'model_path': ..., 'volume': ...}`` for the child."""
        return pickle.dumps({'model_path': self.model_path, 'volume': self.volume})

    def useProcessOutput(self, processOutput):
        """Unpickle the child's stdout bytes into ``self.output``.

        ``self.output`` becomes ``None`` when the bytes are empty (``EOFError``)
        or truncated (``pickle.UnpicklingError``). This can happen when a
        stdout read returns only part of a message.
        """
        try:
            self.output = pickle.loads(processOutput)
        except (EOFError, pickle.UnpicklingError):
            self.output = None
#####################
# Running the process
#####################
import logging
import os

scriptFolder = slicer.modules.YOURMODULEHERE.path.replace('YOURMODULEHERE.py', '/Resources/ProcessScripts/')
scriptPath = os.path.join(scriptFolder, "LivePrediction.slicer.py")
livePredictionProcess = LivePredictionProcess(scriptPath, images, modelPath)
# NOTE: CallbackToUsePredictedImages must already be defined when this line runs.
livePredictionProcess.connect('readyReadStandardOutput()', CallbackToUsePredictedImages)


def onLivePredictProcessCompleted():
    """Log completion and detach the stdout callback from the finished process."""
    logging.info('Live Prediction: Process Finished')
    livePredictionProcess.disconnect('readyReadStandardOutput()', CallbackToUsePredictedImages)


# The lambda makes sure the handler is called with no arguments, whatever
# ProcessesLogic passes to its completion callback.
logic = ProcessesLogic(completedCallback=lambda: onLivePredictProcessCompleted())
# Register the same process object that was created and connected above.
logic.addProcess(livePredictionProcess)
logic.run()
logging.info('Live Prediction: Process Started')
#####################
# Send data to the process
#####################
# Replace the volume held by the process object, then call onStarted() again.
# It writes the control byte, payload length and pickled payload to the child.
livePredictionProcess.volume = newData
livePredictionProcess.onStarted()
#####################
# Receive data from the process
#####################
def CallbackToUsePredictedImages(self=None):
    """Handle ``readyReadStandardOutput()`` by unpickling the child's prediction.

    ``self`` is optional. The signal passes no arguments, so this works when
    connected as a plain function and also when pasted in as a method.
    """
    stdout = livePredictionProcess.readAllStandardOutput().data()
    livePredictionProcess.useProcessOutput(stdout)
    if livePredictionProcess.output is None:
        # useProcessOutput could not decode this chunk; nothing to use yet.
        return
    y = livePredictionProcess.output['prediction']
    # Do something with the outputs (update segmented image array, etc.)
#####################
# Stop the process
#####################
# Switch the control byte to "inactive", then push one more frame. The child
# script reads that byte first and stops when it sees the inactive value.
livePredictionProcess.setActive(active=False)
livePredictionProcess.onStarted()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pickle | |
import sys | |
import tensorflow as tf | |
from tensorflow.keras.models import load_model | |
# Control bytes that prefix every message from the Slicer side:
# ACTIVE means a length and a pickled payload follow; NOT_ACTIVE means stop.
ACTIVE = b"\x01"
NOT_ACTIVE = b"\x00"
def _read_exactly(size):
    """Return exactly ``size`` bytes from stdin, or None if the stream ends first."""
    data = sys.stdin.buffer.read(size)
    return data if len(data) == size else None


def _read_request():
    """Read one framed request from the parent Slicer process.

    Frame: 1 control byte, an 8-byte big-endian payload length, then the
    pickled payload. Returns the unpickled payload, or None in three cases:
    the parent sent NOT_ACTIVE, sent an unknown control byte, or closed stdin.
    """
    control = sys.stdin.buffer.read(1)
    if control != ACTIVE:
        # NOT_ACTIVE is the explicit stop signal. EOF (b'') or any other byte
        # also means there is no more work, so stop rather than reuse stale input.
        return None
    header = _read_exactly(8)
    if header is None:
        return None
    payload = _read_exactly(int.from_bytes(header, byteorder='big'))
    if payload is None:
        return None
    # The payload comes from the parent process over a pipe; pickle must only
    # ever be used on trusted input like this.
    return pickle.loads(payload)


# Do the first read at the same time that we instantiate the model. This ensures
# the TF graph is written and ready right away when live predictions start.
request = _read_request()
if request is None:
    sys.exit()

model = load_model(request['model_path'], compile=False)
# For batch-size-1 predictions, wrapping .call in a tf.function tends to speed things up.
# NOTE(review): newer TF releases replace experimental_relax_shapes with reduce_retracing.
model.call = tf.function(model.call, experimental_relax_shapes=True)
_ = model(request['volume'], training=False)  # dummy prediction on the first (blanked) volume builds the TF graph

while True:
    request = _read_request()
    if request is None:
        break
    output = {'prediction': model(request['volume'], training=False).numpy()}
    sys.stdout.buffer.write(pickle.dumps(output))
    sys.stdout.buffer.flush()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from Processes import Process, ProcessesLogic | |
import pickle | |
class OfflinePredictionProcess(Process):
    """Slicer ``Process`` that runs a single TF prediction in a child script.

    The input is handed over through the file 'data.pkl', a path relative to
    the current working directory. The child's pickled stdout becomes
    ``self.output``.
    """

    def __init__(self, scriptPath, volume, model_path):
        """Create the process.

        Args:
            scriptPath: Path of the child script, passed on to ``Process``.
            volume: Numpy array to use as input for the model.
            model_path: Path to the TF model to load. TF models are not
                picklable, so only the path is handed over.
        """
        import os  # local import: this file's top-level imports do not include os

        Process.__init__(self, scriptPath)
        self.volume = volume
        self.model_path = model_path
        self.name = f"OfflinePrediction-{os.path.basename(model_path)}"
        self.output = None

    def prepareProcessInput(self):
        """Pickle ``{'volume': ..., 'model_path': ...}`` to 'data.pkl' for the child script.

        Returns None; the input travels through the file.
        NOTE(review): confirm how the ``Process`` base class uses this return value.
        """
        payload = {'volume': self.volume, 'model_path': self.model_path}
        with open('data.pkl', 'wb') as f:
            pickle.dump(payload, f)

    def useProcessOutput(self, processOutput):
        """Unpickle the child's stdout into ``self.output`` and delete 'data.pkl'.

        The hand-off file is removed even if unpickling raises, so a failed run
        does not leave stale input behind. Unpickling errors still propagate.
        """
        import contextlib
        import os

        try:
            self.output = pickle.loads(processOutput)
        finally:
            # A missing hand-off file is not a reason to discard the result.
            with contextlib.suppress(FileNotFoundError):
                os.remove('data.pkl')
#####################
# Running the process
#####################
import logging
import os

scriptFolder = slicer.modules.YOURMODULEHERE.path.replace('YOURMODULEHERE.py', '/Resources/ProcessScripts/')
scriptPath = os.path.join(scriptFolder, "OfflinePrediction.slicer.py")
predictionProcess = OfflinePredictionProcess(scriptPath, images, modelPath)


def onOfflinePredictProcessCompleted():
    """Log completion and consume the prediction, if one was stored."""
    logging.info('Offline Prediction: Process Finished')
    if predictionProcess.output is None:
        # output is still the None set in __init__, so no result was stored.
        logging.error('Offline Prediction: no output received')
        return
    y = predictionProcess.output["prediction"]
    # Do something with the outputs (update segmented volume, use as input for volume reconstruction, etc.)


# The lambda makes sure the handler is called with no arguments, whatever
# ProcessesLogic passes to its completion callback.
logic = ProcessesLogic(completedCallback=lambda: onOfflinePredictProcessCompleted())
logic.addProcess(predictionProcess)
logic.run()
logging.info('Offline Prediction: Process Started')
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pickle | |
import sys | |
from tensorflow.keras.models import load_model | |
# Load the hand-off file written by OfflinePredictionProcess.prepareProcessInput().
with open('data.pkl', 'rb') as handle:
    request = pickle.load(handle)

model = load_model(request["model_path"], compile=False)
# Send {'prediction': <model output>} back to the parent as a pickle on stdout.
sys.stdout.buffer.write(pickle.dumps({'prediction': model.predict(request["volume"])}))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment