Skip to content

Instantly share code, notes, and snippets.

@agmond
Last active January 25, 2021 18:48
Show Gist options
  • Save agmond/9240c74f5ff88497eec7f6d2949a38fe to your computer and use it in GitHub Desktop.
Save agmond/9240c74f5ff88497eec7f6d2949a38fe to your computer and use it in GitHub Desktop.
Fix `super()` calls after migration to Python3
"""
The following script fixes `super()` calls after migration from Python 2 to Python 3.
The script edits the code automatically. Upon completion, the following steps are recommended:
- Search for the regex `super\([^\)]` and manually fix any remaining matches (if needed)
- Search for the regex `super\(\s[^\)]` and manually fix any remaining matches (if needed)
- Run Flake8 and manually fix styling problems
"""
import ast
import linecache
from pathlib import Path
# Directory that is searched for Python files; adjust before running.
root_path = Path('/path/to/main/directory/')
# Lazily yields every ``*.py`` file at any depth below ``root_path``.
files = root_path.rglob('*.py')
class SuperVisitor(ast.NodeVisitor):
    """Collect the argument spans of redundant ``super(Class, self)`` calls.

    After ``visit()`` has walked a tree, ``lines_to_edit`` holds one
    ``(lineno, start_col, end_col)`` tuple for every call of the form
    ``super(<enclosing class name>, self)``.  The columns delimit the text
    ``<class name>, self`` and are the offsets reported by ``ast`` (UTF-8
    byte offsets into the line, not character indexes).
    """

    def __init__(self, filename, *args, **kwargs):
        """Create a visitor; *filename* is used in error messages."""
        super().__init__(*args, **kwargs)
        self.filename = filename
        # Name of the innermost class whose body is being visited, or None.
        self.current_class_name = None
        # Collected (lineno, start_col, end_col) spans, in visiting order.
        self.lines_to_edit = []

    def visit_ClassDef(self, node):
        """Visit a class, tracking its name while its body is traversed.

        Returns:
            The visited *node*, unchanged.
        """
        enclosing_class_name = self.current_class_name
        # Decorators, bases and keywords belong to the enclosing scope, so
        # they are visited before the tracked class name is switched.
        for outer_expr in (*node.decorator_list, *node.bases, *node.keywords):
            self.visit(outer_expr)
        self.current_class_name = node.name
        try:
            # ``visit`` (not ``generic_visit``) so that a class nested
            # directly in this body is dispatched to ``visit_ClassDef``.
            for statement in node.body:
                self.visit(statement)
        finally:
            # Restore rather than reset to None: code that follows a nested
            # class still belongs to the enclosing class.
            self.current_class_name = enclosing_class_name
        return node

    def visit_Call(self, node):
        """Record *node* if it is ``super(<current class>, self)``, then recurse.

        Returns:
            The visited *node*, unchanged.

        Raises:
            AssertionError: if the two arguments do not sit on one line, in
                which case a single-line edit cannot remove them.
        """
        if self._is_redundant_super_call(node):
            class_name_arg, self_arg = node.args
            if class_name_arg.lineno != self_arg.end_lineno:
                # Raised explicitly instead of via ``assert`` so the check
                # is not stripped under ``python -O``; same exception type.
                raise AssertionError(f'Not in the same line: {self.filename}:{node.lineno}')
            self.lines_to_edit.append(
                (class_name_arg.lineno, class_name_arg.col_offset, self_arg.end_col_offset)
            )
        # Visits every child exactly once, so calls passed directly as
        # arguments are found and no call is recorded twice.
        self.generic_visit(node)
        return node

    def _is_redundant_super_call(self, node):
        """Return True if *node* is ``super(<current class name>, self)``."""
        if not self.current_class_name:
            return False
        if not (isinstance(node.func, ast.Name) and node.func.id == 'super'):
            return False
        if len(node.args) != 2:
            return False
        class_name_arg, self_arg = node.args
        return (
            isinstance(class_name_arg, ast.Name)
            and class_name_arg.id == self.current_class_name
            and isinstance(self_arg, ast.Name)
            and self_arg.id == 'self'
        )
def _remove_spans(source, spans):
    """Return *source* with every ``(lineno, start, end)`` span deleted.

    ``lineno`` is 1-based.  ``start`` and ``end`` are UTF-8 byte offsets
    into that line, which is how ``ast`` reports ``col_offset`` and
    ``end_col_offset``; slicing the ``str`` directly would cut at the wrong
    place on lines containing non-ASCII characters.
    """
    lines = source.split('\n')
    # Highest offsets first, so removing one span never shifts a span that
    # is still pending on the same line; ``set`` drops duplicate entries.
    for lineno, start, end in sorted(set(spans), reverse=True):
        encoded_line = lines[lineno - 1].encode('utf-8')
        lines[lineno - 1] = (encoded_line[:start] + encoded_line[end:]).decode('utf-8')
    return '\n'.join(lines)


# Rewrite every file in place, dropping the arguments of each
# ``super(Class, self)`` call that the visitor reports.
for filename in map(str, files):
    with open(filename, 'r+') as f:
        source = f.read()
        parsed = ast.parse(source, filename)
        node_visitor = SuperVisitor(filename)
        node_visitor.visit(parsed)
        if not node_visitor.lines_to_edit:
            continue
        # Edit by line number on the text that was parsed, instead of
        # searching for the line's text, so the right line is always
        # changed and several calls on one line are all handled.
        source = _remove_spans(source, node_visitor.lines_to_edit)
        f.seek(0)
        f.truncate(0)
        f.write(source)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment