Last active
July 26, 2020 04:04
-
-
Save brockmanmatt/9e6cb08b5742108df3f5896ecafcdafa to your computer and use it in GitHub Desktop.
FindaWordThatRhymes.ipynb
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "nbformat": 4, | |
| "nbformat_minor": 0, | |
| "metadata": { | |
| "colab": { | |
| "name": "FindaWordThatRhymes.ipynb", | |
| "provenance": [], | |
| "collapsed_sections": [], | |
| "authorship_tag": "ABX9TyPYWOM9ErPrrbg8umKEfDVT", | |
| "include_colab_link": true | |
| }, | |
| "kernelspec": { | |
| "name": "python3", | |
| "display_name": "Python 3" | |
| } | |
| }, | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "view-in-github", | |
| "colab_type": "text" | |
| }, | |
| "source": [ | |
| "<a href=\"https://colab.research.google.com/gist/brockmanmatt/9e6cb08b5742108df3f5896ecafcdafa/findawordthatrhymes.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "J7wnsgT2kPut", | |
| "colab_type": "code", | |
| "colab": { | |
| "resources": { | |
| "http://localhost:8080/nbextensions/google.colab/files.js": { | |
| "data": "Ly8gQ29weXJpZ2h0IDIwMTcgR29vZ2xlIExMQwovLwovLyBMaWNlbnNlZCB1bmRlciB0aGUgQXBhY2hlIExpY2Vuc2UsIFZlcnNpb24gMi4wICh0aGUgIkxpY2Vuc2UiKTsKLy8geW91IG1heSBub3QgdXNlIHRoaXMgZmlsZSBleGNlcHQgaW4gY29tcGxpYW5jZSB3aXRoIHRoZSBMaWNlbnNlLgovLyBZb3UgbWF5IG9idGFpbiBhIGNvcHkgb2YgdGhlIExpY2Vuc2UgYXQKLy8KLy8gICAgICBodHRwOi8vd3d3LmFwYWNoZS5vcmcvbGljZW5zZXMvTElDRU5TRS0yLjAKLy8KLy8gVW5sZXNzIHJlcXVpcmVkIGJ5IGFwcGxpY2FibGUgbGF3IG9yIGFncmVlZCB0byBpbiB3cml0aW5nLCBzb2Z0d2FyZQovLyBkaXN0cmlidXRlZCB1bmRlciB0aGUgTGljZW5zZSBpcyBkaXN0cmlidXRlZCBvbiBhbiAiQVMgSVMiIEJBU0lTLAovLyBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC4KLy8gU2VlIHRoZSBMaWNlbnNlIGZvciB0aGUgc3BlY2lmaWMgbGFuZ3VhZ2UgZ292ZXJuaW5nIHBlcm1pc3Npb25zIGFuZAovLyBsaW1pdGF0aW9ucyB1bmRlciB0aGUgTGljZW5zZS4KCi8qKgogKiBAZmlsZW92ZXJ2aWV3IEhlbHBlcnMgZm9yIGdvb2dsZS5jb2xhYiBQeXRob24gbW9kdWxlLgogKi8KKGZ1bmN0aW9uKHNjb3BlKSB7CmZ1bmN0aW9uIHNwYW4odGV4dCwgc3R5bGVBdHRyaWJ1dGVzID0ge30pIHsKICBjb25zdCBlbGVtZW50ID0gZG9jdW1lbnQuY3JlYXRlRWxlbWVudCgnc3BhbicpOwogIGVsZW1lbnQudGV4dENvbnRlbnQgPSB0ZXh0OwogIGZvciAoY29uc3Qga2V5IG9mIE9iamVjdC5rZXlzKHN0eWxlQXR0cmlidXRlcykpIHsKICAgIGVsZW1lbnQuc3R5bGVba2V5XSA9IHN0eWxlQXR0cmlidXRlc1trZXldOwogIH0KICByZXR1cm4gZWxlbWVudDsKfQoKLy8gTWF4IG51bWJlciBvZiBieXRlcyB3aGljaCB3aWxsIGJlIHVwbG9hZGVkIGF0IGEgdGltZS4KY29uc3QgTUFYX1BBWUxPQURfU0laRSA9IDEwMCAqIDEwMjQ7CgpmdW5jdGlvbiBfdXBsb2FkRmlsZXMoaW5wdXRJZCwgb3V0cHV0SWQpIHsKICBjb25zdCBzdGVwcyA9IHVwbG9hZEZpbGVzU3RlcChpbnB1dElkLCBvdXRwdXRJZCk7CiAgY29uc3Qgb3V0cHV0RWxlbWVudCA9IGRvY3VtZW50LmdldEVsZW1lbnRCeUlkKG91dHB1dElkKTsKICAvLyBDYWNoZSBzdGVwcyBvbiB0aGUgb3V0cHV0RWxlbWVudCB0byBtYWtlIGl0IGF2YWlsYWJsZSBmb3IgdGhlIG5leHQgY2FsbAogIC8vIHRvIHVwbG9hZEZpbGVzQ29udGludWUgZnJvbSBQeXRob24uCiAgb3V0cHV0RWxlbWVudC5zdGVwcyA9IHN0ZXBzOwoKICByZXR1cm4gX3VwbG9hZEZpbGVzQ29udGludWUob3V0cHV0SWQpOwp9CgovLyBUaGlzIGlzIHJvdWdobHkgYW4gYXN5bmMgZ2VuZXJhdG9yIChub3Qgc3VwcG9ydGVkIGluIHRoZSBicm93c2VyIHlldCksCi8vIHdoZXJlIHRoZXJlIGFyZSBtdWx0aXBsZSBhc3luY2hyb25vdXMgc3RlcHMgYW5kIHRoZ
SBQeXRob24gc2lkZSBpcyBnb2luZwovLyB0byBwb2xsIGZvciBjb21wbGV0aW9uIG9mIGVhY2ggc3RlcC4KLy8gVGhpcyB1c2VzIGEgUHJvbWlzZSB0byBibG9jayB0aGUgcHl0aG9uIHNpZGUgb24gY29tcGxldGlvbiBvZiBlYWNoIHN0ZXAsCi8vIHRoZW4gcGFzc2VzIHRoZSByZXN1bHQgb2YgdGhlIHByZXZpb3VzIHN0ZXAgYXMgdGhlIGlucHV0IHRvIHRoZSBuZXh0IHN0ZXAuCmZ1bmN0aW9uIF91cGxvYWRGaWxlc0NvbnRpbnVlKG91dHB1dElkKSB7CiAgY29uc3Qgb3V0cHV0RWxlbWVudCA9IGRvY3VtZW50LmdldEVsZW1lbnRCeUlkKG91dHB1dElkKTsKICBjb25zdCBzdGVwcyA9IG91dHB1dEVsZW1lbnQuc3RlcHM7CgogIGNvbnN0IG5leHQgPSBzdGVwcy5uZXh0KG91dHB1dEVsZW1lbnQubGFzdFByb21pc2VWYWx1ZSk7CiAgcmV0dXJuIFByb21pc2UucmVzb2x2ZShuZXh0LnZhbHVlLnByb21pc2UpLnRoZW4oKHZhbHVlKSA9PiB7CiAgICAvLyBDYWNoZSB0aGUgbGFzdCBwcm9taXNlIHZhbHVlIHRvIG1ha2UgaXQgYXZhaWxhYmxlIHRvIHRoZSBuZXh0CiAgICAvLyBzdGVwIG9mIHRoZSBnZW5lcmF0b3IuCiAgICBvdXRwdXRFbGVtZW50Lmxhc3RQcm9taXNlVmFsdWUgPSB2YWx1ZTsKICAgIHJldHVybiBuZXh0LnZhbHVlLnJlc3BvbnNlOwogIH0pOwp9CgovKioKICogR2VuZXJhdG9yIGZ1bmN0aW9uIHdoaWNoIGlzIGNhbGxlZCBiZXR3ZWVuIGVhY2ggYXN5bmMgc3RlcCBvZiB0aGUgdXBsb2FkCiAqIHByb2Nlc3MuCiAqIEBwYXJhbSB7c3RyaW5nfSBpbnB1dElkIEVsZW1lbnQgSUQgb2YgdGhlIGlucHV0IGZpbGUgcGlja2VyIGVsZW1lbnQuCiAqIEBwYXJhbSB7c3RyaW5nfSBvdXRwdXRJZCBFbGVtZW50IElEIG9mIHRoZSBvdXRwdXQgZGlzcGxheS4KICogQHJldHVybiB7IUl0ZXJhYmxlPCFPYmplY3Q+fSBJdGVyYWJsZSBvZiBuZXh0IHN0ZXBzLgogKi8KZnVuY3Rpb24qIHVwbG9hZEZpbGVzU3RlcChpbnB1dElkLCBvdXRwdXRJZCkgewogIGNvbnN0IGlucHV0RWxlbWVudCA9IGRvY3VtZW50LmdldEVsZW1lbnRCeUlkKGlucHV0SWQpOwogIGlucHV0RWxlbWVudC5kaXNhYmxlZCA9IGZhbHNlOwoKICBjb25zdCBvdXRwdXRFbGVtZW50ID0gZG9jdW1lbnQuZ2V0RWxlbWVudEJ5SWQob3V0cHV0SWQpOwogIG91dHB1dEVsZW1lbnQuaW5uZXJIVE1MID0gJyc7CgogIGNvbnN0IHBpY2tlZFByb21pc2UgPSBuZXcgUHJvbWlzZSgocmVzb2x2ZSkgPT4gewogICAgaW5wdXRFbGVtZW50LmFkZEV2ZW50TGlzdGVuZXIoJ2NoYW5nZScsIChlKSA9PiB7CiAgICAgIHJlc29sdmUoZS50YXJnZXQuZmlsZXMpOwogICAgfSk7CiAgfSk7CgogIGNvbnN0IGNhbmNlbCA9IGRvY3VtZW50LmNyZWF0ZUVsZW1lbnQoJ2J1dHRvbicpOwogIGlucHV0RWxlbWVudC5wYXJlbnRFbGVtZW50LmFwcGVuZENoaWxkKGNhbmNlbCk7CiAgY2FuY2VsLnRleHRDb250ZW50ID0gJ0NhbmNlbCB1cGxvYWQnOwogIGNvbnN0IGNhbmNlbFByb21pc2UgPSBuZXcgUHJvb
WlzZSgocmVzb2x2ZSkgPT4gewogICAgY2FuY2VsLm9uY2xpY2sgPSAoKSA9PiB7CiAgICAgIHJlc29sdmUobnVsbCk7CiAgICB9OwogIH0pOwoKICAvLyBXYWl0IGZvciB0aGUgdXNlciB0byBwaWNrIHRoZSBmaWxlcy4KICBjb25zdCBmaWxlcyA9IHlpZWxkIHsKICAgIHByb21pc2U6IFByb21pc2UucmFjZShbcGlja2VkUHJvbWlzZSwgY2FuY2VsUHJvbWlzZV0pLAogICAgcmVzcG9uc2U6IHsKICAgICAgYWN0aW9uOiAnc3RhcnRpbmcnLAogICAgfQogIH07CgogIGNhbmNlbC5yZW1vdmUoKTsKCiAgLy8gRGlzYWJsZSB0aGUgaW5wdXQgZWxlbWVudCBzaW5jZSBmdXJ0aGVyIHBpY2tzIGFyZSBub3QgYWxsb3dlZC4KICBpbnB1dEVsZW1lbnQuZGlzYWJsZWQgPSB0cnVlOwoKICBpZiAoIWZpbGVzKSB7CiAgICByZXR1cm4gewogICAgICByZXNwb25zZTogewogICAgICAgIGFjdGlvbjogJ2NvbXBsZXRlJywKICAgICAgfQogICAgfTsKICB9CgogIGZvciAoY29uc3QgZmlsZSBvZiBmaWxlcykgewogICAgY29uc3QgbGkgPSBkb2N1bWVudC5jcmVhdGVFbGVtZW50KCdsaScpOwogICAgbGkuYXBwZW5kKHNwYW4oZmlsZS5uYW1lLCB7Zm9udFdlaWdodDogJ2JvbGQnfSkpOwogICAgbGkuYXBwZW5kKHNwYW4oCiAgICAgICAgYCgke2ZpbGUudHlwZSB8fCAnbi9hJ30pIC0gJHtmaWxlLnNpemV9IGJ5dGVzLCBgICsKICAgICAgICBgbGFzdCBtb2RpZmllZDogJHsKICAgICAgICAgICAgZmlsZS5sYXN0TW9kaWZpZWREYXRlID8gZmlsZS5sYXN0TW9kaWZpZWREYXRlLnRvTG9jYWxlRGF0ZVN0cmluZygpIDoKICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgJ24vYSd9IC0gYCkpOwogICAgY29uc3QgcGVyY2VudCA9IHNwYW4oJzAlIGRvbmUnKTsKICAgIGxpLmFwcGVuZENoaWxkKHBlcmNlbnQpOwoKICAgIG91dHB1dEVsZW1lbnQuYXBwZW5kQ2hpbGQobGkpOwoKICAgIGNvbnN0IGZpbGVEYXRhUHJvbWlzZSA9IG5ldyBQcm9taXNlKChyZXNvbHZlKSA9PiB7CiAgICAgIGNvbnN0IHJlYWRlciA9IG5ldyBGaWxlUmVhZGVyKCk7CiAgICAgIHJlYWRlci5vbmxvYWQgPSAoZSkgPT4gewogICAgICAgIHJlc29sdmUoZS50YXJnZXQucmVzdWx0KTsKICAgICAgfTsKICAgICAgcmVhZGVyLnJlYWRBc0FycmF5QnVmZmVyKGZpbGUpOwogICAgfSk7CiAgICAvLyBXYWl0IGZvciB0aGUgZGF0YSB0byBiZSByZWFkeS4KICAgIGxldCBmaWxlRGF0YSA9IHlpZWxkIHsKICAgICAgcHJvbWlzZTogZmlsZURhdGFQcm9taXNlLAogICAgICByZXNwb25zZTogewogICAgICAgIGFjdGlvbjogJ2NvbnRpbnVlJywKICAgICAgfQogICAgfTsKCiAgICAvLyBVc2UgYSBjaHVua2VkIHNlbmRpbmcgdG8gYXZvaWQgbWVzc2FnZSBzaXplIGxpbWl0cy4gU2VlIGIvNjIxMTU2NjAuCiAgICBsZXQgcG9zaXRpb24gPSAwOwogICAgd2hpbGUgKHBvc2l0aW9uIDwgZmlsZURhdGEuYnl0ZUxlbmd0aCkgewogICAgICBjb25zdCBsZW5ndGggPSBNYXRoLm1pbihmaWxlRGF0Y
S5ieXRlTGVuZ3RoIC0gcG9zaXRpb24sIE1BWF9QQVlMT0FEX1NJWkUpOwogICAgICBjb25zdCBjaHVuayA9IG5ldyBVaW50OEFycmF5KGZpbGVEYXRhLCBwb3NpdGlvbiwgbGVuZ3RoKTsKICAgICAgcG9zaXRpb24gKz0gbGVuZ3RoOwoKICAgICAgY29uc3QgYmFzZTY0ID0gYnRvYShTdHJpbmcuZnJvbUNoYXJDb2RlLmFwcGx5KG51bGwsIGNodW5rKSk7CiAgICAgIHlpZWxkIHsKICAgICAgICByZXNwb25zZTogewogICAgICAgICAgYWN0aW9uOiAnYXBwZW5kJywKICAgICAgICAgIGZpbGU6IGZpbGUubmFtZSwKICAgICAgICAgIGRhdGE6IGJhc2U2NCwKICAgICAgICB9LAogICAgICB9OwogICAgICBwZXJjZW50LnRleHRDb250ZW50ID0KICAgICAgICAgIGAke01hdGgucm91bmQoKHBvc2l0aW9uIC8gZmlsZURhdGEuYnl0ZUxlbmd0aCkgKiAxMDApfSUgZG9uZWA7CiAgICB9CiAgfQoKICAvLyBBbGwgZG9uZS4KICB5aWVsZCB7CiAgICByZXNwb25zZTogewogICAgICBhY3Rpb246ICdjb21wbGV0ZScsCiAgICB9CiAgfTsKfQoKc2NvcGUuZ29vZ2xlID0gc2NvcGUuZ29vZ2xlIHx8IHt9OwpzY29wZS5nb29nbGUuY29sYWIgPSBzY29wZS5nb29nbGUuY29sYWIgfHwge307CnNjb3BlLmdvb2dsZS5jb2xhYi5fZmlsZXMgPSB7CiAgX3VwbG9hZEZpbGVzLAogIF91cGxvYWRGaWxlc0NvbnRpbnVlLAp9Owp9KShzZWxmKTsK", | |
| "ok": true, | |
| "headers": [ | |
| [ | |
| "content-type", | |
| "application/javascript" | |
| ] | |
| ], | |
| "status": 200, | |
| "status_text": "" | |
| } | |
| }, | |
| "base_uri": "https://localhost:8080/", | |
| "height": 89 | |
| }, | |
| "outputId": "4d96e8e3-b41f-4441-eaa2-2f9b5d2470cf" | |
| }, | |
| "source": [ | |
| "from google.colab import files\n", | |
| "uploaded = files.upload()\n", | |
| "print(\"done\")" | |
| ], | |
| "execution_count": 1, | |
| "outputs": [ | |
| { | |
| "output_type": "display_data", | |
| "data": { | |
| "text/html": [ | |
| "\n", | |
| " <input type=\"file\" id=\"files-2c7c85f7-ebbb-49b5-939c-41d237cda743\" name=\"files[]\" multiple disabled\n", | |
| " style=\"border:none\" />\n", | |
| " <output id=\"result-2c7c85f7-ebbb-49b5-939c-41d237cda743\">\n", | |
| " Upload widget is only available when the cell has been executed in the\n", | |
| " current browser session. Please rerun this cell to enable.\n", | |
| " </output>\n", | |
| " <script src=\"/nbextensions/google.colab/files.js\"></script> " | |
| ], | |
| "text/plain": [ | |
| "<IPython.core.display.HTML object>" | |
| ] | |
| }, | |
| "metadata": { | |
| "tags": [] | |
| } | |
| }, | |
| { | |
| "output_type": "stream", | |
| "text": [ | |
| "Saving key.json to key.json\n", | |
| "done\n" | |
| ], | |
| "name": "stdout" | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "WHPHrUnhpKnI", | |
| "colab_type": "text" | |
| }, | |
| "source": [ | |
| "Install the OpenAI API client and import the libraries used below." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "zq0ltp2xn4yt", | |
| "colab_type": "code", | |
| "colab": { | |
| "base_uri": "https://localhost:8080/", | |
| "height": 292 | |
| }, | |
| "outputId": "d87d9902-c98c-43f7-e3ac-33f02916f2f3" | |
| }, | |
| "source": [ | |
| "!pip install openai\n", | |
| "import openai, json, pandas as pd, numpy as np" | |
| ], | |
| "execution_count": 2, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "text": [ | |
| "Collecting openai\n", | |
| "\u001b[?25l Downloading https://files.pythonhosted.org/packages/a8/65/c7461f4c87984534683f480ea5742777bc39bbf5721123194c2d0347dc1f/openai-0.2.4.tar.gz (157kB)\n", | |
| "\r\u001b[K |██ | 10kB 13.1MB/s eta 0:00:01\r\u001b[K |████▏ | 20kB 1.7MB/s eta 0:00:01\r\u001b[K |██████▎ | 30kB 2.4MB/s eta 0:00:01\r\u001b[K |████████▍ | 40kB 2.6MB/s eta 0:00:01\r\u001b[K |██████████▍ | 51kB 2.0MB/s eta 0:00:01\r\u001b[K |████████████▌ | 61kB 2.2MB/s eta 0:00:01\r\u001b[K |██████████████▋ | 71kB 2.5MB/s eta 0:00:01\r\u001b[K |████████████████▊ | 81kB 2.7MB/s eta 0:00:01\r\u001b[K |██████████████████▊ | 92kB 2.9MB/s eta 0:00:01\r\u001b[K |████████████████████▉ | 102kB 2.8MB/s eta 0:00:01\r\u001b[K |███████████████████████ | 112kB 2.8MB/s eta 0:00:01\r\u001b[K |█████████████████████████ | 122kB 2.8MB/s eta 0:00:01\r\u001b[K |███████████████████████████ | 133kB 2.8MB/s eta 0:00:01\r\u001b[K |█████████████████████████████▏ | 143kB 2.8MB/s eta 0:00:01\r\u001b[K |███████████████████████████████▎| 153kB 2.8MB/s eta 0:00:01\r\u001b[K |████████████████████████████████| 163kB 2.8MB/s \n", | |
| "\u001b[?25hRequirement already satisfied: requests>=2.20 in /usr/local/lib/python3.6/dist-packages (from openai) (2.23.0)\n", | |
| "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests>=2.20->openai) (3.0.4)\n", | |
| "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests>=2.20->openai) (1.24.3)\n", | |
| "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests>=2.20->openai) (2.10)\n", | |
| "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests>=2.20->openai) (2020.6.20)\n", | |
| "Building wheels for collected packages: openai\n", | |
| " Building wheel for openai (setup.py) ... \u001b[?25l\u001b[?25hdone\n", | |
| " Created wheel for openai: filename=openai-0.2.4-cp36-none-any.whl size=170709 sha256=fa007004b1de497170ec6b15a124b6957783c1d059b1b38a3ebb196696b575de\n", | |
| " Stored in directory: /root/.cache/pip/wheels/74/96/c8/c6e170929c276b836613e1b9985343b501fe455e53d85e7d48\n", | |
| "Successfully built openai\n", | |
| "Installing collected packages: openai\n", | |
| "Successfully installed openai-0.2.4\n" | |
| ], | |
| "name": "stdout" | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "Q2yE0jcnpMEV", | |
| "colab_type": "text" | |
| }, | |
| "source": [ | |
| "Load the API key from the key.json file uploaded above. Keeping the key in a separate file means I don't need to worry about accidentally leaking credentials if I share the Colab notebook (I'm 99% sure the shared notebook is just a JSON file that won't expose the key)." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "bwNXXwHen5x9", | |
| "colab_type": "code", | |
| "colab": {} | |
| }, | |
| "source": [ | |
| "openai.api_key = json.load(open(\"key.json\", \"r\"))[\"key\"]" | |
| ], | |
| "execution_count": 3, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "k67w5H0fpTkT", | |
| "colab_type": "text" | |
| }, | |
| "source": [ | |
| "Default keyword arguments to pass to the API" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "e1EwpqqJkTYh", | |
| "colab_type": "code", | |
| "colab": {} | |
| }, | |
| "source": [ | |
| "#arguments to send the API\n", | |
| "kwargs = {\n", | |
| "\"engine\":\"davinci\",\n", | |
| "\"temperature\":0,\n", | |
| "\"max_tokens\":10,\n", | |
| "\"stop\":\"\\n\",\n", | |
| "}" | |
| ], | |
| "execution_count": 4, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "b0zgRWqAy1GA", | |
| "colab_type": "text" | |
| }, | |
| "source": [ | |
| "So one idea is that, for each sentence we generate, we might want to make sure that its log prob is above some threshold. We can also check whether particular words appear in the output and, if so, discard it.\n", | |
| "\n", | |
| "For instance, let's say that we want to generate words that rhyme with dog." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "7AdSYZJYueSj", | |
| "colab_type": "code", | |
| "colab": {} | |
| }, | |
| "source": [ | |
| "prompt = \"\"\"The following word rhymes with happy: snappy\n", | |
| "The following word rhymes with dog:\"\"\"\n", | |
| "kwargs[\"logprobs\"] = 1\n", | |
| "kwargs[\"max_tokens\"] = 5\n", | |
| "kwargs[\"temperature\"] = .5\n", | |
| "kwargs[\"stop\"] = \"\\n\"" | |
| ], | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "_kYaKW9u1l7Q", | |
| "colab_type": "code", | |
| "colab": { | |
| "base_uri": "https://localhost:8080/", | |
| "height": 272 | |
| }, | |
| "outputId": "be83120f-8a96-4250-9b01-1dafa39346b8" | |
| }, | |
| "source": [ | |
| "myWords = []\n", | |
| "attempt = 0\n", | |
| "while len(myWords) < 10:\n", | |
| " attempt += 1\n", | |
| " r = openai.Completion.create(prompt=prompt, **kwargs)\n", | |
| " newWord = r[\"choices\"][0][\"text\"]\n", | |
| " if attempt % 10 == 0:\n", | |
| " print(\"word: {}\\nFound: {}\\nAttempts: {}\".format(newWord, len(myWords), attempt))\n", | |
| " if newWord.lower() not in myWords:\n", | |
| " myWords.append(newWord)" | |
| ], | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "text": [ | |
| "word: fog\n", | |
| "Found: 4\n", | |
| "Attempts: 10\n", | |
| "word: log\n", | |
| "Found: 6\n", | |
| "Attempts: 20\n", | |
| "word: fog\n", | |
| "Found: 6\n", | |
| "Attempts: 30\n", | |
| "word: god\n", | |
| "Found: 7\n", | |
| "Attempts: 40\n", | |
| "word: haggard\n", | |
| "Found: 8\n", | |
| "Attempts: 50\n" | |
| ], | |
| "name": "stdout" | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "LJxHkt5i1maw", | |
| "colab_type": "code", | |
| "colab": { | |
| "base_uri": "https://localhost:8080/", | |
| "height": 34 | |
| }, | |
| "outputId": "9dd85ef5-f50c-4b03-d8be-f97884064238" | |
| }, | |
| "source": [ | |
| "attempt, \",\".join(myWords)" | |
| ], | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "data": { | |
| "text/plain": [ | |
| "(51, ' log, fog, bag, bog, baggy, god, hoggish, dog, haggard, hound')" | |
| ] | |
| }, | |
| "metadata": { | |
| "tags": [] | |
| }, | |
| "execution_count": 11 | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "Vonc2oDf3KOK", | |
| "colab_type": "code", | |
| "colab": { | |
| "base_uri": "https://localhost:8080/", | |
| "height": 85 | |
| }, | |
| "outputId": "0818b6e4-9bb8-4d7d-df51-dac9d4b55db5" | |
| }, | |
| "source": [ | |
| "kwargs[\"temperature\"] = 1\n", | |
| "myWords = []\n", | |
| "attempt = 0\n", | |
| "while len(myWords) < 10:\n", | |
| " attempt += 1\n", | |
| " r = openai.Completion.create(prompt=prompt, **kwargs)\n", | |
| " newWord = r[\"choices\"][0][\"text\"].strip()\n", | |
| " if attempt % 10 == 0:\n", | |
| " print(\"word: {}\\nFound: {}\\nAttempts: {}\".format(newWord, len(myWords), attempt))\n", | |
| " if newWord.lower() not in myWords:\n", | |
| " myWords.append(newWord)\n", | |
| "print(\"{} attempts: {}\".format(attempt, \",\".join(myWords)))" | |
| ], | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "text": [ | |
| "word: thog\n", | |
| "Found: 8\n", | |
| "Attempts: 10\n", | |
| "11 attempts: flog,wog,shlug,bling,sloth,yodle,jogg,hog,thog,fogg\n" | |
| ], | |
| "name": "stdout" | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "7VRiL4Tf5l6Y", | |
| "colab_type": "code", | |
| "colab": { | |
| "base_uri": "https://localhost:8080/", | |
| "height": 187 | |
| }, | |
| "outputId": "6de378c5-8d89-4498-bf62-6824feb37679" | |
| }, | |
| "source": [ | |
| "import random\n", | |
| "\n", | |
| "prefix = \"\"\"The following word rhymes with happy: snappy\\nThe following word rhymes with snap: trap\\n\"\"\"\n", | |
| "suffix = \"The following word rhymes with dog:\"\n", | |
| "prompt = prefix + suffix\n", | |
| "kwargs[\"logprobs\"] = 1\n", | |
| "kwargs[\"max_tokens\"] = 5\n", | |
| "kwargs[\"temperature\"] = .5\n", | |
| "kwargs[\"stop\"] = \"\\n\"\n", | |
| "kwargs[\"presence_penalty\"] = 1\n", | |
| "\n", | |
| "myWords = []\n", | |
| "attempt = 0\n", | |
| "examples = []\n", | |
| "while len(myWords) < 10:\n", | |
| " attempt += 1\n", | |
| " r = openai.Completion.create(prompt=prompt, **kwargs)\n", | |
| " newWord = r[\"choices\"][0][\"text\"].strip()\n", | |
| " if attempt % 10 == 0:\n", | |
| " print(\"word: {}\\nFound: {}\\nAttempts: {}\".format(newWord, len(myWords), attempt))\n", | |
| " if newWord.lower() not in myWords:\n", | |
| " myWords.append(newWord)\n", | |
| " newExample = \"{} {}\".format(suffix, newWord)\n", | |
| " examples.append(newExample)\n", | |
| " if len(examples) > 3:\n", | |
| " prompt = \"\"\n", | |
| " for example in random.sample(examples, k=min(10, len(examples))):\n", | |
| " prompt += example + \"\\n\"\n", | |
| " prompt += suffix\n", | |
| "\n", | |
| "print(\"{} attempts: {}\".format(attempt, \",\".join(myWords)))" | |
| ], | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "text": [ | |
| "word: log\n", | |
| "Found: 3\n", | |
| "Attempts: 10\n", | |
| "word: log\n", | |
| "Found: 3\n", | |
| "Attempts: 20\n", | |
| "word: flog\n", | |
| "Found: 8\n", | |
| "Attempts: 30\n", | |
| "32 attempts: log,god,bog,fog,plog,hog,slog,shlog,flog,zog\n" | |
| ], | |
| "name": "stdout" | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "HRaI_015AjTR", | |
| "colab_type": "text" | |
| }, | |
| "source": [ | |
| "Let's redo that, but also record the mean log probability of each new word." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "wHklFpu8i-OX", | |
| "colab_type": "code", | |
| "colab": { | |
| "base_uri": "https://localhost:8080/", | |
| "height": 136 | |
| }, | |
| "outputId": "9531f1bb-a2c8-4f29-ac67-83afa276a104" | |
| }, | |
| "source": [ | |
| "import random\n", | |
| "\n", | |
| "prefix = \"\"\"The following word rhymes with happy: snappy\\nThe following word rhymes with snap: trap\\n\"\"\"\n", | |
| "suffix = \"The following word rhymes with dog:\"\n", | |
| "prompt = prefix + suffix\n", | |
| "kwargs[\"logprobs\"] = 1\n", | |
| "kwargs[\"max_tokens\"] = 5\n", | |
| "kwargs[\"temperature\"] = .5\n", | |
| "kwargs[\"stop\"] = \"\\n\"\n", | |
| "kwargs[\"presence_penalty\"] = 1\n", | |
| "\n", | |
| "myWords = []\n", | |
| "attempt = 0\n", | |
| "examples = []\n", | |
| "example_logprobs = []\n", | |
| "while len(myWords) < 10:\n", | |
| " attempt += 1\n", | |
| " r = openai.Completion.create(prompt=prompt, **kwargs)\n", | |
| " newWord = r[\"choices\"][0][\"text\"].strip()\n", | |
| " if attempt % 10 == 0:\n", | |
| " print(\"word: {}\\nFound: {}\\nAttempts: {}\".format(newWord, len(myWords), attempt))\n", | |
| " if newWord.lower() not in myWords:\n", | |
| " myWords.append(newWord)\n", | |
| " newExample = \"{} {}\".format(suffix, newWord)\n", | |
| " examples.append(newExample)\n", | |
| "\n", | |
| " # get logprob for examples\n", | |
| " r_tokens = r[\"choices\"][0][\"logprobs\"][\"tokens\"]\n", | |
| " logprobs = r[\"choices\"][0][\"logprobs\"][\"token_logprobs\"]\n", | |
| " if \"\\n\" in r_tokens:\n", | |
| " logprobs = logprobs[:r_tokens.index(\"\\n\")]\n", | |
| " example_logprobs.append((newWord, np.mean(logprobs)))\n", | |
| "\n", | |
| "\n", | |
| " if len(examples) > 3:\n", | |
| " prompt = \"\"\n", | |
| " for example in random.sample(examples, k=min(10, len(examples))):\n", | |
| " prompt += example + \"\\n\"\n", | |
| " prompt += suffix\n", | |
| "\n", | |
| "print(\"{} attempts: {}\".format(attempt, \",\".join(myWords)))" | |
| ], | |
| "execution_count": 6, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "text": [ | |
| "word: bog\n", | |
| "Found: 3\n", | |
| "Attempts: 10\n", | |
| "word: sod\n", | |
| "Found: 6\n", | |
| "Attempts: 20\n", | |
| "25 attempts: log,god,bog,fog,hog,nod,sod,dog,frog,cod\n" | |
| ], | |
| "name": "stdout" | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "69YOktXxnroc", | |
| "colab_type": "code", | |
| "colab": {} | |
| }, | |
| "source": [ | |
| "df = pd.DataFrame(example_logprobs, columns=[\"word\", \"logprob\"])" | |
| ], | |
| "execution_count": 7, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "QRmXNuXhn-3I", | |
| "colab_type": "code", | |
| "colab": {} | |
| }, | |
| "source": [ | |
| "df[\"%\"] = df.logprob.apply(lambda x: 100*np.e**x)" | |
| ], | |
| "execution_count": 8, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "metadata": { | |
| "id": "oJND-MMvoGcD", | |
| "colab_type": "code", | |
| "colab": { | |
| "base_uri": "https://localhost:8080/", | |
| "height": 359 | |
| }, | |
| "outputId": "43e70980-6225-41bd-cf24-0961b78ca81c" | |
| }, | |
| "source": [ | |
| "df" | |
| ], | |
| "execution_count": 9, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "data": { | |
| "text/html": [ | |
| "<div>\n", | |
| "<style scoped>\n", | |
| " .dataframe tbody tr th:only-of-type {\n", | |
| " vertical-align: middle;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe tbody tr th {\n", | |
| " vertical-align: top;\n", | |
| " }\n", | |
| "\n", | |
| " .dataframe thead th {\n", | |
| " text-align: right;\n", | |
| " }\n", | |
| "</style>\n", | |
| "<table border=\"1\" class=\"dataframe\">\n", | |
| " <thead>\n", | |
| " <tr style=\"text-align: right;\">\n", | |
| " <th></th>\n", | |
| " <th>word</th>\n", | |
| " <th>logprob</th>\n", | |
| " <th>%</th>\n", | |
| " </tr>\n", | |
| " </thead>\n", | |
| " <tbody>\n", | |
| " <tr>\n", | |
| " <th>0</th>\n", | |
| " <td>log</td>\n", | |
| " <td>-1.130146</td>\n", | |
| " <td>32.298610</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>1</th>\n", | |
| " <td>god</td>\n", | |
| " <td>-2.627420</td>\n", | |
| " <td>7.226464</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>2</th>\n", | |
| " <td>bog</td>\n", | |
| " <td>-2.385975</td>\n", | |
| " <td>9.199924</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>3</th>\n", | |
| " <td>fog</td>\n", | |
| " <td>-2.689186</td>\n", | |
| " <td>6.793622</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>4</th>\n", | |
| " <td>hog</td>\n", | |
| " <td>-3.782045</td>\n", | |
| " <td>2.277606</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>5</th>\n", | |
| " <td>nod</td>\n", | |
| " <td>-3.258999</td>\n", | |
| " <td>3.842685</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>6</th>\n", | |
| " <td>sod</td>\n", | |
| " <td>-4.729776</td>\n", | |
| " <td>0.882844</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>7</th>\n", | |
| " <td>dog</td>\n", | |
| " <td>-2.877510</td>\n", | |
| " <td>5.627471</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>8</th>\n", | |
| " <td>frog</td>\n", | |
| " <td>-3.591209</td>\n", | |
| " <td>2.756497</td>\n", | |
| " </tr>\n", | |
| " <tr>\n", | |
| " <th>9</th>\n", | |
| " <td>cod</td>\n", | |
| " <td>-3.758816</td>\n", | |
| " <td>2.331133</td>\n", | |
| " </tr>\n", | |
| " </tbody>\n", | |
| "</table>\n", | |
| "</div>" | |
| ], | |
| "text/plain": [ | |
| " word logprob %\n", | |
| "0 log -1.130146 32.298610\n", | |
| "1 god -2.627420 7.226464\n", | |
| "2 bog -2.385975 9.199924\n", | |
| "3 fog -2.689186 6.793622\n", | |
| "4 hog -3.782045 2.277606\n", | |
| "5 nod -3.258999 3.842685\n", | |
| "6 sod -4.729776 0.882844\n", | |
| "7 dog -2.877510 5.627471\n", | |
| "8 frog -3.591209 2.756497\n", | |
| "9 cod -3.758816 2.331133" | |
| ] | |
| }, | |
| "metadata": { | |
| "tags": [] | |
| }, | |
| "execution_count": 9 | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "YnN6ZfpnQLjD", | |
| "colab_type": "text" | |
| }, | |
| "source": [ | |
| "What if we set the temperature to 0?" | |
| ] | |
| } | |
| ] | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment