-
-
Save inayet/6ae560b82059a3651100df2aa9850f8a to your computer and use it in GitHub Desktop.
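A DSPy module with self-correction: it validates the generated output and, if validation fails, retries once with the error message supplied as an extra input.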
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import logging  # Import the logging module
from typing import Optional

from dspy import Module, OpenAI, settings, ChainOfThought, Assert
# Logger named after this module's import path. Its own threshold is pinned
# to ERROR, so only ERROR and CRITICAL records pass this logger's level check.
logger = logging.getLogger(name=__name__)
logger.setLevel(level=logging.ERROR)
class GenModule(Module):
    """Generate one output field with DSPy and self-correct once on failure.

    ``forward`` runs a ``ChainOfThought`` predictor with the signature
    ``"<input_keys> -> <output_key>"`` and passes the generated value to
    ``validate_output``. If validation raises ``AssertionError`` or
    ``ValueError``, the error is logged and a second ``ChainOfThought``
    predictor is called. Its signature adds an ``error`` input, and it
    receives the original inputs plus the error message. The corrected
    output is validated again, and a second failure propagates to the caller.

    Subclasses must override ``validate_output``.
    """

    def __init__(
        self,
        output_key,
        input_keys: Optional[list[str]] = None,
        lm=None,
    ):
        """Configure the language model and build both predictors.

        Args:
            output_key: Name of the field the module generates.
            input_keys: Names of the input fields that ``forward`` receives as
                keyword arguments. Defaults to ``["prompt"]``.
            lm: Language model passed to ``settings.configure``. When
                ``None``, ``OpenAI(max_tokens=500)`` is created and used.

        Raises:
            ValueError: If ``input_keys`` is empty or contains ``"error"``,
                which is reserved for the correction signature.
        """
        if lm is None:
            lm = OpenAI(max_tokens=500)
        # `settings` is a shared object imported from dspy, so this
        # configuration is not scoped to this instance.
        settings.configure(lm=lm)

        # Store the keys on every path (default or caller-supplied): the
        # signatures below are built from self.input_keys. Copying also
        # accepts tuples and decouples us from the caller's list.
        keys = ["prompt"] if input_keys is None else list(input_keys)
        if not keys:
            raise ValueError("input_keys must contain at least one field name")
        if "error" in keys:
            # The correction signature appends its own "error" input, and
            # forward() passes error=... alongside the caller's kwargs.
            raise ValueError(
                "'error' is reserved for the correction signature and "
                "cannot be used as an input key"
            )
        self.input_keys = keys
        super().__init__()
        self.output_key = output_key

        inputs = ", ".join(self.input_keys)
        # Generation maps the inputs to the output; correction sees the same
        # inputs plus the message from the failed validation.
        self.signature = f"{inputs} -> {self.output_key}"
        self.correction_signature = f"{inputs}, error -> {self.output_key}"

        # DSPy modules for generation and correction.
        self.generate = ChainOfThought(self.signature)
        self.correct_generate = ChainOfThought(self.correction_signature)

    def forward(self, **kwargs):
        """Generate ``output_key`` from ``kwargs``, correcting once on failure.

        Args:
            **kwargs: Input field values, keyed by the names in ``input_keys``.

        Returns:
            The value returned by ``validate_output`` for the first generated
            output that passes validation.

        Raises:
            AssertionError, ValueError: If the corrected output also fails
                validation.
            NotImplementedError: If ``validate_output`` is not overridden.
        """
        gen_result = self.generate(**kwargs)
        output = gen_result.get(self.output_key)
        try:
            return self.validate_output(output)
        except (AssertionError, ValueError) as error:
            logger.error(error)
            # Single retry: feed the validation message back to the model
            # through the extra "error" input of the correction signature.
            corrected_result = self.correct_generate(**kwargs, error=str(error))
            corrected_output = corrected_result.get(self.output_key)
            return self.validate_output(corrected_output)

    def validate_output(self, output):
        """Validate (and optionally transform) a generated value.

        Subclasses must override this. ``forward`` returns whatever this
        method returns. Raising ``AssertionError`` or ``ValueError`` triggers
        one correction attempt.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError("Validation logic should be implemented in subclass")
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment