Created
March 24, 2021 02:47
-
-
Save cleoold/6db17392b33de59c10303c6337eb692f to your computer and use it in GitHub Desktop.
Convert `typing`-style type annotations to the new style (PEP 585 builtin generics, PEP 604 `X | Y` unions, and arrow-style callables)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# requires Python 3.9+ for unparse() | |
import ast | |
from typing import cast | |
class MyTransformer(ast.NodeTransformer):
    """Rewrite ``typing``-style subscripted annotations into the newer syntax.

    * ``List``/``Set``/``FrozenSet``/``Tuple``/``Dict``/``Type`` subscripts are
      renamed to the lower-case builtin generics (PEP 585).
    * ``Union[A, B, ...]`` becomes ``A | B | ...`` and ``Optional[X]`` becomes
      ``X | None`` (PEP 604).
    * ``Callable[[A, B], R]`` becomes an ``ast.FunctionType`` node, which the
      unparser renders as ``(A, B) -> R``; ``Callable[..., R]`` becomes
      ``(*Any, **Any) -> R``.

    Only subscripts whose base is a bare name are rewritten.  Anything else
    (``typing.List[int]``, ``x[0][1]``) and malformed forms such as
    ``Union[()]`` or ``Callable[P, R]`` are left unchanged, but their children
    are still visited.
    """

    # typing aliases whose PEP 585 replacement is simply the lower-cased name
    _BUILTIN_ALIASES = frozenset({'List', 'Set', 'FrozenSet', 'Tuple', 'Dict', 'Type'})

    def visit_Subscript(self, node: ast.Subscript) -> ast.AST:
        """Return the rewritten replacement for *node* (or *node* itself)."""
        base = node.value
        # Only plain names like ``List`` can be recognised. Attribute bases
        # (``typing.List``) or chained subscripts have no ``.id``.
        if not isinstance(base, ast.Name):
            self.generic_visit(node)
            return node

        result: ast.AST = node
        if base.id in self._BUILTIN_ALIASES:
            base.id = base.id.lower()
        elif base.id == 'Union':
            # Union[A, B, C] -> ((A | B) | C); a single non-tuple member or an
            # empty tuple is left as written.
            if isinstance(node.slice, ast.Tuple) and node.slice.elts:
                members = iter(node.slice.elts)
                chained = next(members)
                for member in members:
                    chained = ast.BinOp(left=chained, op=ast.BitOr(), right=member)
                result = chained
        elif base.id == 'Optional':
            result = ast.BinOp(
                left=node.slice,
                op=ast.BitOr(),
                right=ast.Constant(value=None),
            )
        elif base.id == 'Callable':
            spec = node.slice
            # A well-formed Callable is Callable[[args...], R] or Callable[..., R].
            if isinstance(spec, ast.Tuple) and len(spec.elts) == 2:
                params_spec, return_t = spec.elts
                params = None
                if isinstance(params_spec, ast.List):
                    params = params_spec.elts
                elif isinstance(params_spec, ast.Constant) and params_spec.value is Ellipsis:
                    # "..." means arbitrary arguments; rendered as (*Any, **Any)
                    params = [ast.Starred(value=ast.Name(id='Any')), ast.Name(id='**Any')]
                # Any other spec (e.g. a ParamSpec) has no arrow form here.
                if params is not None:
                    result = ast.FunctionType(argtypes=params, returns=return_t)

        # Recurse into the (possibly new) node so nested annotations are converted.
        self.generic_visit(result)
        return result
# hack to avoid ambiguities in "() -> X | Y" | |
# https://bugs.python.org/issue43609 | |
class MyUnparser(ast._Unparser): | |
def visit_FunctionType(self, node: ast.FunctionType): | |
with self.delimit("(", ")"): | |
self.interleave( | |
lambda: self.write(", "), self.traverse, node.argtypes | |
) | |
self.write(" -> ") | |
# add paren when return type is a union (bit or) | |
need_paren = isinstance(node.returns, ast.BinOp) and isinstance(node.returns.op, ast.BitOr) | |
with self.delimit_if("(", ")", need_paren): | |
self.traverse(node.returns) | |
def my_conv(s: str) -> str:
    """Return *s*, a ``typing``-style annotation, rewritten in the new syntax."""
    tree = ast.parse(s)
    converted = MyTransformer().visit(tree)
    return MyUnparser().visit(converted)
def test():
    """Self-check ``my_conv`` against known conversions.

    Raises ``AssertionError`` naming the failing input, the expected output
    and the actual output on the first mismatch.
    """
    cases = [
        ('Union[List[int], Tuple[int], Set[int], Dict[str, int]]',
         'list[int] | tuple[int] | set[int] | dict[str, int]'),
        ('Optional[str]', 'str | None'),
        ('Callable[..., str]', '(*Any, **Any) -> str'),
        ('Callable[[int, str], Callable[[str], Callable[[], None]]]',
         '(int, str) -> (str) -> () -> None'),
        ('Union[Callable[[], Optional[str]], str, None]',
         '() -> (str | None) | str | None'),
    ]
    for source, expected in cases:
        actual = my_conv(source)
        assert actual == expected, f'{source!r}: expected {expected!r}, got {actual!r}'
# Run the self-check only when executed as a script, not on import.
if __name__ == "__main__":
    test()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment