Last active: April 3, 2024 13:54
-
-
Save dharrawal/7ec18464f5d4b6aaa18074b220e550ef to your computer and use it in GitHub Desktop.
A simple framework for logging calls to DSPy programs in a minimally obtrusive way
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
DSPy Logging Utilities | |
Author: Dhar Rawal | |
Works with DSPy to log forward calls and their results, using a custom handler function. | |
Works with typed predictors too! | |
""" | |
import functools | |
import json | |
from typing import Any, Callable, Dict, Optional, Tuple | |
import dspy | |
from pydantic import BaseModel | |
class DSPyProgramLog(BaseModel):
    """Record describing a single intercepted DSPy ``forward`` call.

    ``DSPyForward.intercept`` fills one of these in per call and hands it to the
    handler installed through ``DSPyLogger``.
    """

    # Name of the class that owns the intercepted ``forward`` method.
    dspy_program_class: str
    # Positional arguments after the first one (the first is assumed to be ``self``).
    dspy_input_args: Tuple[Any, ...] = tuple()
    # Keyword arguments exactly as passed to ``forward``.
    dspy_input_kwargs: Dict[str, Any] = dict()
    # Completions taken from the returned prediction; empty when there were none.
    dspy_completions_dict: Dict[str, Any] = dict()
class DSPyForward:  # pylint: disable=too-few-public-methods
    """DSPy Forward Interceptor.

    Provides the ``intercept`` decorator and holds, as a class-level attribute,
    the handler that receives a ``DSPyProgramLog`` for every intercepted call.
    ``DSPyLogger`` installs and removes that handler.
    """

    # Handler called with each DSPyProgramLog; None disables logging.
    # It lives on the class, so every decorated function shares it.
    save_dspyprogramlog_func: Optional[Callable[[DSPyProgramLog], None]] = None

    @classmethod
    def intercept(cls, func: Callable) -> Callable:
        """Decorate a DSPy program's ``forward`` method so its calls are logged.

        For each call, the wrapper records:
        - the owning class name,
        - the positional arguments after ``self``,
        - the keyword arguments,
        - the completions of the returned prediction.

        It then passes that ``DSPyProgramLog`` to the installed handler, if
        any. The wrapped function's return value is passed through unchanged.
        Nothing is logged if ``func`` raises.

        Args:
            func: the ``forward`` method to wrap.

        Returns:
            The wrapping function, with ``func``'s metadata copied by
            ``functools.wraps``.
        """
        # Resolve the owning class name once, at decoration time. For a method
        # the qualname looks like "...Owner.forward". A function defined
        # outside any class has no enclosing component, so fall back to the
        # full qualname rather than raising IndexError on every call.
        qualname_parts = func.__qualname__.split(".")
        program_class = qualname_parts[-2] if len(qualname_parts) > 1 else qualname_parts[0]

        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            dspy_program_log = DSPyProgramLog(dspy_program_class=program_class)
            # args[0] is assumed to be ``self``; slicing an empty tuple gives ().
            dspy_program_log.dspy_input_args = args[1:]
            dspy_program_log.dspy_input_kwargs = kwargs

            result = func(*args, **kwargs)

            # Tolerate return values that carry no completions
            # (e.g. None or a non-Prediction object) instead of crashing.
            completions = getattr(result, "completions", None)
            if completions:
                dspy_program_log.dspy_completions_dict = getattr(  # pylint: disable=protected-access
                    completions, "_completions", {}
                )
            # Otherwise the model default (an empty dict) is kept.

            # Read the handler once so the check and the call see the same value.
            handler = DSPyForward.save_dspyprogramlog_func
            if handler is not None:
                handler(dspy_program_log)  # pylint: disable=not-callable
            return result

        return wrapper
class DSPyLogger:
    """Context manager that installs a handler for ``DSPyForward``-intercepted calls.

    While installed, every call to a ``forward`` decorated with
    ``@DSPyForward.intercept`` passes a ``DSPyProgramLog`` to the handler.
    On exit, the handler that was active before this logger is restored
    (None at top level), so nested ``with DSPyLogger(...)`` blocks do not
    disable the outer logger.
    """

    def __init__(self, save_dspyprogramlog_func: Callable[[DSPyProgramLog], None]):
        """Validate and install *save_dspyprogramlog_func*.

        The handler is installed here, not only in ``__enter__``. This keeps
        existing code working if it relies on construction alone.

        Raises:
            ValueError: if *save_dspyprogramlog_func* is not callable.
        """
        if not callable(save_dspyprogramlog_func):
            raise ValueError("Custom handler must be a callable function")
        self._handler = save_dspyprogramlog_func
        # Remember the previously installed handler so __exit__ can restore it.
        self._previous_handler = DSPyForward.save_dspyprogramlog_func
        DSPyForward.save_dspyprogramlog_func = save_dspyprogramlog_func

    def __enter__(self) -> "DSPyLogger":
        # Re-install if this logger is being reused after an earlier __exit__.
        if DSPyForward.save_dspyprogramlog_func is not self._handler:
            self._previous_handler = DSPyForward.save_dspyprogramlog_func
            DSPyForward.save_dspyprogramlog_func = self._handler
        return self

    def __exit__(self, _exc_type, _exc_val, _exc_tb) -> None:
        # Restore the outer handler instead of unconditionally clearing it.
        # Returning None lets any exception from the block propagate.
        DSPyForward.save_dspyprogramlog_func = self._previous_handler

    @staticmethod
    def default_handler(dspy_program_log: DSPyProgramLog) -> None:
        """Print each field of *dspy_program_log* to stdout, one per line.

        Completions are printed as JSON. Values that are not JSON-serializable
        are rendered with ``repr`` so that logging never crashes the program
        being logged.
        """
        print(f"{dspy_program_log.dspy_program_class}")
        print(f"{dspy_program_log.dspy_input_args}")
        print(f"{dspy_program_log.dspy_input_kwargs}")
        print(f"{json.dumps(dspy_program_log.dspy_completions_dict, default=repr)}")
def _how_to_use():
    """Show the intended usage of this module.

    Steps:
    1. Decorate your DSPy program's ``forward`` with ``@DSPyForward.intercept``.
    2. Call the program inside ``with DSPyLogger(save_func):``.
    3. ``save_func`` then receives a ``DSPyProgramLog`` for every intercepted call.

    NOTE: this calls the configured OpenAI model; replace the placeholder API
    key before running.
    """

    class BasicQA(dspy.Module):
        """Minimal question-answering module used to exercise DSPyLogger."""

        # Language model shared by every BasicQA instance.
        gpt3_turbo = dspy.OpenAI(model="gpt-3.5-turbo", api_key="<YOUR_API_KEY>")

        def __init__(self):
            super().__init__()
            self.generate_answer = dspy.Predict("topic, question -> answer")

        @DSPyForward.intercept
        def forward(self, topic, question):
            """Answer *question* about *topic* using the shared language model."""
            with dspy.context(lm=BasicQA.gpt3_turbo):
                return self.generate_answer(topic=topic, question=question)

    qa_program = BasicQA()

    # Invoke the program while the default (print) handler is installed.
    with DSPyLogger(DSPyLogger.default_handler):
        qa_program("geography quiz", question="What is the capital of France?")

    # Example console output (the answer text depends on the model):
    # BasicQA
    # ('geography quiz',)
    # {'question': 'What is the capital of France?'}
    # {"answer": ["Topic: geography quiz\nQuestion: What is the capital of France?\nAnswer: Paris"]}
if __name__ == "__main__":
    # Run the usage demo only when executed as a script, never on import.
    _how_to_use()
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment