Skip to content

Instantly share code, notes, and snippets.

@cleoold
Created March 24, 2021 02:47
Show Gist options
  • Save cleoold/6db17392b33de59c10303c6337eb692f to your computer and use it in GitHub Desktop.
Save cleoold/6db17392b33de59c10303c6337eb692f to your computer and use it in GitHub Desktop.
Convert `typing`-style type annotations (List, Union, Optional, Callable, ...) to the new built-in generic and `X | Y` union style
# requires Python 3.9+ for unparse()
import ast
from typing import cast
class MyTransformer(ast.NodeTransformer):
    """Rewrite ``typing``-style subscripted annotations into newer syntax.

    Recognised rewrites (only when the subscripted base is a bare name):

    * ``List``, ``Set``, ``Tuple``, ``Dict``, ``FrozenSet`` and ``Type``
      become their lowercase builtin spellings, e.g. ``List[int]`` ->
      ``list[int]`` (PEP 585).
    * ``Union[A, B, ...]`` -> ``A | B | ...``. A single-argument
      ``Union[X]`` is left unchanged.
    * ``Optional[X]`` -> ``X | None``.
    * ``Callable[[A, B], R]`` -> an ``ast.FunctionType`` (unparsed as
      ``(A, B) -> R``), and ``Callable[..., R]`` -> ``(*Any, **Any) -> R``.
      Any other parameter spec (e.g. a ``ParamSpec``) is left unchanged
      rather than guessed at.

    Everything else, including dotted bases such as ``typing.List[int]``,
    is kept as-is, but its children are still visited.
    """

    # typing aliases whose replacement is simply the lowercase builtin name.
    _LOWERCASE_GENERICS = frozenset(
        {'List', 'Set', 'Tuple', 'Dict', 'FrozenSet', 'Type'}
    )

    def visit_Subscript(self, node: ast.Subscript) -> ast.AST:
        """Return the replacement for *node* (possibly *node* itself)."""
        result: ast.AST = node
        name = node.value
        # A subscript base can be any expression (``a.b[0]``, ``f()[0]``);
        # only bare names are candidates. Assuming ast.Name here used to
        # crash with AttributeError on such inputs.
        if isinstance(name, ast.Name):
            if name.id in self._LOWERCASE_GENERICS:
                name.id = name.id.lower()
            elif name.id == 'Union':
                result = self._union(node)
            elif name.id == 'Optional':
                result = ast.BinOp(
                    left=node.slice,
                    op=ast.BitOr(),
                    right=ast.Constant(value=None),
                )
            elif name.id == 'Callable':
                result = self._callable(node)
        # Recurse into the (possibly new) node so nested annotations are
        # converted as well.
        self.generic_visit(result)
        return result

    @staticmethod
    def _union(node: ast.Subscript) -> ast.AST:
        """Chain the members of ``Union[A, B, ...]`` into ``A | B | ...``."""
        members = node.slice
        # Union[X] has a non-tuple slice, and an empty tuple has nothing to
        # chain. Leave both untouched.
        if not (isinstance(members, ast.Tuple) and members.elts):
            return node
        chained, *rest = members.elts
        for elt in rest:
            # Left-nested BinOps unparse as a flat left-associative chain.
            chained = ast.BinOp(left=chained, op=ast.BitOr(), right=elt)
        return chained

    @staticmethod
    def _callable(node: ast.Subscript) -> ast.AST:
        """Turn ``Callable[params, ret]`` into an ``ast.FunctionType``."""
        spec = node.slice
        if not (isinstance(spec, ast.Tuple) and len(spec.elts) == 2):
            return node
        params_spec, return_t = spec.elts
        if isinstance(params_spec, ast.List):
            params = params_spec.elts
        elif (isinstance(params_spec, ast.Constant)
              and params_spec.value is Ellipsis):
            # "Any arguments" is spelled (*Any, **Any). The '**' is baked
            # into the Name's id as a purely textual hack.
            params = [
                ast.Starred(value=ast.Name(id='Any')),
                ast.Name(id='**Any'),
            ]
        else:
            # ParamSpec, Concatenate[...], etc.: there is no faithful rewrite
            # here, so keep the original rather than dropping information.
            return node
        return ast.FunctionType(argtypes=params, returns=return_t)
# hack to avoid ambiguities in "() -> X | Y"
# https://bugs.python.org/issue43609
try:
    _UnparserBase = ast._Unparser
except AttributeError:
    # NOTE(review): the unparser is a private class. Newer CPython releases
    # may move it into the lazily imported private module ``_ast_unparse``;
    # confirm for the target Python version.
    from _ast_unparse import Unparser as _UnparserBase


class MyUnparser(_UnparserBase):
    """``ast`` unparser that renders ``FunctionType`` nodes unambiguously.

    Without extra parentheses, a callable whose return type is a ``|``
    union would be written ``() -> X | Y``. That is ambiguous between a
    callable returning ``X | Y`` and a union of ``() -> X`` with ``Y``.
    This subclass writes ``() -> (X | Y)`` for the former.
    """

    def visit_FunctionType(self, node: ast.FunctionType) -> None:
        """Write *node* as ``(arg, ...) -> ret``."""
        with self.delimit("(", ")"):
            self.interleave(
                lambda: self.write(", "), self.traverse, node.argtypes
            )
        self.write(" -> ")
        returns = node.returns
        # Parenthesize only a union (BitOr) return type.
        need_paren = (
            isinstance(returns, ast.BinOp)
            and isinstance(returns.op, ast.BitOr)
        )
        with self.delimit_if("(", ")", need_paren):
            self.traverse(returns)
def my_conv(s: str) -> str:
    """Return the annotation source *s* rewritten in the new style.

    The string is parsed, rewritten by ``MyTransformer``, and rendered back
    to text by ``MyUnparser``.
    """
    unparser = MyUnparser()
    tree = ast.parse(s)
    converted = MyTransformer().visit(tree)
    return unparser.visit(converted)
def test():
    """Check ``my_conv`` against known input/output pairs, in order."""
    cases = [
        ('Union[List[int], Tuple[int], Set[int], Dict[str, int]]',
         'list[int] | tuple[int] | set[int] | dict[str, int]'),
        ('Optional[str]',
         'str | None'),
        ('Callable[..., str]',
         '(*Any, **Any) -> str'),
        ('Callable[[int, str], Callable[[str], Callable[[], None]]]',
         '(int, str) -> (str) -> () -> None'),
        ('Union[Callable[[], Optional[str]], str, None]',
         '() -> (str | None) | str | None'),
    ]
    for source, expected in cases:
        assert my_conv(source) == expected
if __name__ == '__main__':
    # Run the self-checks only when executed as a script, not on import.
    test()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment