{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "FindaWordThatRhymes.ipynb",
"provenance": [],
"collapsed_sections": [],
"authorship_tag": "ABX9TyPYWOM9ErPrrbg8umKEfDVT",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/brockmanmatt/9e6cb08b5742108df3f5896ecafcdafa/findawordthatrhymes.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"metadata": {
"id": "J7wnsgT2kPut",
"colab_type": "code",
"colab": {
"resources": {
"http://localhost:8080/nbextensions/google.colab/files.js": {
"data": "Ly8gQ29weXJpZ2h0IDIwMTcgR29vZ2xlIExMQwovLwovLyBMaWNlbnNlZCB1bmRlciB0aGUgQXBhY2hlIExpY2Vuc2UsIFZlcnNpb24gMi4wICh0aGUgIkxpY2Vuc2UiKTsKLy8geW91IG1heSBub3QgdXNlIHRoaXMgZmlsZSBleGNlcHQgaW4gY29tcGxpYW5jZSB3aXRoIHRoZSBMaWNlbnNlLgovLyBZb3UgbWF5IG9idGFpbiBhIGNvcHkgb2YgdGhlIExpY2Vuc2UgYXQKLy8KLy8gICAgICBodHRwOi8vd3d3LmFwYWNoZS5vcmcvbGljZW5zZXMvTElDRU5TRS0yLjAKLy8KLy8gVW5sZXNzIHJlcXVpcmVkIGJ5IGFwcGxpY2FibGUgbGF3IG9yIGFncmVlZCB0byBpbiB3cml0aW5nLCBzb2Z0d2FyZQovLyBkaXN0cmlidXRlZCB1bmRlciB0aGUgTGljZW5zZSBpcyBkaXN0cmlidXRlZCBvbiBhbiAiQVMgSVMiIEJBU0lTLAovLyBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC4KLy8gU2VlIHRoZSBMaWNlbnNlIGZvciB0aGUgc3BlY2lmaWMgbGFuZ3VhZ2UgZ292ZXJuaW5nIHBlcm1pc3Npb25zIGFuZAovLyBsaW1pdGF0aW9ucyB1bmRlciB0aGUgTGljZW5zZS4KCi8qKgogKiBAZmlsZW92ZXJ2aWV3IEhlbHBlcnMgZm9yIGdvb2dsZS5jb2xhYiBQeXRob24gbW9kdWxlLgogKi8KKGZ1bmN0aW9uKHNjb3BlKSB7CmZ1bmN0aW9uIHNwYW4odGV4dCwgc3R5bGVBdHRyaWJ1dGVzID0ge30pIHsKICBjb25zdCBlbGVtZW50ID0gZG9jdW1lbnQuY3JlYXRlRWxlbWVudCgnc3BhbicpOwogIGVsZW1lbnQudGV4dENvbnRlbnQgPSB0ZXh0OwogIGZvciAoY29uc3Qga2V5IG9mIE9iamVjdC5rZXlzKHN0eWxlQXR0cmlidXRlcykpIHsKICAgIGVsZW1lbnQuc3R5bGVba2V5XSA9IHN0eWxlQXR0cmlidXRlc1trZXldOwogIH0KICByZXR1cm4gZWxlbWVudDsKfQoKLy8gTWF4IG51bWJlciBvZiBieXRlcyB3aGljaCB3aWxsIGJlIHVwbG9hZGVkIGF0IGEgdGltZS4KY29uc3QgTUFYX1BBWUxPQURfU0laRSA9IDEwMCAqIDEwMjQ7CgpmdW5jdGlvbiBfdXBsb2FkRmlsZXMoaW5wdXRJZCwgb3V0cHV0SWQpIHsKICBjb25zdCBzdGVwcyA9IHVwbG9hZEZpbGVzU3RlcChpbnB1dElkLCBvdXRwdXRJZCk7CiAgY29uc3Qgb3V0cHV0RWxlbWVudCA9IGRvY3VtZW50LmdldEVsZW1lbnRCeUlkKG91dHB1dElkKTsKICAvLyBDYWNoZSBzdGVwcyBvbiB0aGUgb3V0cHV0RWxlbWVudCB0byBtYWtlIGl0IGF2YWlsYWJsZSBmb3IgdGhlIG5leHQgY2FsbAogIC8vIHRvIHVwbG9hZEZpbGVzQ29udGludWUgZnJvbSBQeXRob24uCiAgb3V0cHV0RWxlbWVudC5zdGVwcyA9IHN0ZXBzOwoKICByZXR1cm4gX3VwbG9hZEZpbGVzQ29udGludWUob3V0cHV0SWQpOwp9CgovLyBUaGlzIGlzIHJvdWdobHkgYW4gYXN5bmMgZ2VuZXJhdG9yIChub3Qgc3VwcG9ydGVkIGluIHRoZSBicm93c2VyIHlldCksCi8vIHdoZXJlIHRoZXJlIGFyZSBtdWx0aXBsZSBhc3luY2hyb25vdXMgc3RlcHMgYW5kIHRoZSBQeXRob24gc2lkZSBpcyBnb2luZwovLyB0byBwb2xsIGZvciBjb21wbGV0aW9uIG9mIGVhY2ggc3RlcC4KLy8gVGhpcyB1c2VzIGEgUHJvbWlzZSB0byBibG9jayB0aGUgcHl0aG9uIHNpZGUgb24gY29tcGxldGlvbiBvZiBlYWNoIHN0ZXAsCi8vIHRoZW4gcGFzc2VzIHRoZSByZXN1bHQgb2YgdGhlIHByZXZpb3VzIHN0ZXAgYXMgdGhlIGlucHV0IHRvIHRoZSBuZXh0IHN0ZXAuCmZ1bmN0aW9uIF91cGxvYWRGaWxlc0NvbnRpbnVlKG91dHB1dElkKSB7CiAgY29uc3Qgb3V0cHV0RWxlbWVudCA9IGRvY3VtZW50LmdldEVsZW1lbnRCeUlkKG91dHB1dElkKTsKICBjb25zdCBzdGVwcyA9IG91dHB1dEVsZW1lbnQuc3RlcHM7CgogIGNvbnN0IG5leHQgPSBzdGVwcy5uZXh0KG91dHB1dEVsZW1lbnQubGFzdFByb21pc2VWYWx1ZSk7CiAgcmV0dXJuIFByb21pc2UucmVzb2x2ZShuZXh0LnZhbHVlLnByb21pc2UpLnRoZW4oKHZhbHVlKSA9PiB7CiAgICAvLyBDYWNoZSB0aGUgbGFzdCBwcm9taXNlIHZhbHVlIHRvIG1ha2UgaXQgYXZhaWxhYmxlIHRvIHRoZSBuZXh0CiAgICAvLyBzdGVwIG9mIHRoZSBnZW5lcmF0b3IuCiAgICBvdXRwdXRFbGVtZW50Lmxhc3RQcm9taXNlVmFsdWUgPSB2YWx1ZTsKICAgIHJldHVybiBuZXh0LnZhbHVlLnJlc3BvbnNlOwogIH0pOwp9CgovKioKICogR2VuZXJhdG9yIGZ1bmN0aW9uIHdoaWNoIGlzIGNhbGxlZCBiZXR3ZWVuIGVhY2ggYXN5bmMgc3RlcCBvZiB0aGUgdXBsb2FkCiAqIHByb2Nlc3MuCiAqIEBwYXJhbSB7c3RyaW5nfSBpbnB1dElkIEVsZW1lbnQgSUQgb2YgdGhlIGlucHV0IGZpbGUgcGlja2VyIGVsZW1lbnQuCiAqIEBwYXJhbSB7c3RyaW5nfSBvdXRwdXRJZCBFbGVtZW50IElEIG9mIHRoZSBvdXRwdXQgZGlzcGxheS4KICogQHJldHVybiB7IUl0ZXJhYmxlPCFPYmplY3Q+fSBJdGVyYWJsZSBvZiBuZXh0IHN0ZXBzLgogKi8KZnVuY3Rpb24qIHVwbG9hZEZpbGVzU3RlcChpbnB1dElkLCBvdXRwdXRJZCkgewogIGNvbnN0IGlucHV0RWxlbWVudCA9IGRvY3VtZW50LmdldEVsZW1lbnRCeUlkKGlucHV0SWQpOwogIGlucHV0RWxlbWVudC5kaXNhYmxlZCA9IGZhbHNlOwoKICBjb25zdCBvdXRwdXRFbGVtZW50ID0gZG9jdW1lbnQuZ2V0RWxlbWVudEJ5SWQob3V0cHV0SWQpOwogIG91dHB1dEVsZW1lbnQuaW5uZXJIVE
1MID0gJyc7CgogIGNvbnN0IHBpY2tlZFByb21pc2UgPSBuZXcgUHJvbWlzZSgocmVzb2x2ZSkgPT4gewogICAgaW5wdXRFbGVtZW50LmFkZEV2ZW50TGlzdGVuZXIoJ2NoYW5nZScsIChlKSA9PiB7CiAgICAgIHJlc29sdmUoZS50YXJnZXQuZmlsZXMpOwogICAgfSk7CiAgfSk7CgogIGNvbnN0IGNhbmNlbCA9IGRvY3VtZW50LmNyZWF0ZUVsZW1lbnQoJ2J1dHRvbicpOwogIGlucHV0RWxlbWVudC5wYXJlbnRFbGVtZW50LmFwcGVuZENoaWxkKGNhbmNlbCk7CiAgY2FuY2VsLnRleHRDb250ZW50ID0gJ0NhbmNlbCB1cGxvYWQnOwogIGNvbnN0IGNhbmNlbFByb21pc2UgPSBuZXcgUHJvbWlzZSgocmVzb2x2ZSkgPT4gewogICAgY2FuY2VsLm9uY2xpY2sgPSAoKSA9PiB7CiAgICAgIHJlc29sdmUobnVsbCk7CiAgICB9OwogIH0pOwoKICAvLyBXYWl0IGZvciB0aGUgdXNlciB0byBwaWNrIHRoZSBmaWxlcy4KICBjb25zdCBmaWxlcyA9IHlpZWxkIHsKICAgIHByb21pc2U6IFByb21pc2UucmFjZShbcGlja2VkUHJvbWlzZSwgY2FuY2VsUHJvbWlzZV0pLAogICAgcmVzcG9uc2U6IHsKICAgICAgYWN0aW9uOiAnc3RhcnRpbmcnLAogICAgfQogIH07CgogIGNhbmNlbC5yZW1vdmUoKTsKCiAgLy8gRGlzYWJsZSB0aGUgaW5wdXQgZWxlbWVudCBzaW5jZSBmdXJ0aGVyIHBpY2tzIGFyZSBub3QgYWxsb3dlZC4KICBpbnB1dEVsZW1lbnQuZGlzYWJsZWQgPSB0cnVlOwoKICBpZiAoIWZpbGVzKSB7CiAgICByZXR1cm4gewogICAgICByZXNwb25zZTogewogICAgICAgIGFjdGlvbjogJ2NvbXBsZXRlJywKICAgICAgfQogICAgfTsKICB9CgogIGZvciAoY29uc3QgZmlsZSBvZiBmaWxlcykgewogICAgY29uc3QgbGkgPSBkb2N1bWVudC5jcmVhdGVFbGVtZW50KCdsaScpOwogICAgbGkuYXBwZW5kKHNwYW4oZmlsZS5uYW1lLCB7Zm9udFdlaWdodDogJ2JvbGQnfSkpOwogICAgbGkuYXBwZW5kKHNwYW4oCiAgICAgICAgYCgke2ZpbGUudHlwZSB8fCAnbi9hJ30pIC0gJHtmaWxlLnNpemV9IGJ5dGVzLCBgICsKICAgICAgICBgbGFzdCBtb2RpZmllZDogJHsKICAgICAgICAgICAgZmlsZS5sYXN0TW9kaWZpZWREYXRlID8gZmlsZS5sYXN0TW9kaWZpZWREYXRlLnRvTG9jYWxlRGF0ZVN0cmluZygpIDoKICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgJ24vYSd9IC0gYCkpOwogICAgY29uc3QgcGVyY2VudCA9IHNwYW4oJzAlIGRvbmUnKTsKICAgIGxpLmFwcGVuZENoaWxkKHBlcmNlbnQpOwoKICAgIG91dHB1dEVsZW1lbnQuYXBwZW5kQ2hpbGQobGkpOwoKICAgIGNvbnN0IGZpbGVEYXRhUHJvbWlzZSA9IG5ldyBQcm9taXNlKChyZXNvbHZlKSA9PiB7CiAgICAgIGNvbnN0IHJlYWRlciA9IG5ldyBGaWxlUmVhZGVyKCk7CiAgICAgIHJlYWRlci5vbmxvYWQgPSAoZSkgPT4gewogICAgICAgIHJlc29sdmUoZS50YXJnZXQucmVzdWx0KTsKICAgICAgfTsKICAgICAgcmVhZGVyLnJlYWRBc0FycmF5QnVmZmVyKGZpbGUpOwogICAgfSk7CiAgICAvLyBXYWl0IGZvciB0aGUgZGF0YSB0byBiZSByZWFkeS4KICAgIGxldCBmaWxlRGF0YSA9IHlpZWxkIHsKICAgICAgcHJvbWlzZTogZmlsZURhdGFQcm9taXNlLAogICAgICByZXNwb25zZTogewogICAgICAgIGFjdGlvbjogJ2NvbnRpbnVlJywKICAgICAgfQogICAgfTsKCiAgICAvLyBVc2UgYSBjaHVua2VkIHNlbmRpbmcgdG8gYXZvaWQgbWVzc2FnZSBzaXplIGxpbWl0cy4gU2VlIGIvNjIxMTU2NjAuCiAgICBsZXQgcG9zaXRpb24gPSAwOwogICAgd2hpbGUgKHBvc2l0aW9uIDwgZmlsZURhdGEuYnl0ZUxlbmd0aCkgewogICAgICBjb25zdCBsZW5ndGggPSBNYXRoLm1pbihmaWxlRGF0YS5ieXRlTGVuZ3RoIC0gcG9zaXRpb24sIE1BWF9QQVlMT0FEX1NJWkUpOwogICAgICBjb25zdCBjaHVuayA9IG5ldyBVaW50OEFycmF5KGZpbGVEYXRhLCBwb3NpdGlvbiwgbGVuZ3RoKTsKICAgICAgcG9zaXRpb24gKz0gbGVuZ3RoOwoKICAgICAgY29uc3QgYmFzZTY0ID0gYnRvYShTdHJpbmcuZnJvbUNoYXJDb2RlLmFwcGx5KG51bGwsIGNodW5rKSk7CiAgICAgIHlpZWxkIHsKICAgICAgICByZXNwb25zZTogewogICAgICAgICAgYWN0aW9uOiAnYXBwZW5kJywKICAgICAgICAgIGZpbGU6IGZpbGUubmFtZSwKICAgICAgICAgIGRhdGE6IGJhc2U2NCwKICAgICAgICB9LAogICAgICB9OwogICAgICBwZXJjZW50LnRleHRDb250ZW50ID0KICAgICAgICAgIGAke01hdGgucm91bmQoKHBvc2l0aW9uIC8gZmlsZURhdGEuYnl0ZUxlbmd0aCkgKiAxMDApfSUgZG9uZWA7CiAgICB9CiAgfQoKICAvLyBBbGwgZG9uZS4KICB5aWVsZCB7CiAgICByZXNwb25zZTogewogICAgICBhY3Rpb246ICdjb21wbGV0ZScsCiAgICB9CiAgfTsKfQoKc2NvcGUuZ29vZ2xlID0gc2NvcGUuZ29vZ2xlIHx8IHt9OwpzY29wZS5nb29nbGUuY29sYWIgPSBzY29wZS5nb29nbGUuY29sYWIgfHwge307CnNjb3BlLmdvb2dsZS5jb2xhYi5fZmlsZXMgPSB7CiAgX3VwbG9hZEZpbGVzLAogIF91cGxvYWRGaWxlc0NvbnRpbnVlLAp9Owp9KShzZWxmKTsK",
"ok": true,
"headers": [
[
"content-type",
"application/javascript"
]
],
"status": 200,
"status_text": ""
}
},
"base_uri": "https://localhost:8080/",
"height": 89
},
"outputId": "4d96e8e3-b41f-4441-eaa2-2f9b5d2470cf"
},
"source": [
"from google.colab import files\n",
"uploaded = files.upload()\n",
"print(\"done\")"
],
"execution_count": 1,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/html": [
"\n",
" <input type=\"file\" id=\"files-2c7c85f7-ebbb-49b5-939c-41d237cda743\" name=\"files[]\" multiple disabled\n",
" style=\"border:none\" />\n",
" <output id=\"result-2c7c85f7-ebbb-49b5-939c-41d237cda743\">\n",
" Upload widget is only available when the cell has been executed in the\n",
" current browser session. Please rerun this cell to enable.\n",
" </output>\n",
" <script src=\"/nbextensions/google.colab/files.js\"></script> "
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "stream",
"text": [
"Saving key.json to key.json\n",
"done\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "WHPHrUnhpKnI",
"colab_type": "text"
},
"source": [
"I'll install the API"
]
},
{
"cell_type": "code",
"metadata": {
"id": "zq0ltp2xn4yt",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 292
},
"outputId": "d87d9902-c98c-43f7-e3ac-33f02916f2f3"
},
"source": [
"!pip install openai\n",
"import openai, json, pandas as pd, numpy as np"
],
"execution_count": 2,
"outputs": [
{
"output_type": "stream",
"text": [
"Collecting openai\n",
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/a8/65/c7461f4c87984534683f480ea5742777bc39bbf5721123194c2d0347dc1f/openai-0.2.4.tar.gz (157kB)\n",
"\r\u001b[K |██ | 10kB 13.1MB/s eta 0:00:01\r\u001b[K |████▏ | 20kB 1.7MB/s eta 0:00:01\r\u001b[K |██████▎ | 30kB 2.4MB/s eta 0:00:01\r\u001b[K |████████▍ | 40kB 2.6MB/s eta 0:00:01\r\u001b[K |██████████▍ | 51kB 2.0MB/s eta 0:00:01\r\u001b[K |████████████▌ | 61kB 2.2MB/s eta 0:00:01\r\u001b[K |██████████████▋ | 71kB 2.5MB/s eta 0:00:01\r\u001b[K |████████████████▊ | 81kB 2.7MB/s eta 0:00:01\r\u001b[K |██████████████████▊ | 92kB 2.9MB/s eta 0:00:01\r\u001b[K |████████████████████▉ | 102kB 2.8MB/s eta 0:00:01\r\u001b[K |███████████████████████ | 112kB 2.8MB/s eta 0:00:01\r\u001b[K |█████████████████████████ | 122kB 2.8MB/s eta 0:00:01\r\u001b[K |███████████████████████████ | 133kB 2.8MB/s eta 0:00:01\r\u001b[K |█████████████████████████████▏ | 143kB 2.8MB/s eta 0:00:01\r\u001b[K |███████████████████████████████▎| 153kB 2.8MB/s eta 0:00:01\r\u001b[K |████████████████████████████████| 163kB 2.8MB/s \n",
"\u001b[?25hRequirement already satisfied: requests>=2.20 in /usr/local/lib/python3.6/dist-packages (from openai) (2.23.0)\n",
"Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests>=2.20->openai) (3.0.4)\n",
"Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests>=2.20->openai) (1.24.3)\n",
"Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests>=2.20->openai) (2.10)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests>=2.20->openai) (2020.6.20)\n",
"Building wheels for collected packages: openai\n",
" Building wheel for openai (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
" Created wheel for openai: filename=openai-0.2.4-cp36-none-any.whl size=170709 sha256=fa007004b1de497170ec6b15a124b6957783c1d059b1b38a3ebb196696b575de\n",
" Stored in directory: /root/.cache/pip/wheels/74/96/c8/c6e170929c276b836613e1b9985343b501fe455e53d85e7d48\n",
"Successfully built openai\n",
"Installing collected packages: openai\n",
"Successfully installed openai-0.2.4\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Q2yE0jcnpMEV",
"colab_type": "text"
},
"source": [
"Loading in key.json that I uploaded; I do this so I don't need to worry about accidently leaking creds if I share the colab (which I'm 99% sure is just a json file that won't expose them)"
]
},
{
"cell_type": "code",
"metadata": {
"id": "bwNXXwHen5x9",
"colab_type": "code",
"colab": {}
},
"source": [
"openai.api_key = json.load(open(\"key.json\", \"r\"))[\"key\"]"
],
"execution_count": 3,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "k67w5H0fpTkT",
"colab_type": "text"
},
"source": [
"Default keyword arguments to pass the aPI"
]
},
{
"cell_type": "code",
"metadata": {
"id": "e1EwpqqJkTYh",
"colab_type": "code",
"colab": {}
},
"source": [
"#arguments to send the API\n",
"kwargs = {\n",
"\"engine\":\"davinci\",\n",
"\"temperature\":0,\n",
"\"max_tokens\":10,\n",
"\"stop\":\"\\n\",\n",
"}"
],
"execution_count": 4,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "b0zgRWqAy1GA",
"colab_type": "text"
},
"source": [
"So what what we might think is is for each sentence we generate, we might want to make sure that the log prob is above some threshhold. We can also check if words are in there and if so, discard.\n",
"\n",
"For instance, let's say that we want to generate a rhyme"
]
},
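{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a rough sketch of that check (the helper name `accept_completion` and the cutoff value are just illustrative, not part of the API; it assumes the request was made with `logprobs` set, as in the cells below):"
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"def accept_completion(response, seen_words, min_logprob=-4.0):\n",
"  \"\"\"Return the completed word if it is new and likely enough, else None.\"\"\"\n",
"  choice = response[\"choices\"][0]\n",
"  word = choice[\"text\"].strip().lower()\n",
"  # mean token logprob of the completion, trimmed at the newline stop token\n",
"  tokens = choice[\"logprobs\"][\"tokens\"]\n",
"  logprobs = choice[\"logprobs\"][\"token_logprobs\"]\n",
"  if \"\\n\" in tokens:\n",
"    logprobs = logprobs[:tokens.index(\"\\n\")]\n",
"  if not word or word in seen_words:  # discard empty or duplicate words\n",
"    return None\n",
"  if np.mean(logprobs) < min_logprob:  # discard low-confidence completions (cutoff is arbitrary)\n",
"    return None\n",
"  return word"
],
"execution_count": null,
"outputs": []
},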
{
"cell_type": "code",
"metadata": {
"id": "7AdSYZJYueSj",
"colab_type": "code",
"colab": {}
},
"source": [
"prompt = \"\"\"The following word rhymes with happy: snappy\n",
"The following word rhymes with dog:\"\"\"\n",
"kwargs[\"logprobs\"] = 1\n",
"kwargs[\"max_tokens\"] = 5\n",
"kwargs[\"temperature\"] = .5\n",
"kwargs[\"stop\"] = \"\\n\""
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "_kYaKW9u1l7Q",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 272
},
"outputId": "be83120f-8a96-4250-9b01-1dafa39346b8"
},
"source": [
"myWords = []\n",
"attempt = 0\n",
"while len(myWords) < 10:\n",
" attempt += 1\n",
" r = openai.Completion.create(prompt=prompt, **kwargs)\n",
" newWord = r[\"choices\"][0][\"text\"]\n",
" if attempt % 10 == 0:\n",
" print(\"word: {}\\nFound: {}\\nAttempts: {}\".format(newWord, len(myWords), attempt))\n",
" if newWord.lower() not in myWords:\n",
" myWords.append(newWord)"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"word: fog\n",
"Found: 4\n",
"Attempts: 10\n",
"word: log\n",
"Found: 6\n",
"Attempts: 20\n",
"word: fog\n",
"Found: 6\n",
"Attempts: 30\n",
"word: god\n",
"Found: 7\n",
"Attempts: 40\n",
"word: haggard\n",
"Found: 8\n",
"Attempts: 50\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "LJxHkt5i1maw",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"outputId": "9dd85ef5-f50c-4b03-d8be-f97884064238"
},
"source": [
"attempt, \",\".join(myWords)"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(51, ' log, fog, bag, bog, baggy, god, hoggish, dog, haggard, hound')"
]
},
"metadata": {
"tags": []
},
"execution_count": 11
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "Vonc2oDf3KOK",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 85
},
"outputId": "0818b6e4-9bb8-4d7d-df51-dac9d4b55db5"
},
"source": [
"kwargs[\"temperature\"] = 1\n",
"myWords = []\n",
"attempt = 0\n",
"while len(myWords) < 10:\n",
" attempt += 1\n",
" r = openai.Completion.create(prompt=prompt, **kwargs)\n",
" newWord = r[\"choices\"][0][\"text\"].strip()\n",
" if attempt % 10 == 0:\n",
" print(\"word: {}\\nFound: {}\\nAttempts: {}\".format(newWord, len(myWords), attempt))\n",
" if newWord.lower() not in myWords:\n",
" myWords.append(newWord)\n",
"print(\"{} attempts: {}\".format(attempt, \",\".join(myWords)))"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"word: thog\n",
"Found: 8\n",
"Attempts: 10\n",
"11 attempts: flog,wog,shlug,bling,sloth,yodle,jogg,hog,thog,fogg\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "7VRiL4Tf5l6Y",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 187
},
"outputId": "6de378c5-8d89-4498-bf62-6824feb37679"
},
"source": [
"import random\n",
"\n",
"prefix = \"\"\"The following word rhymes with happy: snappy\\nThe following word rhymes with snap: trap\\n\"\"\"\n",
"suffix = \"The following word rhymes with dog:\"\n",
"prompt = prefix + suffix\n",
"kwargs[\"logprobs\"] = 1\n",
"kwargs[\"max_tokens\"] = 5\n",
"kwargs[\"temperature\"] = .5\n",
"kwargs[\"stop\"] = \"\\n\"\n",
"kwargs[\"presence_penalty\"] = 1\n",
"\n",
"myWords = []\n",
"attempt = 0\n",
"examples = []\n",
"while len(myWords) < 10:\n",
" attempt += 1\n",
" r = openai.Completion.create(prompt=prompt, **kwargs)\n",
" newWord = r[\"choices\"][0][\"text\"].strip()\n",
" if attempt % 10 == 0:\n",
" print(\"word: {}\\nFound: {}\\nAttempts: {}\".format(newWord, len(myWords), attempt))\n",
" if newWord.lower() not in myWords:\n",
" myWords.append(newWord)\n",
" newExample = \"{} {}\".format(suffix, newWord)\n",
" examples.append(newExample)\n",
" if len(examples) > 3:\n",
" prompt = \"\"\n",
" for example in random.sample(examples, k=min(10, len(examples))):\n",
" prompt += example + \"\\n\"\n",
" prompt += suffix\n",
"\n",
"print(\"{} attempts: {}\".format(attempt, \",\".join(myWords)))"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"word: log\n",
"Found: 3\n",
"Attempts: 10\n",
"word: log\n",
"Found: 3\n",
"Attempts: 20\n",
"word: flog\n",
"Found: 8\n",
"Attempts: 30\n",
"32 attempts: log,god,bog,fog,plog,hog,slog,shlog,flog,zog\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "HRaI_015AjTR",
"colab_type": "text"
},
"source": [
"let's redo that, except use the logprobs"
]
},
{
"cell_type": "code",
"metadata": {
"id": "wHklFpu8i-OX",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 136
},
"outputId": "9531f1bb-a2c8-4f29-ac67-83afa276a104"
},
"source": [
"import random\n",
"\n",
"prefix = \"\"\"The following word rhymes with happy: snappy\\nThe following word rhymes with snap: trap\\n\"\"\"\n",
"suffix = \"The following word rhymes with dog:\"\n",
"prompt = prefix + suffix\n",
"kwargs[\"logprobs\"] = 1\n",
"kwargs[\"max_tokens\"] = 5\n",
"kwargs[\"temperature\"] = .5\n",
"kwargs[\"stop\"] = \"\\n\"\n",
"kwargs[\"presence_penalty\"] = 1\n",
"\n",
"myWords = []\n",
"attempt = 0\n",
"examples = []\n",
"example_logprobs = []\n",
"while len(myWords) < 10:\n",
" attempt += 1\n",
" r = openai.Completion.create(prompt=prompt, **kwargs)\n",
" newWord = r[\"choices\"][0][\"text\"].strip()\n",
" if attempt % 10 == 0:\n",
" print(\"word: {}\\nFound: {}\\nAttempts: {}\".format(newWord, len(myWords), attempt))\n",
" if newWord.lower() not in myWords:\n",
" myWords.append(newWord)\n",
" newExample = \"{} {}\".format(suffix, newWord)\n",
" examples.append(newExample)\n",
"\n",
" # get logprob for examples\n",
" r_tokens = r[\"choices\"][0][\"logprobs\"][\"tokens\"]\n",
" logprobs = r[\"choices\"][0][\"logprobs\"][\"token_logprobs\"]\n",
" if \"\\n\" in r_tokens:\n",
" logprobs = logprobs[:r_tokens.index(\"\\n\")]\n",
" example_logprobs.append((newWord, np.mean(logprobs)))\n",
"\n",
"\n",
" if len(examples) > 3:\n",
" prompt = \"\"\n",
" for example in random.sample(examples, k=min(10, len(examples))):\n",
" prompt += example + \"\\n\"\n",
" prompt += suffix\n",
"\n",
"print(\"{} attempts: {}\".format(attempt, \",\".join(myWords)))"
],
"execution_count": 6,
"outputs": [
{
"output_type": "stream",
"text": [
"word: bog\n",
"Found: 3\n",
"Attempts: 10\n",
"word: sod\n",
"Found: 6\n",
"Attempts: 20\n",
"25 attempts: log,god,bog,fog,hog,nod,sod,dog,frog,cod\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "69YOktXxnroc",
"colab_type": "code",
"colab": {}
},
"source": [
"df = pd.DataFrame(example_logprobs, columns=[\"word\", \"logprob\"])"
],
"execution_count": 7,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "QRmXNuXhn-3I",
"colab_type": "code",
"colab": {}
},
"source": [
"df[\"%\"] = df.logprob.apply(lambda x: 100*np.e**x)"
],
"execution_count": 8,
"outputs": []
},
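{
"cell_type": "markdown",
"metadata": {},
"source": [
"Since `logprob` is the mean of the per-token log probabilities, exponentiating it gives the geometric mean per-token probability, expressed as a percentage in the `%` column: $100 \\cdot e^{\\overline{\\log p}}$. For a single-token word this is just the model's probability of that token given the prompt."
]
},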
{
"cell_type": "code",
"metadata": {
"id": "oJND-MMvoGcD",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 359
},
"outputId": "43e70980-6225-41bd-cf24-0961b78ca81c"
},
"source": [
"df"
],
"execution_count": 9,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>word</th>\n",
" <th>logprob</th>\n",
" <th>%</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>log</td>\n",
" <td>-1.130146</td>\n",
" <td>32.298610</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>god</td>\n",
" <td>-2.627420</td>\n",
" <td>7.226464</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>bog</td>\n",
" <td>-2.385975</td>\n",
" <td>9.199924</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>fog</td>\n",
" <td>-2.689186</td>\n",
" <td>6.793622</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>hog</td>\n",
" <td>-3.782045</td>\n",
" <td>2.277606</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>nod</td>\n",
" <td>-3.258999</td>\n",
" <td>3.842685</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>sod</td>\n",
" <td>-4.729776</td>\n",
" <td>0.882844</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <td>dog</td>\n",
" <td>-2.877510</td>\n",
" <td>5.627471</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <td>frog</td>\n",
" <td>-3.591209</td>\n",
" <td>2.756497</td>\n",
" </tr>\n",
" <tr>\n",
" <th>9</th>\n",
" <td>cod</td>\n",
" <td>-3.758816</td>\n",
" <td>2.331133</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" word logprob %\n",
"0 log -1.130146 32.298610\n",
"1 god -2.627420 7.226464\n",
"2 bog -2.385975 9.199924\n",
"3 fog -2.689186 6.793622\n",
"4 hog -3.782045 2.277606\n",
"5 nod -3.258999 3.842685\n",
"6 sod -4.729776 0.882844\n",
"7 dog -2.877510 5.627471\n",
"8 frog -3.591209 2.756497\n",
"9 cod -3.758816 2.331133"
]
},
"metadata": {
"tags": []
},
"execution_count": 9
}
]
},
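{
"cell_type": "markdown",
"metadata": {},
"source": [
"Following up on the threshold idea from earlier, one simple filter is to keep only the candidates whose mean logprob clears some cutoff (the -3.5 below is an arbitrary value, just for illustration):"
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# keep only the rhyme candidates whose mean logprob clears an (arbitrary) cutoff\n",
"threshold = -3.5\n",
"df[df.logprob > threshold].sort_values(\"logprob\", ascending=False)"
],
"execution_count": null,
"outputs": []
},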
{
"cell_type": "markdown",
"metadata": {
"id": "YnN6ZfpnQLjD",
"colab_type": "text"
},
"source": [
"what if set temp to 0"
]
}
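,
{
"cell_type": "markdown",
"metadata": {},
"source": [
"A sketch of that experiment (not run here): with temperature 0 a fixed prompt yields an essentially deterministic completion, so without a cap on attempts the loop can spin forever once it stops finding new words."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"kwargs[\"temperature\"] = 0\n",
"\n",
"myWords = []\n",
"examples = []\n",
"attempt = 0\n",
"prompt = prefix + suffix\n",
"# cap the attempts: with temperature 0 the same prompt keeps returning the same word\n",
"while len(myWords) < 10 and attempt < 50:\n",
"  attempt += 1\n",
"  r = openai.Completion.create(prompt=prompt, **kwargs)\n",
"  newWord = r[\"choices\"][0][\"text\"].strip()\n",
"  if newWord.lower() not in myWords:\n",
"    myWords.append(newWord)\n",
"    examples.append(\"{} {}\".format(suffix, newWord))\n",
"  if len(examples) > 3:\n",
"    prompt = \"\"\n",
"    for example in random.sample(examples, k=min(10, len(examples))):\n",
"      prompt += example + \"\\n\"\n",
"    prompt += suffix\n",
"\n",
"print(\"{} attempts: {}\".format(attempt, \",\".join(myWords)))"
],
"execution_count": null,
"outputs": []
}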
]
}