Skip to content

Instantly share code, notes, and snippets.

@jgabriellima
Last active May 9, 2023 16:22
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jgabriellima/a4ed73df4aa28cb9980520b6de5fe8cf to your computer and use it in GitHub Desktop.
Save jgabriellima/a4ed73df4aa28cb9980520b6de5fe8cf to your computer and use it in GitHub Desktop.
This code creates a call flow graph (plus class, sequence, and flowchart diagrams) from a Python package.
import ast
import base64
import os
import subprocess
import sys

import matplotlib.pyplot as plt
import networkx as nx
class FlowchartAnalyzer(ast.NodeVisitor):
    """Collect the ``if``/``while``/``for`` statements of every function.

    After visiting one or more module ASTs, ``self.flowchart`` maps each
    function name to the ordered list of ``(kind, ast.dump(expr))`` tuples
    found in its body. *kind* is ``'if'``, ``'while'`` or ``'for'``; *expr* is
    the test expression (for ``if``/``while``) or the iterable (for ``for``).
    Control flow outside any function is ignored.
    """

    def __init__(self):
        # Name of the function whose body is being visited; None at
        # module/class level.
        self.current_function = None
        # function name -> list of (kind, dumped expression) tuples. Keyed by
        # bare name, so a later function with the same name (e.g. ``__init__``
        # in two classes) replaces the earlier entry.
        self.flowchart = {}

    def visit_FunctionDef(self, node):
        # Save the enclosing function so that, after a nested ``def``, control
        # flow in the rest of the outer body is still attributed to it.
        outer = self.current_function
        self.current_function = node.name
        self.flowchart[node.name] = []
        try:
            self.generic_visit(node)
        finally:
            self.current_function = outer

    # ``async def`` bodies are recorded exactly like ordinary functions.
    visit_AsyncFunctionDef = visit_FunctionDef

    def _record(self, kind, expr):
        """Append ``(kind, ast.dump(expr))`` to the current function, if any."""
        # Module- or class-level control flow has no function to belong to;
        # skip it instead of failing with KeyError(None).
        if self.current_function is not None:
            self.flowchart[self.current_function].append((kind, ast.dump(expr)))

    def visit_If(self, node):
        self._record('if', node.test)
        self.generic_visit(node)

    def visit_While(self, node):
        self._record('while', node.test)
        self.generic_visit(node)

    def visit_For(self, node):
        self._record('for', node.iter)
        self.generic_visit(node)

    # ``async for`` loops are recorded as ordinary ``for`` loops.
    visit_AsyncFor = visit_For
def flowchart_to_mermaid(flowchart):
    """Render a flowchart mapping as Mermaid ``graph TD`` source.

    Args:
        flowchart: mapping of function name -> list of
            ``(control_type, condition)`` string pairs.

    Returns:
        Mermaid text with one node per control statement (node id
        ``<function>_<index>``, label ``"<control_type> <condition>"``) and an
        edge chaining consecutive statements of the same function.
    """
    lines = ["graph TD"]
    for function, controls in flowchart.items():
        for i, (control_type, condition) in enumerate(controls):
            # The label is wrapped in double quotes, so a literal '"' would end
            # it early. Mermaid uses the #quot; entity code for this; backslash
            # escapes are not Mermaid syntax and would be shown verbatim.
            condition = condition.replace('"', "#quot;").replace("\n", "<br/>")
            node_id = f"{function}_{i}"
            lines.append(f'{node_id}["{control_type} {condition}"]')
            if i > 0:
                lines.append(f"{function}_{i - 1} --> {node_id}")
    return "\n".join(lines) + "\n"
def get_flowchart(pkg_path):
    """Run ``FlowchartAnalyzer`` over every ``.py`` file under *pkg_path*.

    Files that cannot be read or parsed are reported on stderr and skipped;
    the scan continues with the remaining files.

    Returns:
        The analyzer's ``flowchart`` dict (function name -> control list).
    """
    analyzer = FlowchartAnalyzer()
    for root, _, files in os.walk(pkg_path):
        for file in files:
            if not file.endswith(".py"):
                continue
            file_path = os.path.join(root, file)
            try:
                # Read raw bytes: ast.parse then honours a UTF-8 BOM or a
                # PEP 263 coding declaration instead of failing on them.
                # open() is inside the try so a dangling symlink is reported,
                # not fatal.
                with open(file_path, "rb") as f:
                    source = f.read()
                file_ast = ast.parse(source, filename=file_path)
                analyzer.visit(file_ast)
            except Exception as e:  # best-effort scan: report and continue
                # stderr keeps diagnostics out of the Mermaid text on stdout.
                print(f"Error parsing {file_path}: {e}", file=sys.stderr)
    return analyzer.flowchart
def clone_github_repo(github_link, local_dir, check=True):
    """Clone the git repository at *github_link* into *local_dir*.

    Args:
        github_link: repository URL passed to ``git clone``.
        local_dir: destination directory; created if it does not exist.
        check: when true (default), raise ``subprocess.CalledProcessError`` if
            ``git clone`` exits non-zero instead of silently continuing with
            an empty directory.

    Returns:
        The ``subprocess.CompletedProcess`` of the ``git clone`` call.
    """
    os.makedirs(local_dir, exist_ok=True)
    # "--" ends option parsing so a URL beginning with "-" cannot be taken as
    # a git option.
    return subprocess.run(["git", "clone", "--", github_link, local_dir], check=check)
class CallGraphAnalyzer(ast.NodeVisitor):
    """Add ``caller -> callee`` edges for calls made inside function bodies.

    Function names are qualified as ``<module>.<Class>.<name>`` for methods
    and ``<module>.<name>`` otherwise, where ``<module>`` is
    ``self.current_module`` (set it before visiting each file). Edges are
    added with ``call_graph.add_edge(caller, callee)``. Calls made outside any
    function are ignored.
    """

    def __init__(self, call_graph, exclude_libs=False):
        # Graph-like object receiving edges through add_edge(caller, callee).
        self.call_graph = call_graph
        # Qualified name of the function being visited; None outside functions.
        self.current_function = None
        # When true, keep only calls whose qualified name is in self.functions.
        # NOTE: self.functions holds definitions visited *so far*, so calls to
        # functions defined later (in this file or a later one) are dropped.
        self.exclude_libs = exclude_libs
        self.functions = set()
        # Module prefix for qualified names; assigned by the caller per file.
        self.current_module = ""
        # Name of the innermost enclosing class, or None.
        self.current_class = None

    def visit_ClassDef(self, node):
        # Restore the outer class afterwards so a nested class doesn't clear it.
        outer = self.current_class
        self.current_class = node.name
        try:
            self.generic_visit(node)
        finally:
            self.current_class = outer

    def visit_FunctionDef(self, node):
        if self.current_class:
            func_name = f"{self.current_module}.{self.current_class}.{node.name}"
        else:
            func_name = f"{self.current_module}.{node.name}"
        self.functions.add(func_name)
        # Restore the enclosing function afterwards; otherwise calls after a
        # nested def, or module-level calls after a function, would be
        # attributed to the wrong caller.
        outer = self.current_function
        self.current_function = func_name
        try:
            self.generic_visit(node)
        finally:
            self.current_function = outer

    # ``async def`` functions are callers just like ordinary functions.
    visit_AsyncFunctionDef = visit_FunctionDef

    def _callee_name(self, func):
        """Return the qualified name for a call target expression, or None."""
        if isinstance(func, ast.Name):
            # A bare ``self(...)`` call is recorded against the class's name.
            if func.id == 'self' and self.current_class:
                return f"{self.current_module}.{self.current_class}"
            return f"{self.current_module}.{func.id}"
        if isinstance(func, ast.Attribute) and isinstance(func.value, ast.Name):
            if func.value.id == 'self' and self.current_class:
                return f"{self.current_module}.{self.current_class}.{func.attr}"
            return f"{func.value.id}.{func.attr}"
        # Longer attribute chains (a.b.c()) and computed callees are ignored.
        return None

    def visit_Call(self, node):
        func_name = self._callee_name(node.func)
        # Calls at module or class level have no calling function; skip them
        # (networkx graphs reject None as a node).
        if (func_name
                and self.current_function is not None
                and (not self.exclude_libs or func_name in self.functions)):
            self.call_graph.add_edge(self.current_function, func_name)
        self.generic_visit(node)
class ClassDiagramAnalyzer(ast.NodeVisitor):
    """Collect class names, their directly defined methods, and a base class."""

    def __init__(self):
        # class name -> names of (sync and async) methods defined directly in
        # the class body, in source order.
        self.classes = {}
        # class name -> name of ONE base class: the last base listed that is a
        # plain (``Base``) or dotted (``mod.Base``) name. Other bases of a
        # multiply-inheriting class are not kept.
        self.inheritance = {}

    def visit_ClassDef(self, node):
        class_name = node.name
        self.classes[class_name] = [
            item.name
            for item in node.body
            if isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef))
        ]
        for base in node.bases:
            if isinstance(base, ast.Name):
                self.inheritance[class_name] = base.id
            elif isinstance(base, ast.Attribute):
                # ``module.Base`` -> "Base": the module prefix is dropped,
                # matching how bare-name bases are recorded.
                self.inheritance[class_name] = base.attr
        self.generic_visit(node)
def mm(graph):
    """Print, and return, a mermaid.ink URL that renders the Mermaid *graph* text.

    The text is UTF-8 encoded so non-ASCII labels don't raise, and it is
    base64-encoded with the URL-safe alphabet so the payload contains no
    ``/`` or ``+`` that would corrupt the URL path.
    """
    payload = base64.urlsafe_b64encode(graph.encode("utf-8")).decode("ascii")
    url = "https://mermaid.ink/img/" + payload
    print(url)
    return url
def get_call_flow_graph_package(pkg_path, exclude_libs=False):
    """Build a call graph over every ``.py`` file under *pkg_path*.

    Each file is visited with ``CallGraphAnalyzer``, with the file's stem
    (name without ``.py``) as the module prefix. Files that cannot be read or
    parsed are reported on stderr and skipped.

    Args:
        pkg_path: directory to walk recursively.
        exclude_libs: forwarded to ``CallGraphAnalyzer``; keeps only calls to
            functions already seen in the scanned sources.

    Returns:
        networkx.DiGraph of caller -> callee edges.
    """
    call_graph = nx.DiGraph()
    analyzer = CallGraphAnalyzer(call_graph, exclude_libs)
    for root, _, files in os.walk(pkg_path):
        for file in files:
            if not file.endswith(".py"):
                continue
            file_path = os.path.join(root, file)
            try:
                # Raw bytes let ast.parse honour a UTF-8 BOM / PEP 263 coding
                # declaration; open() inside the try keeps a dangling symlink
                # from aborting the whole walk.
                with open(file_path, "rb") as f:
                    source = f.read()
                file_ast = ast.parse(source, filename=file_path)
                # NOTE: the prefix is only the file stem, so same-named modules
                # in different sub-packages share one namespace.
                analyzer.current_module = os.path.splitext(file)[0]
                analyzer.visit(file_ast)
            except Exception as e:  # best-effort scan: report and continue
                print(f"Error parsing {file_path}: {e}", file=sys.stderr)
    return call_graph
def get_call_flow_graph(file_path, class_or_function_name, exclude_libs=False):
    """Build a call graph for one class or function defined in a single file.

    Every ``class``/``def``/``async def`` node named *class_or_function_name*
    anywhere in the file (including nested ones) is visited with
    ``CallGraphAnalyzer``. Names are prefixed with the file's stem, as in
    ``get_call_flow_graph_package``. A read or parse error is reported on
    stderr and an empty graph is returned.

    Returns:
        networkx.DiGraph of caller -> callee edges.
    """
    call_graph = nx.DiGraph()
    analyzer = CallGraphAnalyzer(call_graph, exclude_libs)
    # Qualify names by module instead of leaving an empty prefix, which
    # produced names like ".Class.method".
    analyzer.current_module = os.path.splitext(os.path.basename(file_path))[0]
    target_types = (ast.ClassDef, ast.FunctionDef, ast.AsyncFunctionDef)
    try:
        # Raw bytes let ast.parse honour a UTF-8 BOM / PEP 263 coding
        # declaration; a missing file is reported rather than raised.
        with open(file_path, "rb") as f:
            source = f.read()
        file_ast = ast.parse(source, filename=file_path)
        targets = [
            node for node in ast.walk(file_ast)
            if isinstance(node, target_types) and node.name == class_or_function_name
        ]
        for node in targets:
            analyzer.visit(node)
    except Exception as e:  # best-effort: report and return what we have
        print(f"Error parsing {file_path}: {e}", file=sys.stderr)
    return call_graph
def get_class_diagram(pkg_path):
    """Run ``ClassDiagramAnalyzer`` over every ``.py`` file under *pkg_path*.

    Files that cannot be read or parsed are reported on stderr and skipped.

    Returns:
        ``(classes, inheritance)``: class name -> method names, and
        class name -> base class name.
    """
    analyzer = ClassDiagramAnalyzer()
    for root, _, files in os.walk(pkg_path):
        for file in files:
            if not file.endswith(".py"):
                continue
            file_path = os.path.join(root, file)
            try:
                # Raw bytes let ast.parse honour a UTF-8 BOM / PEP 263 coding
                # declaration; open() inside the try keeps a dangling symlink
                # from aborting the whole walk.
                with open(file_path, "rb") as f:
                    source = f.read()
                file_ast = ast.parse(source, filename=file_path)
                analyzer.visit(file_ast)
            except Exception as e:  # best-effort scan: report and continue
                print(f"Error parsing {file_path}: {e}", file=sys.stderr)
    return analyzer.classes, analyzer.inheritance
def draw_call_flow_graph(call_graph):
    """Draw *call_graph* with a spring (force-directed) layout and show it.

    Nodes are labelled and edges drawn as arrows; ``plt.show()`` is called at
    the end, so this may block until the plot window is closed.
    """
    layout = nx.spring_layout(call_graph)
    draw_options = {
        "with_labels": True,
        "node_size": 2000,
        "font_size": 10,
        "font_weight": 'bold',
        "arrows": True,
    }
    nx.draw(call_graph, layout, **draw_options)
    plt.show()
def graph_to_text(call_graph):
    """Describe every edge of *call_graph* as a ``"<caller> calls <callee>"`` line.

    *call_graph* must expose ``.adj`` as a mapping of caller -> iterable of
    callees (a networkx graph's adjacency view does). The result starts with
    a ``Call Flow Graph:`` header, and every line, including the last, ends
    with a newline.
    """
    lines = ["Call Flow Graph:"]
    for source_node, neighbours in call_graph.adj.items():
        lines.extend(f"{source_node} calls {target}" for target in neighbours)
    return "\n".join(lines) + "\n"
def text_to_mermaid(text, orientation="TB"):
    """Convert ``"<caller> calls <callee>"`` lines into a Mermaid flowchart.

    Args:
        text: text in the format ``graph_to_text`` produces. Every line made
            of exactly three whitespace-separated tokens whose middle token is
            ``calls`` becomes an edge; all other lines are ignored.
        orientation: Mermaid graph direction (e.g. ``TB``, ``TD``, ``LR``).

    Returns:
        Mermaid ``graph`` source with one ``caller --> callee`` line per edge.
    """
    edges = []
    for line in text.splitlines():
        tokens = line.split()
        # Require the exact "<caller> calls <callee>" shape; a substring test
        # for "calls" crashed the 3-way unpack on any other line mentioning it.
        if len(tokens) == 3 and tokens[1] == "calls":
            caller, _, callee = tokens
            edges.append(f"{caller} --> {callee}\n")
    return f"graph {orientation}\n" + "".join(edges)
def class_diagram_to_text(class_diagram, inheritance):
    """Render classes, their methods and their bases as plain text.

    Args:
        class_diagram: mapping of class name -> iterable of method names.
        inheritance: mapping of subclass name -> superclass name.

    Returns:
        A ``Class Diagram:`` header, then each class with one indented
        ``Method`` line per method, then one ``inherits from`` line per entry
        of *inheritance*. Every line ends with a newline.
    """
    parts = ["Class Diagram:\n"]
    for name in class_diagram:
        parts.append(f"Class {name}:\n")
        parts.extend(f" Method {m}\n" for m in class_diagram[name])
    parts.extend(
        f"{child} inherits from {parent}\n" for child, parent in inheritance.items()
    )
    return "".join(parts)
def class_diagram_to_mermaid(class_diagram, inheritance):
    """Render classes and inheritance as Mermaid ``classDiagram`` source.

    Each class becomes a ``class Name { ... }`` block listing its methods as
    ``name()``; each ``subclass -> superclass`` entry becomes a ``--|>`` edge.
    """
    out = ["classDiagram\n"]
    for name, methods in class_diagram.items():
        body = "".join(f" {m}()\n" for m in methods)
        out.append(f"class {name} {{\n{body}}}\n")
    for child, parent in inheritance.items():
        out.append(f"{child} --|> {parent}\n")
    return "".join(out)
class SequenceDiagramAnalyzer(ast.NodeVisitor):
    """Record calls made from methods to methods of user-defined classes.

    ``self.classes`` maps each class name to a list of
    ``(calling_method, called_class, called_method)`` tuples. Only calls of
    the form ``self.m()`` or ``Name.m()`` made inside a method are recorded,
    and only when ``"<class>.<m>"`` is a registered user-defined method.
    All methods of a class are registered when the class is entered, so
    forward references within one class resolve. Classes visited later,
    including those in later files, are still unknown at that point.
    """

    def __init__(self):
        # class name -> list of (calling_method, called_class, called_method).
        self.classes = {}
        # Innermost enclosing class / function names; None when outside one.
        self.current_class = None
        self.current_function = None
        # "Class.method" for methods, bare name for module-level functions.
        self.user_defined_functions = set()

    def visit_ClassDef(self, node):
        outer = self.current_class
        self.current_class = node.name
        self.classes[node.name] = []
        # Pre-register every method so that calls to methods defined further
        # down the same class body are recognised too.
        for item in node.body:
            if isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef)):
                self.user_defined_functions.add(f"{node.name}.{item.name}")
        try:
            self.generic_visit(node)
        finally:
            # Restore the outer class so a nested class doesn't clear it.
            self.current_class = outer

    def visit_FunctionDef(self, node):
        outer = self.current_function
        self.current_function = node.name
        if self.current_class:
            self.user_defined_functions.add(f"{self.current_class}.{node.name}")
        else:
            self.user_defined_functions.add(node.name)
        try:
            self.generic_visit(node)
        finally:
            # Restore the enclosing function after a nested def.
            self.current_function = outer

    visit_AsyncFunctionDef = visit_FunctionDef

    def visit_Call(self, node):
        func = node.func
        # Only calls made inside a method can be attributed to a class and a
        # caller. Module-level callers previously raised KeyError(None), and
        # class-body calls produced a None caller.
        if (self.current_class is not None
                and self.current_function is not None
                and isinstance(func, ast.Attribute)
                and isinstance(func.value, ast.Name)):
            if func.value.id == 'self':
                called_class = self.current_class
            else:
                called_class = func.value.id
            called_function = func.attr
            if f"{called_class}.{called_function}" in self.user_defined_functions:
                self.classes[self.current_class].append(
                    (self.current_function, called_class, called_function))
        self.generic_visit(node)
def sequence_diagram_to_mermaid(sequence_diagram):
    """Render recorded method calls as Mermaid ``sequenceDiagram`` source.

    Args:
        sequence_diagram: mapping of class name -> list of
            ``(calling_method, called_class, called_method)`` tuples.

    Returns:
        One ``Class.caller ->> Target.method: call`` message line per tuple.
    """
    rows = ["sequenceDiagram\n"]
    for owner, calls in sequence_diagram.items():
        rows.extend(
            f"{owner}.{caller} ->> {target_cls}.{target_fn}: call\n"
            for caller, target_cls, target_fn in calls
        )
    return "".join(rows)
def get_sequence_diagram(pkg_path):
    """Run ``SequenceDiagramAnalyzer`` over every ``.py`` file under *pkg_path*.

    Files that cannot be read or parsed are reported on stderr and skipped.

    Returns:
        The analyzer's ``classes`` dict: class name -> list of
        ``(calling_method, called_class, called_method)`` tuples.
    """
    analyzer = SequenceDiagramAnalyzer()
    for root, _, files in os.walk(pkg_path):
        for file in files:
            if not file.endswith(".py"):
                continue
            file_path = os.path.join(root, file)
            try:
                # Raw bytes let ast.parse honour a UTF-8 BOM / PEP 263 coding
                # declaration; open() inside the try keeps a dangling symlink
                # from aborting the whole walk.
                with open(file_path, "rb") as f:
                    source = f.read()
                file_ast = ast.parse(source, filename=file_path)
                analyzer.visit(file_ast)
            except Exception as e:  # best-effort scan: report and continue
                print(f"Error parsing {file_path}: {e}", file=sys.stderr)
    return analyzer.classes
if __name__ == '__main__':
    import argparse

    # Command-line entry point. With no arguments the defaults reproduce the
    # original hard-coded run: path "" and target 'ContractProcessor'.
    parser = argparse.ArgumentParser(
        description="Print Mermaid diagrams (call flow, class, flowchart) for Python code.")
    parser.add_argument(
        "package_path", nargs="?", default="",
        help="Directory walked for the class diagram and flowchart.")
    parser.add_argument(
        "--file", default=None,
        help="Python file searched for --target when building the call-flow "
             "graph (defaults to package_path).")
    parser.add_argument(
        "--target", default="ContractProcessor",
        help="Name of the class or function whose calls are graphed.")
    parser.add_argument(
        "--sequence", action="store_true",
        help="Also print a sequence diagram for package_path.")
    args = parser.parse_args()
    package_path = args.package_path
    source_file = args.file if args.file is not None else package_path

    graph = get_call_flow_graph(source_file, args.target, exclude_libs=True)
    draw_call_flow_graph(graph)
    print(text_to_mermaid(graph_to_text(graph), 'LR'))

    class_diagram, inheritance = get_class_diagram(package_path)
    print(class_diagram_to_mermaid(class_diagram, inheritance))

    if args.sequence:
        sequence_diagram = get_sequence_diagram(package_path)
        print(sequence_diagram_to_mermaid(sequence_diagram))

    # To analyse a remote repository, clone it first and pass its directory:
    #   clone_github_repo("https://github.com/username/repo.git", "/path/to/local/dir")

    flowchart = get_flowchart(package_path)
    print(flowchart_to_mermaid(flowchart))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment