# Source: GitHub gist by @AlexTMallen
# (AlexTMallen/77190b0971deca739acb11f2c73c3212), last active April 14, 2024 03:21.
from tqdm import tqdm
import random
import pickle
from openai import OpenAI
import matplotlib.pyplot as plt
import numpy as np
def get_search_problem(list_len, mode):
    """Build one randomized "largest indexed digit" problem and its answer key.

    ``list_len`` pairs ``(idx, digits)`` are generated, where ``digits`` is a
    shuffled copy of the digits ``1..list_len``. The pairs are constructed so
    that ``digits[idx]`` is different for every pair, which makes the pair with
    the largest indexed digit unique.

    Args:
        list_len: Number of pairs, and also the length of each digit list.
            Must be between 1 and 9 inclusive.
        mode: Prompt layout. One of ``"list-first"``, ``"list-last"``,
            ``"CoT"``, ``"list-first-filler"`` (the question is about the whole
            lettered list) or ``"report-value-first"``, ``"report-value-last"``
            (the question is about one randomly chosen pair).

    Returns:
        A ``(prompt, correct_answer, candidates)`` tuple. For the list modes
        ``correct_answer`` is the letter of the winning pair and ``candidates``
        are the letters in use; for the report-value modes ``correct_answer``
        is the indexed digit as a string and ``candidates`` are the digits of
        the chosen pair as strings, in their shuffled order.

    Raises:
        ValueError: If ``mode`` is not recognized or ``list_len`` is outside
            ``1..9``.
    """
    known_modes = (
        "list-first",
        "list-last",
        "CoT",
        "list-first-filler",
        "report-value-first",
        "report-value-last",
    )
    # Validate up front with real exceptions: an `assert` disappears under -O.
    if mode not in known_modes:
        raise ValueError(f"Unknown mode {mode}")
    if not 1 <= list_len < 10:
        raise ValueError(f"list_len must be between 1 and 9, got {list_len}")

    available_digits = list(range(1, list_len + 1))
    idxs = list(range(list_len))
    random.shuffle(idxs)

    # `idxs` is a permutation, so every pair is assigned a different target
    # digit. Each pair then gets its own shuffle of the digits, and the stored
    # index is wherever the target digit landed in that shuffle.
    pairs = []
    for idx in idxs:
        digit = available_digits[idx]
        shuffled = available_digits.copy()
        random.shuffle(shuffled)
        pairs.append((shuffled.index(digit), shuffled))

    if mode.startswith("report-value"):
        idx, digits = random.choice(pairs)
        if mode == "report-value-first":
            prompt = f"({idx}, {digits})\n\nGiven the above pair `(idx, digits)`, tell me what is `digits[idx]` (indexing from 0). Answer only with the digit."
        else:
            prompt = f"Given the below pair `(idx, digits)`, tell me what is `digits[idx]` (indexing from 0).\n\n({idx}, {digits})\n\nAnswer only with the digit."
        return prompt, str(digits[idx]), [str(d) for d in digits]

    alpha = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
    listing = "\n".join(
        f"{alpha[i]}. ({idx}, {digits})" for i, (idx, digits) in enumerate(pairs)
    )
    instructions = "Given the list of `(idx, digits)` pairs, tell me for which pair is `digits[idx]` largest (indexing from 0)."

    if mode == "list-first":
        prompt = f"{listing}\n\n{instructions} Answer only with the letter."
    elif mode == "list-last":
        prompt = f"{instructions}\n\n{listing}\n\nAnswer only with the letter."
    elif mode == "CoT":
        prompt = listing + "\n\nGiven the list of `(idx, digits)` pairs, tell me for which pair is the `idx`th digit of `digits` largest (indexing from 0). Explain and conclude your answer with \"Answer: \" and the letter."
    else:  # "list-first-filler"
        pairs_str = ", ".join(f"pair {alpha[i]}" for i in range(len(pairs)))
        prompt = (
            f"{listing}\n\n{instructions} "
            f"I will now say \"{pairs_str}\", to give you some time to think. On \"pair A\", think about what the "
            "value of digits[idx] is for pair A, and so on for the remaining pairs. Then when answering you will be able "
            "to recover the value from the corresponding pairs by paying attention to them. Here is the list of strings I "
            f"told you I would say: {pairs_str}. Please answer now. Answer only with the letter."
        )

    indexed_values = [digits[idx] for idx, digits in pairs]
    correct_answer = alpha[indexed_values.index(max(indexed_values))]
    candidates = list(alpha[:len(pairs)])
    return prompt, correct_answer, candidates
def _extract_cot_answer(response):
    """Return the character right after the last "Answer: " marker, or "" if absent."""
    marker = "Answer: "
    # The CoT prompt asks the model to *conclude* with the marker, so the last
    # occurrence is the one that counts. A missing marker must not be treated
    # as an offset into the reasoning text.
    marker_pos = response.rfind(marker)
    if marker_pos == -1:
        return ""
    ans_start = marker_pos + len(marker)
    return response[ans_start:ans_start + 1]


def _candidate_probs(top_logprobs, candidates):
    """Map each candidate token to its probability among the returned top logprobs (0 if not listed)."""
    probs = dict.fromkeys(candidates, 0.0)
    for item in top_logprobs:
        if item.token in probs:
            probs[item.token] = np.exp(item.logprob)
    return probs


# Experiment driver: for every (mode, list length, model) combination, sample
# `n_trials` problems, query the model, and record a per-trial score in
# `results[(mode, list_len, model)]`. CoT trials are scored 0/1 on the parsed
# final letter; all other modes are scored by the probability mass of the
# correct first token, renormalized over the candidate tokens.
client = OpenAI()
results = dict()
n_trials = 100
# Per-million-token rate used for the rough cost print-out below; it is applied
# to the total token count of each request.
cost_per_m_by_model = {
    "gpt-3.5-turbo-0125": 0.5,
    "gpt-4-turbo-2024-04-09": 10.,
}
for mode in ["CoT", "list-first", "list-last", "report-value-first", "report-value-last"]:
    is_cot = mode.startswith("CoT")
    # CoT needs room to reason and only the text is parsed; the other modes are
    # read off the logprobs of a single generated token.
    top_logprobs = 1 if is_cot else 8
    max_tokens = 1024 if is_cot else 1
    for list_len in [3, 5]:
        for model in ["gpt-3.5-turbo-0125", "gpt-4-turbo-2024-04-09"]:
            cost_per_m = cost_per_m_by_model[model]
            correct_probs = []
            total_usage = 0
            for _ in tqdm(range(n_trials), total=n_trials):
                prompt, correct_ans, candidates = get_search_problem(list_len, mode)
                completion = client.chat.completions.create(
                    model=model,
                    messages=[
                        {"role": "user", "content": prompt},
                    ],
                    logprobs=True,
                    top_logprobs=top_logprobs,
                    max_tokens=max_tokens,
                    stop=None,
                )
                choice = completion.choices[0]
                # Guard against a missing message body so the parsing below
                # never operates on None.
                response = choice.message.content or ""
                if is_cot:
                    if choice.finish_reason == "length":
                        print("Ran out of tokens")
                    if len(response) < 20:
                        print(f"Response didn't do CoT: {response}")
                    ans = _extract_cot_answer(response)
                    if ans not in candidates:
                        print(f"Answer \"{ans}\" not in candidates {candidates}")
                    correct_prob = 1 if ans == correct_ans else 0
                else:
                    first_token_logprobs = choice.logprobs.content[0]
                    probs = _candidate_probs(first_token_logprobs.top_logprobs, candidates)
                    ans = response
                    if ans not in candidates:
                        print(f"Answer \"{ans}\" not in candidates {candidates}")
                    total_prob = sum(probs.values())
                    if total_prob > 0:
                        correct_prob = probs[correct_ans] / total_prob
                    else:
                        # None of the candidates showed up in the top logprobs;
                        # score the trial as wrong instead of dividing by zero.
                        print("No candidate token in the top logprobs; scoring this trial as 0")
                        correct_prob = 0.0
                correct_probs.append(correct_prob)
                total_usage += completion.usage.total_tokens
            results[(mode, list_len, model)] = correct_probs
            print(f"Average probability of correct answer: {np.mean(correct_probs)}")
            print(f"API Usage: {total_usage} (${total_usage * cost_per_m / 1e6})")
with open("lm_search_results.pkl", "wb") as f:
    pickle.dump(results, f)
# Plot mean P(correct) per prompt mode for one model, with one bar group per
# list length, error bars of two standard errors, and a dashed chance line for
# each list length.
modes = ["CoT", "list-first", "list-last", "report-value-first", "report-value-last"]
mode_names = {
    "list-first": "internal search\n(instructions after list)",
    "list-last": "internal search\n(instructions before list)",
    "CoT": "chain of thought\n(instructions after list)",
    "report-value-first": "compute value\n(instructions after pair)",
    "report-value-last": "compute value\n(instructions before pair)",
}
list_lengths = [3, 5]
# Model whose results are plotted; set to "gpt-3.5-turbo-0125" for the other model.
model = "gpt-4-turbo-2024-04-09"
colors = ['blue', 'orange']
plt.figure(dpi=200)
width = 0.35  # Width of the bars
for idx, list_len in enumerate(list_lengths):
    random_perf = 1 / list_len
    trial_scores = [results[(mode, list_len, model)] for mode in modes]
    perfs = [np.mean(scores) for scores in trial_scores]
    # Two standard errors of the mean. The sample size comes from the same
    # results being plotted, not from a different model's results.
    stderrs = [2 * np.std(scores) / np.sqrt(len(scores)) for scores in trial_scores]
    xpos = np.arange(len(modes)) + idx * width
    plt.bar(xpos, perfs, width, label=f'List length {list_len}', color=colors[idx], alpha=0.8, yerr=stderrs, capsize=5)
    plt.axhline(random_perf, color=colors[idx], linestyle="--", label=f"Random (list length {list_len})")
plt.axhline(1, color="black", linestyle="-")
plt.ylabel("P(correct)")
plt.title(f"{model} performance on search problems")
plt.xticks(np.arange(len(modes)) + width/2, [mode_names[mode] for mode in modes], rotation=45)
plt.legend()
plt.show()
# End of gist. (GitHub comment-thread footer; not part of the script.)