Skip to content

Instantly share code, notes, and snippets.

@gtors
Last active January 27, 2021 11:13
Show Gist options
  • Save gtors/27f878d2d42721d0f69e0ec98810ef80 to your computer and use it in GitHub Desktop.
Save gtors/27f878d2d42721d0f69e0ec98810ef80 to your computer and use it in GitHub Desktop.
Python script for sorting class definitions in file
"""
Script for sorting class definitions so that every class is defined before the classes whose annotations refer to it. It may be useful for refactoring code after code generation.
For example:
```
class C:
    a: Optional[A] = None
    b: Optional[B] = None
class A:
    b: Optional[B] = None
class B:
    c: int
```
will be fixed to:
```
class B:
    c: int
class A:
    b: Optional[B] = None
class C:
    a: Optional[A] = None
    b: Optional[B] = None
```
"""
import sys
import ast
import functools
from collections import defaultdict
from graphlib import TopologicalSorter # py3.9
def ast_filter(it, ty):
    """Return a generator over the items of *it* that are instances of *ty*.

    If *it* has a ``body`` attribute (as module and class-definition AST nodes
    do), that body is scanned; otherwise *it* itself is treated as the
    iterable. *ty* may be a single type or a tuple of types, exactly as
    accepted by ``isinstance``.
    """
    candidates = getattr(it, 'body', it)
    return (node for node in candidates if isinstance(node, ty))
def _annotation_names(annotation):
    """Yield the type names referenced by the annotation expression *annotation*.

    Generic containers are looked *through*: for ``Dict[str, List[A]]`` the
    names ``str`` and ``A`` are produced, while ``Dict`` and ``List`` are not.
    Names are produced breadth-first, outermost nesting level first.
    String (forward-reference) annotations are not inspected.
    """
    level = [annotation]
    while level:
        deeper = []
        for node in level:
            if isinstance(node, ast.Name):
                yield node.id
            elif isinstance(node, ast.Subscript):
                # Only the type arguments matter here, not the container.
                args = node.slice
                deeper.extend(args.elts if isinstance(args, ast.Tuple) else [args])
            elif isinstance(node, ast.List):
                # e.g. the parameter list in ``Callable[[A], B]``.
                deeper.extend(node.elts)
            elif isinstance(node, ast.BinOp) and isinstance(node.op, ast.BitOr):
                # PEP 604 union written as ``A | B``.
                deeper.extend((node.left, node.right))
        level = deeper


def populate_class_graph(tree):
    """Build a "must be defined after" graph for the classes in *tree*.

    Args:
        tree: A parsed module (anything with a ``body`` attribute) or a plain
            iterable of statement nodes. Only ``ast.ClassDef`` statements
            directly in that body are considered.

    Returns:
        A ``defaultdict(list)`` mapping each class name to the names it
        depends on: its base classes first, then the type names found in its
        annotated attributes, in source order. Names are not restricted to
        classes defined in *tree* (``int`` or imported types appear too), and
        duplicates are kept. The mapping has the ``node -> predecessors``
        shape expected by ``graphlib.TopologicalSorter``.

    A class never lists itself as a dependency: a self-referencing annotation
    such as ``children: List[Node]`` inside ``Node`` would otherwise form a
    one-node cycle and make topological sorting fail.
    """
    graph = defaultdict(list)
    for cls in getattr(tree, "body", tree):
        if not isinstance(cls, ast.ClassDef):
            continue

        referenced = []
        # A subclass cannot be created before its bases exist.
        for base in cls.bases:
            if isinstance(base, ast.Subscript):  # parametrised base: ``Base[T]``
                base = base.value
            if isinstance(base, ast.Name):
                referenced.append(base.id)
        for stmt in cls.body:
            if isinstance(stmt, ast.AnnAssign):
                referenced.extend(_annotation_names(stmt.annotation))

        # Assign (rather than extend) so a redefined class name starts afresh.
        graph[cls.name] = [name for name in referenced if name != cls.name]
    return graph
if __name__ == "__main__":
    # Command-line entry point: read the module named by the single argument,
    # reorder its top-level class definitions so dependencies come first, and
    # write the result next to the input as "fixed_<original name>".
    from pathlib import Path

    if len(sys.argv) != 2:
        sys.exit(f"usage: {sys.argv[0]} FILE.py")
    file_name = Path(sys.argv[1])

    with open(file_name, encoding="utf-8") as f:
        py_code = f.read()
    ast_tree = ast.parse(py_code)

    graph = populate_class_graph(ast_tree)
    topo_sorter = TopologicalSorter(graph)
    # static_order() lists every name after the names it depends on; it raises
    # graphlib.CycleError when the dependencies are circular.
    class_order = tuple(topo_sorter.static_order())
    rank = {name: position for position, name in enumerate(class_order)}

    # A bit messy, but if the file consists only of imports and class
    # definitions, then ok: every non-class statement gets the key -1, so the
    # stable sort moves them all ahead of the classes in their original
    # relative order, and the classes follow in dependency order.
    ast_tree.body.sort(
        key=lambda stmt: rank[stmt.name] if isinstance(stmt, ast.ClassDef) else -1
    )

    # Prefix only the file-name component so a path with directories still
    # resolves. ast.unparse regenerates the source from the tree, so comments
    # and the original formatting are not preserved.
    fixed_path = file_name.with_name("fixed_" + file_name.name)
    with open(fixed_path, "w", encoding="utf-8") as f:
        f.write(ast.unparse(ast_tree))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment