Skip to content

Instantly share code, notes, and snippets.

@Den1al
Last active February 28, 2024 23:48
Show Gist options
  • Save Den1al/3a85c23df1818b4857dccbbe694452d9 to your computer and use it in GitHub Desktop.
Save Den1al/3a85c23df1818b4857dccbbe694452d9 to your computer and use it in GitHub Desktop.
Get the dependencies of functions in Python using abstract syntax trees
import ast
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, List, Tuple, Union
@dataclass
class Dependencies:
    """Arguments and names found while walking one function definition."""

    # ``ast.arg`` nodes encountered in the function's subtree.
    arguments: List[ast.arg] = field(default_factory=list)
    # ``ast.Name`` nodes encountered in the function's subtree.
    variables: List[ast.Name] = field(default_factory=list)


# A function definition node: ``def`` and ``async def`` produce distinct node types.
FunctionDefNode = Union[ast.FunctionDef, ast.AsyncFunctionDef]
FunctionsList = List[FunctionDefNode]
FunctionDepsList = List[Tuple[FunctionDefNode, Dependencies]]
class GetCandidatesNodeVisitor(ast.NodeVisitor):
    """A traversal visitor that collects every function definition in a tree.

    Both ``def`` and ``async def`` nodes are collected, in pre-order, at any
    depth: module level, inside classes (methods) and nested inside other
    functions.
    """

    def __init__(self) -> None:
        super().__init__()
        self.functions: FunctionsList = []

    def visit(self, node: ast.AST) -> "FunctionsList":
        """Visit *node* and return the definitions collected so far.

        ``generic_visit`` re-enters this method for each child, so the return
        value only matters for the outermost call. The list lives on the
        instance, so reusing one visitor accumulates across calls.
        """
        super().visit(node)
        return self.functions

    def visit_FunctionDef(self, node: ast.FunctionDef) -> Any:
        self.functions.append(node)
        # Keep descending so definitions nested inside this function are found
        # as well; without this the walk stops at the outermost function.
        self.generic_visit(node)

    # ``async def`` is a separate node type; handle it identically.
    visit_AsyncFunctionDef = visit_FunctionDef
class GetDependenciesNodeVisitor(ast.NodeVisitor):
    """A traversal visitor that collects all the dependencies of a given function.

    Every ``ast.arg`` node in the visited subtree (parameters of the function
    and of any nested function or lambda) is appended to ``arguments``. Every
    ``ast.Name`` node, in any context (load, store or delete), is appended to
    ``variables``. Names inside parameter annotations, the return annotation,
    decorators and default values are all included.
    """

    def __init__(self) -> None:
        super().__init__()
        self.dependencies = Dependencies()

    def visit(self, node: ast.AST) -> "Dependencies":
        """Visit *node* and return the dependencies accumulated so far."""
        super().visit(node)
        return self.dependencies

    def visit_arg(self, node: ast.arg) -> Any:
        self.dependencies.arguments.append(node)
        # Recurse into the parameter's annotation (e.g. ``x: int``) so its
        # names are collected. The return annotation already is, through the
        # generic visit of the FunctionDef node.
        self.generic_visit(node)

    def visit_Name(self, node: ast.Name) -> Any:
        # A Name's only child is its context marker; nothing further to visit.
        self.dependencies.variables.append(node)
def get_func_dependencies(func_def: ast.FunctionDef) -> Dependencies:
    """Return the arguments and names referenced within *func_def*.

    A brand-new visitor is built on every call, so results from one function
    never carry over into another.
    """
    return GetDependenciesNodeVisitor().visit(func_def)
def parse(source: str) -> FunctionDepsList:
    """Parse Python *source* and pair each function definition with its dependencies.

    Errors raised by ``ast.parse`` (e.g. ``SyntaxError`` for invalid code)
    propagate to the caller unchanged.
    """
    module = ast.parse(source)
    candidates = GetCandidatesNodeVisitor().visit(module)
    # map() evaluates lazily but in order, so each pair is built exactly as
    # the candidates were discovered.
    dependencies = map(get_func_dependencies, candidates)
    return list(zip(candidates, dependencies))
def main(path: Path = Path("./example.py")) -> None:
    """Print the arguments and names used by each function in a Python file.

    Args:
        path: Python source file to analyse. Defaults to ``./example.py``,
            resolved against the current working directory.
    """
    import tokenize

    # tokenize.open honours a PEP 263 coding cookie or BOM and otherwise
    # decodes as UTF-8. Plain open() would use the locale's default encoding.
    with tokenize.open(path) as f:
        source = f.read()
    results = parse(source)

    for func, deps in results:
        print(f"current {func.name=}")
        # The local names below appear verbatim in the "=" f-string output.
        args = [arg.arg for arg in deps.arguments]
        print(f"- {args=}")
        variables = [v.id for v in deps.variables]
        print(f"- {variables=}")


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment