Skip to content

Instantly share code, notes, and snippets.

@PaulLockett
Created March 1, 2024 02:46
Show Gist options
  • Save PaulLockett/44d93020a3ae1ff562dce1e8d35288e0 to your computer and use it in GitHub Desktop.
Save PaulLockett/44d93020a3ae1ff562dce1e8d35288e0 to your computer and use it in GitHub Desktop.
DSPy JSON extraction metric
import json
import dspy
class AssessJsonExtractSig(dspy.Signature):
    # DSPy signature for an LLM judge that grades how closely a trial JSON
    # extraction matches the ideal (reference) extraction of the same input.
    # NOTE: documented with comments rather than a docstring on purpose: DSPy
    # uses a Signature's docstring as the prompt instructions, so adding one
    # would change what the judge model is asked to do.
    # The field `desc` strings below are sent to the LM, so typos in them
    # degrade the prompt; they were corrected ("extraxted" -> "extracted").
    trial_json = dspy.InputField(desc="The trial extracted JSON")
    ideal_json = dspy.InputField(desc="The ideal extracted JSON")
    input_str = dspy.InputField(desc="The str to extract the JSON from")
    # Produced by the LM as free text; AssessJsonExtract parses the leading
    # number out of it.
    closeness = dspy.OutputField(
        desc="a float between 0.0 and 1.0; how closely the trial_json matches the ideal_json"
    )
class AssessJsonExtract(dspy.Module):
    """Assess the quality of a JSON extraction.

    Runs a chain-of-thought LM call over ``AssessJsonExtractSig`` and returns
    the judge's closeness grade as a numeric *string*. ``dspy.Suggest``
    constraints give the LM feedback (and, when this module is wrapped with a
    backtracking handler, a retry) if the grade is not a bare number in
    [0.0, 1.0].
    """

    def __init__(self):
        super().__init__()
        self.generate_answer = dspy.ChainOfThought(AssessJsonExtractSig)

    def extract_decimal_prefix(self, closeness_str):
        """Return the leading decimal number of ``closeness_str`` as a string.

        Whitespace around the number is ignored, so ``" 0.85 - close"`` yields
        ``"0.85"``. Accepted forms are ``"1"``, ``"1."``, ``".5"`` and
        ``"0.5"``; signs and exponents are not recognized.

        Parameters:
            closeness_str (str): Raw LM output, possibly starting with a number.

        Returns:
            str: The extracted number, or ``closeness_str`` unchanged when it
            does not start with a decimal number.
        """
        import re

        # Decimal number at the start of the string, optionally padded.
        match = re.match(r'^\s*(\d*\.\d+|\d+\.?)\s*', closeness_str)
        return match.group(1) if match else closeness_str

    def isDecimal(self, value):
        """Return True if ``value`` can be converted with ``float()``.

        Note that ``float()`` also accepts forms such as ``"nan"``, ``"inf"``
        and ``"1e3"``. Callers that need a plain in-range decimal must check
        further.
        """
        try:
            float(value)
        except (TypeError, ValueError):
            # TypeError covers None and other non-string, non-numeric inputs.
            return False
        return True

    def forward(self, trial_json, ideal_json, input_str):
        """Ask the judge LM how closely ``trial_json`` matches ``ideal_json``.

        Returns:
            str: The closeness grade as text (e.g. ``"0.8"``). ``dspy.Suggest``
            is a soft constraint, so if the LM never produces a valid grade the
            text is still returned as-is. Callers must parse it defensively.
        """
        pred = self.generate_answer(trial_json=trial_json, ideal_json=ideal_json, input_str=input_str)
        closeness = self.extract_decimal_prefix(pred.closeness)
        # The feedback promises a value in [0.0, 1.0], so enforce the range as
        # well as parseability. Previously "7" or "nan" satisfied this check.
        in_range = self.isDecimal(closeness) and 0.0 <= float(closeness) <= 1.0
        dspy.Suggest(
            in_range,
            "closeness must be a decimal number between 0.0 and 1.0 and no other explaining text",
        )
        dspy.Suggest(
            all(char.isdigit() or char == '.' for char in closeness),
            "closeness should not have any characters other than numbers and a decimal point",
        )
        return closeness
def is_json_valid(json_string):
    """Return True if ``json_string`` parses as JSON, otherwise False.

    Non-string inputs such as ``None`` (e.g. a missing model output) are
    reported as invalid instead of raising ``TypeError``.
    """
    try:
        json.loads(json_string)
    except (TypeError, ValueError):
        # json.JSONDecodeError subclasses ValueError. json.loads raises
        # TypeError for None or other non-str/bytes arguments.
        return False
    return True
from dspy.primitives.assertions import assert_transform_module, backtrack_handler
import functools
def metric(gold, pred, trace=None):
    """Score a predicted JSON extraction against the gold example.

    The raw score sums three components, each in [0, 1]:
      * validity  - 1 if ``pred.json`` parses as JSON;
      * alignment - 1 if ``pred.json`` is string-identical to ``gold.JSON``;
      * closeness - an LLM judge's similarity grade (1.0 for an exact match).
    Output that is not valid JSON scores 0 overall, and the judge is skipped.

    Args:
        gold: Example exposing ``input`` (source text) and ``JSON`` (ideal
            extraction).
        pred: Prediction exposing ``json`` (the model's extraction).
        trace: Per DSPy's metric convention, non-None while optimizing
            (bootstrapping). In that case a pass/fail bool is returned.

    Returns:
        If ``trace`` is not None, True only for a perfect score (valid and
        exact match). Otherwise, the score normalized to [0.0, 1.0].
    """
    import math

    input_str, ideal_answer, model_output = gold.input, gold.JSON, pred.json
    correct = is_json_valid(model_output)
    alignment = model_output == ideal_answer

    if not correct:
        # Invalid JSON scores 0 regardless, so don't spend an LLM call on it.
        score = 0
    elif alignment:
        # An exact match is maximally close. Previously closeness was omitted
        # here, capping exact matches at 2/3 and making the trace check
        # `score >= 3` impossible to pass.
        score = correct + alignment + 1.0
    else:
        judge = assert_transform_module(
            AssessJsonExtract(),
            functools.partial(backtrack_handler, max_backtracks=3),
        )
        # NOTE(review): `strongModel` is not defined in this snippet; it is
        # presumably a dspy LM configured at module level elsewhere - confirm.
        with dspy.context(lm=strongModel):
            raw_closeness = judge(trial_json=model_output, ideal_json=ideal_answer, input_str=input_str)
        try:
            closeness = float(raw_closeness)
        except (TypeError, ValueError):
            # A judge failure must not count as a perfect match (the old
            # default of 1.0 did exactly that).
            closeness = 0.0
        if not math.isfinite(closeness):
            closeness = 0.0
        # Keep the component in [0, 1] so the normalized score stays <= 1.0.
        closeness = min(max(closeness, 0.0), 1.0)
        score = correct + alignment + closeness

    if trace is not None:
        return score >= 3
    return score / 3.0
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment