Skip to content

Instantly share code, notes, and snippets.

@sobolevn
Created June 28, 2023 18:03
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save sobolevn/967ac61b8b2575e26a33f5a6bf5df201 to your computer and use it in GitHub Desktop.
Save sobolevn/967ac61b8b2575e26a33f5a6bf5df201 to your computer and use it in GitHub Desktop.
import ast
import argparse
import os
import sys
import tokenize
from dataclasses import dataclass
from typing import Final, TypeAlias
_AnyFunction: TypeAlias = ast.FunctionDef | ast.AsyncFunctionDef
_EXCLUDE_FILES: Final = frozenset({
'Lib/test/badsyntax_3131.py',
'Lib/test/badsyntax_pep3120.py',
'Lib/test/bad_coding2.py',
})
@dataclass(frozen=True, slots=True)
class MethodDuplicate:
    """A ``test_*`` method whose name was already defined in the same class.

    Immutable record produced by ``ClassSpec.add_method``.
    """

    # The repeated (later) function definition node.
    node: _AnyFunction
    # Name of the class whose body contains the repeated definition.
    class_name: str

    def report_error(self) -> str:
        """Return a one-line, human-readable description of the duplicate."""
        return f'Found duplicate method {self.node.name} in {self.class_name} class'
@dataclass(frozen=True, slots=True)
class ClassDuplicate:
node: ast.ClassDef
def report_error(self) -> str:
return f'Found duplicate {self.node.name} class'
class ClassSpec:
    """Track which method names have already been defined in one class body."""

    def __init__(self, class_name: str) -> None:
        # Name of the class this spec describes (used in error reports).
        self.name = class_name
        # Names of the methods registered via add_method() so far.
        self.methods: set[str] = set()

    def add_method(self, method: _AnyFunction) -> MethodDuplicate | None:
        """Register *method* and report it if its name is already taken.

        Args:
            method: a function definition node from this class's body.

        Returns:
            ``None`` the first time a name is seen (the name is recorded);
            a ``MethodDuplicate`` pointing at *method* for any later
            definition with the same name.
        """
        if method.name in self.methods:
            return MethodDuplicate(method, self.name)
        self.methods.add(method.name)
        return None
class TestMethodDuplicatesVisitor(ast.NodeVisitor):
    """Collect duplicate module-level classes and duplicate ``test_*`` methods.

    The tree must first be annotated with ``_parent`` back-references
    (every visited node's ``_parent`` attribute is read).  Findings
    accumulate in ``self.duplicates`` in visiting order.
    """

    # Only methods whose names start with this prefix are checked.
    _METHOD_PREFIX: Final = 'test_'

    def __init__(self) -> None:
        # Spec of the innermost class whose body is currently being visited.
        self.current_class: ClassSpec | None = None
        # Names of module-level classes seen so far.
        self.classes: set[str] = set()
        self.duplicates: list[ClassDuplicate | MethodDuplicate] = []

    def visit_ClassDef(self, node: ast.ClassDef) -> None:
        """Check module-level class-name clashes, then visit the class body."""
        if isinstance(node._parent, ast.Module):
            if node.name in self.classes:
                self.duplicates.append(ClassDuplicate(node))
            self.classes.add(node.name)
        # Save and restore the enclosing class's spec.  Otherwise a nested
        # class would stay "current" after it is visited, and the outer
        # class's later methods would be checked against the nested class's
        # methods instead of their own class's.
        enclosing_class = self.current_class
        self.current_class = ClassSpec(node.name)
        self.generic_visit(node)
        self.current_class = enclosing_class

    def visit_FunctionDef(self, node: _AnyFunction) -> None:
        """Register ``test_*`` methods defined directly in a class body."""
        # Only direct children of a ClassDef count: methods defined under an
        # ``if``/``try`` inside the class, or nested functions, are skipped.
        if (
            self.current_class is not None
            and isinstance(node._parent, ast.ClassDef)
            and node.name.startswith(self._METHOD_PREFIX)
        ):
            result = self.current_class.add_method(node)
            if result is not None:
                self.duplicates.append(result)
        self.generic_visit(node)

    visit_AsyncFunctionDef = visit_FunctionDef
def _read_file(filename: str) -> str:
    """Return the decoded text of *filename*.

    The file is first opened with ``tokenize.open``, which picks the encoding
    the way the tokenizer would.  If that fails with ``SyntaxError`` or
    ``UnicodeError`` -- while detecting the encoding or while decoding the
    contents -- the file is re-read as latin-1, which accepts any byte
    sequence.
    """
    # Approach borrowed from https://github.com/PyCQA/flake8/
    try:
        with tokenize.open(filename) as stream:
            text = stream.read()
    except (SyntaxError, UnicodeError):
        with open(filename, encoding='latin-1') as stream:
            text = stream.read()
    return text
def _set_parents(tree: ast.AST) -> ast.AST:
for statement in ast.walk(tree):
for child in ast.iter_child_nodes(statement):
setattr(child, '_parent', statement)
return tree
def _is_excluded(path: str) -> bool:
    """Return whether *path* refers to one of the files in ``_EXCLUDE_FILES``.

    ``_EXCLUDE_FILES`` holds '/'-separated paths relative to a CPython
    checkout.  The walked path is normalized first, so ``./Lib/...``,
    backslash-separated Windows paths and paths under an absolute checkout
    directory match as well as a plain ``Lib/...`` path does.
    """
    normalized = os.path.normpath(path).replace(os.sep, '/')
    return any(
        normalized == excluded or normalized.endswith('/' + excluded)
        for excluded in _EXCLUDE_FILES
    )


def main() -> None:
    """Scan a directory tree for duplicate test methods and classes.

    Prints one ``path:lineno message`` line to stdout per duplicate.
    Files that cannot be parsed are reported on stderr and skipped.
    Exits with status 1 if any duplicate was found, 0 otherwise.
    """
    # The text is a description; the first positional argument of
    # ArgumentParser is ``prog`` and would replace the program name.
    parser = argparse.ArgumentParser(description='Find duplicate test methods')
    parser.add_argument('dir', type=str)
    args = parser.parse_args()

    found_any_duplicates = False
    for dirpath, dirnames, filenames in os.walk(args.dir):
        # Sorting in place makes os.walk descend, and us report, in a
        # stable order regardless of the filesystem's listing order.
        dirnames.sort()
        for filename in sorted(filenames):
            if not filename.endswith('.py'):
                continue
            full_path = os.path.join(dirpath, filename)
            if _is_excluded(full_path):
                continue
            try:
                source = ast.parse(_read_file(full_path), filename=full_path)
            except (SyntaxError, ValueError) as exc:
                # ValueError: null bytes in source on Python < 3.12.
                # One unparsable file must not abort the whole scan.
                print(f'{full_path}: skipped, cannot parse: {exc}', file=sys.stderr)
                continue
            visitor = TestMethodDuplicatesVisitor()
            visitor.visit(_set_parents(source))
            for dup in visitor.duplicates:
                found_any_duplicates = True
                print(f'{full_path}:{dup.node.lineno} {dup.report_error()}')
    sys.exit(1 if found_any_duplicates else 0)


if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment