Skip to content

Instantly share code, notes, and snippets.

@nankeen nankeen/object_detector.py Secret
Created Jun 17, 2018

Embed
What would you like to do?
from disco.bot import Plugin
from io import BytesIO
from PIL import Image, ImageDraw
import requests
import tensorflow as tf
import numpy as np
class ObjectDetector(Plugin):
    '''
    Disco plugin that runs a frozen TensorFlow detection graph on an image
    fetched from a user-supplied link and replies with the same image
    annotated with red bounding boxes.
    '''

    # Image formats (as reported by PIL's `Image.format`) that the plugin
    # accepts, mapped to a file extension. Only the keys are consulted in
    # this class.
    format_map = {
        'JPEG': 'jpg',
        'PNG': 'png',
        # Pillow identifies JPEG 2000 files as 'JPEG2000' (no space), so the
        # spaced key below could never match; it is kept for compatibility.
        'JPEG2000': 'jpg',
        'JPEG 2000': 'jpg',
    }

    def load(self, ctx):
        '''
        Plugin load hook: defers to the parent class, then loads the
        detection graph and looks up its output tensors.
        '''
        super(ObjectDetector, self).load(ctx)
        self.load_detection_graph()
        self.generate_tensor_dict()

    def load_detection_graph(self, model_path='model/frozen_inference_graph.pb'):
        '''
        Loads the file specified by `model_path` into self.graph

        The file content is parsed as a serialized graph def.
        '''
        self.graph = tf.Graph()
        # Read the whole file first so it is closed before parsing starts
        with tf.gfile.GFile(model_path, 'rb') as f:
            serialized_graph = f.read()
        # A graph def instance is a representation of the graph definitions
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(serialized_graph)
        with self.graph.as_default():
            # name='' so that tensors can later be looked up by their bare
            # names (e.g. 'image_tensor:0') instead of under an import prefix
            tf.import_graph_def(graph_def, name='')

    def generate_tensor_dict(self, tensors=(
        'num_detections',
        'detection_boxes',
        'detection_scores',
        'detection_classes',
    )):
        '''
        Reads all the tensors into a dictionary like so
        {
            'tensor_name': <tensor object>
        }

        The dictionary is stored in self.tensor_dict. Names in `tensors`
        that the loaded graph does not provide are silently skipped.
        '''
        operations = self.graph.get_operations()
        available = {output.name for op in operations for output in op.outputs}
        self.tensor_dict = {}
        for key in tensors:
            # Each name is looked up as the first output (':0') of its op
            tensor_name = key + ':0'
            if tensor_name in available:
                self.tensor_dict[key] = self.graph.get_tensor_by_name(tensor_name)

    def load_image_into_numpy_array(self, image):
        '''
        Converts a PIL image into a (height, width, 3) uint8 numpy array.

        Images that are not already in RGB mode (e.g. RGBA or palette PNGs,
        greyscale JPEGs) are converted first; their pixel data does not have
        three channels and so cannot take this shape otherwise.
        '''
        if image.mode != 'RGB':
            image = image.convert('RGB')
        return np.array(image, dtype=np.uint8)

    def run_inference(self, image):
        '''
        Runs one forward pass of the detection graph on `image`.

        A leading axis of size 1 is added to `image` before it is fed to
        the 'image_tensor:0' placeholder. Returns the evaluated values of
        self.tensor_dict, keyed by the same names.
        '''
        # Gets the image tensor object
        image_tensor = self.graph.get_tensor_by_name('image_tensor:0')
        batch = np.expand_dims(image, 0)
        # NOTE(review): a fresh session is opened for every call; consider
        # reusing one if inference latency matters.
        with self.graph.as_default():
            with tf.Session() as sess:
                # Do forward pass for the output in tensor dict
                return sess.run(self.tensor_dict, feed_dict={image_tensor: batch})

    def draw_bounding_boxes(
        self,
        image,
        inference,
        threshold=0.5
    ):
        '''
        Draws a red rectangle on `image` for every detection whose score is
        above `threshold` and returns the same (modified) image.

        Only the first entry of `inference['detection_boxes']` and
        `inference['detection_scores']` is used. Each box is read as
        (ymin, xmin, ymax, xmax) expressed as fractions of the image height
        and width.
        '''
        width, height = image.size
        # Create a draw interface
        draw = ImageDraw.Draw(image)
        detections = zip(inference['detection_boxes'][0],
                         inference['detection_scores'][0])
        for box, score in detections:
            # Skip rejected detections before doing any coordinate work
            if not score > threshold:
                continue
            ymin, xmin, ymax, xmax = (float(value) for value in box)
            # Scale the bounding box coordinates and draw a red rectangle
            draw.rectangle(
                [(xmin * width, ymin * height), (xmax * width, ymax * height)],
                outline=(255, 0, 0))
        return image

    def load_image_from_url(self, url, timeout=10):
        '''
        Downloads `url` and opens the response body as a PIL image.

        `timeout` is the number of seconds handed to `requests.get`; without
        one a stalled server would block the command indefinitely.

        Raises OSError when the body is not a decodable image or its format
        is not a key of `format_map`. Download failures (connection errors,
        timeouts, HTTP error statuses) raise requests' exceptions, which
        derive from IOError, i.e. OSError on Python 3.
        '''
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
        image = Image.open(BytesIO(response.content))
        if image.format not in self.format_map:
            raise OSError('Unrecognized format')
        # Image.open is lazy: force the decode here so truncated or corrupt
        # data is reported from this method rather than at first pixel access
        image.load()
        return image

    def create_attachment_from_image(self, image):
        '''
        Encodes `image` as PNG into an in-memory buffer and returns the
        buffer rewound to its start, ready to be read as an attachment.
        '''
        buffer = BytesIO()
        image.save(buffer, 'PNG')
        buffer.seek(0)
        return buffer

    @Plugin.command('!loss', '<link:str...>')
    def command_detect_loss(self, event, link):
        '''
        Handles the `!loss <link>` command: fetches the image at `link`,
        runs detection on it and replies with the annotated image. Replies
        with an apology instead when no supported image can be loaded.
        '''
        try:
            image = self.load_image_from_url(link)
        except OSError:
            event.msg.reply("I can't find a JPEG or PNG image at the link you've given")
            return
        # Normalise to RGB once so the array fed to the model and the canvas
        # the boxes are drawn on are the same picture in the same mode
        image = image.convert('RGB')
        # Run inference for the image
        image_numpy = self.load_image_into_numpy_array(image)
        inference = self.run_inference(image_numpy)
        image = self.draw_bounding_boxes(image, inference)
        # Send the image with bounding boxes back
        attachment = self.create_attachment_from_image(image)
        filename = 'loss_inference.png'
        event.msg.reply(
            "I found the following image at the link you've given.",
            attachments=[(filename, attachment)])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.