-
-
Save johnjosephhorton/9cb426fd349c9328a0071cf23510126d to your computer and use it in GitHub Desktop.
Proposed question re-factor
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import annotations | |
import re | |
import textwrap | |
from abc import ABC, abstractmethod | |
from jinja2 import Template, Environment, meta | |
from typing import Any, Type, Union | |
from edsl.exceptions import ( | |
QuestionAnswerValidationError, | |
QuestionAttributeMissing, | |
QuestionResponseValidationError, | |
QuestionSerializationError, | |
QuestionScenarioRenderError, | |
) | |
from edsl.questions.question_registry import get_question_class | |
from edsl.questions.utils import LLMResponse | |
from edsl.utilities.utilities import HTMLSnippet | |
class Question(ABC):
    """Abstract base class shared by every question type.

    The methods below read ``self.question_type`` and ``self.question_name``,
    so concrete subclasses are expected to provide both, in addition to
    implementing the abstract methods and the ``instructions`` property.
    """

    @property
    def data(self) -> dict:
        """Return the instance attributes as a dictionary keyed by public name.

        Only a single *leading* underscore is stripped from each attribute
        name (``_question_text`` becomes ``question_text``); underscores
        elsewhere in the name are preserved. ``question_type`` is not added
        here; ``to_dict`` adds it.
        """
        return {
            (key[1:] if key.startswith("_") else key): value
            for key, value in self.__dict__.items()
        }

    def to_dict(self) -> dict:
        """Serialize the question to a dictionary, including its ``question_type``."""
        # ``self.data`` builds a fresh dictionary on every access, so it can
        # be extended directly without touching instance state.
        serialized = self.data
        serialized["question_type"] = self.question_type
        return serialized

    @classmethod
    def from_dict(cls, data: dict) -> "Question":
        """Construct a Question from the dictionary created by ``to_dict``.

        The ``question_type`` entry selects the concrete class through
        ``get_question_class``; all remaining entries are passed to that
        class as keyword arguments. ``data`` itself is not modified.

        Raises:
            QuestionSerializationError: if ``data`` has no ``question_type`` key.
        """
        local_data = data.copy()
        try:
            question_type = local_data.pop("question_type")
        except KeyError as err:
            raise QuestionSerializationError(
                "Question data does not have a 'question_type' field"
            ) from err
        question_class = get_question_class(question_type)
        return question_class(**local_data)

    def __repr__(self):
        """Return ``ClassName(key = value, ...)`` built from ``self.data``.

        String values are wrapped in single quotes, and the substring
        ``"Enhanced"`` is removed from the class name.
        """
        class_name = self.__class__.__name__.replace("Enhanced", "")
        items = [
            f"{k} = '{v}'" if isinstance(v, str) else f"{k} = {v}"
            for k, v in self.data.items()
            if k != "question_type"
        ]
        return f"{class_name}({', '.join(items)})"

    @abstractmethod
    def validate_answer(self, answer: dict[str, str]):
        """Validate an answer dictionary; implemented by each question type."""
        pass

    # TODO: Throws an error that should be addressed at QuestionFunctional
    def __add__(self, other_question):
        """
        Composes two questions into a single question.

        >>> from edsl.scenarios.Scenario import Scenario
        >>> from edsl.questions.QuestionFreeText import QuestionFreeText
        >>> from edsl.questions.QuestionNumerical import QuestionNumerical
        >>> q1 = QuestionFreeText(question_text = "What is the capital of {{country}}", question_name = "capital")
        >>> q2 = QuestionNumerical(question_text = "What is the population of {{capital}}, in millions. Please round", question_name = "population")
        >>> q3 = q1 + q2
        >>> Scenario({"country": "France"}).to(q3).run().select("capital_population")
        ['2']
        """
        # Imported here rather than at module level, as in the original code
        # (presumably to avoid a circular import -- confirm).
        from edsl.questions import compose_questions

        return compose_questions(self, other_question)

    @property
    @abstractmethod
    def instructions(self) -> str:  # pragma: no cover
        """
        Instructions for each question.

        - the values are question type-specific
        - the templating standard is Jinja2.
        - it is necessary to include both "answer" and "comment" as the only keys.
        - Note: children should implement this method as a property.

        An example for `QuestionFreeText`:

            You are being asked the following question: {{question_text}}
            Return a valid JSON formatted like this:
            {"answer": "<your free text answer>", "comment": "<put explanation here>"}
        """
        pass

    @staticmethod
    def scenario_render(text: str, scenario_dict: dict) -> str:
        """Render ``text`` as a Jinja2 template until the output stops changing.

        Scenarios come in as dictionaries. The text is rendered against
        ``scenario_dict`` repeatedly because a substituted value may itself
        contain template variables (nesting).

        Raises:
            QuestionScenarioRenderError: if the text is still changing after
                ``MAX_NESTING + 1`` render passes. Hitting that limit is
                almost certainly a self-referential scenario rather than a
                genuinely 100-deep nested question.
        """
        MAX_NESTING = 100
        current = text
        # One pass more than MAX_NESTING: the error is raised only once the
        # number of passes exceeds the limit and the text is still changing.
        for _ in range(MAX_NESTING + 1):
            rendered = Template(current).render(scenario_dict)
            if rendered == current:
                return rendered
            current = rendered
        raise QuestionScenarioRenderError(
            "Too much nesting - you created an infnite loop here, pal"
        )

    def get_prompt(self, scenario=None) -> str:
        """Return the prompt that should be sent to the LLM for this question.

        The ``instructions`` template is first rendered with the question's
        own attributes (``self.data``). Any template variables that remain
        must be supplied by ``scenario``, which is then rendered in with
        ``scenario_render``.

        Raises:
            QuestionScenarioRenderError: if ``scenario`` lacks a variable that
                the rendered instructions still reference. The message lists
                only the variables that are actually missing.
        """
        if scenario is None:
            scenario = {}
        template_with_attributes = Template(self.instructions).render(self.data)
        ast = Environment().parse(template_with_attributes)
        undeclared_variables = meta.find_undeclared_variables(ast)
        missing_variables = {
            name for name in undeclared_variables if name not in scenario
        }
        if missing_variables:
            raise QuestionScenarioRenderError(
                f"Scenario is missing variables: {missing_variables}"
            )
        return self.scenario_render(template_with_attributes, scenario)

    @abstractmethod
    def translate_answer_code_to_answer(self):  # pragma: no cover
        """Translates the answer code to the actual answer. Behavior depends on the question type."""
        pass

    @abstractmethod
    def simulate_answer(self, human_readable=True) -> dict:  # pragma: no cover
        """Simulates a valid answer for debugging purposes (what the validator expects)"""
        pass

    def formulate_prompt(self, traits=None, focal_item=None):
        """
        Builds the prompt to send to the LLM.

        The system prompt contains:
        - A fixed instruction to answer in character as a human
        - The traits of the agent, when ``traits`` is given

        The user prompt contains:
        - The focal item and a description of what it is, when ``focal_item`` is given
        - The question prompt from ``get_prompt``

        Returns:
            A ``(prompt, system_prompt)`` tuple of strings.
        """
        system_prompt = textwrap.dedent(
            """\
            You are answering questions as if you were a human.
            Do not break character.
            """
        )
        if traits is not None:
            relevant_trait = traits.relevant_traits(self)
            system_prompt += f"Your traits are: {relevant_trait}"

        prompt = ""
        if focal_item is not None:
            prompt += textwrap.dedent(
                f"""\
                The question you will be asked will be about a {focal_item.meta_description}.
                The particular one you are responding to is: {focal_item.content}.
                """
            )
        # NOTE: no scenario is passed here, so get_prompt raises if the
        # rendered instructions still contain template variables.
        prompt += self.get_prompt()
        return prompt, system_prompt

    ################
    # Question -> Survey methods
    ################
    def add_question(self, other):
        "Adds a question to this question by turning them into a survey with two questions"
        from edsl.surveys.Survey import Survey

        return Survey([self, other], [self.question_name, other.question_name])

    def run(self, *args, **kwargs):
        "Turns a single question into a survey and runs it, forwarding all arguments to Survey.run."
        from edsl.surveys.Survey import Survey

        survey = Survey([self], [self.question_name])
        return survey.run(*args, **kwargs)

    def by(self, *args):
        "Turns a single question into a survey and returns the result of Survey.by for the given arguments."
        from edsl.surveys.Survey import Survey

        survey = Survey([self], [self.question_name])
        return survey.by(*args)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import random | |
import textwrap | |
from jinja2 import Template | |
from typing import Optional, Type | |
from edsl.questions import Question | |
from edsl.exceptions import QuestionAnswerValidationError | |
from edsl.utilities.utilities import random_string | |
from edsl.utilities.utilities import is_valid_variable_name | |
# Upper bound on the number of options a multiple choice question may have.
MAX_OPTIONS = 10


class QuestionMultipleChoice(Question):
    """A question answered by picking exactly one of a list of string options.

    The question text, options and name are validated on assignment by the
    property setters below and stored under a leading underscore
    (``_question_text``, ``_question_options``, ``_question_name``).
    """

    question_type = "multiple_choice"

    def __init__(self, question_text, question_options, question_name):
        """Create the question; each assignment runs the matching validating setter."""
        self.question_text = question_text
        self.question_options = question_options
        self.question_name = question_name

    @property
    def question_name(self):
        """The name of the question."""
        return self._question_name

    @question_name.setter
    def question_name(self, new_question_name):
        "Validates the question name"
        if not is_valid_variable_name(new_question_name):
            # ValueError is still caught by any existing `except Exception`.
            raise ValueError("Question name is not a valid variable name!")
        self._question_name = new_question_name

    @property
    def question_options(self):
        """The list of answer options."""
        return self._question_options

    @question_options.setter
    def question_options(self, new_question_options):
        "Validates the question options: between 2 and MAX_OPTIONS entries, all strings"
        if len(new_question_options) > MAX_OPTIONS:
            raise ValueError("Question options are too long!")
        if len(new_question_options) < 2:
            raise ValueError("Question options are too short!")
        if not all(isinstance(x, str) for x in new_question_options):
            raise ValueError("Question options must be strings!")
        self._question_options = new_question_options

    @property
    def question_text(self):
        """The text of the question."""
        return self._question_text

    @question_text.setter
    def question_text(self, new_question_text):
        "Validates the question text: at most 1000 characters"
        if len(new_question_text) > 1000:
            raise ValueError("Question is too long!")
        self._question_text = new_question_text

    def validate_response(self, response: dict[str, str]):
        """Return ``response`` unchanged; no validation is performed here."""
        return response

    def validate_answer(self, answer: dict[str, str]):
        """Check that ``answer`` has the required keys and a valid option index.

        Returns:
            The same ``answer`` dictionary when it is valid.

        Raises:
            QuestionAnswerValidationError: if the ``answer`` or ``comment`` key
                is missing, or if ``answer["answer"]`` is not an index into
                ``question_options``.
        """
        for required_key in ("answer", "comment"):
            if required_key not in answer:
                raise QuestionAnswerValidationError(
                    f"Answer {answer} does not contain '{required_key}' key."
                )
        if answer["answer"] not in range(len(self.question_options)):
            raise QuestionAnswerValidationError(
                f"Answer {answer} is not a valid option."
            )
        return answer

    @property
    def instructions(self) -> str:
        """Jinja2 template asking for a JSON reply with the zero-based option code."""
        return textwrap.dedent(
            """\
            You are being asked the following question: {{question_text}}
            The options are
            {% for option in question_options %}
            {{ loop.index0 }}: {{option}}
            {% endfor %}
            Return a valid JSON formatted like this, selecting only the number of the option:
            {"answer": <put answer code here>, "comment": "<put explanation here>"}
            Only 1 option may be selected.
            """
        )

    ################
    # Less important
    ################
    def translate_answer_code_to_answer(self, answer_code, scenario=None):
        """
        Translates the answer code to the actual answer.

        For example, for question_options ["a", "b", "c"], the answer codes are 0, 1, and 2.
        The LLM will respond with 0, and this code will translate that to "a".
        Each option is rendered as a Jinja2 template against ``scenario`` first.

        # TODO: REMOVE
        >>> q = QuestionMultipleChoice(question_text = "How are you?", question_options = ["Good", "Great", "OK", "Bad"], question_name = "how_feeling")
        >>> q.translate_answer_code_to_answer(0, {})
        'Good'
        """
        scenario = scenario or {}
        translated_options = [
            Template(str(option)).render(scenario) for option in self.question_options
        ]
        return translated_options[int(answer_code)]

    def simulate_answer(self, human_readable=True) -> dict:
        """Simulates a valid answer for debugging purposes.

        With ``human_readable`` true the answer is a randomly chosen option
        text; otherwise it is a randomly chosen option index. The comment is
        whatever ``random_string()`` returns.
        """
        if human_readable:
            answer = random.choice(self.question_options)
        else:
            answer = random.choice(range(len(self.question_options)))
        return {
            "answer": answer,
            "comment": random_string(),
        }
if __name__ == "__main__":
    # Demo: build a multiple choice question, run it, print the selected
    # column, then round-trip the question through its dictionary form.
    demo_question = QuestionMultipleChoice(
        question_name="goose_fight",
        question_options=["yes, somtimes", "no", "only on Tuesdays"],
        question_text="Do you enjoying eating custard while skydiving?",
    )

    demo_results = demo_question.run()
    demo_results.select("goose_fight").print()

    serialized_question = demo_question.to_dict()
    print(f"Serialized dictionary:{serialized_question}")

    # Rebuild a question object from the serialized dictionary.
    restored_question = Question.from_dict(serialized_question)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment