Skip to content

Instantly share code, notes, and snippets.

@Magnus167
Created July 10, 2023 16:47
Show Gist options
  • Save Magnus167/727a4f982c541f6fb626d49b50d09d14 to your computer and use it in GitHub Desktop.
Save Magnus167/727a4f982c541f6fb626d49b50d09d14 to your computer and use it in GitHub Desktop.
Getting function-level dependencies from a Python package or a set of notebooks.
"""
Scans Jupyter Notebooks for functions used from a given package.
For every notebook found, one "<notebook path>: <names>" line is printed.
"""
import ast
import functools
import glob
from typing import Callable, Dict, List, Set, Tuple, Type, get_origin

import nbformat
def _empty_value_for(func: Callable) -> object:
    """
    Build an "empty" value matching the return annotation of *func*.

    ``Set[str]`` -> ``set()``, ``List[int]`` -> ``[]``, ``str`` -> ``""``, etc.
    Returns ``None`` when *func* has no return annotation or the annotated type
    cannot be instantiated without arguments (abstract types, unions, ...).
    """
    annotation = getattr(func, "__annotations__", {}).get("return")
    # typing generics (Set[str]) map to their runtime class (set) via get_origin.
    target = get_origin(annotation) or annotation
    if isinstance(target, type):
        try:
            return target()
        except TypeError:
            return None
    return None


def try_except(fail_ok=False):
    """
    Decorator factory controlling how exceptions from the wrapped function are handled.

    :param <bool> fail_ok: if True, ordinary exceptions are swallowed and the
        wrapped function returns an empty instance of its annotated return type
        (e.g. ``set()`` for ``-> Set[str]``, ``None`` if unannotated) so callers
        can keep going; if False, every exception propagates unchanged.
    :return <Callable>: the decorator.

    ``MemoryError`` and ``OSError`` always propagate, whatever ``fail_ok`` says.
    """
    # KeyboardInterrupt / SystemExit derive from BaseException, so `except
    # Exception` below never catches them; they are listed for clarity only.
    halt_exceptions: Tuple[Type[BaseException], ...] = (
        KeyboardInterrupt,
        SystemExit,
        MemoryError,
        OSError,
    )

    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except Exception as exc:
                if not fail_ok or isinstance(exc, halt_exceptions):
                    raise
                return _empty_value_for(func)

        return wrapper

    return decorator
def _is_magic_line(stripped: str) -> bool:
    """
    Tell whether a whitespace-stripped line is IPython magic / shell-escape syntax.

    ``%name`` / ``%%name`` are line / cell magics and ``!cmd`` is a shell escape.
    A ``%`` not followed by a name (e.g. ``% b`` continuing a modulo expression)
    and ``!=`` are ordinary Python and are not treated as magics.
    """
    if stripped.startswith("!"):
        return not stripped.startswith("!=")
    if stripped.startswith("%"):
        name = stripped.lstrip("%")
        return name[:1].isalpha() or name[:1] == "_"
    return False


@try_except(fail_ok=False)
def remove_jupyter_magics(code: str) -> str:
    """
    Neutralise Jupyter magics and shell escapes so the code can be parsed by ``ast``.

    - An unindented magic line is prefixed with ``"# "``.
    - An indented magic line (e.g. ``!pip install x`` inside an ``except:``
      block) keeps its indentation and is replaced by ``pass`` followed by the
      original text as a comment, so a block whose only statement was the
      magic still parses.

    :param <str> code: code to remove Jupyter magics from.
    :return <str> code: code with Jupyter magics removed.
    """
    cleaned: List[str] = []
    for line in code.split("\n"):
        stripped = line.lstrip()
        if not _is_magic_line(stripped):
            cleaned.append(line)
        elif stripped == line:
            cleaned.append("# " + line)
        else:
            indent = line[: len(line) - len(stripped)]
            cleaned.append(indent + "pass  # " + stripped)
    return "\n".join(cleaned)
@try_except(fail_ok=True)
def extract_functions_used(code: str) -> Set[str]:
    """
    Collect every bare identifier referenced in the code.

    Every ``ast.Name`` node is included: function names, but also variables,
    builtins and module aliases (``np`` in ``np.array``). Attribute names
    (``array`` in ``np.array``) are not included.

    :param <str> code: code to extract names from.
    :return <set> functions_used: set of identifiers; empty if the code is not
        valid Python.
    """
    try:
        tree: ast.AST = ast.parse(code)
    except (SyntaxError, ValueError):
        # Cells that still aren't Python after magic removal (e.g. the body of
        # a %%bash cell) contribute nothing instead of aborting the scan.
        return set()
    return {node.id for node in ast.walk(tree) if isinstance(node, ast.Name)}
@try_except(fail_ok=True)
def extract_functions_imported(code: str) -> Set[str]:
    """
    Collect the top-level package names imported by the code.

    ``import a.b as c`` and ``from a.b import f`` both yield ``"a"``; the names
    the import binds (``c``, ``f``) are not included. ``from . import x`` has
    no module name and is skipped.

    :param <str> code: code to extract imports from.
    :return <set> functions_imported: set of top-level package names; empty if
        the code is not valid Python.
    """
    functions_imported: Set[str] = set()
    try:
        tree: ast.AST = ast.parse(code)
    except (SyntaxError, ValueError):
        return functions_imported
    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            functions_imported.update(alias.name.split(".")[0] for alias in node.names)
        elif isinstance(node, ast.ImportFrom) and node.module:
            # node.module is None for `from . import x`.
            functions_imported.add(node.module.split(".")[0])
    return functions_imported
def main(
    notebooks_directory: str = "./notebooks",
    package_name: str = "package_name",
) -> Dict[str, List[str]]:
    """
    Driver function: scan every notebook under *notebooks_directory* and print the matches.

    For each ``*.ipynb`` file (at any nesting depth, in sorted path order), the
    code cells are cleaned of magics. The identifiers they reference and the
    top-level packages they import are then collected. The identifiers that
    also appear among the imported package names are kept, but only if
    *package_name* itself is imported somewhere in that notebook.

    NOTE(review): this compares bare identifiers with top-level package names,
    so ``import numpy as np`` (used as ``np``) and ``from pkg import f`` (used as
    ``f``) never match. The result is effectively "imported packages referenced
    by their own name", not individual functions of *package_name*.

    :param <str> notebooks_directory: root directory to search for notebooks.
    :param <str> package_name: top-level package a notebook must import for
        any of its matches to be reported.
    :return <dict>: notebook path -> sorted list of matched names (also printed).
    """
    mapping: Dict[str, List[str]] = {}
    pattern = f"{notebooks_directory}/**/*.ipynb"
    for notebook_file in sorted(glob.glob(pattern, recursive=True)):
        with open(notebook_file, "r", encoding="utf8") as f:
            notebook: nbformat.NotebookNode = nbformat.read(f, as_version=4)

        functions_used: Set[str] = set()
        functions_imported: Set[str] = set()
        for cell in notebook.cells:
            if cell.cell_type != "code":
                continue
            code: str = remove_jupyter_magics(cell.source)
            # `or ()` guards against helpers that return None on failure.
            functions_used.update(extract_functions_used(code) or ())
            functions_imported.update(extract_functions_imported(code) or ())

        # The package check does not depend on the individual name; test it once.
        if package_name in functions_imported:
            mapping[notebook_file] = sorted(functions_used & functions_imported)
        else:
            mapping[notebook_file] = []

    for notebook_file, names in mapping.items():
        print(f"{notebook_file}: {names}")
    return mapping
# Only scan when run as a script, not when imported.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment