Skip to content

Instantly share code, notes, and snippets.

@TakanoTaiga
Last active February 8, 2024 21:10
Show Gist options
  • Save TakanoTaiga/eae37b30c34b19fe0d595908bb2a810f to your computer and use it in GitHub Desktop.
Save TakanoTaiga/eae37b30c34b19fe0d595908bb2a810f to your computer and use it in GitHub Desktop.
i2t ros2
# Copyright 2023 Taiga Takano
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from transformers import BlipProcessor, BlipForConditionalGeneration
import rclpy
from rclpy.node import Node
from sensor_msgs.msg import Image as rosimg
import cv_bridge
class MinimalSubscriber(Node):
    """ROS 2 node that captions incoming camera images with BLIP.

    The node subscribes to a ``sensor_msgs/Image`` topic. For each message it
    runs a BLIP conditional image-captioning model, using a fixed text prompt
    as the caption prefix, and writes the decoded caption to the node logger.

    All constructor arguments are keyword-only and default to the values this
    node originally hard-coded, so ``MinimalSubscriber()`` keeps working.

    Args:
        model_id: Hugging Face model id used for both processor and model.
        topic: Image topic to subscribe to.
        prompt: Text prefix that conditions the generated caption.
        qos_depth: History depth of the subscription. Captioning is slow, so
            a smaller depth limits how stale the processed frames can get.
        device: Torch device string. ``None`` selects ``"cuda"`` when a GPU
            is available and ``"cpu"`` otherwise.
    """

    DEFAULT_MODEL_ID = "Salesforce/blip-image-captioning-large"
    DEFAULT_TOPIC = '/camera/color/image_raw'
    DEFAULT_PROMPT = "a photography of"

    def __init__(
        self,
        *,
        model_id=DEFAULT_MODEL_ID,
        topic=DEFAULT_TOPIC,
        prompt=DEFAULT_PROMPT,
        qos_depth=10,
        device=None,
    ):
        # Initialise the Node base first so the node (and its logger) exist
        # before the slow model download/load below.
        super().__init__('blip_ros2')

        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device
        # float16 matches the original GPU setup; many CPU kernels lack
        # float16 support, so fall back to float32 off-GPU.
        if torch.device(device).type == "cuda":
            self.dtype = torch.float16
        else:
            self.dtype = torch.float32

        self.prompt = prompt
        self.processor = BlipProcessor.from_pretrained(model_id)
        self.model = BlipForConditionalGeneration.from_pretrained(
            model_id, torch_dtype=self.dtype
        ).to(self.device)
        self.bridge = cv_bridge.CvBridge()

        self.subscription = self.create_subscription(
            rosimg,
            topic,
            self.listener_callback,
            qos_depth,
        )

    def listener_callback(self, msg):
        """Caption a single image message and log the result.

        Frames that cv_bridge cannot convert are logged as a warning and
        skipped, so one bad message does not terminate the node.

        Args:
            msg: ``sensor_msgs.msg.Image`` received on the subscribed topic.
        """
        try:
            # The BLIP processor interprets numpy arrays as RGB; requesting
            # 'bgr8' (as before) fed it colour-swapped images.
            frame = self.bridge.imgmsg_to_cv2(msg, desired_encoding='rgb8')
        except cv_bridge.CvBridgeError as err:
            self.get_logger().warning(f'Skipping frame: {err}')
            return

        inputs = self.processor(
            images=frame, text=self.prompt, return_tensors="pt"
        ).to(self.device, self.dtype)
        output_ids = self.model.generate(**inputs)
        caption = self.processor.decode(output_ids[0], skip_special_tokens=True)
        self.get_logger().info(caption)
def main(args=None):
    """Run the BLIP captioning node until it is interrupted.

    The node and the ROS context are always cleaned up, including when
    spinning is stopped with Ctrl+C or when node construction fails (for
    example, if the model cannot be loaded).

    Args:
        args: Optional command-line arguments forwarded to ``rclpy.init``.
    """
    rclpy.init(args=args)
    node = None
    try:
        node = MinimalSubscriber()
        rclpy.spin(node)
    except KeyboardInterrupt:
        # Ctrl+C is the normal way to stop this node; exit without a traceback.
        pass
    finally:
        if node is not None:
            node.destroy_node()
        # The context may already be shut down (e.g. by rclpy's signal
        # handling); calling shutdown() again would raise.
        if rclpy.ok():
            rclpy.shutdown()


if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment