Created
November 12, 2024 01:41
-
-
Save msaroufim/2df6d47a0164b3524eefb97aa23f6696 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import ast | |
from pathlib import Path | |
from typing import Set, Dict | |
from collections import defaultdict | |
def analyze_imports(file_path: str) -> Dict[str, Set[str]]:
    """
    Analyze a Python file's imports and return its package dependencies.

    The file is parsed with ``ast`` and never executed. Only the top-level
    package of each import is recorded, so ``import os.path`` is reported
    as ``os``. Relative imports (``from . import x``, ``from .pkg import x``)
    refer to the file's own package and are skipped.

    Args:
        file_path (str): Path to the Python file to analyze

    Returns:
        Dict[str, Set[str]]: A ``defaultdict(set)`` that may contain:
            - 'imports': Set of package names from regular imports
            - 'from_imports': Set of (package, (imported_names, ...)) tuples
              from absolute from-imports, with the names sorted
        A key only exists once a matching import has been seen; reading a
        missing key yields an empty set.

    Raises:
        OSError: If the file cannot be opened or read.
        SyntaxError: If the file does not contain valid Python source.
    """
    # Read raw bytes so the parser can honour a BOM or PEP 263 coding
    # declaration instead of relying on the platform's default encoding.
    with open(file_path, 'rb') as file:
        tree = ast.parse(file.read(), filename=file_path)

    imports = defaultdict(set)
    for node in ast.walk(tree):
        # Handle regular imports (import torch)
        if isinstance(node, ast.Import):
            for name in node.names:
                base_package = name.name.split('.')[0]
                imports['imports'].add(base_package)
        # Handle from-imports (from torch import nn)
        elif isinstance(node, ast.ImportFrom):
            # level > 0 marks a relative import; "from .pkg import x" still
            # has a module name, so checking node.module alone is not enough.
            if node.module and node.level == 0:
                base_package = node.module.split('.')[0]
                # Sort the de-duplicated names so the tuple is deterministic
                # and equivalent statements collapse to a single set entry.
                imported_names = sorted({n.name for n in node.names})
                imports['from_imports'].add((base_package, tuple(imported_names)))
    return imports
def print_dependencies(file_path: str):
    """
    Print a human-readable dependency report for one Python file.

    The report lists the packages found by ``analyze_imports``: first the
    plain ``import`` packages, then each from-import package with the names
    pulled from it. Any exception raised along the way is caught and shown
    as a one-line error message instead of propagating.

    Args:
        file_path (str): Path to the Python file to report on
    """
    try:
        found = analyze_imports(file_path)
        print(f"\nDependencies for {file_path}:")

        # Section 1: packages brought in with "import x"
        print("\nDirect imports:")
        direct = found['imports']
        if not direct:
            print(" None")
        else:
            for package in sorted(direct):
                print(f" - {package}")

        # Section 2: packages brought in with "from x import a, b"
        print("\nFrom imports:")
        from_entries = found['from_imports']
        if not from_entries:
            print(" None")
        else:
            for package, members in sorted(from_entries):
                joined = ', '.join(sorted(members))
                print(f" - {package}: {joined}")
    except Exception as err:
        # Best-effort reporting: one bad file must not stop a batch run
        print(f"Error analyzing {file_path}: {err}")
# Demo entry point: runs only when this file is executed as a script
if __name__ == "__main__":
    # Report on one specific file...
    print_dependencies("your_script.py")

    # ...then on every Python file directly inside the current directory
    search_root = Path(".")
    for script in search_root.glob("*.py"):
        print_dependencies(str(script))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment