-
-
Save trianta2/fd04bdbfc9bdef5631c0d76582a04aca to your computer and use it in GitHub Desktop.
protobuf solution
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*-
"""
Example solution: describe protobuf message types as namedtuples.

``module_msgs`` takes a generated ``*_pb2`` module and returns, for each
top-level message, a namedtuple instance with one field per message field.
Each value describes that field's protobuf type:

* a ``'TYPE_*'`` string (a ``FieldDescriptor`` constant name) for scalars,
* a nested namedtuple instance for message-typed fields,
* ``Repeated(value)`` for repeated fields, ``Map(key, value)`` for map fields.
"""
from collections import namedtuple | |
from google.protobuf.descriptor import FieldDescriptor | |
# Reverse lookup from a FieldDescriptor.TYPE_* integer code to the constant's
# name, e.g. the value of FieldDescriptor.TYPE_INT32 maps to 'TYPE_INT32'.
_type_dict = dict(
    (type_code, attr_name)
    for attr_name, type_code in vars(FieldDescriptor).items()
    if attr_name.startswith('TYPE_')
)
# Wrapper for a `repeated` field; `value` holds the element type.
Repeated = namedtuple('Repeated', 'value')
# Wrapper for a `map<K, V>` field; `key` / `value` hold the entry types.
Map = namedtuple('Map', 'key value')
def _field_type(field, context):
    '''Return the type of a single field, ignoring whether it is repeated.

    Message-typed fields yield the message's namedtuple (built, or fetched
    from ``context``, by message_as_namedtuple). Every other field yields the
    'TYPE_*' name of its FieldDescriptor type code.
    '''
    nested = field.message_type
    if nested is None:
        return _type_dict[field.type]
    return message_as_namedtuple(nested, context)
def field_type(field, context):
    '''Describe the protobuf type of the given field descriptor.

    Returns:
        Map(key, value) for map fields, Repeated(value) for any other
        repeated field, and otherwise a 'TYPE_*' str (scalar) or a message
        namedtuple, as produced by _field_type.
    '''
    if field.label != FieldDescriptor.LABEL_REPEATED:
        return _field_type(field, context)

    entry = field.message_type
    # A map field is a repeated field whose message type has the map_entry
    # option set; the entry's first field is the key, the second the value.
    if entry is None or not entry.GetOptions().map_entry:
        return Repeated(_field_type(field, context))

    key_desc = entry.fields[0]
    value_desc = entry.fields[1]
    return Map(_field_type(key_desc, context), _field_type(value_desc, context))
# Placeholder stored in `context` while a message's fields are being resolved;
# meeting it again means the message (directly or indirectly) contains itself.
_IN_PROGRESS = object()


def message_as_namedtuple(descr, context):
    '''Returns a namedtuple instance describing the given message descriptor.

    The namedtuple type is named after the message (descr.name) and has one
    field per message field; each value is what field_type returns for it.

    Args:
        descr: a protobuf message Descriptor.
        context: dict cache shared across calls so each message is built once
            and every reference to it yields the identical object. Entries are
            keyed by the message's fully-qualified name (descr.full_name).

    Raises:
        RecursionError: if the message refers to itself, directly or through
            other messages. An immutable namedtuple cannot contain itself.
    '''
    # Key by full_name, not name: nested or packaged messages may share a
    # short name (e.g. A.Item and B.Item) and must not share a cache entry.
    key = descr.full_name
    if key in context:
        if context[key] is _IN_PROGRESS:
            raise RecursionError(
                'message {} is recursive and cannot be represented as a '
                'namedtuple'.format(key))
        return context[key]

    context[key] = _IN_PROGRESS
    try:
        Msg = namedtuple(descr.name, [f.name for f in descr.fields])
        context[key] = Msg(*(field_type(f, context) for f in descr.fields))
    except BaseException:
        # Remove the placeholder so a later call does not report a spurious
        # recursion for this message; then propagate the original error.
        del context[key]
        raise
    return context[key]
def module_msgs(module):
    '''Map each top-level message name in a generated protobuf module to the
    namedtuple instance describing that message.

    One cache is shared across all messages, so a message referenced from
    several places is represented by the same object everywhere.
    '''
    shared_cache = {}
    by_name = {}
    for msg_name, descriptor in module.DESCRIPTOR.message_types_by_name.items():
        by_name[msg_name] = message_as_namedtuple(descriptor, shared_cache)
    return by_name
def is_message(field):
    '''Returns True if a field type (as returned by field_type) is a custom
    message type.

    Message types are namedtuple instances. Repeated and Map are namedtuples
    too, so they must be excluded explicitly; otherwise
    is_message(Repeated('TYPE_INT32')) would wrongly be True. To ask about the
    element or value type, pass .value of the Repeated/Map instead. Scalar
    types are 'TYPE_*' strings and are never messages.
    '''
    return isinstance(field, tuple) and not isinstance(field, (Repeated, Map))
if __name__ == '__main__':
    # Self-test. test_pb2 is presumably the protoc output for the accompanying
    # schema defining Data1 and Data2 -- TODO confirm it is importable.
    import test_pb2

    msgs = module_msgs(test_pb2)
    Data1 = msgs['Data1']
    Data2 = msgs['Data2']

    # Explicit checks rather than bare `assert` statements. Python strips
    # asserts under -O, which would silently turn this self-test into a no-op.
    # Each check is a lambda so checks run lazily and in order, and the first
    # failure is reported, as with the original sequence of asserts.
    checks = [
        ('Data2.a is a message', lambda: is_message(Data2.a)),
        ('Data2.a is the shared Data1 object', lambda: Data2.a is Data1),
        ('Data2.b is a Map', lambda: isinstance(Data2.b, Map)),
        ('Data2.b has message values', lambda: is_message(Data2.b.value)),
        ('Data2.c is a Map', lambda: isinstance(Data2.c, Map)),
        ('Data2.c has scalar values', lambda: not is_message(Data2.c.value)),
        ('Data2.d is Repeated', lambda: isinstance(Data2.d, Repeated)),
        ('Data2.d holds messages', lambda: is_message(Data2.d.value)),
        ('Data2.e is Repeated', lambda: isinstance(Data2.e, Repeated)),
        ('Data2.e holds scalars', lambda: not is_message(Data2.e.value)),
    ]
    for description, check in checks:
        if not check():
            raise AssertionError('self-test failed: ' + description)

    print(msgs)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Schema used by the Python self-test; presumably compiled with protoc into
// the test_pb2 module the example imports (TODO confirm).
syntax = "proto3";

// Message made only of scalar fields:
// double, float, int32, int64, bool, string, bytes.
message Data1 {
  double a = 1;
  float b = 2;
  int32 c = 3;
  int64 d = 4;
  bool e = 5;
  string f = 6;
  bytes g = 7;
}

// One field for each shape the Python field_type helper distinguishes.
message Data2 {
  Data1 a = 1;               // singular message field
  map<string, Data1> b = 2;  // map with message values
  map<string, int32> c = 3;  // map with scalar values
  repeated Data1 d = 4;      // repeated message field
  repeated int32 e = 5;      // repeated scalar field
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Hello, are you still active here? I came across your code, and it's going to be really helpful for my project, but I have some questions I'd like to ask you, if possible.