Last active
April 14, 2024 03:21
-
-
Save AlexTMallen/77190b0971deca739acb11f2c73c3212 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from tqdm import tqdm | |
import random | |
import pickle | |
from openai import OpenAI | |
import matplotlib.pyplot as plt | |
import numpy as np | |
def get_search_problem(list_len, mode):
    """Generate a synthetic "search" prompt for probing LM in-context computation.

    Builds `list_len` pairs `(idx, digits)` where `digits` is a shuffled list of
    the distinct digits 1..list_len and `idx` locates one digit in it. Because
    the target indices form a permutation, every pair targets a *different*
    digit, so "which pair has the largest digits[idx]?" has a unique answer.

    Args:
        list_len: number of pairs (and digits per pair); must be in 1..9 so the
            digits "1".."9" suffice and answers are single tokens.
        mode: prompt layout, one of "list-first", "list-last", "CoT",
            "list-first-filler", "report-value-first", "report-value-last".
            The "report-value-*" modes ask for `digits[idx]` of a single
            randomly chosen pair instead of the argmax over pairs.

    Returns:
        (prompt, correct_answer, candidates): the prompt string, the correct
        single-character answer, and the list of valid single-token answers
        (letters for search modes, digit strings for report-value modes).

    Raises:
        ValueError: if `mode` is unknown or `list_len` is outside 1..9.
    """
    # Validate with an exception rather than `assert` (asserts vanish under -O);
    # also reject list_len <= 0, which previously crashed later in random.choice.
    if not 0 < list_len < 10:
        raise ValueError(f"list_len must be between 1 and 9, got {list_len}")
    available_digits = list("123456789"[:list_len])
    # A permutation of target indices guarantees each pair targets a distinct digit.
    idxs = list(range(list_len))
    random.shuffle(idxs)
    pairs = []
    for idx in idxs:
        av_digits = available_digits.copy()
        digit = av_digits[idx]
        random.shuffle(av_digits)
        # Re-locate the chosen digit after shuffling so (new_idx, av_digits)
        # still points at the same target digit.
        new_idx = av_digits.index(digit)
        pairs.append((new_idx, list(map(int, av_digits))))
    alpha = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
    prompt = "\n".join(f"{alpha[i]}. ({idx}, {digits})" for i, (idx, digits) in enumerate(pairs))
    instructions = "Given the list of `(idx, digits)` pairs, tell me for which pair is `digits[idx]` largest (indexing from 0)."
    if mode == "list-first":
        prompt += f"\n\n{instructions} Answer only with the letter."
    elif mode == "list-last":
        prompt = f"{instructions}\n\n{prompt}\n\nAnswer only with the letter."
    elif mode == "CoT":
        prompt += "\n\nGiven the list of `(idx, digits)` pairs, tell me for which pair is the `idx`th digit of `digits` largest (indexing from 0). Explain and conclude your answer with \"Answer: \" and the letter."
    elif mode == "list-first-filler":
        # "Filler tokens" variant: repeats the pair labels to give the model
        # extra sequence positions before it must answer.
        pairs_str = ", ".join(f"pair {alpha[i]}" for i in range(len(pairs)))
        prompt += "\n\nGiven the list of `(idx, digits)` pairs, tell me for which pair is `digits[idx]` largest (indexing from 0). " \
            f"I will now say \"{pairs_str}\", to give you some time to think. On \"pair A\", think about what the " \
            "value of digits[idx] is for pair A, and so on for the remaining pairs. Then when answering you will be able " \
            f"to recover the value from the corresponding pairs by paying attention to them. Here is the list of strings I " \
            f"told you I would say: {pairs_str}. Please answer now. Answer only with the letter."
    elif mode == "report-value-first":
        idx, digits = random.choice(pairs)
        prompt = f"({idx}, {digits})\n\nGiven the above pair `(idx, digits)`, tell me what is `digits[idx]` (indexing from 0). Answer only with the digit."
    elif mode == "report-value-last":
        idx, digits = random.choice(pairs)
        prompt = f"Given the below pair `(idx, digits)`, tell me what is `digits[idx]` (indexing from 0).\n\n({idx}, {digits})\n\nAnswer only with the digit."
    else:
        raise ValueError(f"Unknown mode {mode}")
    if mode.startswith("report-value"):
        correct_answer = str(digits[idx])
        candidates = list(map(str, digits))
    else:
        # Argmax over pairs of the digit each pair's idx points at (digits are
        # already ints, so no conversion is needed). Ties are impossible: every
        # pair targets a distinct digit by construction.
        best = max(range(len(pairs)), key=lambda i: pairs[i][1][pairs[i][0]])
        correct_answer = alpha[best]
        candidates = list(alpha[:len(pairs)])
    return prompt, correct_answer, candidates
# Run the search-problem benchmark over prompting modes, list lengths, and
# models, recording the per-trial probability assigned to the correct answer,
# then pickle the results keyed by (mode, list_len, model).
client = OpenAI()
results = dict()
n_trials = 100
for mode in ["CoT", "list-first", "list-last", "report-value-first", "report-value-last"]:
    for list_len in [3, 5]:
        for model in ["gpt-3.5-turbo-0125", "gpt-4-turbo-2024-04-09"]:
            # Approximate $/1M tokens. NOTE(review): the input-token rate is
            # applied to total_tokens, so cost is underestimated for
            # output-heavy CoT runs.
            cost_per_m = {
                "gpt-3.5-turbo-0125": 0.5,
                "gpt-4-turbo-2024-04-09": 10.,
            }[model]
            correct_probs = []
            total_usage = 0
            for i in tqdm(range(n_trials), total=n_trials):
                prompt, correct_ans, candidates = get_search_problem(list_len, mode)
                # CoT is graded on the text after "Answer: "; all other modes
                # are graded on the logprob distribution of the single first
                # token, restricted to the candidate answers.
                top_logprobs = 1 if mode.startswith("CoT") else 8
                max_tokens = 1024 if mode.startswith("CoT") else 1
                stop = None
                completion = client.chat.completions.create(
                    model=model,
                    messages=[
                        {"role": "user", "content": prompt},
                    ],
                    logprobs=True,
                    top_logprobs=top_logprobs,
                    max_tokens=max_tokens,
                    stop=stop,
                )
                if mode.startswith("CoT"):
                    if completion.choices[0].finish_reason == "length":
                        print("Ran out of tokens")
                    response = completion.choices[0].message.content
                    if len(response) < 20:
                        print(f"Response didn't do CoT: {response}")
                    marker = "Answer: "
                    marker_pos = response.find(marker)
                    # str.find returns -1 when the marker is absent; the old
                    # arithmetic then silently graded response[7:8], which
                    # could coincidentally match. Treat a missing marker as a
                    # wrong (empty) answer instead.
                    if marker_pos == -1:
                        ans = ""
                    else:
                        ans_idx = marker_pos + len(marker)
                        ans = response[ans_idx:ans_idx + 1]
                    if ans not in candidates:
                        print(f"Answer \"{ans}\" not in candidates {candidates}")
                    correct_prob = 1 if ans == correct_ans else 0
                else:
                    # Renormalize over candidate tokens found in the top-8
                    # logprobs; candidates not returned contribute probability 0.
                    first_token_logprobs = completion.choices[0].logprobs.content[0]
                    probs = {k: 0 for k in candidates}
                    for cand in candidates:
                        for item in first_token_logprobs.top_logprobs:
                            if item.token == cand:
                                probs[cand] = np.exp(item.logprob)
                                break
                    ans = completion.choices[0].message.content
                    if ans not in candidates:
                        print(f"Answer \"{ans}\" not in candidates {candidates}")
                    total_prob = sum(probs.values())
                    # Guard the (rare) case where no candidate token appeared
                    # in the top logprobs, which previously divided by zero.
                    correct_prob = probs[correct_ans] / total_prob if total_prob > 0 else 0
                correct_probs.append(correct_prob)
                total_usage += completion.usage.total_tokens
            results[(mode, list_len, model)] = correct_probs
            print(f"Average probability of correct answer: {np.mean(correct_probs)}")
            print(f"API Usage: {total_usage} (${total_usage * cost_per_m / 1e6})")

with open("lm_search_results.pkl", "wb") as f:
    pickle.dump(results, f)
# Plot mean P(correct) per mode for one model, with ~2-sigma error bars and a
# dashed random-guessing baseline per list length.
modes = ["CoT", "list-first", "list-last", "report-value-first", "report-value-last"]
mode_names = {
    "list-first": "internal search\n(instructions after list)",
    "list-last": "internal search\n(instructions before list)",
    "CoT": "chain of thought\n(instructions after list)",
    "report-value-first": "compute value\n(instructions after pair)",
    "report-value-last": "compute value\n(instructions before pair)",
}
list_lengths = [3, 5]
# model = "gpt-3.5-turbo-0125"
model = "gpt-4-turbo-2024-04-09"
colors = ['blue', 'orange']
plt.figure(dpi=200)
width = 0.35  # Width of the bars
for idx, list_len in enumerate(list_lengths):
    random_perf = 1 / list_len
    perfs = [np.mean(results[(mode, list_len, model)]) for mode in modes]
    # ~2-sigma standard error of the mean. Bug fix: the sample count was
    # previously hard-coded to the gpt-3.5 run's results, which gave wrong
    # error bars whenever the plotted model's trial count differed.
    stderrs = [
        2 * np.std(results[(mode, list_len, model)]) / np.sqrt(len(results[(mode, list_len, model)]))
        for mode in modes
    ]
    xpos = np.arange(len(modes)) + idx * width
    plt.bar(xpos, perfs, width, label=f'List length {list_len}', color=colors[idx], alpha=0.8, yerr=stderrs, capsize=5)
    plt.axhline(random_perf, color=colors[idx], linestyle="--", label=f"Random (list length {list_len})")
plt.axhline(1, color="black", linestyle="-")
plt.ylabel("P(correct)")
plt.title(f"{model} performance on search problems")
plt.xticks(np.arange(len(modes)) + width/2, [mode_names[mode] for mode in modes], rotation=45)
plt.legend()
plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment