Skip to content

Instantly share code, notes, and snippets.

@mikalv
Created August 13, 2019 01:49
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save mikalv/5b0867d4220012a786fc980830d70058 to your computer and use it in GitHub Desktop.
Save mikalv/5b0867d4220012a786fc980830d70058 to your computer and use it in GitHub Desktop.
Generates a Graphviz DOT file of the Python module dependencies within a package
#!/usr/bin/env python
import _ast
import argparse
import ast
import colorsys
import hashlib
import logging
import os.path
import sys

try:
    import md5  # Python 2 only; kept for backward compatibility
except ImportError:  # the md5 module was removed in Python 3
    md5 = None
# Module-level logger, named after this module so log records can be filtered.
_logger = logging.getLogger(__name__)
class DependencyDotGenerator:
    """Render a module dependency graph as Graphviz DOT text.

    Adapted from the examples on http://www.tarind.com/depgraph.html.
    ``label``, ``color``, ``node_attributes`` and ``edge_attributes`` are
    default policies that subclasses may override.
    """

    def render(self, dependencies, output=None):
        """Write the DOT description of ``dependencies``.

        Args:
            dependencies: mapping of module name -> iterable of the names of
                the modules it imports.
            output: path of the file to write. When falsy, the graph is
                written to ``sys.stdout`` (which is left open).
        """
        if output:
            # The context manager closes the file even if a write fails.
            with open(output, 'w') as f:
                self._write_graph(f, dependencies)
        else:
            # sys.stdout belongs to the caller: flush it, never close it.
            stream = sys.stdout
            self._write_graph(stream, dependencies)
            stream.flush()

    def _write_graph(self, f, dependencies):
        """Emit the complete ``digraph`` block to the open text stream ``f``."""
        f.write('digraph G {\n')
        f.write('ranksep=1.0;\n')
        f.write('node [style=filled,fontname=Helvetica,fontsize=10];\n')
        declared = set()
        referenced = set()
        for m, deps in dependencies.items():
            # Sorted so the edge order does not depend on set iteration order.
            for d in sorted(deps):
                f.write('%s -> %s' % (self.fix(m), self.fix(d)))
                self.write_attributes(f, self.edge_attributes(m, d))
                f.write(';\n')
                referenced.add(d)
            self._write_node(f, m)
            declared.add(m)
        # Modules that are imported but import nothing themselves never show
        # up as keys; declare them too so they get the same label and color.
        for m in sorted(referenced - declared):
            self._write_node(f, m)
        f.write('}\n')

    def _write_node(self, f, m):
        """Emit the node statement (name plus attributes) for module ``m``."""
        f.write(self.fix(m))
        self.write_attributes(f, self.node_attributes(m))
        f.write(';\n')

    def fix(self, s):
        """Convert a module name to a syntactically correct DOT node name."""
        return s.replace('.', '_')

    def write_attributes(self, f, a):
        """Write the attribute list ``a`` as `` [k=v,...]``; nothing if empty."""
        if a:
            f.write(' [')
            f.write(','.join(a))
            f.write(']')

    def node_attributes(self, m):
        """Return the DOT attribute strings for the node of module ``m``."""
        return [
            'label="%s"' % self.label(m),
            'fillcolor="%s"' % self.color(m),
        ]

    def edge_attributes(self, k, v):
        """Return the DOT attribute strings for the edge ``k -> v``.

        The default policy adds no attributes.
        """
        return []

    def label(self, s):
        """Convert a module name to a formatted node label.

        This is a default policy - please override. Each dot in the dotted
        name is followed by a DOT line break so that every component is
        rendered on its own line.
        """
        return '\\.\\n'.join(s.split('.'))

    def color(self, s):
        """Return the node color for this module name.

        This is a default policy - please override. The color is derived
        from a hash of the parent package name, so modules in the same
        package share a color and all unpackaged modules share one color.
        """
        t = self.normalise_module_name_for_hash_coloring(s)
        return self.color_from_name(t)

    def normalise_module_name_for_hash_coloring(self, s):
        """Return the parent package of ``s`` ('' when ``s`` has no dot).

        Modules at the same level therefore hash to the same color.
        """
        i = s.rfind('.')
        if i < 0:
            return ''
        return s[:i]

    def color_from_name(self, name):
        """Map ``name`` to a stable pastel ``#rrggbb`` color string.

        MD5 is used purely as a cheap, stable hash here, not for security.
        """
        # hashlib needs bytes on Python 3; Python 2 ``str`` already is bytes.
        if not isinstance(name, bytes):
            name = name.encode('utf-8')
        # bytearray yields integers on both Python 2 and Python 3.
        n = bytearray(hashlib.md5(name).digest())
        hf = float(n[0] + n[1] * 0xff) / 0xffff
        sf = float(n[2]) / 0xff
        vf = float(n[3]) / 0xff
        rgb = colorsys.hsv_to_rgb(hf, 0.3 + 0.6 * sf, 0.8 + 0.2 * vf)
        # %x requires integers on Python 3, and a component of exactly 1.0
        # would scale to 256 and overflow two hex digits, hence the clamp.
        return '#%02x%02x%02x' % tuple(min(255, int(c * 256)) for c in rgb)
class ImportVisitor(ast.NodeVisitor):
    """Record which modules of one package import which others.

    ``visit`` inspects a single node and does not recurse, so callers are
    expected to feed it every node of a tree (e.g. via ``ast.walk``) after
    setting ``cur_module_name`` to the module that tree belongs to.
    """

    def __init__(self, package_name):
        """Create a visitor that only records imports inside ``package_name``."""
        # Dependency graph: module name -> set of imported module names.
        self.depgraph = {}
        # Only modules in this top-level package are recorded.
        self.package_name = package_name
        # Dotted name of the module whose AST is currently being visited.
        self.cur_module_name = None

    def add_dependency(self, depend_module):
        """Record ``depend_module`` for the current module if it is in-package.

        Empty or ``None`` names are ignored. The match is on whole dotted
        components, so package ``pkg`` does not capture ``pkgother``.
        """
        if not depend_module:
            return
        package = self.package_name
        if depend_module == package or depend_module.startswith(package + '.'):
            self.depgraph.setdefault(self.cur_module_name, set()).add(depend_module)

    def visit(self, node):
        """Record the imports made by ``node`` if it is an import statement.

        Any other node type is ignored, and children are not visited.
        """
        if isinstance(node, ast.Import):
            # Import(alias* names); alias = (identifier name, identifier? asname)
            for alias in node.names:
                self.add_dependency(alias.name)
        elif isinstance(node, ast.ImportFrom):
            # ImportFrom(identifier? module, alias* names, int? level)
            self.add_dependency(self._resolve_from_target(node))

    def _resolve_from_target(self, node):
        """Return the absolute module name an ``ImportFrom`` node refers to.

        Relative imports are resolved against ``cur_module_name``. For
        ``from . import x`` the containing package is returned, because the
        AST alone cannot tell whether ``x`` is a submodule or an attribute.
        Returns ``None`` when the import climbs above the top-level package.
        """
        level = node.level or 0
        if not level:
            return node.module
        # Level 1 is the package containing the current module, so dropping
        # ``level`` trailing components gives the base package.
        base = (self.cur_module_name or '').split('.')[:-level]
        if not base:
            return None
        if node.module:
            base.append(node.module)
        return '.'.join(base)
def build_dot_graph_for_package(package_path, output_file):
    """Analyze every ``.py`` file under a package and render its import graph.

    Args:
        package_path: path to the top-level package directory. Relative
            paths and trailing separators are accepted.
        output_file: path of the DOT file to write, or a falsy value to
            write to stdout.

    Files that cannot be parsed are logged as warnings and skipped.
    """
    # Normalise first: splitting "pkg/" would give an empty package name and
    # a bare "pkg" would give an empty parent directory.
    package_path = os.path.abspath(package_path)
    parent_dir, package_name = os.path.split(package_path)
    import_visitor = ImportVisitor(package_name)
    log = logging.getLogger(__name__)

    for root, dirs, files in os.walk(package_path):
        dirs.sort()  # walk sub-packages in a deterministic order
        # Dotted package name of this directory, relative to the parent of
        # the top-level package, using the platform's path separator.
        module_prefix = os.path.relpath(root, parent_dir).replace(os.sep, '.')
        for pyfile in sorted(files):
            if not pyfile.endswith('.py'):
                continue
            module_name = '.'.join([module_prefix, os.path.splitext(pyfile)[0]])
            source_path = os.path.join(root, pyfile)
            # Read bytes so ast.parse can honour the file's own encoding
            # declaration instead of the locale default.
            with open(source_path, 'rb') as source_file:
                source = source_file.read()
            try:
                ast_tree = ast.parse(source, source_path)
            except (SyntaxError, ValueError) as exc:
                log.warning('skipping %s: %s', source_path, exc)
                continue
            import_visitor.cur_module_name = module_name
            for node in ast.walk(ast_tree):
                import_visitor.visit(node)

    dot_generator = DependencyDotGenerator()
    dot_generator.render(import_visitor.depgraph, output_file)
def main():
    """Command-line entry point: parse the options and emit the graph."""
    logging.basicConfig(level=logging.INFO)

    cli = argparse.ArgumentParser(
        description='Build module dependency graph for a package.')
    cli.add_argument(
        '-p', '--path',
        required=True,
        help='path to the top level package we want to analyze')
    cli.add_argument(
        '-o', '--out',
        help='output file, if missing, output is written to stdout')

    options = cli.parse_args()
    build_dot_graph_for_package(options.path, options.out)
# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment