Skip to content

Instantly share code, notes, and snippets.

@inayet
Forked from seanchatmangpt/gen_module.py
Created February 13, 2024 15:39
Show Gist options
  • Save inayet/6ae560b82059a3651100df2aa9850f8a to your computer and use it in GitHub Desktop.
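A DSPy module with self-correction: it validates its generated output and, if validation fails, retries once with the error message as an extra input.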
import logging # Import the logging module
from dspy import Module, OpenAI, settings, ChainOfThought, Assert
# Module-level logger. Its level is pinned to ERROR, so records below ERROR
# logged through it are dropped.
logger = logging.getLogger(name=__name__)
logger.setLevel(level=logging.ERROR)
class GenModule(Module):
    """DSPy module that generates one output field and self-corrects once.

    The module first predicts ``output_key`` from the fields named in
    ``input_keys``. If :meth:`validate_output` rejects the result with an
    ``AssertionError`` or ``ValueError``, it asks a second predictor for a
    corrected output, passing the error message as an extra ``error`` input.
    It then validates again.

    Subclasses must implement :meth:`validate_output`.
    """

    def __init__(self, output_key, input_keys: "list[str] | None" = None, lm=None):
        """Build the generation and correction predictors.

        Args:
            output_key: Name of the output field in both signatures.
            input_keys: Names of the input fields. Defaults to ``["prompt"]``.
                The list is copied, so later changes by the caller have no
                effect on this module.
            lm: Language model to use. When ``None``, the constructor creates
                ``OpenAI(max_tokens=500)``.

        Note:
            Every construction calls dspy's ``settings.configure(lm=lm)``,
            which reconfigures the shared ``settings`` object as a side effect.
        """
        super().__init__()

        if lm is None:
            lm = OpenAI(max_tokens=500)
        settings.configure(lm=lm)

        # Previously ``self.input_keys`` was assigned only when ``input_keys``
        # was None, so passing explicit keys crashed with AttributeError below.
        self.input_keys = ["prompt"] if input_keys is None else list(input_keys)
        self.output_key = output_key

        # Signatures look like "a, b -> out" and "a, b, error -> out".
        inputs = ", ".join(self.input_keys)
        self.signature = f"{inputs} -> {self.output_key}"
        self.correction_signature = f"{inputs}, error -> {self.output_key}"

        # One predictor for the first attempt, and one that also receives the
        # validation error message for the correction attempt.
        self.generate = ChainOfThought(self.signature)
        self.correct_generate = ChainOfThought(self.correction_signature)

    def forward(self, **kwargs):
        """Generate the output, validate it, and retry once on failure.

        Args:
            **kwargs: Values for the input fields named in ``self.input_keys``.

        Returns:
            The value returned by :meth:`validate_output` for the accepted
            output.

        Raises:
            AssertionError, ValueError: If the corrected output also fails
                validation. The first failure is attached as the implicit
                exception context.
        """
        output = self.generate(**kwargs).get(self.output_key)
        try:
            return self.validate_output(output)
        except (AssertionError, ValueError) as error:
            logger.error(error)
            # Single correction attempt: feed the validation message back in.
            corrected = self.correct_generate(**kwargs, error=str(error))
            return self.validate_output(corrected.get(self.output_key))

    def validate_output(self, output):
        """Validate ``output`` and return the value ``forward`` should yield.

        Subclasses must override this method. Raising ``AssertionError`` or
        ``ValueError`` triggers the single correction attempt in ``forward``.
        Any other exception propagates unchanged.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError("Validation logic should be implemented in subclass")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment