Skip to content

Instantly share code, notes, and snippets.

@helinwang
Last active July 3, 2020 23:28
Show Gist options
  • Save helinwang/91c5532d78664920ee8395afc19d75a2 to your computer and use it in GitHub Desktop.
Save helinwang/91c5532d78664920ee8395afc19d75a2 to your computer and use it in GitHub Desktop.
Decode arbitrary Protobuf messages in TensorFlow using tf.io.decode_proto and TFRecordDataset.
import tensorflow as tf
from google.protobuf.descriptor_pb2 import FileDescriptorSet
from google.protobuf.descriptor_pb2 import FileDescriptorProto
import baz_pb2
def decode(x):
    """Decode serialized ``foo.bar.Baz`` protobuf data into four tensors.

    The message schema is handed to ``tf.io.decode_proto`` inline: the file
    descriptor of ``baz_pb2`` is copied into a ``FileDescriptorProto``,
    wrapped in a ``FileDescriptorSet``, serialized, and prefixed with
    ``bytes://``.

    Args:
        x: Serialized message data, forwarded unchanged to
            ``tf.io.decode_proto``.

    Returns:
        A 4-tuple with the decoded values of fields ``a``, ``b``, ``c`` and
        ``d`` (requested as float32, int32, float32 and float32), in that
        order.
    """
    file_proto = FileDescriptorProto()
    baz_pb2.DESCRIPTOR.CopyToProto(file_proto)
    # NOTE(review): only baz_pb2's own file descriptor is placed in the set;
    # if baz.proto imports other .proto files their descriptors presumably
    # need to be included as well -- confirm against the schema.
    descriptor_set = FileDescriptorSet(file=[file_proto])
    schema_source = b"bytes://" + descriptor_set.SerializeToString()

    decoded = tf.io.decode_proto(
        x,
        message_type="foo.bar.Baz",
        field_names=["a", "b", "c", "d"],
        output_types=[tf.float32, tf.int32, tf.float32, tf.float32],
        descriptor_source=schema_source,
    )
    values = decoded.values
    return values[0], values[1], values[2], values[3]
def main():
    """Read ``output.recordio``, decode each record, and print the result."""
    records = tf.data.TFRecordDataset(
        "output.recordio",
        buffer_size=10000,
        num_parallel_reads=tf.data.experimental.AUTOTUNE,
    ).map(decode)

    for decoded in records:
        print(decoded)
# Run the pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment