Skip to content

Instantly share code, notes, and snippets.

@johnjosephhorton
Created January 9, 2024 01:05
Show Gist options
  • Save johnjosephhorton/9cb426fd349c9328a0071cf23510126d to your computer and use it in GitHub Desktop.
Save johnjosephhorton/9cb426fd349c9328a0071cf23510126d to your computer and use it in GitHub Desktop.
Proposed question re-factor
from __future__ import annotations
import re
import textwrap
from abc import ABC, abstractmethod
from jinja2 import Template, Environment, meta
from typing import Any, Type, Union
from edsl.exceptions import (
QuestionAnswerValidationError,
QuestionAttributeMissing,
QuestionResponseValidationError,
QuestionSerializationError,
QuestionScenarioRenderError,
)
from edsl.questions.question_registry import get_question_class
from edsl.questions.utils import LLMResponse
from edsl.utilities.utilities import HTMLSnippet
class Question(ABC):
    """Abstract base class for questions.

    Provides dictionary (de)serialization, Jinja2 prompt rendering against a
    scenario, and helpers that wrap a question in a ``Survey``. Concrete
    subclasses supply ``instructions``, ``validate_answer``,
    ``translate_answer_code_to_answer`` and ``simulate_answer``; the methods
    here also read ``self.question_type`` and ``self.question_name``, which
    this class does not define itself.
    """

    @property
    def data(self):
        """Return the instance attributes as a new dictionary.

        Keys are the names found in the instance ``__dict__`` with a single
        leading underscore removed, so an attribute stored as
        ``_question_text`` is reported as ``question_text``. Underscores
        elsewhere in a name are preserved. Attributes that live on the class
        rather than the instance do not appear.
        """
        return {
            # Strip only a *leading* underscore; removing the first underscore
            # anywhere would mangle plain names such as "question_name".
            (name[1:] if name.startswith("_") else name): value
            for name, value in self.__dict__.items()
        }

    def to_dict(self) -> dict:
        """Return ``data`` plus a ``"question_type"`` entry."""
        return {**self.data, "question_type": self.question_type}

    @classmethod
    def from_dict(cls, data: dict) -> Question:
        """Construct a question from a dictionary created by ``to_dict``.

        The ``"question_type"`` entry selects the class through
        ``get_question_class``; every remaining entry is passed to that class
        as a keyword argument. The input dictionary is not modified.

        Raises:
            QuestionSerializationError: if ``data`` has no
                ``"question_type"`` key.
        """
        local_data = data.copy()
        try:
            question_type = local_data.pop("question_type")
        except KeyError:
            raise QuestionSerializationError(
                "Question data does not have a 'question_type' field"
            ) from None
        question_class = get_question_class(question_type)
        return question_class(**local_data)

    def __repr__(self):
        """Return ``ClassName(key = value, ...)`` built from ``data``.

        The substring ``"Enhanced"`` is removed from the class name, string
        values are wrapped in single quotes, and a ``question_type`` key is
        skipped if present.
        """
        class_name = self.__class__.__name__.replace("Enhanced", "")
        items = [
            f"{k} = '{v}'" if isinstance(v, str) else f"{k} = {v}"
            for k, v in self.data.items()
            if k != "question_type"
        ]
        return f"{class_name}({', '.join(items)})"

    @abstractmethod
    def validate_answer(self, answer: dict[str, str]):
        """Validate an answer dictionary; implemented by each question type."""
        pass

    # TODO: Throws an error that should be addressed at QuestionFunctional
    def __add__(self, other_question):
        """
        Composes two questions into a single question.

        >>> from edsl.scenarios.Scenario import Scenario
        >>> from edsl.questions.QuestionFreeText import QuestionFreeText
        >>> from edsl.questions.QuestionNumerical import QuestionNumerical
        >>> q1 = QuestionFreeText(question_text = "What is the capital of {{country}}", question_name = "capital")
        >>> q2 = QuestionNumerical(question_text = "What is the population of {{capital}}, in millions. Please round", question_name = "population")
        >>> q3 = q1 + q2
        >>> Scenario({"country": "France"}).to(q3).run().select("capital_population")
        ['2']
        """
        # Imported here rather than at module level (presumably to avoid a
        # circular import between the question modules).
        from edsl.questions import compose_questions

        return compose_questions(self, other_question)

    @property
    @abstractmethod
    def instructions(self) -> str:  # pragma: no cover
        """
        Instructions for each question.

        - the values are question type-specific
        - the templating standard is Jinja2.
        - it is necessary to include both "answer" and "comment" as the only keys.
        - Note: children should implement this method as a property.

        An example for `QuestionFreeText`:
            You are being asked the following question: {{question_text}}
            Return a valid JSON formatted like this:
            {"answer": "<your free text answer>", "comment": "<put explanation here>"}
        """
        pass

    @staticmethod
    def scenario_render(text: str, scenario_dict: dict) -> str:
        """Render ``text`` with ``scenario_dict`` until it stops changing.

        A scenario value may itself contain template variables, so the text
        is rendered repeatedly until a pass leaves it unchanged.

        Raises:
            QuestionScenarioRenderError: if the text is still changing after
                more than ``MAX_NESTING`` (100) passes, which almost certainly
                means the scenario references itself.
        """
        t = text
        MAX_NESTING = 100
        counter = 0
        while True:
            counter += 1
            new_t = Template(t).render(scenario_dict)
            if new_t == t:
                # Fixed point reached: nothing left to substitute.
                break
            t = new_t
            if counter > MAX_NESTING:
                raise QuestionScenarioRenderError(
                    "Too much nesting - you created an infnite loop here, pal"
                )
        return new_t

    def get_prompt(self, scenario=None) -> str:
        """Return the prompt to send to the LLM for this question.

        ``instructions`` is first rendered with the question's own ``data``;
        any template variables still present afterwards must be supplied by
        ``scenario``, which is then rendered in via ``scenario_render``.

        Args:
            scenario: mapping of variable names to values; ``None`` is
                treated as an empty mapping.

        Raises:
            QuestionScenarioRenderError: if the rendered instructions use a
                variable that ``scenario`` does not contain.
        """
        if scenario is None:
            scenario = {}
        template = Template(self.instructions)
        template_with_attributes = template.render(self.data)
        env = Environment()
        ast = env.parse(template_with_attributes)
        undeclared_variables = meta.find_undeclared_variables(ast)
        # Report only the variables that are actually absent from the scenario.
        missing_variables = {v for v in undeclared_variables if v not in scenario}
        if missing_variables:
            raise QuestionScenarioRenderError(
                f"Scenario is missing variables: {missing_variables}"
            )
        prompt = self.scenario_render(template_with_attributes, scenario)
        return prompt

    @abstractmethod
    def translate_answer_code_to_answer(self):  # pragma: no cover
        """Translates the answer code to the actual answer. Behavior depends on the question type.

        NOTE(review): subclasses may declare extra parameters (for example
        ``answer_code`` and ``scenario``); this abstract signature does not
        list them.
        """
        pass

    @abstractmethod
    def simulate_answer(self, human_readable=True) -> dict:  # pragma: no cover
        """Simulates a valid answer for debugging purposes (what the validator expects)"""
        pass

    def formulate_prompt(self, traits=None, focal_item=None):
        """Build the user prompt and system prompt for the LLM.

        The system prompt holds a fixed "stay in character" instruction,
        followed by the agent's relevant traits when ``traits`` is given. The
        user prompt holds an optional description of the focal item followed
        by ``get_prompt()``. ``get_prompt`` is called without a scenario, so
        it raises if the rendered instructions still contain variables.

        Args:
            traits: optional object exposing ``relevant_traits(question)``.
            focal_item: optional object exposing ``meta_description`` and
                ``content``.

        Returns:
            A ``(prompt, system_prompt)`` tuple of strings.
        """
        system_prompt = ""
        instruction_part = textwrap.dedent(
            """\
            You are answering questions as if you were a human.
            Do not break character.
            """
        )
        system_prompt += instruction_part
        if traits is not None:
            relevant_trait = traits.relevant_traits(self)
            traits_part = f"Your traits are: {relevant_trait}"
            system_prompt += traits_part

        prompt = ""
        if focal_item is not None:
            focal_item_prompt_fragment = textwrap.dedent(
                f"""\
                The question you will be asked will be about a {focal_item.meta_description}.
                The particular one you are responding to is: {focal_item.content}.
                """
            )
            prompt += focal_item_prompt_fragment
        prompt += self.get_prompt()
        return prompt, system_prompt

    ################
    # Question -> Survey methods
    ################
    def add_question(self, other):
        """Return a ``Survey`` containing this question followed by ``other``."""
        from edsl.surveys.Survey import Survey

        s = Survey([self, other], [self.question_name, other.question_name])
        return s

    def run(self, *args, **kwargs):
        """Wrap this question in a one-question ``Survey`` and run it.

        All arguments are forwarded to ``Survey.run``.
        """
        from edsl.surveys.Survey import Survey

        s = Survey([self], [self.question_name])
        return s.run(*args, **kwargs)

    def by(self, *args):
        """Wrap this question in a one-question ``Survey`` and call its ``by``.

        All positional arguments are forwarded to ``Survey.by``.
        """
        from edsl.surveys.Survey import Survey

        s = Survey([self], [self.question_name])
        return s.by(*args)
import random
import textwrap
from jinja2 import Template
from typing import Optional, Type
from edsl.questions import Question
from edsl.exceptions import QuestionAnswerValidationError
from edsl.utilities.utilities import random_string
from edsl.utilities.utilities import is_valid_variable_name
# Upper bound on the number of options a multiple-choice question may have.
MAX_OPTIONS = 10


class QuestionMultipleChoice(Question):
    """A question answered by picking exactly one of a list of string options.

    ``question_text``, ``question_options`` and ``question_name`` are
    validated on assignment by the property setters below and stored under
    the corresponding leading-underscore attribute.
    """

    question_type = "multiple_choice"

    def __init__(self, question_text, question_options, question_name):
        # Each assignment goes through the validating setter of that property.
        self.question_text = question_text
        self.question_options = question_options
        self.question_name = question_name

    @property
    def question_name(self):
        """The name of the question."""
        return self._question_name

    @question_name.setter
    def question_name(self, new_question_name):
        """Store the name, rejecting it unless ``is_valid_variable_name`` accepts it."""
        if not is_valid_variable_name(new_question_name):
            raise Exception("Question name is not a valid variable name!")
        self._question_name = new_question_name

    @property
    def question_options(self):
        """The options the respondent may choose from."""
        return self._question_options

    @question_options.setter
    def question_options(self, new_question_options):
        """Store the options after validating them.

        There must be between 2 and ``MAX_OPTIONS`` options (inclusive) and
        every option must be a ``str``.
        """
        # Use the module constant rather than repeating the literal limit.
        if len(new_question_options) > MAX_OPTIONS:
            raise Exception("Question options are too long!")
        if len(new_question_options) < 2:
            raise Exception("Question options are too short!")
        if not all(isinstance(x, str) for x in new_question_options):
            raise Exception("Question options must be strings!")
        self._question_options = new_question_options

    @property
    def question_text(self):
        """The text of the question."""
        return self._question_text

    @question_text.setter
    def question_text(self, new_question_text):
        """Store the text, rejecting anything longer than 1000 characters."""
        if len(new_question_text) > 1000:
            raise Exception("Question is too long!")
        self._question_text = new_question_text

    def validate_response(self, response: dict[str, str]):
        """Return ``response`` unchanged; no validation is performed."""
        return response

    def validate_answer(self, answer: dict[str, str]):
        """Check that ``answer`` is well formed and return it unchanged.

        Raises:
            QuestionAnswerValidationError: if the ``"answer"`` or
                ``"comment"`` key is missing, or if ``answer["answer"]`` is
                not one of the option indices ``0 .. len(options) - 1``.
        """
        if "answer" not in answer:
            raise QuestionAnswerValidationError(
                f"Answer {answer} does not contain 'answer' key."
            )
        if "comment" not in answer:
            raise QuestionAnswerValidationError(
                f"Answer {answer} does not contain 'comment' key."
            )
        if answer["answer"] not in range(len(self.question_options)):
            raise QuestionAnswerValidationError(
                f"Answer {answer} is not a valid option."
            )
        return answer

    @property
    def instructions(self) -> str:
        """Jinja2 template for the prompt.

        Lists each option next to its zero-based index and asks for a JSON
        reply with ``answer`` and ``comment`` keys.
        """
        return textwrap.dedent(
            """\
            You are being asked the following question: {{question_text}}
            The options are
            {% for option in question_options %}
            {{ loop.index0 }}: {{option}}
            {% endfor %}
            Return a valid JSON formatted like this, selecting only the number of the option:
            {"answer": <put answer code here>, "comment": "<put explanation here>"}
            Only 1 option may be selected.
            """
        )

    ################
    # Less important
    ################
    def translate_answer_code_to_answer(self, answer_code, scenario=None):
        """
        Translates the answer code to the actual answer.

        For example, for question_options ["a", "b", "c"], the answer codes are 0, 1, and 2.
        The LLM will respond with 0, and this code will translate that to "a".
        Each option is rendered as a Jinja2 template against ``scenario``
        (an empty mapping when ``scenario`` is falsy) before the lookup.

        # TODO: REMOVE
        >>> q = QuestionMultipleChoice(question_text = "How are you?", question_options = ["Good", "Great", "OK", "Bad"], question_name = "how_feeling")
        >>> q.translate_answer_code_to_answer(0, {})
        'Good'
        """
        scenario = scenario or dict()
        translated_options = [
            Template(str(option)).render(scenario) for option in self.question_options
        ]
        return translated_options[int(answer_code)]

    def simulate_answer(self, human_readable=True) -> dict:
        """Simulates a valid answer for debugging purposes.

        With ``human_readable`` true the ``"answer"`` value is a randomly
        chosen option string; otherwise it is a randomly chosen option index.
        The ``"comment"`` value comes from ``random_string()``.
        """
        if human_readable:
            answer = random.choice(self.question_options)
        else:
            answer = random.choice(range(len(self.question_options)))
        return {
            "answer": answer,
            "comment": random_string(),
        }
if __name__ == "__main__":
    # Demo: build a multiple-choice question, run it as a one-question
    # survey, print the selected column, then round-trip it through its
    # dictionary form.
    demo_question = QuestionMultipleChoice(
        question_name="goose_fight",
        question_options=["yes, somtimes", "no", "only on Tuesdays"],
        question_text="Do you enjoying eating custard while skydiving?",
    )

    survey_results = demo_question.run()
    survey_results.select("goose_fight").print()

    serialized = demo_question.to_dict()
    print(f"Serialized dictionary:{serialized}")

    # Rebuild the question from its serialized form.
    restored_question = Question.from_dict(serialized)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment