Last active
January 12, 2019 11:33
-
-
Save ryanwang522/d1dd6f62e036e3710592bc34aa338736 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import print_function | |
import sys | |
import rospy | |
import cv2 | |
from std_msgs.msg import String | |
from sensor_msgs.msg import Image | |
import cv_bridge | |
import numpy as np | |
# Topic on which a "Detected" notice is published whenever YOLO finds objects.
# An explicit queue_size is required by rospy for asynchronous publishing;
# omitting it triggers a deprecation warning and synchronous publishing.
pub = rospy.Publisher("face_detect", String, queue_size=10)

# Darknet YOLOv3 model files, opened relative to the current working directory:
# a text file with one class name per line, the network config and its weights.
arg_classes = "yolov3.txt"
arg_config = "yolov3.cfg"
arg_weights = "yolov3.weights"

# Per-class drawing colours; yolo() fills this in and draw_prediction() reads it.
# (The original declared the misspelled name COLOR, which nothing used.)
COLORS = None
# Legacy alias for the old misspelled name, kept for backward compatibility.
COLOR = COLORS
def get_output_layers(net):
    """Return the names of the network's unconnected (output) layers.

    ``net.getUnconnectedOutLayers()`` yields 1-based indices into
    ``net.getLayerNames()``. Depending on the OpenCV version those indices
    come back either wrapped as an (N, 1) array or as a flat (N,) array, so
    they are flattened before use. Indexing ``i[0]`` only works for the
    former.

    :param net: a ``cv2.dnn`` network (anything exposing ``getLayerNames``
        and ``getUnconnectedOutLayers``).
    :returns: list of output layer names, suitable for ``net.forward``.
    """
    layer_names = net.getLayerNames()
    out_indices = np.array(net.getUnconnectedOutLayers()).flatten()
    # Indices are 1-based, Python lists are 0-based.
    return [layer_names[i - 1] for i in out_indices]
def draw_prediction(img, classes, class_id, confidence, x, y, x_plus_w, y_plus_h):
    """Draw one labelled bounding box onto ``img`` in place.

    The rectangle runs from (x, y) to (x_plus_w, y_plus_h), each coordinate
    truncated with ``int()``. Its caption is ``str(classes[class_id])``,
    placed 10 px up and left of the top-left corner. Both use the colour
    ``COLORS[class_id]`` from the module-level table. ``confidence`` is
    accepted for interface compatibility but is not drawn.
    """
    caption = str(classes[class_id])
    box_color = COLORS[class_id]
    left, top = int(x), int(y)
    right, bottom = int(x_plus_w), int(y_plus_h)

    cv2.rectangle(img, (left, top), (right, bottom), box_color, 2)
    cv2.putText(
        img,
        caption,
        (left - 10, top - 10),
        cv2.FONT_HERSHEY_SIMPLEX,
        0.5,
        box_color,
        2,
    )
# Cache for the YOLO model so the class list and network are read from disk
# once, on the first yolo() call, rather than on every frame.
_yolo_model = {}


def _load_yolo_model():
    """Return ``(classes, net)``, loading them on first use.

    On the first call this also fills the module-level ``COLORS`` table (one
    random colour per class) that draw_prediction() reads. Generating it only
    once keeps each class's box colour stable from frame to frame.
    """
    global COLORS
    if not _yolo_model:
        with open(arg_classes, 'r') as f:
            classes = [line.strip() for line in f]
        net = cv2.dnn.readNet(arg_weights, arg_config)
        COLORS = np.random.uniform(0, 255, size=(len(classes), 3))
        _yolo_model['classes'] = classes
        _yolo_model['net'] = net
    return _yolo_model['classes'], _yolo_model['net']


def _collect_detections(outs, width, height, conf_threshold):
    """Convert raw network output rows into candidate boxes.

    Each row is read as: columns 0-3 are the box centre x/y and width/height
    relative to the image (scaled here by ``width``/``height``). Columns 5+
    are the per-class scores. Column 4 is not used here.

    :returns: ``(class_ids, confidences, boxes)`` for every row whose best
        class score exceeds ``conf_threshold``. Each box is
        ``[x, y, w, h]``, with (x, y) the top-left corner.
    """
    class_ids = []
    confidences = []
    boxes = []
    for out in outs:
        for detection in out:
            scores = detection[5:]
            class_id = np.argmax(scores)
            confidence = scores[class_id]
            if confidence > conf_threshold:
                center_x = int(detection[0] * width)
                center_y = int(detection[1] * height)
                w = int(detection[2] * width)
                h = int(detection[3] * height)
                class_ids.append(class_id)
                confidences.append(float(confidence))
                boxes.append([center_x - w / 2, center_y - h / 2, w, h])
    return class_ids, confidences, boxes


def yolo(image, conf_threshold=0.5, nms_threshold=0.4):
    """Run YOLOv3 on ``image`` and draw the surviving detections onto it.

    :param image: image array with shape (height, width, channels). It is
        annotated in place.
    :param conf_threshold: minimum class score for a candidate box; also
        passed to non-maximum suppression.
    :param nms_threshold: overlap threshold for non-maximum suppression.
    :returns: ``(image, count)``, where ``count`` is the number of boxes kept
        after non-maximum suppression.
    """
    height, width = image.shape[:2]
    classes, net = _load_yolo_model()

    # 0.00392 ~= 1/255 rescales 0-255 pixel values; swapRB=True swaps the
    # first and third channels before inference.
    blob = cv2.dnn.blobFromImage(
        image, 0.00392, (416, 416), (0, 0, 0), True, crop=False)
    net.setInput(blob)
    outs = net.forward(get_output_layers(net))

    class_ids, confidences, boxes = _collect_detections(
        outs, width, height, conf_threshold)

    indices = cv2.dnn.NMSBoxes(boxes, confidences, conf_threshold, nms_threshold)
    # NMSBoxes returns an (N, 1) array, a flat (N,) array or an empty tuple
    # depending on the OpenCV version and input; normalise to a flat int array.
    kept = np.array(indices, dtype=int).flatten()
    for i in kept:
        x, y, w, h = boxes[i]
        draw_prediction(image, classes, class_ids[i], confidences[i],
                        round(x), round(y), round(x + w), round(y + h))
    return image, len(kept)
# One bridge is enough for the node's lifetime; no need to rebuild it per frame.
_bridge = cv_bridge.CvBridge()


def image_callback(data):
    """Handle one camera frame: detect objects, display, and notify.

    Converts the ROS image message to an OpenCV ``bgr8`` array, runs yolo()
    on it and shows the annotated frame. If at least one object was kept,
    it publishes "Detected" on ``pub``. Frames that cv_bridge cannot convert
    are reported and skipped.

    :param data: ``sensor_msgs/Image`` message from the subscribed topic.
    """
    try:
        image = _bridge.imgmsg_to_cv2(data, "bgr8")
    except cv_bridge.CvBridgeError as e:
        print(e)
        # Without a decoded frame there is nothing to process (the original
        # fell through and failed on the unbound ``image``).
        return

    # No colour conversion happens here; yolo() receives the BGR frame as-is.
    img, obj_cnt = yolo(image)
    cv2.imshow("Image window", img)
    # HighGUI only repaints the window when its event loop is pumped.
    cv2.waitKey(1)

    if obj_cnt > 0:
        msg = "Detected"
        # rospy builds a std_msgs/String from the positional argument.
        pub.publish(msg)
def extract_image():
    """Start the ROS node and process camera frames until shutdown.

    Subscribes image_callback to ``/camera/rgb/image_color`` and blocks in
    ``rospy.spin()``. Once spinning ends, it closes any OpenCV windows the
    callback opened.
    """
    rospy.init_node("image", anonymous=True)
    # YOLO runs on every frame and may be slower than the camera rate. Keep
    # only the newest frame instead of rospy's default unbounded queue, which
    # would make latency grow without bound. buff_size must be large enough
    # to hold a whole image for queue_size=1 to actually drop stale frames.
    rospy.Subscriber("/camera/rgb/image_color", Image, image_callback,
                     queue_size=1, buff_size=2 ** 24)
    # spin() simply keeps python from exiting until this node is stopped
    rospy.spin()
    cv2.destroyAllWindows()
def _run_node():
    """Run extract_image(), reporting a ROS interrupt instead of raising it."""
    try:
        extract_image()
    except rospy.ROSInterruptException:
        print("Keyboard Interrupt")


if __name__ == "__main__":
    _run_node()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment