Last active: June 17, 2022 01:39
-
-
Save Desdaemon/226c44e45dd39c9b2b4d4964eff2f447 to your computer and use it in GitHub Desktop.
Create Python Stubs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
# pyright: basic | |
import re | |
import sys | |
import os | |
import ast as a | |
from typing import Callable, Generic, Iterable, Iterator, TypeVar | |
from multiprocessing import Pool | |
from pathlib import Path | |
# Type parameters for ``filter_map``: T is the type of the input items,
# U the type of the mapped results that are kept.
T = TypeVar('T')
U = TypeVar('U')
def eprint(*args):
    """Print ``args`` to standard error with ``print`` semantics.

    Arguments are space-separated and a newline is appended. ``sys.stderr``
    is looked up at call time, so redirections are respected.
    """
    stream = sys.stderr
    print(*args, file=stream)
# The ``...`` expression used as a placeholder value/body in generated stubs.
# ``ast.Constant`` replaces ``ast.Ellipsis``, which has been deprecated since
# Python 3.8 and is removed in 3.14.
ELLIPSIS = a.Constant(value=...)
# Single-statement body ``...``. It is assigned to every stubbed function and
# to every class whose body becomes empty. This one list object is shared by
# all of those nodes, so it must never be mutated in place.
STUB: list[a.stmt] = [a.Expr(ELLIPSIS)]
def main():
    """Generate ``.pyi`` stubs for every ``.py`` file named on the command line.

    Paths come from ``_resolve_files`` (files as given, directories walked
    recursively). Each path is made absolute with ``os.path.realpath``, and
    the files are distributed across a ``multiprocessing.Pool``, which runs
    ``generate_file_stub`` on each one.
    """
    sources = [
        os.path.realpath(candidate)
        for candidate in _resolve_files()
        if candidate.endswith('.py')
    ]
    with Pool() as workers:
        workers.map(generate_file_stub, sources)
def generate_file_stub(file: str):
    """Write a ``.pyi`` stub next to the Python source file ``file``.

    Every node of the parsed module is passed to ``stubify``. After that,
    ``from typing import *`` and ``from typing_extensions import *`` are
    inserted so that names used in docstring-derived annotations can resolve.
    The result is written to ``file`` with its suffix replaced by ``.pyi``,
    overwriting any existing stub there.

    Read and parse failures are reported on stderr and the file is skipped.
    This way one bad input does not abort the whole ``Pool.map`` that drives
    this function.

    :param file: path to a ``.py`` source file.
    """
    source_path = Path(file)
    outpath = source_path.with_suffix('.pyi')
    try:
        # Read bytes, not text: ``ast.parse`` then honours PEP 263 coding
        # cookies and a UTF-8 BOM instead of the platform locale encoding.
        contents = source_path.read_bytes()
    except OSError as err:
        eprint(f'[ERR] Could not read {file} due to:\n{err}')
        return
    try:
        module = a.parse(contents, file, type_comments=True)
    except Exception as err:  # SyntaxError, ValueError, RecursionError, ...
        eprint(f'[ERR] Could not parse {file} due to:\n{err}')
        return
    for node in a.walk(module):
        stubify(node)
    header = a.parse(
        'from typing import *\n'
        'from typing_extensions import *\n').body
    # Insert the header after the module docstring and any __future__ imports.
    # A future import that follows other statements is a SyntaxError, and a
    # docstring that is no longer the first statement stops being a docstring.
    insert_at = _prelude_length(module)
    module.body[insert_at:insert_at] = header
    outpath.write_text(a.unparse(module) + '\n', encoding='utf-8')


def _prelude_length(module: a.Module) -> int:
    """Count the leading statements that must stay first: docstring + __future__ imports."""
    body = module.body
    index = 1 if a.get_docstring(module, clean=False) is not None else 0
    while (index < len(body)
           and isinstance(body[index], a.ImportFrom)
           and body[index].module == '__future__'):
        index += 1
    return index
def stubify(tree: a.AST):
    """Stub out a single AST node in place; other node kinds are left alone.

    Function definitions (sync or async) go to ``stubify_function``. Classes
    go to ``stubify_class``. Annotated assignments get ``...`` as their value.
    """
    if isinstance(tree, (a.FunctionDef, a.AsyncFunctionDef)):
        stubify_function(tree)
    elif isinstance(tree, a.AnnAssign):
        tree.value = ELLIPSIS
    elif isinstance(tree, a.ClassDef):
        stubify_class(tree)
def annotate_function(func: a.FunctionDef | a.AsyncFunctionDef,
                      docstring: str):
    """Fill in missing annotations on ``func`` from a Sphinx-style docstring.

    Parameter types come from ``:param``/``:type`` fields and the return type
    from ``:rtype:`` (both via ``parse_docstring``). Each type string is
    cleaned with ``preprocesss_type`` and inserted as a string annotation.
    Annotations and type comments already present in the source are never
    overwritten.

    :param func: function node to annotate in place.
    :param docstring: docstring text to read types from.
    """
    if func.type_comment:
        # A ``# type: (...) -> ...`` comment already gives the full signature;
        # adding annotations as well would give the function two signatures.
        return
    type_anns, rtype = parse_docstring(docstring)
    if rtype and func.returns is None:
        returns = preprocesss_type(rtype[1])
        if returns:
            func.returns = a.Constant(returns)
    if not type_anns:
        return
    params = func.args
    for arg in (*params.posonlyargs, *params.args, *params.kwonlyargs):
        type_ = type_anns.pop(arg.arg, None)
        if type_ and arg.annotation is None and not arg.type_comment:
            arg.annotation = a.Constant(preprocesss_type(type_))
def stubify_function(func: a.FunctionDef | a.AsyncFunctionDef):
    """Replace ``func``'s body with ``...``, first lifting types from its docstring.

    ``__init__`` is not annotated from its own docstring; its annotations
    come from the class docstring instead (see ``stubify_class``). The
    function's own docstring is discarded along with the rest of its body.
    """
    is_init = func.name == '__init__'
    docstring = None if is_init else a.get_docstring(func)
    if docstring:
        annotate_function(func, docstring)
    func.body = STUB
def stubify_class(cls: a.ClassDef):
    """Stub a class node in place.

    If the class has a non-empty docstring, two things happen. The docstring
    statement is removed from the body. The docstring text is then used to
    annotate the first ``__init__`` defined directly in the class body. A body
    left empty by removing the docstring is replaced with ``...``.
    """
    docstring = a.get_docstring(cls)
    if docstring:
        cls.body = cls.body[1:]
        init = next(
            (stmt for stmt in cls.body
             if isinstance(stmt, a.FunctionDef) and stmt.name == '__init__'),
            None,
        )
        if init is not None:
            annotate_function(init, docstring)
            return
    if not cls.body:
        cls.body = STUB
# Sphinx cross-reference role prefixes (``:class:``, ``:py:obj:``, ...), plus
# the ``~`` shortener and the backticks that wrap types in docstrings.
_TYPE_MARKUP = re.compile(
    r':(?:py:)?(?:class|obj|data|const|exc|attr|func|meth|mod):|[~`]')


def preprocesss_type(raw: str):
    """Strip reST/Sphinx markup from a docstring type expression.

    For example ``' :py:class:`~pkg.Foo` '`` becomes ``'pkg.Foo'``.
    Previously only ``:class:`` was recognised, so ``:py:class:`` left a
    ``:py`` fragment behind. Surrounding whitespace is removed as well.

    :param raw: type text captured from a ``:param``/``:type``/``:rtype:`` field.
    :returns: the bare type expression.
    """
    return _TYPE_MARKUP.sub('', raw.strip()).strip()
# ``:param [type] name: description`` -> (type or '', name).
# The type may contain spaces and brackets (``Dict[str, int]``) but no colons.
# Matching stops at the first colon after the name, so colons inside the
# description are harmless. A leading ``*``/``**`` on the name is dropped.
ANN_PATTERN = re.compile(r':param[ \t]+(?:([^:\n]+?)[ \t]+)?\*{0,2}(\w+)[ \t]*:')
# ``:type name: type`` -> (name, type); the type runs to the end of the line.
TYPE_PATTERN = re.compile(r':type[ \t]+\*{0,2}(\w+)[ \t]*:[ \t]*(.+)')
# ``:rtype: type`` -> group 1 holds the return type text.
RTYPE_PATTERN = re.compile(r':rtype:(.+)')


def parse_docstring(src: str):
    """Extract Sphinx-style type information from a docstring.

    :param src: docstring text.
    :returns: a tuple ``(types, rtype)``.

        ``types`` maps parameter names to raw type text. ``:type`` fields
        override types given inline in ``:param``, and an untyped ``:param``
        maps to ``''``.

        ``rtype`` is the ``re.Match`` for the ``:rtype:`` field (group 1 holds
        its text), or ``None`` if there is no such field.
    """
    inline_types = {
        name: type_.strip() for type_, name in ANN_PATTERN.findall(src)
    }
    explicit_types = {
        name: type_.strip() for name, type_ in TYPE_PATTERN.findall(src)
    }
    return inline_types | explicit_types, RTYPE_PATTERN.search(src)
class filter_map(Generic[T, U]):
    """Iterator that maps ``mapper`` over ``iterable`` and skips falsy results.

    The filter is truthiness, not ``is not None``. Results such as ``0``,
    ``''`` or ``[]`` are dropped as well as ``None``. Input is consumed
    lazily, one item per step.

    :param mapper: function applied to each input item.
    :param iterable: source of input items.
    """
    __slots__ = ['mapper', 'iter']

    def __init__(self, mapper: Callable[[T], U | None], iterable: Iterable[T]):
        self.mapper = mapper
        self.iter = iter(iterable)

    def __iter__(self) -> Iterator[U]:
        return self

    def __next__(self) -> U:
        # Loop rather than recurse. The previous ``return next(self)`` added a
        # stack frame for every skipped item, so a long run of falsy results
        # raised RecursionError.
        for item in self.iter:
            result = self.mapper(item)
            if result:
                return result
        raise StopIteration
def _resolve_files():
    """Yield candidate file paths named by the command-line arguments.

    Each argument in ``sys.argv[1:]`` that is an existing file is yielded
    unchanged. Every other argument is walked recursively with ``os.walk``,
    and each file found beneath it is yielded. Paths that do not exist yield
    nothing, because ``os.walk`` ignores errors by default. With no arguments
    nothing is yielded.
    """
    for root_path in sys.argv[1:]:
        if os.path.isfile(root_path):
            yield root_path
            continue
        for dirpath, _subdirs, names in os.walk(root_path):
            yield from (os.path.join(dirpath, name) for name in names)
# Run only when executed as a script, not on import. This guard also keeps
# spawn-started multiprocessing workers from re-running main().
if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.