Skip to content

Instantly share code, notes, and snippets.

@tadamcz
Created February 3, 2025 20:55
Show Gist options
  • Save tadamcz/e639d7df0663a9c2e15ac7f97275b01b to your computer and use it in GitHub Desktop.
Save tadamcz/e639d7df0663a9c2e15ac7f97275b01b to your computer and use it in GitHub Desktop.
Epoch AI MATH implementation (based on inspect_evals)
from typing import Optional, List
from inspect_ai._eval.registry import task
from inspect_ai._eval.task import Task
from inspect_ai.dataset._sources.hf import hf_dataset
from bench.model import DEFAULT_GRADER_MODEL
from bench.task.hendrycks_math.dataset import filter_dataset, record_to_sample
from bench.task.hendrycks_math.scorer import (
normalized_string_match,
sympy_equiv,
model_graded_equiv,
)
from bench.task.hendrycks_math.solver import math_solver
@task
def hendrycks_math(
    levels: Optional[List[str]] = None,
    subjects: Optional[List[str]] = None,
    fewshot: int = 0,
    fewshot_seed: int = 42,
) -> Task:
    """Build the MATH evaluation task.

    Args:
        levels: Difficulty levels to keep; ``None`` is replaced by an empty list.
        subjects: Subjects to keep; ``None`` is replaced by an empty list.
        fewshot: Number of few-shot examples, forwarded to ``math_solver``.
        fewshot_seed: Seed for few-shot sampling, forwarded to ``math_solver``.

    Returns:
        A ``Task`` over the (optionally filtered) test split, scored by the
        string-match, sympy and model-graded scorers, with ``epochs=8``.
    """
    wanted_levels = [] if levels is None else levels
    wanted_subjects = [] if subjects is None else subjects

    test_split = hf_dataset(
        # TODO: change this back to the official dataset if DMCA issues are resolved
        "tadamcz/hendrycks___competition_math",
        split="test",
        trust=True,
        sample_fields=record_to_sample,
        shuffle=True,
        auto_id=True,
    )

    # Subset the data based on levels and/or subjects
    selected = filter_dataset(
        dataset=test_split,
        levels=wanted_levels,
        subjects=wanted_subjects,
    )

    scorers = [
        normalized_string_match(),
        sympy_equiv(),
        model_graded_equiv(model=DEFAULT_GRADER_MODEL),
    ]

    return Task(
        dataset=selected,
        plan=math_solver(fewshot=fewshot, fewshot_seed=fewshot_seed),
        scorer=scorers,
        epochs=8,
        metadata={"inspect-log-public": True},
    )
@task(name="MATH level 5")
def hendrycks_math_lvl_5() -> Task:
    """Return the MATH task restricted to level "5", registered as "MATH level 5"."""
    # All other hendrycks_math parameters keep their defaults.
    return hendrycks_math(levels=["5"])
from typing import Dict
from inspect_ai.dataset import Dataset, Sample
from bench.task.hendrycks_math.scorer import remove_boxed, last_boxed_only_string
def _as_filter_values(values) -> list:
    """Coerce a filter argument (None, a single value, or a collection) to a list."""
    if values is None:
        return []
    if isinstance(values, (list, tuple, set, frozenset)):
        return list(values)
    return [values]


def filter_dataset(dataset: Dataset, levels: list, subjects: list) -> Dataset:
    """Filters the MATH dataset by levels and/or subjects.

    Arguments:
        dataset (Dataset): Dataset object to be filtered.
        levels (List): List of levels to filter on, 1 to 5. A single value is
            accepted too; ``None`` or an empty collection means "no level filter".
        subjects (List): List of subjects to filter on (matched
            case-insensitively). A single value is accepted too; ``None`` or
            an empty collection means "no subject filter".

    Returns:
        The dataset after applying ``dataset.filter`` once per active filter.
        Samples without metadata, or without the relevant metadata key, are
        dropped by an active filter.
    """
    # Levels are compared as strings so that 5 and "5" select the same samples.
    wanted_levels = {str(level) for level in _as_filter_values(levels)}
    # Subjects are compared lowercased so "Algebra" matches metadata "algebra".
    wanted_subjects = {str(subject).lower() for subject in _as_filter_values(subjects)}

    def _metadata_matches(sample, key, wanted, fold_case=False) -> bool:
        metadata = sample.metadata
        if metadata is None or key not in metadata:
            return False
        value = str(metadata[key])
        if fold_case:
            value = value.lower()
        return value in wanted

    # Filter dataset by levels, if required
    if wanted_levels:
        dataset = dataset.filter(
            predicate=lambda sample: _metadata_matches(sample, "level", wanted_levels),
        )

    # Filter dataset by subjects, if required
    if wanted_subjects:
        dataset = dataset.filter(
            predicate=lambda sample: _metadata_matches(
                sample, "subject", wanted_subjects, fold_case=True
            ),
        )

    return dataset
def record_to_sample(record: Dict) -> Sample:
    """Convert one raw MATH record into a ``Sample``.

    Reads the ``problem``, ``solution``, ``level`` and ``type`` keys of
    ``record``. The target is the content of the last ``\\boxed`` expression
    in the solution; the full solution is kept in metadata for few-shot use.
    """
    solution = record["solution"]
    return Sample(
        input=record["problem"],
        target=remove_boxed(last_boxed_only_string(solution)),
        metadata={
            # removeprefix drops the literal word "level"; the previous
            # lstrip("level ") stripped any leading run of the characters
            # 'l', 'e', 'v' and ' ' rather than the prefix itself.
            "level": record["level"].lower().removeprefix("level").lstrip(),
            "subject": record["type"].lower(),
            "solution": solution,
        },
    )
def sample_to_fewshot(sample: Sample) -> str:
    """Render a sample as a PROBLEM / SOLUTION / ANSWER few-shot example.

    The solution text is read from ``sample.metadata["solution"]``; an
    ``AssertionError`` is raised when the sample has no metadata or the
    stored solution is ``None``.
    """
    # Based on https://arxiv.org/pdf/2206.14858 - Appendix D.2
    # Tags are capitalized to match the format of the user prompt
    metadata = sample.metadata
    worked_solution = None if metadata is None else metadata["solution"]
    assert (
        worked_solution is not None
    ), "Solution not found in sample, make sure to include it in the 'sample.metadata' dict."

    sections = (
        f"PROBLEM:\n{sample.input}",
        f"SOLUTION:\n{worked_solution}",
    )
    return "\n\n".join(sections) + f"\nANSWER: {sample.target}"
import logging
import re
import signal
import sympy # type: ignore
from inspect_ai.model import Model
from inspect_ai.model import get_model
from inspect_ai.scorer import (
Score,
AnswerPattern,
CORRECT,
INCORRECT,
)
from inspect_ai.scorer import (
Target,
accuracy,
scorer,
stderr,
)
from inspect_ai.solver import TaskState
from sympy.parsing.latex import parse_latex # type: ignore
# Module-level logger used for debug messages in the comparison helpers.
logger = logging.getLogger(__name__)

# Prompt sent to the grader model when judging equivalence. It is filled with
# %-formatting using the keys "expression1" and "expression2", and the grader
# is instructed to reply with only "Yes" or "No".
EQUIVALANCE_TEMPLATE = r"""
Look at the following two expressions (answers to a math problem) and judge whether they are equivalent. Only perform trivial simplifications
Examples:
Expression 1: $2x+3$
Expression 2: $3+2x$
Yes
Expression 1: 3/2
Expression 2: 1.5
Yes
Expression 1: $x^2+2x+1$
Expression 2: $y^2+2y+1$
No
Expression 1: $x^2+2x+1$
Expression 2: $(x+1)^2$
Yes
Expression 1: 3245/5
Expression 2: 649
No
(these are actually equal, don't mark them equivalent if you need to
do nontrivial simplifications)
Expression 1: 2/(-3)
Expression 2: -2/3
Yes
(trivial simplifications are allowed)
Expression 1: 72 degrees
Expression 2: 72
Yes
(give benefit of the doubt to units)
Expression 1: 64
Expression 2: 64 square feet
Yes
(give benefit of the doubt to units)
---
YOUR TASK
Respond with only "Yes" or "No" (without quotes). Do not include a rationale.
Expression 1: %(expression1)s
Expression 2: %(expression2)s
""".strip()
# Ordered (before, after) string replacements applied by
# normalize_final_answer. Order matters: each replacement runs on the result
# of the previous one (e.g. spaces are removed before the "\\text{and}" rules).
SUBSTITUTIONS = [
    ("an ", ""),
    ("a ", ""),
    (".$", "$"),
    ("\\$", ""),
    (r"\ ", ""),
    (" ", ""),
    ("mbox", "text"),
    (",\\text{and}", ","),
    ("\\text{and}", ","),
    ("\\text{m}", "\\text{}"),
]

# Substrings deleted outright by normalize_final_answer after the
# substitutions above (mostly unit words and LaTeX spacing/decoration).
REMOVED_EXPRESSIONS = [
    "square",
    "ways",
    "integers",
    "dollars",
    "mph",
    "inches",
    "ft",
    "hours",
    "km",
    "units",
    "\\ldots",
    "sue",
    "points",
    "feet",
    "minutes",
    "digits",
    "cents",
    "degrees",
    "cm",
    "gm",
    "pounds",
    "meters",
    "meals",
    "edges",
    "students",
    "childrentickets",
    "multiples",
    "\\text{s}",
    "\\text{.}",
    "\\text{\ns}",
    "\\text{}^2",
    "\\text{}^3",
    "\\text{\n}",
    "\\text{}",
    r"\mathrm{th}",
    r"^\circ",
    r"^{\circ}",
    r"\;",
    r",\!",
    "{,}",
    '"',
    "\\dots",
]
def _extract_answer_helper(text: str) -> str | None:
"""Helper function to extract answer from a text string after finding 'ANSWER'."""
try:
# Find start of answer after "ANSWER"
answer_start = text.index("ANSWER") + 6
# Skip optional colon and spaces
while answer_start < len(text) and (
text[answer_start] == ":" or text[answer_start].isspace()
):
answer_start += 1
answer = text[answer_start:].strip()
return answer if answer else None
except ValueError:
return None
def extract_answer(completion: str) -> str | None:
    """Extract answer from model completion using string manipulation.

    Three patterns are tried in order: a LaTeX ``\\text{ANSWER...}`` marker,
    a markdown-bold line containing ``ANSWER``, and finally any line
    containing ``ANSWER``.

    Args:
        completion: The model completion text to extract from

    Returns:
        The extracted answer string, or None if no answer found
    """
    # Try LaTeX text pattern first
    marker_pos = completion.find("\\text{ANSWER")
    if marker_pos != -1:
        brace_pos = completion.find("}", marker_pos)
        # Without a closing brace this pattern is abandoned and the
        # line-based patterns below get a chance.
        if brace_pos != -1:
            # Answer written inside the braces (closing brace excluded)
            inside = _extract_answer_helper(completion[marker_pos:brace_pos])
            if inside:
                return inside
            # Otherwise take the rest of the line after the closing brace
            trailing = completion[brace_pos + 1 :].strip()
            first_line = trailing.partition("\n")[0]
            return first_line if first_line else None

    rows = completion.split("\n")

    # Try markdown bold pattern
    for row in rows:
        stripped = row.strip()
        is_bold = stripped.startswith("**") and stripped.endswith("**")
        if not (is_bold and "ANSWER" in stripped):
            continue
        # Remove the markdown bold markers before looking for the answer
        found = _extract_answer_helper(stripped[2:-2].strip())
        if found:
            return found

    # Try simple line pattern
    for row in rows:
        if "ANSWER" not in row:
            continue
        found = _extract_answer_helper(row)
        if found:
            return found

    return None
async def score_helper(
    state: TaskState,
    target: Target,
    model_graded: bool,
    use_sympy: bool = False,
    model: Model | None = None,
) -> Score:
    """Score one completion against its target.

    Args:
        state: Task state whose ``output.completion`` holds the model answer.
        target: Expected answer.
        model_graded: When True, ask ``model`` to judge equivalence; otherwise
            use ``match_helper``.
        use_sympy: Forwarded to ``match_helper`` (ignored when model graded).
        model: Grader model; required when ``model_graded`` is True.

    Returns:
        A ``Score`` that is INCORRECT when no answer could be extracted,
        otherwise CORRECT/INCORRECT according to the chosen comparison. For
        model-graded scores the grader's usage is stored under the
        ``grader_model_usage`` metadata key.

    Raises:
        ValueError: If an answer was extracted, ``model_graded`` is True and
            ``model`` is None.
    """
    answer = extract_answer(state.output.completion)
    if not answer:
        return Score(
            value=INCORRECT,
            explanation="Answer not found in model output: " + f"{state.output.completion}",
        )

    metadata = None
    if model_graded:
        # Ask grader model to judge equivalence
        if model is None:
            raise ValueError("Model is required for model graded scoring")
        prompt = EQUIVALANCE_TEMPLATE % ({"expression1": target.text, "expression2": answer})
        result = await model.generate(prompt)
        correct = result.completion.strip().lower() == "yes"
        # Attach the usage at construction time. The previous code only
        # updated score.metadata when it was not None, but never passed any
        # metadata to Score, so (assuming Score's metadata defaults to None;
        # see the inspect_ai Score reference) the usage was never recorded.
        metadata = {"grader_model_usage": result.usage}
    else:
        correct = await match_helper(
            answer=answer,
            target=target,
            use_sympy=use_sympy,
        )

    return Score(
        value=CORRECT if correct else INCORRECT,
        explanation=state.output.completion,
        answer=answer,
        metadata=metadata,
    )
# From here till normalize_final_answer() is borrowed from:
# https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/minerva_math/utils.py#L144
class timeout:
    """Context manager raising ``TimeoutError`` when its body runs too long.

    Implemented with ``signal.alarm`` / ``SIGALRM``. Per the Python ``signal``
    documentation this is only available on Unix and handlers can only be
    installed from the main thread.

    Args:
        seconds: Whole seconds before the alarm fires.
        error_message: Message carried by the raised ``TimeoutError``.
    """

    def __init__(self, seconds=1, error_message="Timeout"):
        self.seconds = seconds
        self.error_message = error_message
        # Handler that was installed before __enter__, restored in __exit__.
        self._previous_handler = None

    def handle_timeout(self, signum, frame):
        """SIGALRM handler: abort the guarded block."""
        raise TimeoutError(self.error_message)

    def __enter__(self):
        # signal.signal returns the handler it replaces; remember it so the
        # process-wide SIGALRM disposition is not permanently overwritten.
        self._previous_handler = signal.signal(signal.SIGALRM, self.handle_timeout)
        signal.alarm(self.seconds)

    def __exit__(self, type, value, traceback):
        # Cancel any pending alarm first so it cannot fire into the restored
        # handler.
        signal.alarm(0)
        # A previous handler of None means it was not installed from Python
        # and cannot be passed back to signal.signal.
        if self._previous_handler is not None:
            signal.signal(signal.SIGALRM, self._previous_handler)
async def is_equiv_sympy(x1: str, x2: str) -> bool:
    """Return True when two normalized LaTeX strings simplify to the same value.

    Both strings are parsed with ``parse_latex``, subtracted, and the
    difference is simplified; the answer is True only if that result equals
    0. Parse, subtraction and simplification failures, a timeout (5 seconds)
    and recursion errors are logged at debug level and yield False.
    """
    try:
        with timeout(seconds=5):
            try:
                lhs = parse_latex(x1)
                rhs = parse_latex(x2)
            except (
                sympy.parsing.latex.errors.LaTeXParsingError,
                sympy.SympifyError,
                TypeError,
            ) as e:
                logger.debug(f"Couldn't parse one of {x1} or {x2}: {e}")
                return False

            try:
                difference = lhs - rhs
            except TypeError:
                logger.debug(f"Couldn't subtract {x1} and {x2}")
                return False

            try:
                return bool(sympy.simplify(difference) == 0)
            except (ValueError, TypeError, ZeroDivisionError):
                logger.debug(f"Had some trouble simplifying when comparing {x1} and {x2}")
                return False
    except TimeoutError:
        logger.debug(f"Timed out comparing {x1} and {x2}")
        return False
    except RecursionError:
        logger.debug(f"Recursion error comparing {x1} and {x2}")
        return False
async def normalize_final_answer(final_answer: str) -> str:
    """
    Normalize a final answer to a quantitative reasoning question.

    Keeps only the text after the last "=", applies SUBSTITUTIONS and
    REMOVED_EXPRESSIONS, unwraps LaTeX wrappers, expands shorthand
    ``\\frac``/``\\sqrt`` arguments and drops thousands separators.
    Copied character for character from appendix D of Lewkowycz et al. (2022)
    """
    text = final_answer.split("=")[-1]

    for before, after in SUBSTITUTIONS:
        text = text.replace(before, after)
    for expr in REMOVED_EXPRESSIONS:
        text = text.replace(expr, "")

    # Extract answer that is in LaTeX math, is bold,
    # is surrounded by a box, etc.
    unwrap_rules = (
        (r"(.*?)(\$)(.*?)(\$)(.*)", "$\\3$"),
        (r"(\\text\{)(.*?)(\})", "\\2"),
        (r"(\\textbf\{)(.*?)(\})", "\\2"),
        (r"(\\overline\{)(.*?)(\})", "\\2"),
        (r"(\\boxed\{)(.*)(\})", "\\2"),
    )
    for pattern, replacement in unwrap_rules:
        text = re.sub(pattern, replacement, text)

    # If surrounded by `\(` and `\)`, remove them
    if text.startswith("\\(") and text.endswith("\\)"):
        text = text[2:-2]

    # Normalize shorthand TeX:
    # \fracab -> \frac{a}{b}
    # \frac{abc}{bef} -> \frac{abc}{bef}
    # \fracabc -> \frac{a}{b}c
    # \sqrta -> \sqrt{a}
    # \sqrtab -> sqrt{a}b
    text = re.sub(r"(frac)([^{])(.)", "frac{\\2}{\\3}", text)
    text = re.sub(r"(sqrt)([^{])", "sqrt{\\2}", text)
    text = text.replace("$", "")

    # Normalize 100,000 -> 100000
    without_commas = text.replace(",", "")
    return without_commas if without_commas.isdigit() else text
async def is_equiv(str1: str | None, str2: str | None) -> bool:
    """Compare two answers after canonicalising both with ``strip_string``.

    Two ``None`` values compare equal (a debug message is logged); a single
    ``None`` never equals a string.
    """
    if str1 is None or str2 is None:
        if str1 is None and str2 is None:
            logger.debug("WARNING: Both None")
            return True
        return False

    canonical_1 = await strip_string(str1)
    canonical_2 = await strip_string(str2)
    return canonical_1 == canonical_2
async def strip_string(string):
    """Canonicalise an answer string so equivalent spellings compare equal.

    Removes line breaks, spacing/sizing commands, degree and dollar markers,
    trailing units and percent signs, normalises leading decimals, drops a
    short "x =" prefix, and rewrites sqrt / fraction shorthand via
    ``fix_sqrt``, ``fix_fracs`` and ``fix_a_slash_b``.

    Args:
        string: The answer text to canonicalise.

    Returns:
        The canonicalised string (possibly empty).
    """
    # linebreaks
    string = string.replace("\n", "")
    # remove inverse spaces
    string = string.replace("\\!", "")
    # replace \\ with \
    string = string.replace("\\\\", "\\")
    # replace tfrac and dfrac with frac
    string = string.replace("tfrac", "frac")
    string = string.replace("dfrac", "frac")
    # remove \left and \right
    string = string.replace("\\left", "")
    string = string.replace("\\right", "")
    # Remove circ (degrees)
    string = string.replace("^{\\circ}", "")
    string = string.replace("^\\circ", "")
    # remove dollar signs
    string = string.replace("\\$", "")
    # remove units (on the right)
    string = await remove_right_units(string)
    # remove percentage
    string = string.replace("\\%", "")
    # The second pass used to be spelled "\%", an invalid escape sequence
    # that denotes the same two characters. It is kept, correctly escaped,
    # so results stay identical: a pass can expose a new "\%" after removal.
    string = string.replace("\\%", "")
    # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
    string = string.replace(" .", " 0.")
    string = string.replace("{.", "{0.")
    # if empty, return empty string
    if len(string) == 0:
        return string
    if string[0] == ".":
        string = "0" + string
    # to consider: get rid of e.g. "k = " or "q = " at beginning
    sides = string.split("=")
    if len(sides) == 2 and len(sides[0]) <= 2:
        string = sides[1]
    # fix sqrt3 --> sqrt{3}
    string = await fix_sqrt(string)
    # remove spaces
    string = string.replace(" ", "")
    # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b}
    string = await fix_fracs(string)
    # manually change 0.5 --> \frac{1}{2}
    if string == "0.5":
        string = "\\frac{1}{2}"
    # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
    string = await fix_a_slash_b(string)
    return string
async def fix_fracs(string):
    """Rewrite shorthand ``\\frac`` arguments with explicit braces.

    ``\\frac12`` becomes ``\\frac{1}{2}`` and ``\\frac1{72}`` becomes
    ``\\frac{1}{72}``; already braced fractions are left alone.

    Args:
        string: Text possibly containing ``\\frac`` commands.

    Returns:
        The rewritten text, or the input unchanged when any ``\\frac`` is
        followed by fewer than two characters (including a trailing
        ``\\frac``, which previously raised ``IndexError``).
    """
    head, *tails = string.split("\\frac")
    if not tails:
        return string

    pieces = [head]
    for tail in tails:
        pieces.append("\\frac")
        if tail.startswith("{"):
            # Numerator already braced: keep the rest verbatim.
            pieces.append(tail)
            continue
        if len(tail) < 2:
            # Not enough characters for numerator and denominator.
            return string
        numerator, denominator, rest = tail[0], tail[1], tail[2:]
        if denominator == "{":
            # \frac1{72}: only the numerator needs braces.
            pieces.append("{" + numerator + "}" + denominator + rest)
        else:
            pieces.append("{" + numerator + "}{" + denominator + "}" + rest)
    return "".join(pieces)
async def fix_a_slash_b(string):
    """Rewrite a plain integer fraction ``a/b`` as ``\\frac{a}{b}``.

    Args:
        string: Candidate text.

    Returns:
        ``\\frac{a}{b}`` when ``string`` is exactly two integers in canonical
        form separated by one slash; otherwise ``string`` unchanged.
    """
    parts = string.split("/")
    if len(parts) != 2:
        return string
    try:
        numerator = int(parts[0])
        denominator = int(parts[1])
    except ValueError:
        return string
    # Require the canonical spelling (no padding, signs or whitespace that
    # int() tolerates). An explicit comparison is used instead of assert so
    # the check also runs under `python -O`.
    if string != "{}/{}".format(numerator, denominator):
        return string
    return "\\frac{" + str(numerator) + "}{" + str(denominator) + "}"
async def remove_right_units(string):
    """Drop a trailing units annotation introduced by ``\\text{ ``.

    Args:
        string: Answer text.

    Returns:
        Everything before the first ``\\text{ `` marker, or ``string``
        unchanged when the marker is absent. Inputs with several markers are
        cut at the first one instead of failing an assertion.
    """
    # "\\text{ " only ever occurs (at least in the val set) when describing units
    value, _, _ = string.partition("\\text{ ")
    return value
async def fix_sqrt(string):
    """Add braces around a single-character ``\\sqrt`` argument.

    ``\\sqrt3`` becomes ``\\sqrt{3}``; already braced arguments are kept.

    Args:
        string: Text possibly containing ``\\sqrt`` commands.

    Returns:
        The rewritten text. A ``\\sqrt`` with nothing after it is left as is
        (this previously raised ``IndexError``).
    """
    if "\\sqrt" not in string:
        return string
    head, *tails = string.split("\\sqrt")
    pieces = [head]
    for tail in tails:
        if tail and not tail.startswith("{"):
            pieces.append("\\sqrt{" + tail[0] + "}" + tail[1:])
        else:
            # Already braced, or nothing follows the command.
            pieces.append("\\sqrt" + tail)
    return "".join(pieces)
def remove_boxed(s):
    """Strip the ``\\boxed`` wrapper from a boxed expression.

    Handles both ``\\boxed <content>`` and ``\\boxed{<content>}``. An
    ``AssertionError`` is raised when ``s`` does not start with the wrapper
    (or, for the braced form, does not end with ``}``).
    """
    spaced_prefix = "\\boxed "
    if spaced_prefix in s:
        assert s.startswith(spaced_prefix)
        return s[len(spaced_prefix) :]

    braced_prefix = "\\boxed{"
    assert s.startswith(braced_prefix)
    assert s.endswith("}")
    return s[len(braced_prefix) : -1]
def last_boxed_only_string(string):
    """Return the last ``\\boxed{...}`` (or ``\\fbox{...}``) expression in ``string``.

    For the brace-less ``\\boxed <content>`` form, the content runs up to the
    next ``$``. Returns None when no box is found or its braces never close.
    """
    spaced_marker = "\\boxed "
    if spaced_marker in string:
        after_marker = string.split(spaced_marker)[-1]
        return spaced_marker + after_marker.split("$")[0]

    start = string.rfind("\\boxed")
    if start < 0:
        start = string.rfind("\\fbox")
    if start < 0:
        return None

    # Walk forward from the command, tracking brace depth, until the brace
    # that closes the box is reached.
    depth = 0
    for offset, char in enumerate(string[start:]):
        if char == "{":
            depth += 1
        elif char == "}":
            depth -= 1
            if depth == 0:
                return string[start : start + offset + 1]
    return None
@scorer(metrics=[accuracy(), stderr()])
def model_graded_equiv(model: Model):
    """Scorer that asks the grader ``model`` whether answer and target are equivalent."""

    async def score(state: TaskState, target: Target):
        # Delegates to score_helper on its model-graded path.
        return await score_helper(
            state=state,
            target=target,
            model=model,
            model_graded=True,
        )

    return score
@scorer(metrics=[accuracy(), stderr()])
def sympy_equiv():
    """Scorer that compares answer and target through score_helper with ``use_sympy=True``."""

    async def score(state: TaskState, target: Target):
        # No grader model involved: model_graded is False.
        return await score_helper(
            state=state,
            target=target,
            model_graded=False,
            use_sympy=True,
        )

    return score
@scorer(metrics=[accuracy(), stderr()])
def normalized_string_match():
    """Scorer that compares answer and target through score_helper with ``use_sympy=False``."""

    async def score(state: TaskState, target: Target):
        # No grader model and no sympy: plain normalized comparison.
        return await score_helper(
            state=state,
            target=target,
            model_graded=False,
            use_sympy=False,
        )

    return score
async def match_helper(
    answer: str,
    target: Target,
    use_sympy: bool = False,
) -> bool:
    """Decide whether ``answer`` matches ``target`` without a grader model.

    An exact string match short-circuits to True. Otherwise both sides are
    normalized with ``normalize_final_answer`` and compared with
    ``is_equiv_sympy`` (when ``use_sympy``) or ``is_equiv``.
    """
    expected = target.text
    # If the strings already match exactly, we can return True immediately
    if answer == expected:
        return True

    norm_answer = await normalize_final_answer(answer)
    norm_target = await normalize_final_answer(expected)

    # Use sympy library for exact match based on https://arxiv.org/pdf/2206.14858
    comparator = is_equiv_sympy if use_sympy else is_equiv
    return await comparator(norm_answer, norm_target)
from inspect_ai.dataset import hf_dataset
from inspect_ai.solver import (
Solver,
generate,
prompt_template,
system_message,
)
from .dataset import record_to_sample, sample_to_fewshot
# Few-shot prompt template partially based on https://arxiv.org/pdf/2206.14858 - Appendix D.2
# Filled with str.format; "{examples}" receives the rendered few-shot examples.
SYSTEM_W_EXAMPLES_PROMPT_TEMPLATE = """
You will be asked to solve a math problem. Some examples of problems and solutions are provided below.
{examples}
""".strip()

# Setup for problem + instructions for providing answer
# Passed to prompt_template(); it contains a "{prompt}" placeholder for the problem text.
USER_PROMPT_TEMPLATE = """
Solve the following math problem step by step. The last line of your response should be of the form "ANSWER: $ANSWER" (without quotes) where $ANSWER is the answer to the problem.
{prompt}
Remember to put your answer on its own line at the end in the form "ANSWER: $ANSWER" (without quotes) where $ANSWER is the answer to the problem, and you do not need to use a \\boxed command.
""".strip()
def math_solver(
    fewshot: int,
    fewshot_seed: int,
) -> list[Solver]:
    """Build solver for MATH task.

    The base pipeline is the user prompt template followed by generation.
    When ``fewshot`` is non-zero, a system message containing that many
    rendered training examples is placed in front of it.

    Arguments:
        fewshot (int): Number of few shot examples to use.
        fewshot_seed (int): Random seed for sampling few shot examples.
    """
    base_steps = [prompt_template(USER_PROMPT_TEMPLATE), generate()]
    if not fewshot:
        return base_steps

    # NOTE(review): this dataset id differs from the one used for the test
    # split in the task definition; confirm it is still the intended
    # source for few-shot examples.
    fewshot_samples = hf_dataset(
        "hendrycks/competition_math",
        split="train",
        trust=True,
        sample_fields=record_to_sample,
        shuffle=True,
        seed=fewshot_seed,
        limit=fewshot,
    )

    rendered_examples = "\n\n".join(
        [sample_to_fewshot(sample=sample) for sample in fewshot_samples]
    )
    examples_message = system_message(
        SYSTEM_W_EXAMPLES_PROMPT_TEMPLATE.format(examples=rendered_examples)
    )
    return [examples_message, *base_steps]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment