
@BigNerd
Created June 27, 2020 12:56
A wrapper for thread-safe execution of Keras model prediction when using TensorFlow as the backend
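import threading

import keras
from keras.models import Model


class ThreadSafeModel:
    """Serializes predict() calls and runs them in the session/graph captured at construction."""

    def __init__(self, model: Model):
        self.model = model
        self.lock = threading.Lock()
        # Store the session that was used when creating or loading the model
        self.session = keras.backend.get_session()

    def predict(self, x, batch_size=None, verbose=0, steps=None, callbacks=None,
                max_queue_size=10, workers=1, use_multiprocessing=False):
        # Only one thread at a time may run the model
        with self.lock:
            # Make the stored graph the default for this thread and restore the session
            with self.session.graph.as_default():
                keras.backend.set_session(self.session)
                # Keyword arguments avoid depending on the positional order of predict()
                return self.model.predict(
                    x, batch_size=batch_size, verbose=verbose, steps=steps,
                    callbacks=callbacks, max_queue_size=max_queue_size,
                    workers=workers, use_multiprocessing=use_multiprocessing)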
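The core idea of the wrapper is a single lock that serializes every `predict()` call. The sketch below isolates that locking pattern so it can run without TensorFlow. It deliberately leaves out the TF1-specific session/graph handling, and `DummyModel` and `LockedPredictor` are hypothetical stand-ins rather than part of the gist. The dummy model counts how many calls overlap, and the wrapper keeps that count at 1 even when eight threads call it at once.

```python
import threading
import time
from concurrent.futures import ThreadPoolExecutor


class DummyModel:
    """Hypothetical stand-in for a Keras model; not thread safe by itself."""

    def __init__(self):
        self.active_calls = 0
        self.max_concurrent = 0

    def predict(self, x):
        # Record how many calls are running at the same time.
        self.active_calls += 1
        self.max_concurrent = max(self.max_concurrent, self.active_calls)
        time.sleep(0.001)  # simulate work so that unlocked calls would overlap
        result = [v * 2 for v in x]
        self.active_calls -= 1
        return result


class LockedPredictor:
    """Hypothetical minimal version of ThreadSafeModel: one lock around predict()."""

    def __init__(self, model):
        self.model = model
        self.lock = threading.Lock()

    def predict(self, x):
        # Only one thread may be inside the model's predict() at a time.
        with self.lock:
            return self.model.predict(x)


model = DummyModel()
wrapper = LockedPredictor(model)

# Call the wrapper concurrently from a thread pool; map() preserves input order.
with ThreadPoolExecutor(max_workers=8) as pool:
    results = list(pool.map(wrapper.predict, [[i] for i in range(100)]))

print(model.max_concurrent, results[:3])
```

The lock makes prediction strictly sequential. That trades throughput for safety, which suits a shared model object that the backend does not guarantee to be re-entrant.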