Skip to content

Instantly share code, notes, and snippets.

@goodside
Created October 8, 2022 05:06
Show Gist options
  • Star 8 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save goodside/454169131c93ab5c57a9cdfac0754028 to your computer and use it in GitHub Desktop.
Save goodside/454169131c93ab5c57a9cdfac0754028 to your computer and use it in GitHub Desktop.
"""
Toy demonstration of chain-of-thought and consensus prompting using the OpenAI API.
© Riley Goodside 2022
"""
import os
import re
from statistics import mode
import openai
# Configure the OpenAI client from the environment, failing fast with a clear
# message when the key is missing.
try:
    openai.api_key = os.environ["OPENAI_API_KEY"]
except KeyError:
    # `from None` suppresses the internal KeyError so the user only sees the
    # actionable message instead of a chained two-part traceback.
    raise RuntimeError(
        "Please set the OPENAI_API_KEY environment variable."
    ) from None
def complete(prompt: str, **kwargs):
    """
    Request a completion for `prompt` via `openai.Completion.create`.
    The engine defaults to "text-davinci-002" unless the caller passes
    `engine`; every other keyword argument is forwarded unchanged.
    Returns the text of the first choice with surrounding whitespace removed.
    """
    # Caller-supplied options take precedence over the default engine.
    request_options = {"engine": "text-davinci-002", **kwargs}
    result = openai.Completion.create(prompt=prompt, **request_options)
    first_choice = result.choices[0]
    return first_choice.text.strip()
def calculate_step_by_step(question: str, max_retries=3) -> str:
    """
    Answer a math question via chain-of-thought prompting.

    The model is first asked to reason step by step; a second request then
    asks it to condense that reasoning into a final answer, from which the
    first number is extracted. If no number can be found, the whole process
    is repeated, up to `max_retries` additional times.

    Args:
        question: The question text (without a leading "Q: ", which is added
            here).
        max_retries: How many extra attempts to make after the first one
            fails to yield a number. Values below zero behave like zero.

    Returns:
        The extracted number as a string, e.g. "42", "-3" or "1.5".

    Raises:
        RuntimeError: If no attempt produced a parseable number.
    """
    prompt = f"Q: {question}\nA: Let's think step by step."
    # One initial attempt plus the allowed number of retries.
    attempts = max(max_retries, 0) + 1
    for _ in range(attempts):
        long_answer = complete(prompt, max_tokens=128, temperature=0.5)
        extraction_prompt = (
            f"{prompt} {long_answer}\nTherefore, the answer (Arabic numerals) is"
        )
        short_answer = complete(extraction_prompt, max_tokens=32, temperature=0)
        # Drop a trailing period and thousands separators, and keep only the
        # right-hand side of any "lhs = rhs" the model may have written.
        short_answer = (
            short_answer.strip().rstrip(".").replace(",", "").split("=")[-1]
        )
        # Require digits after a decimal point so that a sentence-ending
        # period ("5. Because ...") is not captured as part of the number;
        # otherwise "5" and "5." would count as different answers.
        match = re.search(r"-?\d+(?:\.\d+)?", short_answer)
        if match:
            return match.group()
    raise RuntimeError(f"Could not extract answer from '{short_answer}'")
def answer_by_consensus(question: str, n=10) -> str:
    """
    Ask the same question `n` independent times and return the answer that
    occurs most often among the samples.
    """
    votes = [calculate_step_by_step(question) for _ in range(n)]
    return mode(votes)
# Sample multi-step question for the demo. It deliberately carries no "Q: "
# prefix: calculate_step_by_step() and the __main__ printout each add their own,
# so including one here would produce "Q: Q: ...".
EXAMPLE_QUESTION = """\
What is x + 6 * y^5 where x is the sum of the squares of the individual digits of the
release year of Miley Cyrus's "Bangerz" and y is the day-of-month portion of Harry
Styles's birthday?\
"""
if __name__ == "__main__":
    # Echo the question first so it is visible while the samples are collected,
    # then print the answer most samples agreed on.
    print("Q:", EXAMPLE_QUESTION)
    consensus = answer_by_consensus(EXAMPLE_QUESTION)
    print("A:", consensus)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment