Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
A Python script that loads a TensorFlow model into memory without holding the global TensorFlow graph hostage
import tensorflow as tf
import numpy as np
# The category name and probability percentage
class CategoryScore:
    """A category label paired with the score assigned to it."""

    def __init__(self, category, probability: float):
        """Store the label and its score.

        Args:
            category: The category label; stored as given.
            probability: The score associated with ``category``; stored as
                given, with no validation or rescaling.
        """
        self.category = category
        self.probability = probability

    def __repr__(self):
        """Return a constructor-style representation for debugging."""
        return "{}(category={!r}, probability={!r})".format(
            type(self).__name__, self.category, self.probability
        )
# The categorizer handles running tensorflow models
class Categorizer:
    """Run a serialized TensorFlow graph and report one score per category.

    The model is imported into a graph owned by this instance and executed
    through a session bound to that graph, so TensorFlow's process-wide
    default graph is neither populated nor reset by constructing an instance.
    """

    def __init__(self, model_file_path: str, map: list):
        """Load the model and open a session on it.

        Args:
            model_file_path: Path to a serialized ``GraphDef``; read in
                binary mode.
            map: Category labels, indexed by position in the model output.
                (The name shadows the builtin but is kept so existing
                keyword callers continue to work.)

        The graph is expected to expose an output tensor named ``loss:0``
        and an input tensor named ``Placeholder:0``.
        """
        self.map = map
        self.graph = tf.Graph()
        self.graph_def = self.graph.as_graph_def()
        with tf.gfile.GFile(model_file_path, 'rb') as f:
            self.graph_def.ParseFromString(f.read())

        # as_default() only takes effect as a context manager; called bare it
        # does nothing and the import would land in the global default graph.
        with self.graph.as_default():
            tf.import_graph_def(self.graph_def, name='')

        output_layer = 'loss:0'
        self.input_node = 'Placeholder:0'

        # Bind the session to our own graph explicitly. This removes the need
        # to call tf.reset_default_graph(), which would discard whatever the
        # caller had built in the default graph and is documented as undefined
        # behaviour while a session is active.
        self.sess = tf.Session(graph=self.graph)
        self.prob_tensor = self.graph.get_tensor_by_name(output_layer)

    def close(self):
        """Close the underlying session."""
        self.sess.close()

    def __enter__(self):
        """Return this instance so it can be used in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Close the session on leaving the ``with`` block."""
        self.close()

    def score(self, image):
        """Run the model on a single image and return a score per category.

        Args:
            image: The model input for one example. It is wrapped in a
                one-element list before being fed to the input tensor.
                NOTE(review): presumably must match the placeholder's
                per-example shape -- confirm against the model.

        Returns:
            A list of ``CategoryScore`` objects, one per model output value,
            in output order. Each score is rounded to 8 decimal places and
            converted to ``numpy.float64``.

        Raises:
            IndexError: If the model yields more outputs than ``self.map``
                has labels.
        """
        # The feed is a batch of one, so exactly one row of predictions
        # comes back; the trailing comma unpacks it.
        predictions, = self.sess.run(self.prob_tensor, {self.input_node: [image]})
        return [
            CategoryScore(self.map[label_index], np.float64(np.round(p, 8)))
            for label_index, p in enumerate(predictions)
        ]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.