Skip to content

Instantly share code, notes, and snippets.

@irmen
Created February 7, 2018 21:25
Show Gist options
  • Save irmen/559c1ea8cdc7468a327351120c7b50ec to your computer and use it in GitHub Desktop.
Save irmen/559c1ea8cdc7468a327351120c7b50ec to your computer and use it in GitHub Desktop.
# Import graph plotter: draws the import dependencies inside a Python package.
import argparse
import os
import ast
import graphviz
def gather_files(path, recurse=True):
    """Collect the paths of all ``.py`` files below the package directory *path*.

    Args:
        path: directory that must contain an ``__init__.py``.
        recurse: when true, descend into subdirectories (except well-known
            VCS/tool/cache directories); when false, only look at *path* itself.

    Returns:
        A list of file paths (each one *dirpath* joined with the file name),
        in a reproducible order: directories are visited top-down in sorted
        order and the files of each directory are sorted by name.

    Raises:
        ValueError: if *path* does not contain an ``__init__.py``.
    """
    print("Gather source files ({0})...".format("recursively" if recurse else "nonrecursively"))
    if not os.path.isfile(os.path.join(path, "__init__.py")):
        raise ValueError("The path you provided was not the root of a package (there's no __init__.py)")
    skip_dirs = {".hg", ".svn", ".git", ".tox", "__pycache__", ".idea", ".tmp", ".bak"}
    all_files = []
    for dirpath, dirnames, filenames in os.walk(path):
        if recurse:
            # Prune in place so os.walk does not descend into skipped directories;
            # sorting makes the traversal order independent of the file system.
            dirnames[:] = sorted(name for name in dirnames if name not in skip_dirs)
        else:
            # Emptying the list stops os.walk after the top-level directory.
            dirnames.clear()
        all_files.extend(
            os.path.join(dirpath, name)
            for name in sorted(filenames)
            if name.endswith(".py")
        )
    return all_files
def gather_imports(files, package_root):
    """Parse every file in *files* and record the import statements it contains.

    Args:
        files: iterable of paths to Python source files located inside
            *package_root*.
        package_root: directory of the top-level package; module names are
            derived from each file's path relative to this directory's parent.

    Returns:
        A dict mapping dotted module name (``pkg/sub/__init__.py`` becomes
        ``pkg.sub``, ``pkg/mod.py`` becomes ``pkg.mod``) to a list of dicts,
        one per ``import`` / ``from ... import`` statement found anywhere in
        the file, with the keys ``"names"``, ``"from"``, ``"level"`` and
        ``"line_no"``.

    Raises:
        SyntaxError: propagated from ``ast.parse`` when a file cannot be parsed.
    """
    def convert_path_to_package(path):
        # A package is named after its directory, a module after its file stem.
        if os.path.basename(path) == "__init__.py":
            path = os.path.dirname(path)
        elif path.endswith(".py"):
            path = path[:-3]
        return path.replace(os.sep, ".")

    # Module names are relative to the directory that contains the package.
    # relpath (instead of slicing off a fixed-length prefix) also works when
    # that directory is the current directory or the file system root.
    base_dir = os.path.dirname(package_root) or os.curdir
    file_imports = {}
    print("Parsing...", end="", flush=True)
    for fn in files:
        # Read raw bytes so ast.parse honours the file's own encoding
        # declaration (UTF-8 by default) rather than the locale encoding.
        with open(fn, "rb") as fh:
            tree = ast.parse(fh.read(), fn)
        print(".", end="", flush=True)
        imports = [
            {
                "names": [alias.name for alias in node.names],
                # Plain "import X" nodes have neither module nor level.
                "from": getattr(node, "module", None) or "",
                "level": getattr(node, "level", None) or 0,
                "line_no": node.lineno,
            }
            for node in ast.walk(tree)
            if isinstance(node, (ast.Import, ast.ImportFrom))
        ]
        module_name = convert_path_to_package(os.path.relpath(fn, base_dir))
        file_imports[module_name] = imports
    print()
    return file_imports
def process_imports(package_name, file_imports):
    """Work out which modules of the package each module imports.

    Every recorded import is converted to absolute dotted names:

    * ``import X``            -> ``X``
    * ``from X import Y``     -> ``X`` and ``X.Y``
    * ``from . import Y``     -> ``base`` and ``base.Y``
    * ``from .X import Y``    -> ``base``, ``base.X`` and ``base.X.Y``

    where ``base`` is the package the relative import is anchored at. Names
    that are not keys of *file_imports* (external modules, imported functions
    or classes, ...) are dropped, as is a module's reference to itself.

    Args:
        package_name: name of the top-level package.
        file_imports: dict mapping dotted module name to a list of import
            dicts with the keys ``"names"``, ``"from"`` and ``"level"``.
            The dict and its contents are not modified.

    Returns:
        A dict mapping each module name to the set of module names (all of
        them keys of *file_imports*) that it imports.
    """
    all_modules = set(file_imports)

    def module_from_package(modulename):
        return (modulename == package_name
                or modulename.startswith(package_name + ".")
                or modulename.startswith("."))

    def is_package(modulename):
        # The input does not say whether a name came from an __init__.py, so
        # treat a module as a package when another scanned module lives below it.
        child_prefix = modulename + "."
        return any(other.startswith(child_prefix) for other in all_modules)

    def relative_base(modulename, level):
        # Level 1 means "my own package": for a package's __init__ that is the
        # module name itself, for a plain module it is the name minus its last
        # component. Each additional level strips one more component.
        parts = modulename.split(".")
        strip = level - 1 if is_package(modulename) else level
        return ".".join(parts[:max(0, len(parts) - strip)])

    def absolute_name_groups(modulename, imports):
        # One list of absolute names per implied "import a, b, c" statement.
        groups = []
        for imp in imports:
            names, source, level = imp["names"], imp["from"], imp["level"]
            if level > 0:
                base = relative_base(modulename, level)
                if source:
                    target = base + "." + source
                    groups.append([target] + [target + "." + name for name in names])
                else:
                    groups.append([base + "." + name for name in names])
                groups.append([base])
            elif source:
                groups.append([source + "." + name for name in names])
                groups.append([source])
            else:
                groups.append(list(names))
        return groups

    print("ALL MODULES:")
    for modulename in sorted(all_modules):
        print(" ", modulename)
    imports_per_module = {}
    for modulename in sorted(all_modules):
        print("\nIMPORTS OF", modulename)
        imported_names = set()
        for group in absolute_name_groups(modulename, file_imports[modulename]):
            if any(module_from_package(name) for name in group):
                imported_names.update(group)
        # Only names that are actual scanned modules count as dependencies.
        imported_names &= all_modules
        # A module importing (parts of) itself is not a dependency on another module.
        imported_names.discard(modulename)
        imports_per_module[modulename] = imported_names
        print(" ", imported_names)
    return imports_per_module
def get_real_dirname_case(full_path, dirname):
    """Return *dirname* spelled the way it is actually stored on disk.

    The parent directory of *full_path* is listed and searched for an entry
    matching *dirname*. An exact match wins; otherwise the first entry that
    matches case-insensitively is returned. This lets a user type ``mypkg``
    for a package whose directory is really called ``MyPkg``.

    Args:
        full_path: path whose parent directory is searched.
        dirname: the entry name to look for, in any letter case.

    Returns:
        The matching entry name as returned by ``os.listdir``.

    Raises:
        IOError: if no entry matches *dirname*.
    """
    # dirname("name") is "", which os.listdir rejects; fall back to the cwd.
    parent = os.path.dirname(full_path) or os.curdir
    entries = os.listdir(parent)
    # On case-sensitive file systems "Pkg" and "pkg" can coexist, so an exact
    # match must take precedence over whichever variant is listed first.
    if dirname in entries:
        return dirname
    wanted = dirname.lower()
    for entry in entries:
        if entry.lower() == wanted:
            return entry
    raise IOError("directory not found: " + dirname)
def make_graph(imports_per_module, package, skip_modules=None):
    """Render the import dependencies as a PNG via graphviz and open it.

    An edge ``A -> B`` is drawn for every module ``B`` in
    ``imports_per_module[A]``, unless either end is in *skip_modules*.
    Modules that take part in no drawn edge do not appear in the graph.
    The graph is rendered from the file name ``output.gv`` and opened in a
    viewer.

    Args:
        imports_per_module: dict mapping module name to an iterable of the
            module names it imports.
        package: package name, used in the graph's name.
        skip_modules: optional iterable of module names to leave out of the
            graph. Defaults to the Pyro4/Pyro5 modules that were previously
            hard-coded here.
    """
    if skip_modules is None:
        skip_modules = {
            "Pyro5",
            "Pyro4",
            "Pyro4.configuration",
            "Pyro4.errors",
            "Pyro5.errors",
            "Pyro5.config",
            "Pyro4.constants",
            "Pyro4.utils",
        }
    else:
        skip_modules = set(skip_modules)
    graph = graphviz.Digraph(name="import dependencies of "+package, comment="comment goes here", format="png", graph_attr={"ratio": "0.5"})
    # Sorted iteration keeps the generated graph source stable between runs
    # (set iteration order of strings varies with hash randomization).
    for node in sorted(imports_per_module):
        if node in skip_modules:
            continue
        for edge in sorted(imports_per_module[node]):
            if edge not in skip_modules:
                graph.edge(node, edge)
    graph.render("output.gv", view=True)
def main():
    """Command-line entry point: scan a package and plot its import graph.

    Parses the command line, resolves the package directory's on-disk name,
    then runs the pipeline gather_files -> gather_imports -> process_imports
    -> make_graph.
    """
    cli = argparse.ArgumentParser(description="Detect cyclic imports in Python packages")
    cli.add_argument("--norecurse", action="store_true", help="don't scan path recursively")
    cli.add_argument("path", type=str, help="path where to look for source files")
    options = cli.parse_args()

    options.path = os.path.abspath(options.path)
    parent_dir, requested_name = os.path.split(options.path)
    # Use the directory name as it is spelled on disk, not as the user typed it.
    package = get_real_dirname_case(options.path, requested_name)
    package_root = os.path.join(parent_dir, package)
    print("Processing package '{0}' in '{1}'...".format(package, package_root))

    source_files = gather_files(package_root, not options.norecurse)
    raw_imports = gather_imports(source_files, package_root)
    dependencies = process_imports(package, raw_imports)
    make_graph(dependencies, package)


# Run the tool only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment