Created February 7, 2018 21:25
-
-
Save irmen/559c1ea8cdc7468a327351120c7b50ec to your computer and use it in GitHub Desktop.
import graph plotter
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse | |
import os | |
import ast | |
import graphviz | |
def gather_files(path, recurse=True):
    """Collect the ``.py`` files of the package rooted at *path*.

    :param path: directory that must contain an ``__init__.py``.
    :param recurse: when true, walk into subdirectories except VCS/tooling
        directories (``.git``, ``__pycache__``, ...); when false, scan only
        *path* itself.
    :returns: list of file paths, each joined onto the directory it was found in.
    :raises ValueError: if *path* contains no ``__init__.py``.
    """
    mode = "recursively" if recurse else "nonrecursively"
    print("Gather source files ({0})...".format(mode))
    init_file = os.path.join(path, "__init__.py")
    if not os.path.isfile(init_file):
        raise ValueError("The path you provided was not the root of a package (there's no __init__.py)")
    ignored = {".hg", ".svn", ".git", ".tox", "__pycache__", ".idea", ".tmp", ".bak"}
    found = []
    for directory, subdirs, names in os.walk(path):
        # os.walk only honours pruning done in place on the subdirs list.
        if not recurse:
            del subdirs[:]
        else:
            subdirs[:] = [sub for sub in subdirs if sub not in ignored]
        found.extend(os.path.join(directory, name) for name in names if name.endswith(".py"))
    return found
def gather_imports(files, package_root):
    """Parse each source file and record the import statements it contains.

    :param files: paths of Python source files inside *package_root*.
    :param package_root: path of the package directory. Each file's path is taken
        relative to the parent of this directory to derive its dotted module name.
    :returns: dict mapping dotted module names (``pkg.sub.mod``; an ``__init__.py``
        maps to its package name, e.g. ``pkg.sub``) to a list of dicts with keys
        ``names`` (imported names), ``from`` (module of a ``from`` import, else ""),
        ``level`` (number of leading dots; 0 for absolute imports) and ``line_no``.
    """
    def convert_path_to_package(path):
        # "pkg/sub/__init__.py" -> "pkg.sub"; "pkg/sub/mod.py" -> "pkg.sub.mod"
        head, tail = os.path.split(path)
        if tail == "__init__.py":
            path = head
        elif path.endswith(".py"):
            path = path[:-3]
        return path.replace(os.sep, ".")

    file_imports = {}
    base_dir = os.path.dirname(package_root)
    print("Parsing...", end="", flush=True)
    for fn in files:
        # Read raw bytes so ast.parse decodes them per the file's PEP 263 coding
        # declaration (UTF-8 by default) rather than the platform locale encoding.
        with open(fn, "rb") as fh:
            tree = ast.parse(fh.read(), fn)
        print(".", end="", flush=True)
        imports = [
            {
                "names": [alias.name for alias in node.names],
                "from": getattr(node, "module", None) or "",
                "level": getattr(node, "level", None) or 0,
                "line_no": node.lineno,
            }
            for node in ast.walk(tree)
            if isinstance(node, (ast.Import, ast.ImportFrom))
        ]
        # relpath avoids the off-by-one of stripping a fixed-length prefix
        # (e.g. when the package lives directly under the filesystem root).
        module_name = convert_path_to_package(os.path.relpath(fn, base_dir))
        file_imports[module_name] = imports
    print()
    return file_imports
def _relative_import_base(modulename, level, is_package):
    """Return the dotted package a relative import with *level* leading dots refers to.

    In a package's ``__init__.py`` (recorded under the package's own dotted name)
    a single dot means the package itself; in an ordinary module it means the
    package containing the module. Returns "" when *level* climbs above the
    top-level name, so derived names fall outside the package and get filtered out.
    """
    parts = modulename.split(".")
    drop = level - 1 if is_package else level
    if drop >= len(parts):
        return ""
    return ".".join(parts[:len(parts) - drop])


def process_imports(package_name, file_imports):
    """Reduce each module's raw imports to the set of package modules it imports.

    :param package_name: dotted name of the top-level package; it is always
        treated as a package when resolving relative imports.
    :param file_imports: mapping of module name -> list of dicts with ``names``,
        ``from`` and ``level`` keys (as built by gather_imports). Not modified.
    :returns: dict of module name -> set of module names it imports, restricted
        to the keys of *file_imports* and excluding the module itself.

    Every import is first rewritten into absolute ``import`` form:
    ``from X import Y`` yields ``X`` and ``X.Y`` (Y may be a submodule), and a
    relative ``from`` is resolved against the importing module's package first.
    A module counts as a package (i.e. came from an ``__init__.py``) when it is
    *package_name* or other modules live directly beneath it; a sub-package
    without any submodules is not recognised as one.
    """
    all_modules = set(file_imports)
    packages = {package_name}
    packages.update(name.rsplit(".", 1)[0] for name in all_modules if "." in name)

    def convert_imports(modulename, imports):
        # Builds a new list of absolute imports; the caller's dicts stay untouched.
        is_package = modulename in packages
        converted = []
        for imp in imports:
            source = imp["from"]
            if imp["level"] > 0:
                base = _relative_import_base(modulename, imp["level"], is_package)
                # "from ..X import Y" -> from base.X ; "from .. import Y" -> from base
                source = base + "." + source if source else base
            elif not source:
                # plain "import a.b.c" is already absolute
                converted.append({"names": list(imp["names"]), "level": 0, "from": ""})
                continue
            # "from S import Y, Z" depends on S.Y, S.Z (when submodules) and on S itself
            converted.append({"names": [source + "." + name for name in imp["names"]], "level": 0, "from": ""})
            converted.append({"names": [source], "level": 0, "from": ""})
        return converted

    print("ALL MODULES:")
    for modulename in sorted(all_modules):
        print(" ", modulename)
    imports_per_module = {}
    for modulename in sorted(all_modules):
        imports = convert_imports(modulename, file_imports[modulename])
        print("\nIMPORTS OF", modulename)
        imported_names = set()
        for imp in imports:
            imported_names.update(imp["names"])
        # Keep only modules of this package (external imports, "X.*" wildcards and
        # names imported from modules fall away) and drop self-references such as
        # a package __init__ importing from itself.
        imported_names &= all_modules
        imported_names.discard(modulename)
        imports_per_module[modulename] = imported_names
        print(" ", imported_names)
    return imports_per_module
def get_real_dirname_case(full_path, dirname):
    """Return *dirname* spelled exactly as it appears on disk.

    Searches the parent directory of *full_path* for an entry equal to *dirname*
    ignoring case. An exact-case match is preferred, so the right entry is chosen
    on case-sensitive filesystems where e.g. ``pkg`` and ``Pkg`` can coexist.

    :raises IOError: if no matching entry exists; the message names *dirname*
        as the caller spelled it.
    """
    entries = os.listdir(os.path.dirname(full_path))
    if dirname in entries:
        return dirname
    wanted = dirname.lower()
    for entry in entries:
        if entry.lower() == wanted:
            return entry
    raise IOError("directory not found: " + dirname)
# Modules hidden from the graph by default: the hand-picked Pyro4/Pyro5 modules
# this script was originally written to filter out.
DEFAULT_SKIP_MODULES = frozenset({
    "Pyro4", "Pyro4.configuration", "Pyro4.constants", "Pyro4.errors", "Pyro4.utils",
    "Pyro5", "Pyro5.config", "Pyro5.errors",
})


def make_graph(imports_per_module, package, skip_modules=None, view=True):
    """Draw the import graph of *package* with graphviz and render it as PNG.

    :param imports_per_module: mapping of module name -> iterable of module names
        that module imports; each pair becomes a directed edge importer -> imported.
    :param package: package name, used in the graph's name.
    :param skip_modules: module names to leave out of the graph, both as importer
        and as imported module. Defaults to DEFAULT_SKIP_MODULES.
    :param view: passed to graphviz's ``render``; when true the rendered output
        is opened after writing.

    The DOT source is written to ``output.gv`` in the current directory and
    rendered from there.
    """
    if skip_modules is None:
        skip_modules = DEFAULT_SKIP_MODULES
    skip_modules = set(skip_modules)
    graph = graphviz.Digraph(name="import dependencies of "+package, comment="comment goes here", format="png", graph_attr={"ratio": "0.5"})
    # Sorted iteration makes the DOT output (and thus the layout) reproducible.
    for node in sorted(imports_per_module):
        if node in skip_modules:
            continue
        for edge in sorted(imports_per_module[node]):
            if edge not in skip_modules:
                graph.edge(node, edge)
    graph.render("output.gv", view=view)
def main():
    """Command-line entry point: scan a package directory and plot its internal import graph."""
    parser = argparse.ArgumentParser(description="Plot the import dependency graph of a Python package")
    parser.add_argument("--norecurse", action="store_true", help="don't scan path recursively")
    parser.add_argument("path", type=str, help="path where to look for source files")
    args = parser.parse_args()
    path = os.path.abspath(args.path)
    # Use the on-disk spelling of the package directory; the user may have typed
    # it with different case on a case-insensitive filesystem.
    package = get_real_dirname_case(path, os.path.basename(path))
    package_root = os.path.join(os.path.dirname(path), package)
    print("Processing package '{0}' in '{1}'...".format(package, package_root))
    files = gather_files(package_root, not args.norecurse)
    imports = gather_imports(files, package_root)
    imports_per_module = process_imports(package, imports)
    make_graph(imports_per_module, package)
# Run the command-line tool only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.