Created September 7, 2016 04:57
-
-
Save bigsnarfdude/4cbf171dd1ab0c13edfad3eaa5506745 to your computer and use it in GitHub Desktop.
[self-driving-car] ros tensorflow
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import rospy | |
from sensor_msgs.msg import Image | |
from std_msgs.msg import String | |
from cv_bridge import CvBridge | |
import cv2 | |
import numpy as np | |
import tensorflow as tf | |
from tensorflow.models.image.imagenet import classify_image | |
class RosTensorFlow(object):
    """ROS node that labels incoming images with the ImageNet classifier graph.

    Subscribes to the ``image`` topic (sensor_msgs/Image) and runs each frame
    through the graph imported by ``classify_image.create_graph()``. For each
    of the ``~use_top_k`` best-scoring classes whose score exceeds
    ``~score_threshold``, it logs the label and publishes it on ``result``
    (std_msgs/String).

    Private ROS parameters:
        ~score_threshold: minimum score a class needs to be published
            (default 0.1).
        ~use_top_k: number of highest-scoring classes considered per frame
            (default 5); 0 or less publishes nothing.
    """

    def __init__(self):
        # Fetch the pretrained model if needed and import its graph.
        classify_image.maybe_download_and_extract()
        self._session = tf.Session()
        classify_image.create_graph()
        # Per-frame invariants: resolve the output tensor and build the
        # node-ID -> English label lookup once instead of in every callback.
        self._softmax_tensor = self._session.graph.get_tensor_by_name(
            'softmax:0')
        self._node_lookup = classify_image.NodeLookup()
        self._cv_bridge = CvBridge()
        self.score_threshold = rospy.get_param('~score_threshold', 0.1)
        self.use_top_k = rospy.get_param('~use_top_k', 5)
        self._pub = rospy.Publisher('result', String, queue_size=1)
        # Subscribe last: rospy may start invoking ``callback`` as soon as the
        # subscriber exists, and the callback reads every attribute set above.
        self._sub = rospy.Subscriber('image', Image, self.callback,
                                     queue_size=1)

    @staticmethod
    def _top_predictions(predictions, top_k, threshold):
        """Select the best-scoring classes from a 1-D score array.

        Args:
            predictions: 1-D numpy array of per-class scores, indexed by
                node ID.
            top_k: how many of the highest scores to consider; values <= 0
                select nothing.
            threshold: a class is kept only if its score is strictly greater.

        Returns:
            List of ``(node_id, score)`` pairs ordered from highest to lowest
            score.
        """
        # Clamp k: a plain ``[-k:]`` slice would select *every* class when
        # k == 0 and the wrong ones when k < 0.
        count = max(int(top_k), 0)
        ranked = predictions.argsort()[::-1][:count]
        return [(node_id, predictions[node_id]) for node_id in ranked
                if predictions[node_id] > threshold]

    def callback(self, image_msg):
        """Classify one image message and publish the confident labels.

        Args:
            image_msg: sensor_msgs/Image to classify. It is converted to a
                BGR8 OpenCV image via cv_bridge.
        """
        cv_image = self._cv_bridge.imgmsg_to_cv2(image_msg, "bgr8")
        # The graph is fed through 'DecodeJpeg/contents:0', so the frame is
        # re-encoded as JPEG bytes first. Adapted from
        # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/models/image/imagenet/classify_image.py
        ok, jpeg = cv2.imencode('.jpg', cv_image)
        if not ok:
            rospy.logwarn('Could not JPEG-encode the incoming image; skipping it')
            return
        predictions = self._session.run(
            self._softmax_tensor, {'DecodeJpeg/contents:0': jpeg.tobytes()})
        predictions = np.squeeze(predictions)
        for node_id, score in self._top_predictions(
                predictions, self.use_top_k, self.score_threshold):
            human_string = self._node_lookup.id_to_string(node_id)
            rospy.loginfo('%s (score = %.5f)' % (human_string, score))
            self._pub.publish(human_string)

    def main(self):
        """Block in ``rospy.spin()`` until the node is shut down."""
        rospy.spin()
if __name__ == '__main__':
    # Register this process as a ROS node before RosTensorFlow() creates its
    # publisher and subscriber, then block until shutdown.
    rospy.init_node('rostensorflow')
    node = RosTensorFlow()
    node.main()
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment