Skip to content

Instantly share code, notes, and snippets.

@weimzh
Last active July 26, 2023 04:10
Show Gist options
  • Save weimzh/6a95a73ff2f9c6fc06f19a5d31795e80 to your computer and use it in GitHub Desktop.
Save weimzh/6a95a73ff2f9c6fc06f19a5d31795e80 to your computer and use it in GitHub Desktop.
bag2cyber.py
#!/usr/bin/env python3
#
# Convert entire rosbag to Apollo Cyber Record format.
#
# Copyright (c) 2023, Wei Mingzhi <whistler_wmz@users.sf.net>.
#
# SPDX-License-Identifier: BSD-3-Clause
# https://spdx.org/licenses/BSD-3-Clause.html
#
import sys
import rosbag
import sys
import os
import re
from google.protobuf.descriptor_pb2 import DescriptorProto, FieldDescriptorProto, FileDescriptorProto
from google.protobuf.descriptor_pool import DescriptorPool
from google.protobuf.message_factory import GetMessages
from cyber_record import record
# Mapping from ROS built-in field types to protobuf scalar field types.
# Signed ROS integers use the zig-zag encoded SINT32/SINT64 protobuf types.
# ROS 'time' and 'duration' are flattened to one integer nanosecond count
# (the converter calls to_nsec() on them).  uint8 arrays are special-cased
# as 'bytes' by generate_proto rather than going through this table.
# The deprecated ROS aliases 'byte' (signed 8-bit) and 'char' (unsigned
# 8-bit) are listed so that definitions still using them are not mistaken
# for nested message types.
basetype_map = {
    'bool': FieldDescriptorProto.TYPE_BOOL,
    'int8': FieldDescriptorProto.TYPE_SINT32,
    'uint8': FieldDescriptorProto.TYPE_UINT32,
    'byte': FieldDescriptorProto.TYPE_SINT32,
    'char': FieldDescriptorProto.TYPE_UINT32,
    'int16': FieldDescriptorProto.TYPE_SINT32,
    'uint16': FieldDescriptorProto.TYPE_UINT32,
    'int32': FieldDescriptorProto.TYPE_SINT32,
    'uint32': FieldDescriptorProto.TYPE_UINT32,
    'int64': FieldDescriptorProto.TYPE_SINT64,
    'uint64': FieldDescriptorProto.TYPE_UINT64,
    'float32': FieldDescriptorProto.TYPE_FLOAT,
    'float64': FieldDescriptorProto.TYPE_DOUBLE,
    'string': FieldDescriptorProto.TYPE_STRING,
    'time': FieldDescriptorProto.TYPE_UINT64,
    'duration': FieldDescriptorProto.TYPE_SINT64,
}
def _start_definition(packages, full_name):
    """Create (or reset) the empty field table for ``pkg/Name`` and return it."""
    parts = full_name.split('/')
    fields = {}
    packages.setdefault(parts[0], {})[parts[1]] = fields
    return fields


def _parse_full_text(packages, msg_type, full_text):
    """Parse one message's ``_full_text`` into ``packages`` in place.

    The text holds the definition of ``msg_type`` followed by the definitions
    it depends on, each introduced by a ``MSG: pkg/Name`` line; lines starting
    with ``===`` are separators.  Each definition is stored as
    ``packages[pkg][Name] = {field_name: ros_type}`` in declaration order.
    Comments (after ``#``), blank lines and constant declarations (any line
    containing ``=``) are skipped.
    """
    fields = _start_definition(packages, msg_type)
    for line in full_text.split('\n'):
        if line.startswith('==='):
            continue
        if line.startswith('MSG: '):
            fields = _start_definition(packages, line.split(' ')[1])
            continue
        line = line.split('#')[0].strip()
        if not line or '=' in line:
            # Blank/comment-only line, or a constant (ignored for now).
            continue
        tokens = line.split()
        # "<type> <name>": first token is the type, last token the field name.
        fields[tokens[-1]] = tokens[0]


def read_bag_types(bag_path):
    """Collect the message layouts and topic types used in a ROS bag.

    Every message in the bag is read once; the first ``_full_text`` seen for
    each ROS type is parsed (including the dependent types it embeds).

    Args:
        bag_path: path of the bag file to open with ``rosbag.Bag``.

    Returns:
        ``(packages, topics)`` where ``packages`` is
        ``{package: {message: {field: ros_type}}}`` and ``topics`` maps each
        topic name to the ROS type name (``pkg/Name``) of its messages (the
        type of the last message read on that topic).
    """
    types = {}
    topics = {}
    # Context manager so the bag is closed even if reading fails midway.
    with rosbag.Bag(bag_path) as bag:
        for topic, msg, _ in bag.read_messages():
            if msg._type not in types:
                types[msg._type] = msg._full_text
            topics[topic] = msg._type

    packages = {}
    for msg_type, full_text in types.items():
        _parse_full_text(packages, msg_type, full_text)
    return (packages, topics)
def generate_proto(packages):
    """Describe every collected ROS message type as one protobuf file.

    Naming: ROS ``pkg/Name`` becomes protobuf ``pkg__Name`` inside the
    protobuf package ``apollo`` (file name ``apollo``).  Field numbers start
    at 1 and follow the insertion order of each field table.

    Type mapping:
      * ROS built-ins are looked up in ``basetype_map``;
      * ``uint8`` arrays (fixed or variable length) become one optional
        ``bytes`` field;
      * any other ``T[]`` / ``T[N]`` becomes a repeated field of ``T``;
      * message references become TYPE_MESSAGE fields whose ``type_name`` is
        ``a__B`` for ``a/B``, ``<same package>__T`` when ``T`` is defined in
        the same package, and ``std_msgs__T`` otherwise (e.g. a bare
        ``Header``).

    Args:
        packages: ``{package: {message: {field: ros_type}}}`` as returned by
            ``read_bag_types``.

    Returns:
        The populated ``FileDescriptorProto``.
    """
    desc = FileDescriptorProto()
    desc.name = 'apollo'
    desc.package = 'apollo'
    for pkg_name, messages in packages.items():
        for msg_name, fields in messages.items():
            proto_msg = desc.message_type.add()
            proto_msg.name = '%s__%s' % (pkg_name, msg_name)
            for number, (field_name, ros_type) in enumerate(fields.items(), start=1):
                field = proto_msg.field.add()
                field.name = field_name
                field.number = number
                is_array = '[' in ros_type
                elem_type = ros_type.partition('[')[0]

                if is_array and elem_type == 'uint8':
                    # Byte arrays travel as a single opaque 'bytes' value.
                    field.label = FieldDescriptorProto.LABEL_OPTIONAL
                    field.type = FieldDescriptorProto.TYPE_BYTES
                    continue

                if is_array:
                    field.label = FieldDescriptorProto.LABEL_REPEATED
                else:
                    field.label = FieldDescriptorProto.LABEL_OPTIONAL

                if elem_type in basetype_map:
                    field.type = basetype_map[elem_type]
                    continue

                field.type = FieldDescriptorProto.TYPE_MESSAGE
                if '/' in elem_type:
                    field.type_name = elem_type.replace('/', '__')
                elif elem_type in messages:
                    field.type_name = '%s__%s' % (pkg_name, elem_type)
                else:
                    field.type_name = 'std_msgs__' + elem_type
    return desc
def convert_to_rosmsg(packages, msg_classes, ros_msg):
    """Convert a ROS message object into its generated protobuf message.

    Despite the name, the result is a *protobuf* message: an instance of
    ``msg_classes['apollo.' + pkg + '__' + Name]``, filled field by field
    following the layout recorded in ``packages``.  Nested messages and
    arrays of messages are converted recursively.

    Args:
        packages: ``{package: {message: {field: ros_type}}}`` as returned by
            ``read_bag_types``.
        msg_classes: mapping of protobuf full name -> message class (as
            returned by ``GetMessages``).
        ros_msg: a ROS message instance exposing ``_type`` (``pkg/Name``)
            and one attribute per field.

    Returns:
        The populated protobuf message.
    """
    msg_package, msg_tp = ros_msg._type.split('/')
    proto_msg = msg_classes['apollo.' + ros_msg._type.replace('/', '__')]()
    for member, member_type in packages[msg_package][msg_tp].items():
        value = getattr(ros_msg, member)
        if member_type in basetype_map:
            # Scalar built-in; time/duration are stored as integer nanoseconds.
            if member_type in ('time', 'duration'):
                value = value.to_nsec()
            setattr(proto_msg, member, value)
        elif '[' in member_type:
            array_type = member_type.split('[')[0]
            if array_type == 'uint8':
                # Mapped to a protobuf 'bytes' field.  Values read from a bag
                # are presumably already bytes; coerce list/tuple forms so the
                # bytes field does not reject them.
                if isinstance(value, (list, tuple)):
                    value = bytes(value)
                setattr(proto_msg, member, value)
            elif array_type in ('time', 'duration'):
                getattr(proto_msg, member).extend(elem.to_nsec() for elem in value)
            elif array_type in basetype_map:
                # One extend call instead of a Python-level append per element.
                getattr(proto_msg, member).extend(value)
            else:
                repeated = getattr(proto_msg, member)
                for elem in value:
                    repeated.add().CopyFrom(convert_to_rosmsg(packages, msg_classes, elem))
        else:
            # Single nested message.
            getattr(proto_msg, member).CopyFrom(convert_to_rosmsg(packages, msg_classes, value))
    return proto_msg
def convert_bag_to_record(in_bag, out_record_path, ignore_sensor=False):
    """Convert every message of a ROS bag into an Apollo Cyber record.

    All message types found in the bag are described by one generated
    protobuf file descriptor, which is attached to every channel.  Each
    message is converted with ``convert_to_rosmsg`` and written with its
    original bag timestamp in nanoseconds.

    Args:
        in_bag: path of the input ROS bag.
        out_record_path: path of the Cyber record to write.
        ignore_sensor: when true, topics whose type is in ``sensor_msgs/``
            are neither declared as channels nor written.
    """
    packages, topics = read_bag_types(in_bag)
    fd = generate_proto(packages)
    fd_str = fd.SerializeToString()
    # Build the classes before creating the output so a bad descriptor does
    # not leave a half-written record behind.
    message_classes = GetMessages([fd])

    def skipped(ros_type):
        return ignore_sensor and ros_type.startswith('sensor_msgs/')

    out_record = record.RecordWriter()
    out_record.open(out_record_path)
    try:
        for topic, ros_type in topics.items():
            if skipped(ros_type):
                continue
            # NOTE(review): the channel message_type is the ROS type name
            # ('pkg/Name'), while the protobuf full name is
            # 'apollo.' + ros_type.replace('/', '__'); confirm which one the
            # record consumer expects.
            out_record.write_channel(topic, ros_type, fd_str)
        with rosbag.Bag(in_bag) as bag:
            for topic, msg, t in bag.read_messages():
                if skipped(msg._type):
                    continue
                converted = convert_to_rosmsg(packages, message_classes, msg)
                out_record.write_message(topic, converted.SerializeToString(), t.to_nsec())
    finally:
        # Always close the output, even if conversion fails midway.
        out_record.close()
if __name__ == '__main__':
    # Local import: only needed when run as a script.
    import argparse

    parser = argparse.ArgumentParser(
        description='Convert an entire rosbag to Apollo Cyber Record format.')
    parser.add_argument('in_bag', help='input ROS bag file')
    parser.add_argument('out_record', help='output Cyber record file')
    parser.add_argument('--ignore-sensor', action='store_true',
                        help='skip all sensor_msgs/* topics')
    args = parser.parse_args()
    convert_bag_to_record(args.in_bag, args.out_record,
                          ignore_sensor=args.ignore_sensor)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment