Created
March 1, 2024 02:46
-
-
Save PaulLockett/44d93020a3ae1ff562dce1e8d35288e0 to your computer and use it in GitHub Desktop.
DSPy JSON extraction metric
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import json | |
import dspy | |
class AssessJsonExtractSig(dspy.Signature):
    # Signature for an LM judge that compares a trial JSON extraction with the
    # ideal one, given the text both were extracted from.
    #
    # NOTE(review): deliberately no class docstring -- DSPy appears to use a
    # Signature's docstring as the prompt instructions, so adding one would
    # change the prompt; confirm against the installed DSPy version.

    # Inputs: the candidate extraction, the reference extraction, and the
    # source text.
    trial_json = dspy.InputField(desc="The trial extracted JSON")
    ideal_json = dspy.InputField(desc="The ideal extracted JSON")
    input_str = dspy.InputField(desc="The str to extract the JSON from")
    # Output: requested as a bare float in [0.0, 1.0]; callers parse and
    # validate it themselves.
    closeness = dspy.OutputField(desc="a float between 0.0 and 1.0; how closely the trial_json matches the ideal_json")
class AssessJsonExtract(dspy.Module):
    """Assess the quality of a JSON extraction.

    Runs a chain-of-thought predictor over ``AssessJsonExtractSig`` and
    reduces its ``closeness`` output to a bare decimal string.
    """

    def __init__(self):
        super().__init__()
        self.generate_answer = dspy.ChainOfThought(AssessJsonExtractSig)

    def extract_decimal_prefix(self, closeness_str):
        """
        Extract the decimal number prefix from a string if present.

        If the value is not a string, or does not start with a decimal
        number, it is returned unchanged.

        Parameters:
            closeness_str (str): The input string potentially starting with a
                decimal number (leading whitespace is ignored).

        Returns:
            str: The extracted decimal number as a string, or the original
            value if no decimal prefix is found.
        """
        import re

        # re.match raises TypeError on non-strings (e.g. None); hand such
        # values back untouched so the validation in forward() can reject
        # them instead of crashing here.
        if not isinstance(closeness_str, str):
            return closeness_str

        # Regular expression to match a decimal number at the start of the string
        decimal_prefix_pattern = r'^\s*(\d*\.\d+|\d+\.?)\s*'
        match = re.match(decimal_prefix_pattern, closeness_str)
        # Fall back to the original string when there is no decimal prefix
        return match.group(1) if match else closeness_str

    def isDecimal(self, value):
        """Return True if ``float(value)`` succeeds, False otherwise."""
        try:
            float(value)
        except (TypeError, ValueError):
            # TypeError covers None and other non-string, non-numeric values.
            return False
        return True

    def forward(self, trial_json, ideal_json, input_str):
        """
        Ask the LM how close ``trial_json`` is to ``ideal_json``.

        Parameters:
            trial_json: The extraction being judged.
            ideal_json: The reference extraction.
            input_str: The text both extractions were made from.

        Returns:
            The closeness value with any trailing explanation stripped. It is
            returned even when the suggestions below are not satisfied, so
            callers should still parse it defensively.
        """
        pred = self.generate_answer(trial_json=trial_json, ideal_json=ideal_json, input_str=input_str)
        closeness = self.extract_decimal_prefix(pred.closeness)
        dspy.Suggest(
            # Enforce the range the message promises; this also rejects
            # "nan"/"inf", which float() accepts.
            self.isDecimal(closeness) and 0.0 <= float(closeness) <= 1.0,
            "closeness must be a decimal number between 0.0 and 1.0 and no other explaining text",
        )
        dspy.Suggest(
            # str() is a no-op for strings and keeps this check from raising
            # on a non-string value.
            all(char.isdigit() or char == '.' for char in str(closeness)),
            "closeness should not have any characters other than numbers and a decimal point",
        )
        return closeness
def is_json_valid(json_string):
    """Return True if ``json_string`` parses as JSON, False otherwise.

    Parameters:
        json_string: The candidate JSON text. Values that ``json.loads``
            cannot accept at all (e.g. ``None`` or an int) are reported as
            invalid rather than raising.

    Returns:
        bool: Whether ``json.loads`` accepted the value.
    """
    try:
        json.loads(json_string)
    except (TypeError, ValueError):
        # ValueError covers json.JSONDecodeError (malformed text);
        # TypeError is raised for non str/bytes/bytearray input.
        return False
    return True
from dspy.primitives.assertions import assert_transform_module, backtrack_handler | |
import functools | |
def metric(gold, pred, trace=None):
    """Score a predicted JSON extraction against the gold example.

    The score has three one-point components: the output is valid JSON, the
    output equals the gold JSON exactly, and an LM-judged closeness in
    [0.0, 1.0]. Invalid JSON scores 0 overall. An exact match is treated as
    closeness 1.0, so it earns the full 3 points.

    Parameters:
        gold: Example exposing ``input`` (source text) and ``JSON`` (ideal
            extraction).
        pred: Prediction exposing ``json`` (the model's extraction).
        trace: When not None, a strict pass/fail is returned instead of a
            normalised score.

    Returns:
        bool: ``score >= 3`` when ``trace`` is not None.
        float: ``score / 3.0`` (in [0.0, 1.0]) otherwise.
    """
    input_str, ideal_answer, model_output = gold.input, gold.JSON, pred.json

    if not is_json_valid(model_output):
        # Invalid JSON always scores zero, so skip the LM-backed assessment.
        score = 0.0
    elif model_output == ideal_answer:
        # Valid + exact match: full marks. Without the closeness point a
        # perfect answer could never reach the ``score >= 3`` threshold.
        score = 3.0
    else:
        assessor = assert_transform_module(
            AssessJsonExtract(),
            functools.partial(backtrack_handler, max_backtracks=3),
        )
        # NOTE(review): ``strongModel`` is not defined in the visible part of
        # this file; it is expected to be provided at module level.
        with dspy.context(lm=strongModel):
            raw_closeness = assessor(
                trial_json=model_output,
                ideal_json=ideal_answer,
                input_str=input_str,
            )
        try:
            closeness = float(raw_closeness)
        except (TypeError, ValueError):
            # NOTE(review): an unparseable judge reply falls back to perfect
            # closeness, as in the original; confirm this leniency is wanted.
            closeness = 1.0
        # Clamp so a misbehaving judge cannot push the score out of range.
        # NaN fails the first comparison and is mapped to 0.0.
        if not closeness >= 0.0:
            closeness = 0.0
        elif closeness > 1.0:
            closeness = 1.0
        # One point for valid JSON, none for exact match, plus closeness.
        score = 1.0 + closeness

    if trace is not None:
        return score >= 3
    return score / 3.0
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment