Created
October 23, 2020 19:31
-
-
Save theunkn0wn1/9b2407669d607e8d15bacab6e6fb4cc9 to your computer and use it in GitHub Desktop.
Post-process protoc-generated Python sources to use relative imports
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import ast | |
import pathlib | |
import dataclasses | |
from typing import Optional, Set | |
import logging | |
@dataclasses.dataclass(frozen=True)
class Alias:
    """One name bound by an import statement.

    Frozen, so instances are immutable and hashable and can be stored in a set.
    """

    # the imported name as written in the statement, e.g. "foo_pb2"
    name: str
    # line number of the import statement that bound this name
    line: int
    # the ``as`` rebinding, or None when there is no ``as`` clause
    as_name: Optional[str]

    @classmethod
    def from_ast_alias(cls, alias: ast.alias, line: int):
        """Build an ``Alias`` from an ``ast.alias`` node plus its statement's line number."""
        imported, rebound = alias.name, alias.asname
        # positional order mirrors the field order: name, line, as_name
        return cls(imported, line, rebound)
# Relative path to the current working directory (same as pathlib.Path()).
cwd = pathlib.Path(".")
def get_imports(raw: str) -> Set[Alias]:
    """Collect every name bound by a module-level import statement in ``raw``.

    Only statements directly in the module body are inspected; imports nested
    inside functions, classes, or ``try``/``if`` blocks are not collected.

    Args:
        raw: Python source text (``ast.parse`` raises ``SyntaxError`` if invalid).

    Returns:
        Set[Alias]: one entry per imported name, tagged with the ``lineno``
        (1-based) of the import statement it came from.
    """
    module = ast.parse(raw)
    return {
        Alias.from_ast_alias(imported, statement.lineno)
        for statement in module.body
        if isinstance(statement, (ast.Import, ast.ImportFrom))
        for imported in statement.names
    }
def post_process(target: pathlib.Path) -> str:
    """
    Post processes a generated gRPC python source file to use relative imports

    Only plain ``import <name>`` statements whose name contains ``_pb`` (the
    form protoc emits for sibling modules, e.g. ``import foo_pb2 as foo__pb2``)
    are rewritten to ``from . import ...``. ``from <pkg> import ...`` lines are
    left untouched: prefixing them would yield invalid syntax such as
    ``from google.protobuf from . import timestamp_pb2``.

    Args:
        target: path to target file (read as UTF-8)
    Returns:
        str: contents of processed file; ``target`` itself is not modified
    """
    raw = target.read_text(encoding="utf-8")  # protoc emits UTF-8 sources
    lines = raw.split("\n")
    for alias in get_imports(raw):
        if "_pb" not in alias.name:
            # not a generated *_pb import
            continue
        if "." in alias.name:
            # `from . import a.b` is not valid syntax; leave dotted imports as-is
            continue
        index = alias.line - 1  # ast line numbers are 1-based
        line = lines[index]
        statement = line.lstrip()
        if not statement.startswith("import "):
            # a `from X import ...` statement, or a line already patched to
            # `from . import ...` (e.g. several aliases on one line)
            continue
        indent = line[: len(line) - len(statement)]
        # Prefix only the leading keyword; str.replace would also rewrite any
        # later "import " text on the same line.
        lines[index] = f"{indent}from . {statement}"
    return "\n".join(lines)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment