{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import openai\n",
"openai.api_key = \"blank\"  # placeholder -- set your own OpenAI API key here\n",
"\n",
"from transformers import GPT2TokenizerFast\n",
"tokenizer = GPT2TokenizerFast.from_pretrained(\"gpt2\", add_prefix_space=True)\n",
"\n",
"import math"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Score how well `hint` points at `word`: the probability mass davinci puts on \" Yes\".\n",
"def runscore(hint, word, use_fakes=False):\n",
"    preprompt = \"\"\"\n",
"    I am thinking of something big and full of water.\n",
"    Am I thinking of ocean? Yes.\n",
"    Am I thinking of chair? No.\n",
"    Am I thinking of tree? No.\n",
"\n",
"    I am thinking of something yellow and bright in the sky.\n",
"    Am I thinking of dog? No.\n",
"    Am I thinking of sun? Yes.\n",
"    Am I thinking of shout? No.\n",
"\n",
"    I am thinking of what you do to a soccer ball.\n",
"    Am I thinking of kick? Yes.\n",
"    Am I thinking of drink? No.\n",
"    Am I thinking of house? No.\n",
"\n",
"    I am thinking of {}.\n",
"    Am I thinking of {}?\"\"\"\n",
"    fake_words = [\"door\", \"justice\", \"ocean\", \"chair\", \"tree\", \"dog\", \"sun\", \"shout\", \"kick\", \"drink\", \"house\"]\n",
"\n",
"    humanprompt = \"\"\"\n",
"    I am thinking of [{}].\n",
"    Do you think I am thinking of {}?\"\"\"\n",
"\n",
"    truth_tokens = tokenizer.tokenize(\"Yes.\")\n",
"    truth_tokens = [t.replace(\"Ġ\",\" \") for t in truth_tokens]\n",
"\n",
"    allprobs = []\n",
"    all_words = ([word] + fake_words) if use_fakes else [word]\n",
"    for w in all_words:\n",
"        response = openai.Completion.create(engine=\"davinci\", prompt=preprompt.format(hint, w), temperature=0.1, logprobs=10, max_tokens=1)\n",
"        probs = response.choices[0].logprobs.top_logprobs\n",
"        total_prob = 1\n",
"        correct = True\n",
"        if ' Yes' in probs[0]:\n",
"            total_prob = math.exp(probs[0][' Yes'])\n",
"        else:\n",
"            total_prob = 0\n",
"        allprobs.append(total_prob)\n",
"\n",
"    # With fakes: reward the real word's \"Yes\" probability minus the average over the decoys.\n",
"    total_score = (allprobs[0] - sum(allprobs[1:])/(len(allprobs)-1)) if use_fakes else allprobs[0]\n",
"#     print(\"Main score is {:.3f}\".format(allprobs[0]))\n",
"#     print(\"Other scores are \" + \" \".join(['-{:.3f}'.format(x) for x in allprobs[1:]]))\n",
"#     print(\"Total is {:.3f}\".format(total_score))\n",
"    return total_score\n",
"\n",
"hint = \"a popular drink that keeps you awake and is brown\"\n",
"word = \"coffee\"\n",
"print(runscore(hint, word))"
]
},
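{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Added sanity check (not in the original gist): runscore() above looks up the\n",
"# key ' Yes' -- with a leading space -- in top_logprobs. GPT-2's BPE folds the\n",
"# preceding space into the token itself (the 'Ġ' marker below), so the answer\n",
"# after the question comes back as the single token ' Yes' or ' No'.\n",
"print(tokenizer.tokenize(\" Yes\"))\n",
"print(tokenizer.tokenize(\" No\"))"
]
},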
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"forward_prompt = \"\"\"A description of ocean:\n",
"something big and full of water.\n",
"\n",
"A description of sun:\n",
"something yellow and bright in the sky.\n",
"\n",
"A description of kick:\n",
"what you do to a soccer ball.\n",
"\n",
"A description of{}:\n",
"\"\"\"\n",
"\n",
"def best_sentence(word):\n",
"    word = \" \"+word\n",
"    logit_bias = {tokenizer.encode(word)[0]: -100}  # ban the word's first BPE token so the description cannot echo it\n",
"    response = openai.Completion.create(engine=\"davinci\", prompt=forward_prompt.format(word), temperature=0,\n",
"                                        max_tokens=10, n=1, stop=\".\", logit_bias=logit_bias)\n",
"    return response.choices[0].text\n",
"\n",
"def causal_sentence(word, use_fakes=False):\n",
"    word = \" \"+word\n",
"    logit_bias = {tokenizer.encode(word)[0]: -100}\n",
"    response = openai.Completion.create(engine=\"davinci\", prompt=forward_prompt.format(word), temperature=0.8,\n",
"                                        max_tokens=10, n=20, stop=\".\", logit_bias=logit_bias)\n",
"    resps = [x.text for x in response.choices]\n",
"    mx = -1000\n",
"    best_r = \"[null]\"\n",
"    for r in resps:\n",
"        score = runscore(r, word, use_fakes)  # rescore each sampled description with the yes/no probe\n",
"        if score > mx:\n",
"            mx = score\n",
"            best_r = r\n",
"#         print(score, r)\n",
"    return best_r\n",
"\n",
"# hint = \"a popular drink that keeps you awake and is brown\"\n",
"# word = \"coffee\"\n",
"# print(runscore(hint, word))\n",
"causal_sentence(\"coffee\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"for w in [\"coffee\", \"singing\", \"earth\", \"head\", \"table\"]:\n",
"    print()\n",
"    print(w)\n",
"    print(\"Top probability sentence:\", best_sentence(w))\n",
"    print(\"Top causal sentence (no fakes):\", causal_sentence(w))\n",
"#     print(\"Top causal sentence (w/ fakes):\", causal_sentence(w, use_fakes=True))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"tokenizer.encode(\" coffee\")[0]"
]
},
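{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hedged expansion of the cell above (added, not in the original gist):\n",
"# tokenizer.encode(\" coffee\") gives the BPE ids for the word with its leading\n",
"# space, and best_sentence/causal_sentence bias the first id by -100, which\n",
"# effectively bans that token so the generated description cannot open by\n",
"# echoing the target word itself.\n",
"ids = tokenizer.encode(\" coffee\")\n",
"print(ids)\n",
"print({ids[0]: -100})  # the logit_bias dict passed to openai.Completion.create"
]
},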
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Probability that davinci says \" Yes\" to: are \"book, paper, painting\" all <hint>?\n",
"def runjudge(hint):\n",
"#     preprompt = \"\"\"\n",
"#     Kids these days are so loud.\n",
"#     Is this something an old man would say? Yes.\n",
"\n",
"#     Can I copy your homework?\n",
"#     Is this something an old man would say? No.\n",
"\n",
"#     My bones ache today.\n",
"#     Is this something an old man would say? Yes.\n",
"\n",
"#     {}.\n",
"#     Is this something an old man would say?\"\"\"\n",
"    preprompt = \"\"\"\n",
"    apple, pomegranate, cherry\n",
"    Are these all red fruits? Yes.\n",
"    Are these all blue fruits? No.\n",
"\n",
"    car, truck, skateboard.\n",
"    Are these all cars? No.\n",
"    Are these all vehicles with four wheels? Yes.\n",
"\n",
"    dolphin, fish, octopus.\n",
"    Are these all plants? No.\n",
"    Are these all underwater animals? Yes.\n",
"\n",
"    book, paper, painting.\n",
"    Are these all {}?\"\"\"\n",
"\n",
"    truth_tokens = tokenizer.tokenize(\"Yes.\")\n",
"    truth_tokens = [t.replace(\"Ġ\",\" \") for t in truth_tokens]\n",
"\n",
"    allprobs = []\n",
"    response = openai.Completion.create(engine=\"davinci\", prompt=preprompt.format(hint), temperature=0.1, logprobs=10, max_tokens=1)\n",
"    probs = response.choices[0].logprobs.top_logprobs\n",
"    total_prob = 1\n",
"    correct = True\n",
"    if ' Yes' in probs[0]:\n",
"        total_prob = math.exp(probs[0][' Yes'])\n",
"    else:\n",
"        total_prob = 0\n",
"    allprobs.append(total_prob)\n",
"\n",
"    total_score = allprobs[0]\n",
"    return total_score\n",
"\n",
"hint = \"towels\"\n",
"print(runjudge(hint))\n",
"\n",
"hint = \"square objects\"\n",
"print(runjudge(hint))\n",
"\n",
"hint = \"things in your house\"\n",
"print(runjudge(hint))"
]
},
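{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Added illustration (not in the original gist): the few-shot examples inside\n",
"# runjudge are fixed, and so is the item list \"book, paper, painting\" -- the\n",
"# hint only fills the final question. Rendering that last slot locally (no API\n",
"# call) shows what each runjudge(hint) call above is actually asking:\n",
"print(\"\"\"\n",
"    book, paper, painting.\n",
"    Are these all {}?\"\"\".format(\"square objects\"))"
]
},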
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"forward_prompt = \"\"\"apple, pomegranate, cherry\n",
"These are all red fruits.\n",
"\n",
"car, truck, skateboard.\n",
"These are all vehicles with four wheels.\n",
"\n",
"dolphin, fish, octopus.\n",
"These are all underwater animals.\n",
"\n",
"book, paper, painting.\n",
"These are all\"\"\"\n",
"\n",
"def best_sentence():\n",
"    response = openai.Completion.create(engine=\"davinci\", prompt=forward_prompt, temperature=0, max_tokens=10, n=1, stop=\".\")\n",
"    return response.choices[0].text\n",
"\n",
"def causal_sentence():\n",
"    response = openai.Completion.create(engine=\"davinci\", prompt=forward_prompt, temperature=0.8, max_tokens=10, n=20, stop=\".\")\n",
"    resps = [x.text for x in response.choices]\n",
"    mx = -1000\n",
"    best_r = \"[null]\"\n",
"    for r in resps:\n",
"        score = runjudge(r)\n",
"        if score > mx:\n",
"            mx = score\n",
"            best_r = r\n",
"        print(score, r)\n",
"    return best_r\n",
"\n",
"print(best_sentence()+\".\")\n",
"print(causal_sentence()+\".\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"forward_prompt = \"\"\"An old man would say:\n",
"\"\"\"\n",
"\n",
"def best_sentence():\n",
"    response = openai.Completion.create(engine=\"davinci\", prompt=forward_prompt, temperature=0, max_tokens=40, n=1, stop=\".\")\n",
"    return response.choices[0].text\n",
"\n",
"def causal_sentence():\n",
"    response = openai.Completion.create(engine=\"davinci\", prompt=forward_prompt, temperature=0.5, max_tokens=40, n=10, stop=\".\")\n",
"    resps = [x.text for x in response.choices]\n",
"    mx = -1000\n",
"    best_r = \"[null]\"\n",
"    for r in resps:\n",
"        score = runjudge(r)\n",
"        if score > mx:\n",
"            mx = score\n",
"            best_r = r\n",
"#         print(score, r)\n",
"    return best_r\n",
"\n",
"print(best_sentence()+\".\")\n",
"print(causal_sentence()+\".\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def make_story():\n",
"    starting_prompt = \"Here is a story about apples that an old man would say:\"\n",
"    for n in range(5):\n",
"        response = openai.Completion.create(engine=\"davinci\", prompt=starting_prompt, temperature=0.8, max_tokens=40, n=1, stop=\".\")\n",
"        starting_prompt += response.choices[0].text + \".\"\n",
"    return starting_prompt\n",
"make_story()\n",
"\n",
"# def make_story_causal():\n",
"#     starting_prompt = \"Here is a story about apples.\"\n",
"#     for n in range(10):\n",
"#         response = openai.Completion.create(engine=\"davinci\", prompt=starting_prompt, temperature=0.8, max_tokens=40, n=20, stop=\".\")\n",
"#         resps = [x.text for x in response.choices]\n",
"#         mx = -1000\n",
"#         best_r = \"[null]\"\n",
"#         for r in resps:\n",
"#             score = runjudge(r)\n",
"#             if score > mx:\n",
"#                 mx = score\n",
"#                 best_r = r\n",
"#         starting_prompt += best_r + \".\"\n",
"#         print(starting_prompt)\n",
"#     return starting_prompt\n",
"# make_story_causal()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "tadpole",
"language": "python",
"name": "tadpole"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
}