Skip to content

Instantly share code, notes, and snippets.

@podhmo
Created May 10, 2020
Embed
What would you like to do?
from __future__ import annotations
import typing as t
import typing_extensions as tx
import inspect
import dataclasses
import enum
import os.path
from types import ModuleType
from inflection import pluralize
from prestring import NEWLINE
from prestring import Module as _Module
from prestring.utils import reify
from prestring.codeobject import Symbol, CodeObjectModuleMixin
from metashape.analyze.walker import Walker
Command = t.Callable[..., t.Any]


def scan_module(
    module: ModuleType,
    *,
    is_ignored: t.Optional[t.Callable[[Command], bool]] = None,  # inspect.isclass
    targets: t.Optional[t.List[str]] = None,
) -> t.Dict[str, Command]:
    """Collect the callable definitions exposed by ``module``.

    Args:
        module: Module whose ``__dict__`` is inspected.
        is_ignored: Optional predicate; a value for which it returns a truthy
            result is left out.
        targets: Names to look at.  When empty or ``None`` every name in the
            module namespace is considered.

    Returns:
        Mapping of name to value for every entry that passes all filters.
    """
    namespace = module.__dict__
    names = targets or list(namespace.keys())
    found: t.Dict[str, Command] = {}
    for name in names:
        value = namespace.get(name)
        # Missing / ``None`` entries, private names and values that do not
        # carry ``__module__`` are never candidates.
        if value is None or name.startswith("_") or not hasattr(value, "__module__"):
            continue
        if is_ignored and is_ignored(value):
            continue
        # Values whose ``__module__`` equals this module's ``__name__`` are
        # kept regardless of their spelling.  Anything else must either have
        # ``__origin__`` (e.g. ``typing.List[X]``) or a name whose first
        # character is unchanged by ``str.upper``.
        defined_here = value.__module__ == __name__
        if (
            not defined_here
            and not hasattr(value, "__origin__")
            and name[0] != name[0].upper()
        ):
            continue
        if callable(value):
            found[name] = value
    return found
# Classification attached to every scanned definition.  ``typing.Literal`` is
# used directly: the file already needs Python 3.8+ (it calls
# ``typing.get_args``), so no third-party backport is required for this alias.
Kind = t.Literal["object", "service", "list", "enum", "unknown"]


@dataclasses.dataclass
class Item:
    """One scanned definition together with its classification."""

    # Name the definition is bound to in the scanned namespace.
    name: str
    # The definition itself (a class, or a typing alias for ``list`` items).
    type_: t.Type[t.Any]
    # Which emitter handles this item.
    kind: Kind
# special type
class Service:
    """Base class marking a definition as a service.

    Subclasses declare their rpc methods as ordinary functions in the class
    body; ``_get_methods`` returns them in declaration order.
    """

    # Per-class cache of rpc method names.  Stays ``None`` on ``Service``
    # itself; each subclass gets its own list on first use.
    _service_attrs_cache = None

    @classmethod
    def _get_methods(cls):
        """Return the callables declared directly in ``cls``'s class body.

        Names that ``Service`` itself defines, dunder names (for example
        ``__annotations__``, which only appears on subclasses with annotated
        attributes) and non-callable attributes are skipped.  Calling this on
        ``Service`` itself returns an empty list.
        """
        if cls is Service:
            # The base class declares no rpc methods; previously this path
            # tried to iterate the ``None`` placeholder.
            return []
        if "_service_attrs_cache" not in cls.__dict__:
            base_names = set(Service.__dict__)
            cls._service_attrs_cache = [
                name
                for name in list(cls.__dict__)
                if name not in base_names
                and not (name.startswith("__") and name.endswith("__"))
                # Look the attribute up through the class so that
                # staticmethod/classmethod wrappers are resolved first.
                and callable(getattr(cls, name))
            ]
        return [getattr(cls, name) for name in cls._service_attrs_cache]
def walk(defs: t.Dict[str, t.Type[t.Any]]) -> t.Iterator[Item]:
    """Classify every scanned definition and yield it as an ``Item``.

    Values with ``__origin__`` are treated as list aliases (anything whose
    origin is not ``list`` trips the assertion), ``Service`` itself is
    skipped, classes become ``enum`` / ``service`` / ``object``, and every
    other value is reported as ``unknown``.
    """
    for name, value in defs.items():
        if hasattr(value, "__origin__"):
            kind: Kind = value.__origin__.__name__  # list
            assert kind == "list"
        elif value == Service:
            # The marker base class is never emitted.
            continue
        elif not inspect.isclass(value):
            kind = "unknown"
        elif issubclass(value, enum.Enum):
            kind = "enum"
        elif issubclass(value, Service):
            kind = "service"
        else:
            kind = "object"
        yield Item(name=name, kind=kind, type_=value)
class Module(CodeObjectModuleMixin, _Module):
    """prestring module specialised for writing ``.proto`` source."""

    def sep(self) -> None:
        """Append a ``NEWLINE`` marker to the body."""
        self.body.append(NEWLINE)

    @reify
    def _import_area(self) -> Module:
        """Submodule that collects ``import`` statements.

        Created on first access at the current position of the body, so the
        caller decides where imports appear by touching this attribute.
        """
        sm = self.submodule("", newline=True)
        self.stmt("")
        return sm

    @reify
    def _imported_paths(self) -> t.Set[str]:
        """Paths for which an ``import`` line has already been written."""
        return set()

    def import_(self, path: str) -> Symbol:
        """Register an import of ``path`` and return a symbol for its package.

        The ``import "<path>";`` line is written only the first time a given
        path is requested, so repeated requests no longer produce duplicate
        import statements.  The returned symbol is named after the directory
        part of ``path`` with ``/`` replaced by ``.``.
        """
        im = self._import_area
        if path not in self._imported_paths:
            self._imported_paths.add(path)
            im.stmt(f'import "{path}";')
        prefix = os.path.dirname(path).replace("/", ".")
        return self.symbol(prefix)
class TypeResolver:
    """Translate Python types into the names written to the proto output."""

    def __init__(
        self, m: Module, *, aliases: t.Optional[t.Dict[t.Any, Symbol]] = None
    ) -> None:
        """Remember the output module and the known type aliases.

        Args:
            m: Module used to register imports for ``PROTO_PACKAGE`` types.
            aliases: Optional mapping from a type to the symbol to emit for
                it.  The mapping is copied, so the caller's dict is left
                untouched.
        """
        self.m = m
        # Work on a private copy: adding the built-in ``str`` alias must not
        # leak back into the dict owned by the caller.
        self.aliases: t.Dict[t.Any, t.Any] = dict(aliases or {})
        self.aliases[str] = "string"

    def resolve_type(self, typ: t.Type[t.Any]) -> Symbol:
        """Return the name to emit for ``typ``.

        Lookup order: an explicit alias; then, for types carrying
        ``PROTO_PACKAGE``, the type name qualified by the symbol returned
        from importing that package; otherwise the bare ``__name__``.
        """
        if typ in self.aliases:
            return self.aliases[typ]
        if hasattr(typ, "PROTO_PACKAGE"):
            return getattr(self.m.import_(typ.PROTO_PACKAGE), typ.__name__)
        return typ.__name__
def emit(items: t.Iterator[Item], *, name: str) -> Module:
    """Render the classified items as a proto3 module for package ``name``.

    Object, enum and service items are handed to the walker and emitted in
    the order the walker yields them; list items are emitted afterwards.
    Types that carry ``PROTO_PACKAGE`` are imported instead of being emitted.
    """
    from metashape.runtime import get_walker

    m = Module(indent=" ")
    m.stmt('syntax = "proto3";')
    m.sep()
    m.stmt(f"package {name};")
    m.sep()
    # Accessed only for its side effect: the import area is created here so
    # that import lines land right after the package statement.
    m._import_area

    all_items = list(items)
    w = get_walker([])
    by_type: t.Dict[t.Type[t.Any], Item] = {}
    aliases: t.Dict[t.Any, Symbol] = {}
    for item in all_items:
        if item.kind not in ("object", "enum", "service"):
            continue
        typ = item.type_
        by_type[typ] = item
        w.append(typ)
        if hasattr(typ, "PROTO_PACKAGE"):
            aliases[typ] = getattr(m.import_(typ.PROTO_PACKAGE), item.name)
        else:
            aliases[typ] = m.symbol(item.name)

    resolver = TypeResolver(m, aliases=aliases)
    emitters = {
        "object": emit_class,
        "enum": emit_enum,
        "service": emit_service,
    }
    for cls in w.walk(kinds=["object", None]):
        if hasattr(cls, "PROTO_PACKAGE"):
            # Externally defined message: imported above, nothing to emit.
            continue
        item = by_type[cls]
        emit_one = emitters.get(item.kind)
        if emit_one is not None:
            emit_one(m, item, w=w, resolver=resolver)
        m.sep()

    for item in all_items:
        if item.kind == "list":
            emit_list(m, item, w=w, resolver=resolver)
    return m
def emit_class(m: Module, item: Item, *, w: Walker, resolver: TypeResolver) -> Symbol:
    """Write ``item`` as a proto ``message`` and return its symbol.

    Fields are taken from the walker's view of the class (private fields
    included) and numbered sequentially from 1 in the order they are yielded.

    Returns:
        The symbol for the message name.  The field loop uses its own
        variable, so the result is no longer overwritten by the name of the
        last field.
    """
    name = item.name
    cls = item.type_
    m.stmt(f"message {name} {{")
    with m.scope():
        fields = w.for_type(cls).walk(ignore_private=False)
        for number, (field_name, info, _metadata) in enumerate(fields, start=1):
            typ = resolver.resolve_type(info.type_)
            m.stmt(f"{typ} {field_name} = {number};")  # todo: deprecated
    m.stmt("}")
    return m.symbol(name)
def emit_enum(m: Module, item: Item, *, w: Walker, resolver: TypeResolver) -> Symbol:
    """Write ``item`` as a proto ``enum`` and return its symbol.

    One ``NAME = value;`` line is written per member, in iteration order of
    the enum class.  ``w`` and ``resolver`` are accepted for signature parity
    with the other emitters and are not used.
    """
    enum_cls = t.cast(t.Type[enum.Enum], item.type_)  # enum.Enum
    m.stmt(f"enum {item.name} {{")
    with m.scope():
        for member in enum_cls:
            m.stmt(f"{member.name} = {member.value};")
    m.stmt("}")
    return m.symbol(item.name)
def emit_service(m: Module, item: Item, *, w: Walker, resolver: TypeResolver) -> Symbol:
    """Write ``item`` as a proto ``service`` and return its symbol.

    Every method returned by the class's ``_get_methods`` becomes one ``rpc``
    entry.  Request types come from the annotated positional parameters
    (unannotated ones such as ``self`` are skipped); the response type comes
    from the ``return`` annotation.
    """
    service_name = item.name
    service_cls = item.type_
    m.stmt(f"service {service_name} {{")
    with m.scope():
        for method in service_cls._get_methods():
            spec = inspect.getfullargspec(method)
            # Resolved hints take precedence over the raw (possibly string)
            # annotations reported by the argspec.
            hints = {**spec.annotations, **t.get_type_hints(method)}
            request_types = [
                str(resolver.resolve_type(hints[arg]))
                for arg in spec.args
                if arg in hints
            ]
            # todo: support tuple
            response_types = [str(resolver.resolve_type(hints.get("return")))]
            m.stmt(
                f"rpc {method.__name__}({', '.join(request_types)}) returns ({', '.join(response_types)}) {{"
            )
            with m.scope():
                pass
            m.stmt("}")
    m.stmt("}")
    return m.symbol(service_name)
def emit_list(m: Module, item: Item, *, w: Walker, resolver: TypeResolver) -> Symbol:
    """Write a list alias as a message with one ``repeated`` field.

    The element type is the first type argument of the alias.  Its proto
    name is obtained through ``resolver`` so that aliases (``str`` ->
    ``string``) and ``PROTO_PACKAGE`` types are rendered the same way as in
    ordinary message fields.  The field is named after the lower-cased,
    pluralized element class name and always gets field number 1.
    """
    name = item.name
    elem = t.get_args(item.type_)[0]
    elem_type = resolver.resolve_type(elem)
    field_name = pluralize(elem.__name__.lower())
    m.stmt(f"message {name} {{")
    with m.scope():
        m.stmt(f"repeated {elem_type} {field_name} = 1;")
    m.stmt("}")
    return m.symbol(name)
--- expected.proto 2020-05-10 23:51:12.000000000 +0900
+++ result.proto 2020-05-10 23:52:30.000000000 +0900
@@ -2,7 +2,6 @@
package myapp;
-import "google/api/annotations.proto";
import "google/type/date.proto";
import "google/protobuf/empty.proto";
@@ -11,7 +10,7 @@
string first_name = 2;
string family_name = 3;
Sex sex = 4;
- uint32 age = 5 [ deprecated = true ];
+ uint32 age = 5;
google.type.Date birthday = 6;
}
@@ -22,17 +21,19 @@
OTHER = 3;
}
-message UserList { repeated User users = 1; }
-
service UserService {
rpc Get(GetRequest) returns (User) {
- option deprecated = false;
- option (google.api.http) = {
- get : "user"
- };
+
+ }
+ rpc List(google.protobuf.Empty) returns (UserList) {
+
}
- rpc List(google.protobuf.Empty) returns (UserList) {}
}
-message GetRequest { uint64 id = 1; }
+message GetRequest {
+ uint64 id = 1;
+}
+message UserList {
+ repeated User users = 1;
+}
syntax = "proto3";
package myapp;
import "google/api/annotations.proto";
import "google/type/date.proto";
import "google/protobuf/empty.proto";
message User {
uint64 id = 1;
string first_name = 2;
string family_name = 3;
Sex sex = 4;
uint32 age = 5 [ deprecated = true ];
google.type.Date birthday = 6;
}
enum Sex {
SEX_UNKNOWN = 0;
MALE = 1;
FEMALE = 2;
OTHER = 3;
}
message UserList { repeated User users = 1; }
service UserService {
rpc Get(GetRequest) returns (User) {
option deprecated = false;
option (google.api.http) = {
get : "user"
};
}
rpc List(google.protobuf.Empty) returns (UserList) {}
}
message GetRequest { uint64 id = 1; }
from __future__ import annotations
import enum
import typing as t
from typestubs import uint32, uint
from typestubs import Date, Empty
from _emit import Service
class User:
    """Example message: fields are declared purely through annotations."""

    id: uint
    first_name: str
    family_name: str
    sex: Sex
    age: uint32
    birthday: Date
# The values are written out by hand; if they did not have to follow
# https://en.wikipedia.org/wiki/ISO/IEC_5218, ``enum.auto()`` could be used.
@enum.unique
class Sex(enum.IntEnum):
    """Example enum with explicit integer values."""

    SEX_UNKNOWN = 0
    MALE = 1
    FEMALE = 2
    OTHER = 3  # 9
# List alias for a collection of users.
UserList = t.List[User]


class UserService(Service):
    """Example service; the method bodies are intentionally empty."""

    def Get(self, req: GetRequest) -> User:
        """Signature only: takes ``GetRequest``, returns ``User``."""

    def List(self, empty: Empty) -> UserList:
        """Signature only: takes ``Empty``, returns ``UserList``."""
class GetRequest:
    """Example request message for ``UserService.Get``."""

    id: uint
def main() -> None:
    """Scan this module, emit the proto source for it and print the result."""
    import sys

    from _emit import emit, scan_module, walk

    this_module = sys.modules[__name__]
    # ``main`` itself is passed as ignored so it is never treated as a definition.
    defs = scan_module(this_module, is_ignored=lambda obj: obj == main)
    print(emit(walk(defs), name="myapp"))


if __name__ == "__main__":
    main()
default:
python main.py | tee result.proto
diff -u expected.proto result.proto > a.diff || exit 0
doc:
protoc --doc_out=html,index.html:./ result.proto
# gen:
# protoc --go_out=plugins=grpc:. helloworld.proto
syntax = "proto3";
package myapp;
import "google/type/date.proto";
import "google/protobuf/empty.proto";
message User {
uint64 id = 1;
string first_name = 2;
string family_name = 3;
Sex sex = 4;
uint32 age = 5;
google.type.Date birthday = 6;
}
enum Sex {
SEX_UNKNOWN = 0;
MALE = 1;
FEMALE = 2;
OTHER = 3;
}
service UserService {
rpc Get(GetRequest) returns (User) {
}
rpc List(google.protobuf.Empty) returns (UserList) {
}
}
message GetRequest {
uint64 id = 1;
}
message UserList {
repeated User users = 1;
}
import typing as t
def proto_package(path: str) -> t.Callable[[t.Type[t.Any]], t.Type[t.Any]]:
    """Build a class decorator that records ``path`` as ``PROTO_PACKAGE``.

    The decorated class is returned unchanged apart from the new attribute.
    """

    def attach(target: t.Type[t.Any]) -> t.Type[t.Any]:
        setattr(target, "PROTO_PACKAGE", path)
        return target

    return attach
#
# Open question: do distinctions such as uint32 vs. uint64 matter here?
uint64 = t.NewType("uint64", int)
uint32 = t.NewType("uint32", int)
# Plain ``uint`` is the same object as ``uint64``.
uint = uint64
@proto_package("google/type/date.proto")
class Date:
    """Placeholder class tagged with the ``google/type/date.proto`` path."""
@proto_package("google/protobuf/empty.proto")
class Empty:
    """Placeholder class tagged with the ``google/protobuf/empty.proto`` path."""
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment