Created
April 2, 2024 23:07
-
-
Save noklam/117be538e95173fc19493e4eb5e6fab3 to your computer and use it in GitHub Desktop.
`%load_node` backport for the Kedro 0.18 series: replace `kedro/ipython/__init__.py` with this file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
This script creates an IPython extension to load Kedro-related variables in | |
local scope. | |
""" | |
from __future__ import annotations | |
import inspect | |
import logging | |
import os | |
import sys | |
import typing | |
import warnings | |
from pathlib import Path | |
from types import MappingProxyType | |
from typing import Any, Callable, OrderedDict, Union | |
from IPython.core.getipython import get_ipython | |
from IPython.core.magic import needs_local_scope, register_line_magic | |
from IPython.core.magic_arguments import argument, magic_arguments, parse_argstring | |
from rich.console import Console | |
from rich.syntax import Syntax | |
from kedro.framework.cli import load_entry_points | |
from kedro.framework.cli.project import CONF_SOURCE_HELP, PARAMS_ARG_HELP | |
from kedro.framework.cli.utils import ENV_HELP, _split_params | |
from kedro.framework.project import ( | |
LOGGING, # noqa: F401 | |
_ProjectPipelines, | |
configure_project, | |
pipelines, | |
) | |
from kedro.framework.session import KedroSession | |
from kedro.framework.startup import bootstrap_project | |
from kedro.pipeline.node import Node | |
logger = logging.getLogger(__name__) | |
FunctionParameters = MappingProxyType | |
_PYPROJECT = "pyproject.toml" | |
def _is_databricks() -> bool: | |
return "DATABRICKS_RUNTIME_VERSION" in os.environ | |
def _is_project(project_path: Union[str, Path]) -> bool:
    """Tell whether ``project_path`` looks like the root of a Kedro project.

    The directory qualifies when it holds a ``_PYPROJECT`` file whose text
    contains the ``[tool.kedro]`` table header. Any error raised while reading
    that file is treated as "not a project".

    Args:
        project_path: Directory to inspect; ``~`` is expanded and the path is
            resolved before the lookup.
    """
    candidate = Path(project_path).expanduser().resolve() / _PYPROJECT
    if candidate.is_file():
        try:
            contents = candidate.read_text(encoding="utf-8")
        except Exception:  # noqa: broad-except
            return False
        return "[tool.kedro]" in contents
    return False
def _find_kedro_project(current_dir: Path) -> Path | None:  # pragma: no cover
    """Find the closest Kedro project root at or above ``current_dir``.

    ``current_dir`` is expanded (``~``) and resolved to an absolute path first.
    That directory and every one of its ancestors, up to and including the
    filesystem root, are then checked with ``_is_project``.

    Args:
        current_dir: Directory to start the search from.

    Returns:
        The first directory that is a Kedro project, or ``None`` if there is none.
    """
    # Resolve first: for a relative path such as ``Path(".")``, ``.parent`` is
    # the path itself, so walking parents without resolving would stop at once.
    start = Path(current_dir).expanduser().resolve()
    # ``parents`` runs from the immediate parent up to the root (inclusive), so
    # the filesystem root is checked as well.
    for candidate in (start, *start.parents):
        if _is_project(candidate):
            return candidate
    return None
def load_ipython_extension(ipython: Any) -> None:
    """
    Entry point IPython calls for ``%load_ext kedro.ipython``.

    The extension is loaded either manually or automatically by
    `kedro ipython` / `kedro jupyter lab/notebook`. IPython looks up this
    function by name; see
    https://ipython.readthedocs.io/en/stable/config/extensions/index.html

    Registers the ``%reload_kedro`` and ``%load_node`` line magics. If a Kedro
    project is found at or above the current working directory, it is then
    loaded via ``reload_kedro()``. Otherwise only a warning is logged.

    Args:
        ipython: The running IPython shell.
    """
    magics = (
        (magic_reload_kedro, "reload_kedro", "Registered line magic '%reload_kedro'"),
        (magic_load_node, "load_node", "Registered line magic '%load_node'"),
    )
    for magic_func, magic_name, registered_msg in magics:
        ipython.register_magic_function(magic_func, magic_name=magic_name)
        logger.info(registered_msg)

    if _find_kedro_project(Path.cwd()) is not None:
        reload_kedro()
        return

    logger.warning(
        "Kedro extension was registered but couldn't find a Kedro project. "
        "Make sure you run '%reload_kedro <project_root>'."
    )
@typing.no_type_check
@needs_local_scope
@magic_arguments()
@argument(
    "path",
    type=str,
    help=(
        "Path to the project root directory. If not given, use the previously set "
        "project root."
    ),
    nargs="?",
    default=None,
)
@argument("-e", "--env", type=str, default=None, help=ENV_HELP)
@argument(
    "--params",
    type=lambda value: _split_params(None, None, value),
    default=None,
    help=PARAMS_ARG_HELP,
)
@argument("--conf-source", type=str, default=None, help=CONF_SOURCE_HELP)
def magic_reload_kedro(
    line: str,
    local_ns: dict[str, Any] | None = None,
    conf_source: str | None = None,
) -> None:
    """
    The `%reload_kedro` IPython line magic.

    Parses ``line`` against the ``@argument`` declarations above and forwards
    the parsed values to ``reload_kedro``.

    Args:
        line: The raw argument string typed after ``%reload_kedro``.
        local_ns: The caller's local namespace. IPython supplies it because the
            function is decorated with ``@needs_local_scope``.
        conf_source: Not used. The ``--conf-source`` value parsed from ``line`` is
            what gets forwarded. The parameter is kept only so the signature
            stays backward compatible.

    See https://kedro.readthedocs.io/en/stable/notebooks_and_ipython/kedro_and_notebooks.html#reload-kedro-line-magic
    for more.
    """
    args = parse_argstring(magic_reload_kedro, line)
    reload_kedro(args.path, args.env, args.params, local_ns, args.conf_source)
def reload_kedro(
    path: str | None = None,
    env: str | None = None,
    extra_params: dict[str, Any] | None = None,
    local_namespace: dict[str, Any] | None = None,
    conf_source: str | None = None,
) -> None:  # pragma: no cover
    """Load (or reload) a Kedro project into the running IPython session.

    This is the implementation behind the ``%reload_kedro`` line magic (it is
    also called by ``load_ipython_extension``). It is not meant to be imported
    and called directly from user code.

    Steps, in order:

    1. Resolve the project root and bootstrap the project.
    2. Evict the project package from ``sys.modules``, then configure the project.
    3. Create a ``KedroSession`` and load its context.
    4. Push ``context``, ``catalog``, ``session`` and ``pipelines`` into the
       IPython namespace.
    5. Register every line magic exposed via the ``line_magic`` entry point.

    Args:
        path: Project root; when omitted it is resolved by ``_resolve_project_path``.
        env: Kedro configuration environment passed to the session.
        extra_params: Extra parameters passed to the session.
        local_namespace: Namespace of the scope that invoked the magic.
        conf_source: Configuration source passed to the session.
    """
    project_path = _resolve_project_path(path, local_namespace)
    metadata = bootstrap_project(project_path)
    package_name = metadata.package_name

    # Evict cached modules before configuring, presumably so the project package
    # is imported afresh afterwards (picking up edits made since the last load).
    _remove_cached_modules(package_name)
    configure_project(package_name)

    session = KedroSession.create(
        project_path,
        env=env,
        extra_params=extra_params,
        conf_source=conf_source,
    )
    context = session.load_context()
    shell_variables = {
        "context": context,
        "catalog": context.catalog,
        "session": session,
        "pipelines": pipelines,
    }
    get_ipython().push(variables=shell_variables)  # type: ignore[no-untyped-call]

    logger.info("Kedro project %s", str(metadata.project_name))
    logger.info(
        "Defined global variable 'context', 'session', 'catalog' and 'pipelines'"
    )

    for extra_magic in load_entry_points("line_magic"):
        register_line_magic(needs_local_scope(extra_magic))  # type: ignore[no-untyped-call]
        logger.info("Registered line magic '%s'", extra_magic.__name__)  # type: ignore[attr-defined]
def _resolve_project_path( | |
path: str | None = None, local_namespace: dict[str, Any] | None = None | |
) -> Path: | |
""" | |
Resolve the project path to use with reload_kedro, updating or adding it | |
(in-place) to the local ipython Namespace (``local_namespace``) if necessary. | |
Arguments: | |
path: the path to use as a string object | |
local_namespace: Namespace with local variables of the scope where the line | |
magic is invoked in a dict. | |
""" | |
if path: | |
project_path = Path(path).expanduser().resolve() | |
else: | |
if local_namespace and "context" in local_namespace: | |
project_path = local_namespace["context"].project_path | |
else: | |
project_path = _find_kedro_project(Path.cwd()) | |
if project_path: | |
logger.info( | |
"Resolved project path as: %s.\nTo set a different path, run " | |
"'%%reload_kedro <project_root>'", | |
project_path, | |
) | |
if ( | |
project_path | |
and local_namespace | |
and "context" in local_namespace | |
and project_path != local_namespace["context"].project_path | |
): | |
logger.info("Updating path to Kedro project: %s...", project_path) | |
return project_path | |
def _remove_cached_modules(package_name: str) -> None: # pragma: no cover | |
to_remove = [mod for mod in sys.modules if mod.startswith(package_name)] | |
# `del` is used instead of `reload()` because: If the new version of a module does not | |
# define a name that was defined by the old version, the old definition remains. | |
for module in to_remove: | |
del sys.modules[module] | |
def _guess_run_environment() -> str: # pragma: no cover | |
"""Best effort to guess the IPython/Jupyter environment""" | |
# https://github.com/microsoft/vscode-jupyter/issues/7380 | |
if os.environ.get("VSCODE_PID") or os.environ.get("VSCODE_CWD"): | |
return "vscode" | |
elif _is_databricks(): | |
return "databricks" | |
elif hasattr(get_ipython(), "kernel"): # type: ignore[no-untyped-call] | |
# IPython terminal does not have this attribute | |
return "jupyter" | |
else: | |
return "ipython" | |
@typing.no_type_check
@magic_arguments()
@argument(
    "node",
    type=str,
    help=("Name of the Node."),
    nargs="?",
    default=None,
)
def magic_load_node(args: str) -> None:
    """The line magic %load_node <node_name>.

    Generates code that reproduces a node outside its pipeline: loading its
    input datasets from the `DataCatalog`, the import statements of its module,
    the node function's definition, and a call to it.

    How the code is delivered depends on the detected environment:

    * Jupyter (Notebook >7.0 or Lab): one new notebook cell per snippet.
    * IPython terminal or VSCode Notebook: all snippets joined into a single
      pre-filled input.
    * anything else: the snippets are printed.
    """
    node_name = parse_argstring(magic_load_node, args).node
    generated_cells = _load_node(node_name, pipelines)
    environment = _guess_run_environment()

    if environment == "jupyter":
        # Only Jupyter front ends can create real cells, so emit one per snippet.
        for snippet in generated_cells:
            _create_cell_with_text(snippet, is_jupyter=True)
        return

    if environment in ("ipython", "vscode"):
        # These can only pre-fill one input, so merge every snippet into it.
        _create_cell_with_text("\n\n".join(generated_cells), is_jupyter=False)
        return

    _print_cells(generated_cells)
class _NodeBoundArguments(inspect.BoundArguments): | |
"""Similar to inspect.BoundArguments""" | |
def __init__( | |
self, signature: inspect.Signature, arguments: OrderedDict[str, Any] | |
) -> None: | |
super().__init__(signature, arguments) | |
@property | |
def input_params_dict(self) -> dict[str, str] | None: | |
"""A mapping of {variable name: dataset_name}""" | |
var_positional_arg_name = self._find_var_positional_arg() | |
inputs_params_dict = {} | |
for param, dataset_name in self.arguments.items(): | |
if param == var_positional_arg_name: | |
# If the argument is *args, use the dataset name instead | |
for arg in dataset_name: | |
inputs_params_dict[arg] = arg | |
else: | |
inputs_params_dict[param] = dataset_name | |
return inputs_params_dict | |
def _find_var_positional_arg(self) -> str | None: | |
"""Find the name of the VAR_POSITIONAL argument( *args), if any.""" | |
for k, v in self.signature.parameters.items(): | |
if v.kind == inspect.Parameter.VAR_POSITIONAL: | |
return k | |
return None | |
def _create_cell_with_text(text: str, is_jupyter: bool = True) -> None:
    """Place ``text`` into the front end as new code.

    Args:
        text: Source code to insert.
        is_jupyter: If true, use ``ipylab`` to insert a new cell below the
            current one and fill it with ``text``. Otherwise, pre-fill the next
            IPython input with ``text`` via ``set_next_input``.
    """
    if not is_jupyter:
        get_ipython().set_next_input(text)  # type: ignore[no-untyped-call]
        return

    # Imported lazily: ipylab is only needed on the Jupyter path.
    from ipylab import JupyterFrontEnd

    front_end = JupyterFrontEnd()
    # Noted this only works with Notebook >7.0 or Jupyter Lab. It doesn't work with
    # VS Code Notebook due to incompatible backends.
    front_end.commands.execute("notebook:insert-cell-below")
    front_end.commands.execute("notebook:replace-selection", {"text": text})
def _print_cells(cells: list[str]) -> None:
    """Pretty-print each generated cell as syntax-highlighted Python.

    Used when cells cannot be inserted into the front end. Each cell is
    preceded by an empty line.

    Args:
        cells: Source text of each generated cell, in display order.
    """
    # One console serves every print; constructing a new Console for each call
    # was redundant, loop-invariant work.
    console = Console()
    for cell in cells:
        console.print("")
        console.print(Syntax(cell, "python", theme="monokai", line_numbers=False))
def _load_node(node_name: str, pipelines: _ProjectPipelines) -> list[str]:
    """Generate the notebook cells that reproduce a node outside its pipeline.

    Emits an "experimental feature" warning on every call.

    Args:
        node_name (str): Name of the node to load (not the name of its function).
        pipelines (_ProjectPipelines): The project's pipelines, searched for
            ``node_name``.

    Returns:
        list[str]: One string per notebook cell, in this order:
        dataset loading (omitted when there is nothing to load), imports,
        the node function's definition, and a call to that function.

    Raises:
        ValueError: If no pipeline contains a node called ``node_name``.
        FileNotFoundError: If the source file of the node function is not found.
    """
    warnings.warn(
        "This is an experimental feature, only Jupyter Notebook (>7.0), Jupyter Lab, IPython, and VSCode Notebook "
        "are supported. If you encounter unexpected behaviour or would like to suggest "
        "feature enhancements, add it under this github issue https://github.com/kedro-org/kedro/issues/3580"
    )

    target = _find_node(node_name, pipelines)
    func = target.func
    imports_text = _prepare_imports(func)
    definition_text = _prepare_function_body(func)
    bound_args = _get_node_bound_arguments(target)
    inputs_text = _format_node_inputs_text(_prepare_node_inputs(bound_args))
    call_text = _prepare_function_call(func, bound_args)

    # The dataset-loading cell is only emitted when there is something to load.
    leading_cells = [inputs_text] if inputs_text else []
    return leading_cells + [imports_text, definition_text, call_text]
def _find_node(node_name: str, pipelines: _ProjectPipelines) -> Node: | |
for pipeline in pipelines.values(): | |
try: | |
found_node: Node = pipeline.filter(node_names=[node_name]).nodes[0] | |
return found_node | |
except ValueError: | |
continue | |
# If reached the node was not found in the project | |
raise ValueError( | |
f"Node with name='{node_name}' not found in any pipelines. Remember to specify the node name, not the node function." | |
) | |
def _extract_top_level_imports(source: str) -> list[str]:
    """Return the module-level ``import`` / ``from ... import`` statements in ``source``.

    Statements are returned whole and in file order. This includes statements
    that span several physical lines (parenthesised name lists or backslash
    continuations), with trailing whitespace stripped from each line.
    Indented imports are ignored, as is import-like text inside strings or
    comments.
    """
    import io
    import tokenize

    source_lines = source.splitlines()
    skipped = {
        tokenize.NL,
        tokenize.NEWLINE,
        tokenize.COMMENT,
        tokenize.INDENT,
        tokenize.DEDENT,
    }
    statements: list[str] = []
    start_row = 0  # 0 means "not inside a logical line"; token rows start at 1
    is_import = False
    for tok in tokenize.generate_tokens(io.StringIO(source).readline):
        if not start_row:
            if tok.type in skipped:
                continue
            start_row = tok.start[0]
            is_import = (
                tok.type == tokenize.NAME
                and tok.start[1] == 0
                and tok.string in ("import", "from")
            )
        elif tok.type == tokenize.NEWLINE:
            # NEWLINE closes a logical line. Line breaks inside brackets emit
            # NL instead, and backslash continuations emit no token at all.
            if is_import:
                rows = source_lines[start_row - 1 : tok.end[0]]
                statements.append("\n".join(row.rstrip() for row in rows))
            start_row = 0
    return statements


def _prepare_imports(node_func: Callable) -> str:
    """Prepare the import statements for loading a node.

    Collects every module-level import statement from the file that defines
    ``node_func``, including imports written across several lines.

    Args:
        node_func: The node's function.

    Returns:
        The import statements, one after another, separated by newlines.

    Raises:
        FileNotFoundError: If the source file of ``node_func`` cannot be located.
    """
    import tokenize

    python_file = inspect.getsourcefile(node_func)
    logger.info("Loading node definition from %s", python_file)
    if not python_file:
        raise FileNotFoundError(f"Could not find {node_func.__name__}")
    # tokenize.open decodes the file the way Python itself does (PEP 263 coding
    # cookie, UTF-8 by default) instead of using the platform locale encoding.
    with tokenize.open(python_file) as file:
        source = file.read()
    return "\n".join(_extract_top_level_imports(source)).strip()
def _get_node_bound_arguments(node: Node) -> _NodeBoundArguments:
    """Bind the node's input dataset names to the parameters of its function.

    The inputs are split into positional and keyword parts by
    ``Node._process_inputs_for_bind`` and bound with ``inspect.Signature.bind``.
    That raises ``TypeError`` if they do not fit the function's signature.
    """
    positional, keyword = Node._process_inputs_for_bind(node.inputs)
    bound = inspect.signature(node.func).bind(*positional, **keyword)
    return _NodeBoundArguments(bound.signature, bound.arguments)
def _prepare_node_inputs( | |
node_bound_arguments: _NodeBoundArguments, | |
) -> dict[str, str] | None: | |
# Remove the *args. For example {'first_arg':'a', 'args': ('b','c')} | |
# will be loaded as follow: | |
# first_arg = catalog.load("a") | |
# b = catalog.load("b") # It doesn't have an arg name, so use the dataset name instead. | |
# c = catalog.load("c") | |
return node_bound_arguments.input_params_dict | |
def _format_node_inputs_text(input_params_dict: dict[str, str] | None) -> str | None: | |
statements = [ | |
"# Prepare necessary inputs for debugging", | |
"# All debugging inputs must be defined in your project catalog", | |
] | |
if not input_params_dict: | |
return None | |
for func_param, dataset_name in input_params_dict.items(): | |
statements.append(f'{func_param} = catalog.load("{dataset_name}")') | |
input_statements = "\n".join(statements) | |
return input_statements | |
def _prepare_function_body(func: Callable) -> str: | |
source_lines, _ = inspect.getsourcelines(func) | |
body = "".join(source_lines) | |
return body | |
def _prepare_function_call( | |
node_func: Callable, node_bound_arguments: _NodeBoundArguments | |
) -> str: | |
"""Prepare the text for the function call.""" | |
func_name = node_func.__name__ | |
args = node_bound_arguments.input_params_dict | |
kwargs = node_bound_arguments.kwargs | |
# Construct the statement of func_name(a=1,b=2,c=3) | |
args_str_literal = [f"{node_input}" for node_input in args] if args else [] | |
kwargs_str_literal = [ | |
f"{node_input}={dataset_name}" for node_input, dataset_name in kwargs.items() | |
] | |
func_params = ", ".join(args_str_literal + kwargs_str_literal) | |
body = f"""{func_name}({func_params})""" | |
return body |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment