Last active
February 28, 2024 23:48
-
-
Save Den1al/3a85c23df1818b4857dccbbe694452d9 to your computer and use it in GitHub Desktop.
Get the dependencies of functions in Python using abstract syntax trees.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import ast
import tokenize
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, List, Tuple, Union
@dataclass
class Dependencies:
    """Container for the AST nodes a single function depends on.

    Both lists default to empty, and every instance gets its own fresh list
    objects (``default_factory``), so instances never share state.
    """

    # ``ast.arg`` nodes, i.e. declared parameters.
    arguments: List[ast.arg] = field(default_factory=list)
    # ``ast.Name`` nodes, i.e. identifiers referenced by name.
    variables: List[ast.Name] = field(default_factory=list)
# A function definition node: either a plain ``def`` or an ``async def``.
FunctionNode = Union[ast.FunctionDef, ast.AsyncFunctionDef]
# The function definitions found in a module.
FunctionsList = List[FunctionNode]
# Each function definition paired with the dependencies collected from it.
FunctionDepsList = List[Tuple[FunctionNode, Dependencies]]
class GetCandidatesNodeVisitor(ast.NodeVisitor):
    """A traversal visitor that collects the candidate function definitions.

    Both ``def`` and ``async def`` definitions are collected, including methods
    inside class bodies (classes are traversed generically). The visitor does
    not descend into a function it has collected, so functions nested inside
    another function are not collected separately.
    """

    def __init__(self) -> None:
        super().__init__()
        self.functions: FunctionsList = []

    def visit(self, node: ast.AST) -> "FunctionsList":
        """Traverse *node* and return every function definition found so far."""
        # NOTE: ``generic_visit`` also dispatches through this method for each
        # child node; the return value only matters for the outermost call.
        super().visit(node)
        return self.functions

    def visit_FunctionDef(self, node: ast.FunctionDef) -> Any:
        # Deliberately no ``generic_visit``: nested functions are not collected.
        self.functions.append(node)

    # ``async def`` nodes are a different AST type and were silently skipped;
    # collect them exactly like regular functions.
    visit_AsyncFunctionDef = visit_FunctionDef
class GetDependenciesNodeVisitor(ast.NodeVisitor):
    """A traversal visitor that collects all the dependencies of a given function.

    Every ``ast.arg`` node reached (the function's own parameters and those of
    any nested function or lambda) is appended to ``dependencies.arguments``.
    Every ``ast.Name`` node reached (loads and stores alike, duplicates
    included) is appended to ``dependencies.variables``, in traversal order.
    """

    def __init__(self) -> None:
        super().__init__()
        self.dependencies = Dependencies()

    def visit(self, node: ast.AST) -> "Dependencies":
        """Traverse *node* and return the dependencies collected so far."""
        # NOTE: ``generic_visit`` also dispatches through this method for each
        # child node; the return value only matters for the outermost call.
        super().visit(node)
        return self.dependencies

    def visit_arg(self, node: ast.arg) -> Any:
        self.dependencies.arguments.append(node)
        # Recurse so names used in the parameter's annotation (``x: Foo``) are
        # collected too, consistent with defaults, decorators and the return
        # annotation, which the generic traversal already reaches.
        self.generic_visit(node)

    def visit_Name(self, node: ast.Name) -> Any:
        self.dependencies.variables.append(node)
def get_func_dependencies(func_def: ast.FunctionDef) -> Dependencies:
    """Return the arguments and names referenced inside *func_def*.

    A brand-new visitor is created on every call, so results from one
    function never leak into another.
    """
    return GetDependenciesNodeVisitor().visit(func_def)
def parse(source: str) -> FunctionDepsList:
    """Parse *source* and pair each candidate function with its dependencies.

    Raises ``SyntaxError`` (from ``ast.parse``) if *source* is not valid
    Python.
    """
    module = ast.parse(source)
    candidates = GetCandidatesNodeVisitor().visit(module)
    # One (function, dependencies) tuple per candidate, in discovery order.
    return list(zip(candidates, map(get_func_dependencies, candidates)))
def main(path: Union[str, Path] = "./example.py") -> None:
    """Print the arguments and variables referenced by each function in *path*.

    The file is opened with ``tokenize.open``, which detects the encoding the
    same way the interpreter does (BOM or PEP 263 coding cookie, otherwise
    UTF-8). It therefore does not depend on the platform's locale encoding.

    Args:
        path: Python source file to analyse; defaults to ``./example.py``.
    """
    with tokenize.open(Path(path)) as f:
        source = f.read()
    # Parse after the file is closed; only the read needs the handle.
    for func, deps in parse(source):
        print(f"current {func.name=}")
        args = [arg.arg for arg in deps.arguments]
        print(f"- {args=}")
        variables = [v.id for v in deps.variables]
        print(f"- {variables=}")
# Run the demo only when executed as a script, never on import.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment