Skip to content

Instantly share code, notes, and snippets.

@msaroufim
Created November 12, 2024 01:41
Show Gist options
  • Save msaroufim/2df6d47a0164b3524eefb97aa23f6696 to your computer and use it in GitHub Desktop.
Save msaroufim/2df6d47a0164b3524eefb97aa23f6696 to your computer and use it in GitHub Desktop.
import ast
from pathlib import Path
from typing import Set, Dict
from collections import defaultdict
def analyze_imports(file_path: str) -> Dict[str, Set]:
    """
    Analyze a Python file's import statements and collect its package dependencies.

    Every import statement in the file is inspected, including ones nested
    inside functions, classes or ``if``/``try`` blocks. Only absolute imports
    are reported. Relative imports (``from . import x``, ``from .pkg import y``)
    are skipped, because they point at the file's own package rather than at
    an external dependency.

    Args:
        file_path (str): Path to the Python file to analyze.

    Returns:
        Dict[str, Set]: A ``defaultdict(set)`` with up to two keys:
            - 'imports': set of top-level package names from ``import X.Y``
              statements (e.g. ``import os.path`` -> ``'os'``).
            - 'from_imports': set of ``(package, names)`` tuples from
              ``from X.Y import a, b`` statements. ``package`` is the
              top-level package (``'X'``) and ``names`` is a sorted tuple
              of the imported names (e.g. ``('torch', ('nn', 'optim'))``).
            A key that received no entries is absent. Because the result
            is a defaultdict, indexing it still yields an empty set.

    Raises:
        OSError: If the file cannot be opened or read.
        SyntaxError: If the source cannot be decoded or parsed as Python.
            Older Python versions raise ValueError instead for sources
            containing null bytes.
    """
    # Read raw bytes so that ast.parse (via compile) applies Python's own
    # source-decoding rules: a PEP 263 coding cookie if present, otherwise
    # UTF-8. Text mode would use the platform locale instead.
    with open(file_path, 'rb') as file:
        source = file.read()
    tree = ast.parse(source, filename=str(file_path))

    imports = defaultdict(set)
    for node in ast.walk(tree):
        # Regular imports: ``import torch.nn as nn`` -> 'torch'
        if isinstance(node, ast.Import):
            for alias in node.names:
                imports['imports'].add(alias.name.split('.')[0])
        # From-imports: ``from torch import nn`` -> ('torch', ('nn',))
        elif isinstance(node, ast.ImportFrom):
            # level > 0 marks a relative import. node.module is also set for
            # ``from .pkg import x``, so checking the module alone is not enough.
            if node.module and not node.level:
                base_package = node.module.split('.')[0]
                # Sort the names so the same set of imported names always
                # yields the same tuple, whatever order they were written in.
                imported_names = tuple(sorted({alias.name for alias in node.names}))
                imports['from_imports'].add((base_package, imported_names))
    return imports
def print_dependencies(file_path: str):
    """Print a human-readable dependency report for one Python file.

    The report has a "Direct imports" section listing the packages pulled in
    by plain ``import`` statements. It then has a "From imports" section
    listing ``module: name, name`` entries, with each group's names sorted.
    An empty section prints ``None``. Any exception raised while analysing or
    printing is caught and reported as a single ``Error analyzing ...`` line,
    so a caller looping over many files keeps going.
    """
    try:
        report = analyze_imports(file_path)
        print(f"\nDependencies for {file_path}:")

        print("\nDirect imports:")
        direct = report['imports']
        if not direct:
            print(" None")
        else:
            for package in sorted(direct):
                print(f" - {package}")

        print("\nFrom imports:")
        from_pairs = report['from_imports']
        if not from_pairs:
            print(" None")
        else:
            for module_name, imported in sorted(from_pairs):
                print(f" - {module_name}: {', '.join(sorted(imported))}")
    except Exception as exc:
        # Best-effort: report and continue rather than abort a batch run.
        print(f"Error analyzing {file_path}: {exc}")
# Example usage: pass Python files and/or directories on the command line,
#   python this_script.py my_module.py src/
# or run with no arguments to analyze every *.py file in the current directory.
if __name__ == "__main__":
    import sys

    targets = [Path(arg) for arg in sys.argv[1:]] or [Path(".")]
    for target in targets:
        if target.is_dir():
            # Non-recursive, like the original example. Sorted so the report
            # order does not depend on filesystem enumeration order.
            for python_file in sorted(target.glob("*.py")):
                print_dependencies(str(python_file))
        else:
            # Files, and paths that don't exist, go straight through.
            # print_dependencies reports any read/parse error itself.
            print_dependencies(str(target))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment