Skip to content

Instantly share code, notes, and snippets.

@WuStangDan
Created October 24, 2017 23:38
Show Gist options
  • Save WuStangDan/f9cb0c4cda925dd3bd892fbf52f9e3e6 to your computer and use it in GitHub Desktop.
Save WuStangDan/f9cb0c4cda925dd3bd892fbf52f9e3e6 to your computer and use it in GitHub Desktop.
class TrafficLightClassifier(object):
    """Hold a frozen TensorFlow inference graph and a session bound to it.

    On construction the serialized ``GraphDef`` is loaded into a fresh
    ``tf.Graph``, handles to its input/output tensors are looked up by name,
    and a ``tf.Session`` on that graph is opened and stored as ``self.sess``.

    NOTE(review): the tensor names (``image_tensor``, ``detection_boxes``,
    ``detection_scores``, ``detection_classes``, ``num_detections``) look like
    the exports of the TensorFlow Object Detection API -- confirm against the
    model actually deployed with this code.

    NOTE(review): ``tf`` is not imported in the visible source; the enclosing
    module is assumed to provide it (e.g. ``import tensorflow as tf``, TF 1.x
    graph/session API).

    The instance owns its session: call :meth:`close` when done, or use the
    instance as a context manager (``with TrafficLightClassifier() as c:``).
    """

    def __init__(self, path_to_model='frozen_inference_graph.pb'):
        """Load the frozen graph from *path_to_model* and open a session.

        Args:
            path_to_model: Path to a serialized ``GraphDef`` (``.pb``) file.
                Defaults to ``'frozen_inference_graph.pb'``; a relative path
                is resolved against the process's working directory.

        Errors raised while reading or parsing the file, or while looking up
        the tensors, propagate unchanged; no session is created in that case.
        """
        self.detection_graph = tf.Graph()
        with self.detection_graph.as_default():
            od_graph_def = tf.GraphDef()
            with tf.gfile.GFile(path_to_model, 'rb') as fid:
                od_graph_def.ParseFromString(fid.read())
            # name='' imports the nodes without a name prefix, so the tensors
            # can be looked up below by their original names.
            tf.import_graph_def(od_graph_def, name='')

        graph = self.detection_graph
        self.image_tensor = graph.get_tensor_by_name('image_tensor:0')
        self.d_boxes = graph.get_tensor_by_name('detection_boxes:0')
        self.d_scores = graph.get_tensor_by_name('detection_scores:0')
        self.d_classes = graph.get_tensor_by_name('detection_classes:0')
        self.num_d = graph.get_tensor_by_name('num_detections:0')

        # Bind the session explicitly to our graph rather than relying on the
        # process-wide default graph.
        self.sess = tf.Session(graph=self.detection_graph)

    def close(self):
        """Close the owned session, if one is open. Safe to call repeatedly."""
        sess = getattr(self, 'sess', None)
        if sess is not None:
            sess.close()
        self.sess = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        self.close()
        # Do not suppress exceptions raised inside the ``with`` body.
        return False
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment