Last active
September 12, 2019 08:53
-
-
Save onlined/368f900dacd1dbcf4b7ac07e827306ed to your computer and use it in GitHub Desktop.
Removes the unnecessary `-> None` return annotation from `__init__` methods
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import annotations

import ast
import io
import sys
from typing import List
def is_annotated_with_none_return(func: ast.FunctionDef) -> bool:
    """Return True when *func* is annotated with exactly ``-> None``.

    A missing annotation, any other constant (``-> False``, ``-> 0``) and the
    string form ``-> 'None'`` all return False.
    """
    returns = func.returns
    if sys.version_info < (3, 8):
        # Before 3.8 the parser produced ast.NameConstant for None/True/False.
        return isinstance(returns, ast.NameConstant) and returns.value is None
    # From 3.8 on the parser produces ast.Constant. ast.NameConstant is a
    # deprecated alias (DeprecationWarning since 3.12, removed in 3.14), so it
    # must not be touched on modern interpreters.
    return isinstance(returns, ast.Constant) and returns.value is None
class _InitFinder(ast.NodeVisitor): | |
def __init__(self) -> None: # -> None is written for test | |
super().__init__() | |
# Actually Literal['C', 'F'], but not in the standard library | |
self.class_and_function_stack: List[str] = [] | |
self.inits: List[ast.FunctionDef] = [] | |
def in_class_scope(self) -> bool: | |
return bool(self.class_and_function_stack and | |
self.class_and_function_stack[-1] == 'C') | |
def is_init(self, func: ast.FunctionDef) -> bool: | |
return func.name == '__init__' and self.in_class_scope() | |
def visit_ClassDef(self, cls: ast.ClassDef): | |
self.class_and_function_stack.append('C') | |
self.generic_visit(cls) | |
self.class_and_function_stack.pop() | |
def visit_FunctionDef(self, func: ast.FunctionDef): | |
if self.is_init(func) and is_annotated_with_none_return(func): | |
self.inits.append(func) | |
self.class_and_function_stack.append('F') | |
self.generic_visit(func) | |
self.class_and_function_stack.pop() | |
def get_inits(module: ast.Module) -> List[ast.FunctionDef]:
    """Return every ``__init__`` method in *module* annotated ``-> None``.

    Results are listed in the order a depth-first walk of the tree meets them.
    """
    visitor = _InitFinder()
    visitor.visit(module)
    collected = visitor.inits
    return collected
def remove_return_annotation(line: str, offset: int) -> str:
    """Strip a ``-> annotation`` from one line of a ``def`` header.

    Args:
        line: A physical source line holding the end of a function header.
        offset: ``str`` index on *line* where the return annotation starts.

    Returns:
        *line* with the text between the last ``)`` before *offset* and the
        first ``:`` at or after *offset* removed. If either delimiter is not
        on this line (for instance ``)`` and ``->`` are split across lines),
        *line* is returned unchanged rather than mangled.
    """
    rparen_index = line.rfind(')', 0, offset)
    colon_index = line.find(':', offset)
    if rparen_index == -1 or colon_index == -1:
        # With -1 indices the slice below would collapse the whole line to
        # its last character; leave layouts we cannot edit untouched.
        return line
    return line[:rparen_index + 1] + line[colon_index:]
def main() -> None:
    """Filter Python source from stdin to stdout, dropping ``-> None`` from
    every ``__init__`` method that carries it.

    Lines are split exactly where the parser counts line breaks and are
    written back with their original endings. The output therefore differs
    from the input only in the removed annotations.
    """
    source = sys.stdin.read()
    tree = ast.parse(source, '<stdin>')
    # str.splitlines() also breaks on \f, \v, \x1c-\x1e, \x85, \u2028 and
    # \u2029. Any of those inside a comment or string literal would shift
    # later lines relative to the AST's lineno. With newline='' StringIO
    # splits only on \n, \r\n and \r, and keeps each ending untranslated.
    lines = io.StringIO(source, newline='').readlines()
    for init in get_inits(tree):
        ann = init.returns
        assert ann is not None  # get_inits only returns annotated methods
        index = ann.lineno - 1
        line = lines[index]
        # ast reports col_offset in UTF-8 bytes. Convert it to a str index so
        # non-ASCII text earlier on the line does not skew the slicing.
        offset = len(line.encode('utf-8')[:ann.col_offset].decode('utf-8'))
        lines[index] = remove_return_annotation(line, offset)
    sys.stdout.write(''.join(lines))
# Run as a stdin -> stdout filter only when executed as a script.
if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment