Skip to content

Instantly share code, notes, and snippets.

@Desdaemon
Last active June 17, 2022 01:39
Show Gist options
  • Save Desdaemon/226c44e45dd39c9b2b4d4964eff2f447 to your computer and use it in GitHub Desktop.
Save Desdaemon/226c44e45dd39c9b2b4d4964eff2f447 to your computer and use it in GitHub Desktop.
Create Python Stubs
#!/usr/bin/env python
# pyright: basic
import re
import sys
import os
import ast as a
from typing import Callable, Generic, Iterable, Iterator, TypeVar
from multiprocessing import Pool
from pathlib import Path
# Type variables for the generic ``filter_map`` helper: T is the element type
# of the input iterable and U is the (non-None) type produced by the mapper.
T = TypeVar('T')
U = TypeVar('U')
def eprint(*args):
    """Print *args* to standard error, formatted exactly like :func:`print`.

    Values are ``str()``-converted, joined with single spaces and terminated
    by a newline; the stream is looked up at call time, so redirections of
    ``sys.stderr`` are respected.
    """
    stream = sys.stderr
    print(*args, sep=' ', end='\n', file=stream)
# Shared ``...`` expression node used as the placeholder value in stubs.
# ``ast.Ellipsis`` is deprecated since Python 3.8 and removed in 3.14;
# ``ast.Constant(value=...)`` is the node it always produced.
ELLIPSIS = a.Constant(value=...)
# Body used for stubbed functions and emptied classes: a single ``...``
# statement. The same list and node objects end up shared by many parents,
# which is fine because the tree is only unparsed afterwards -- never mutate
# this list in place.
STUB: list[a.stmt] = [a.Expr(ELLIPSIS)]
def main():
    """Generate ``.pyi`` stubs for the ``.py`` files named on the command line.

    Each argument may be a file or a directory (walked recursively); only
    paths ending in ``.py`` are processed, in parallel via a process pool.
    Paths are canonicalised with ``os.path.realpath`` and de-duplicated so a
    file reachable from two arguments is never handed to two workers at
    once, which would race on writing the same ``.pyi`` output.
    """
    if len(sys.argv) < 2:
        eprint(f'usage: {sys.argv[0]} PATH [PATH ...]')
        sys.exit(2)
    # dict.fromkeys drops duplicates while keeping first-seen order.
    files = list(dict.fromkeys(
        os.path.realpath(path)
        for path in _resolve_files()
        if path.endswith('.py')))
    if not files:
        return
    with Pool() as pool:
        pool.map(generate_file_stub, files)
def generate_file_stub(file: str):
    """Write a ``.pyi`` stub next to the Python source *file*.

    The source is parsed with type comments enabled, every node is passed
    through :func:`stubify`, a ``from typing import *`` /
    ``from typing_extensions import *`` prelude is prepended, and the
    unparsed result is written to *file* with its suffix replaced by
    ``.pyi``. Read, parse and write failures are reported on stderr and the
    file is skipped, so one bad input does not abort the whole worker pool.
    """
    outpath = Path(file).with_suffix('.pyi')
    try:
        # Reading bytes lets ast.parse honour a PEP 263 coding cookie
        # (UTF-8 by default) instead of depending on the locale encoding.
        source = Path(file).read_bytes()
    except OSError as err:
        eprint(f'[ERR] Could not read {file} due to:\n{err}')
        return
    try:
        module = a.parse(source, file, type_comments=True)
    except Exception as err:  # per-file boundary: report and move on
        eprint(f'[ERR] Could not parse {file} due to:\n{err}')
        return
    # Snapshot the traversal first so stubify's in-place rewrites cannot
    # influence which nodes get visited.
    for node in list(a.walk(module)):
        stubify(node)
    prelude = a.parse(
        'from typing import *\n'
        'from typing_extensions import *\n').body
    module.body = [*prelude, *module.body]
    try:
        # unparse drops comments (including any coding cookie), so the stub
        # must be written in the default source encoding, UTF-8.
        outpath.write_text(a.unparse(module), encoding='utf-8')
    except OSError as err:
        eprint(f'[ERR] Could not write {outpath} due to:\n{err}')
def stubify(tree: a.AST):
match tree:
case a.FunctionDef() | a.AsyncFunctionDef():
stubify_function(tree)
case a.AnnAssign():
tree.value = ELLIPSIS
case a.ClassDef():
stubify_class(tree)
case _:
return
def annotate_function(func: a.FunctionDef | a.AsyncFunctionDef,
                      docstring: str):
    """Fill missing annotations on *func* from a Sphinx-style *docstring*.

    Parameter types come from ``:param <type> <name>:`` / ``:type <name>:``
    fields and the return type from ``:rtype:`` (via :func:`parse_docstring`);
    each is cleaned with :func:`preprocesss_type` and attached as a string
    annotation. Source-level typing always wins: nothing is added when the
    function has a ``# type:`` signature comment, and a parameter or return
    slot that is already annotated (or has its own type comment) is left
    alone. Positional-only, regular and keyword-only parameters are covered.

    :param func: function node, annotated in place.
    :param docstring: docstring text to harvest types from.
    """
    if func.type_comment:
        # A signature type comment already types every parameter and the
        # return value; adding annotations would duplicate the signature.
        return
    type_anns, rtype = parse_docstring(docstring)
    if rtype and func.returns is None:
        return_type = preprocesss_type(rtype[1])
        if return_type:  # skip an empty ``:rtype:`` field
            func.returns = a.Constant(value=return_type)
    if not type_anns:
        return
    params = func.args
    for arg in (*params.posonlyargs, *params.args, *params.kwonlyargs):
        if arg.annotation is not None or arg.type_comment:
            continue
        type_ = preprocesss_type(type_anns.get(arg.arg, ''))
        if type_:
            arg.annotation = a.Constant(value=type_)
def stubify_function(func: a.FunctionDef | a.AsyncFunctionDef):
    """Replace *func*'s body with ``...`` after harvesting docstring types.

    For every function except ``__init__``, a docstring (if present) is fed
    to :func:`annotate_function` first. ``__init__`` is skipped here; its
    parameters are annotated from the class docstring by
    :func:`stubify_class`.
    """
    if func.name == '__init__':
        doc = None
    else:
        doc = a.get_docstring(func)
    if doc:
        annotate_function(func, doc)
    func.body = STUB
def stubify_class(cls: a.ClassDef):
    """Drop a class docstring and use it to annotate the class's ``__init__``.

    Does nothing for a class without a (non-empty) docstring. Otherwise the
    docstring statement is removed from the body and the first plain
    ``def __init__`` found in the body is annotated from that docstring via
    :func:`annotate_function`. If removing the docstring leaves the body
    empty, it is replaced with ``...``.
    """
    doc = a.get_docstring(cls)
    if not doc:
        return
    cls.body = cls.body[1:]
    init = next(
        (stmt for stmt in cls.body
         if isinstance(stmt, a.FunctionDef) and stmt.name == '__init__'),
        None,
    )
    if init is not None:
        annotate_function(init, doc)
        return
    if not cls.body:
        cls.body = STUB
def preprocesss_type(raw: str):
    """Turn a Sphinx docstring type into plain annotation text.

    Removes cross-reference markup that is meaningless inside an annotation:
    any role, with or without a domain (``:class:``, ``:obj:``,
    ``:py:class:``, ...), the ``~`` (shorten) and ``!`` (no link) target
    prefixes, and the backticks around targets. Surrounding whitespace is
    trimmed, e.g. ``' :py:class:`~foo.Bar` '`` -> ``'foo.Bar'``.

    :param raw: type text as written in the docstring.
    :returns: the cleaned type string ('' if nothing is left).
    """
    return re.sub(r':(?:[\w-]+:)+|[~`!]', '', raw.strip()).strip()
# Sphinx ``:param [type] name:`` fields (``parameter``, ``arg``, ``argument``,
# ``key`` and ``keyword`` are accepted synonyms). Group 1 is the optional
# type, group 2 the parameter name. The type is matched lazily and the name
# must be a whitespace-delimited identifier directly followed by a colon, so
# a colon later in the description (or inside a ``:class:`` role in the type)
# is not mistaken for the end of the field. Only horizontal whitespace is
# allowed so a match never spans lines.
ANN_PATTERN = re.compile(
    r':(?:param|parameter|arg|argument|key|keyword)[ \t]+'
    r'(.*?)[ \t]*(?<!\S)(\w+)[ \t]*:')
# ``:type name: <type>`` -> (name, type), on a single line.
TYPE_PATTERN = re.compile(r':type[ \t]+(\w+)[ \t]*:[ \t]*(.+)')


def parse_docstring(src: str):
    """Extract Sphinx type information from a docstring.

    :param src: docstring text.
    :returns: a pair ``(types, rtype)``. ``types`` maps parameter names to
        raw type strings taken from ``:param <type> <name>:`` fields, with
        any ``:type <name>:`` field taking precedence; an untyped ``:param``
        maps to ``''``. ``rtype`` is the ``re.Match`` for the ``:rtype:``
        field (group 1 holds the raw return type) or ``None`` if absent.
    """
    param_types = {name: type_.strip()
                   for type_, name in ANN_PATTERN.findall(src)}
    explicit_types = {name: type_.strip()
                      for name, type_ in TYPE_PATTERN.findall(src)}
    return param_types | explicit_types, re.search(r':rtype:(.+)', src)
class filter_map(Generic[T, U]):
    """Lazily apply *mapper* to each item of *iterable*, skipping ``None``.

    Equivalent to ``(y for x in iterable if (y := mapper(x)) is not None)``
    packaged as a typed iterator. Only ``None`` means "no value", so falsy
    results such as ``0`` or ``''`` are yielded, matching the ``U | None``
    mapper signature.
    """
    __slots__ = ['mapper', 'iter']

    def __init__(self, mapper: Callable[[T], U | None], iterable: Iterable[T]):
        self.mapper = mapper
        self.iter = iter(iterable)

    def __iter__(self) -> Iterator[U]:
        return self

    def __next__(self) -> U:
        # Loop instead of recursing: a long run of skipped items used to
        # recurse once per item and could raise RecursionError.
        for item in self.iter:
            val = self.mapper(item)
            if val is not None:
                return val
        raise StopIteration
def _resolve_files():
if len(sys.argv) < 2:
return
for root_path in sys.argv[1:]:
if os.path.isfile(root_path):
yield root_path
else:
for path, _, filenames in os.walk(root_path):
for filename in filenames:
yield os.path.join(path, filename)
# Script entry point. The guard also matters for multiprocessing: with the
# "spawn" start method worker processes re-import this module, and must not
# call main() (and start a new Pool) themselves.
if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment