introToLogProbs.ipynb
@Mistobaan, forked from brockmanmatt/introtologprobs.ipynb (created August 13, 2020)
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "introToLogProbs.ipynb",
"provenance": [],
"collapsed_sections": [],
"authorship_tag": "ABX9TyN0wJnZBPsjhHaevTzCp+o2",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/brockmanmatt/7a346d641e2d2159eb3319f888193212/introtologprobs.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"metadata": {
"id": "J7wnsgT2kPut",
"colab_type": "code",
"colab": {
"resources": {
"http://localhost:8080/nbextensions/google.colab/files.js": {
"data": "Ly8gQ29weXJpZ2h0IDIwMTcgR29vZ2xlIExMQwovLwovLyBMaWNlbnNlZCB1bmRlciB0aGUgQXBhY2hlIExpY2Vuc2UsIFZlcnNpb24gMi4wICh0aGUgIkxpY2Vuc2UiKTsKLy8geW91IG1heSBub3QgdXNlIHRoaXMgZmlsZSBleGNlcHQgaW4gY29tcGxpYW5jZSB3aXRoIHRoZSBMaWNlbnNlLgovLyBZb3UgbWF5IG9idGFpbiBhIGNvcHkgb2YgdGhlIExpY2Vuc2UgYXQKLy8KLy8gICAgICBodHRwOi8vd3d3LmFwYWNoZS5vcmcvbGljZW5zZXMvTElDRU5TRS0yLjAKLy8KLy8gVW5sZXNzIHJlcXVpcmVkIGJ5IGFwcGxpY2FibGUgbGF3IG9yIGFncmVlZCB0byBpbiB3cml0aW5nLCBzb2Z0d2FyZQovLyBkaXN0cmlidXRlZCB1bmRlciB0aGUgTGljZW5zZSBpcyBkaXN0cmlidXRlZCBvbiBhbiAiQVMgSVMiIEJBU0lTLAovLyBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC4KLy8gU2VlIHRoZSBMaWNlbnNlIGZvciB0aGUgc3BlY2lmaWMgbGFuZ3VhZ2UgZ292ZXJuaW5nIHBlcm1pc3Npb25zIGFuZAovLyBsaW1pdGF0aW9ucyB1bmRlciB0aGUgTGljZW5zZS4KCi8qKgogKiBAZmlsZW92ZXJ2aWV3IEhlbHBlcnMgZm9yIGdvb2dsZS5jb2xhYiBQeXRob24gbW9kdWxlLgogKi8KKGZ1bmN0aW9uKHNjb3BlKSB7CmZ1bmN0aW9uIHNwYW4odGV4dCwgc3R5bGVBdHRyaWJ1dGVzID0ge30pIHsKICBjb25zdCBlbGVtZW50ID0gZG9jdW1lbnQuY3JlYXRlRWxlbWVudCgnc3BhbicpOwogIGVsZW1lbnQudGV4dENvbnRlbnQgPSB0ZXh0OwogIGZvciAoY29uc3Qga2V5IG9mIE9iamVjdC5rZXlzKHN0eWxlQXR0cmlidXRlcykpIHsKICAgIGVsZW1lbnQuc3R5bGVba2V5XSA9IHN0eWxlQXR0cmlidXRlc1trZXldOwogIH0KICByZXR1cm4gZWxlbWVudDsKfQoKLy8gTWF4IG51bWJlciBvZiBieXRlcyB3aGljaCB3aWxsIGJlIHVwbG9hZGVkIGF0IGEgdGltZS4KY29uc3QgTUFYX1BBWUxPQURfU0laRSA9IDEwMCAqIDEwMjQ7CgpmdW5jdGlvbiBfdXBsb2FkRmlsZXMoaW5wdXRJZCwgb3V0cHV0SWQpIHsKICBjb25zdCBzdGVwcyA9IHVwbG9hZEZpbGVzU3RlcChpbnB1dElkLCBvdXRwdXRJZCk7CiAgY29uc3Qgb3V0cHV0RWxlbWVudCA9IGRvY3VtZW50LmdldEVsZW1lbnRCeUlkKG91dHB1dElkKTsKICAvLyBDYWNoZSBzdGVwcyBvbiB0aGUgb3V0cHV0RWxlbWVudCB0byBtYWtlIGl0IGF2YWlsYWJsZSBmb3IgdGhlIG5leHQgY2FsbAogIC8vIHRvIHVwbG9hZEZpbGVzQ29udGludWUgZnJvbSBQeXRob24uCiAgb3V0cHV0RWxlbWVudC5zdGVwcyA9IHN0ZXBzOwoKICByZXR1cm4gX3VwbG9hZEZpbGVzQ29udGludWUob3V0cHV0SWQpOwp9CgovLyBUaGlzIGlzIHJvdWdobHkgYW4gYXN5bmMgZ2VuZXJhdG9yIChub3Qgc3VwcG9ydGVkIGluIHRoZSBicm93c2VyIHlldCksCi8vIHdoZXJlIHRoZXJlIGFyZSBtdWx0aXBsZSBhc3luY2hyb25vdXMgc3RlcHMgYW5kIHRoZSBQeXRob24gc2lkZSBpcyBnb2luZwovLyB0byBwb2xsIGZvciBjb21wbGV0aW9uIG9mIGVhY2ggc3RlcC4KLy8gVGhpcyB1c2VzIGEgUHJvbWlzZSB0byBibG9jayB0aGUgcHl0aG9uIHNpZGUgb24gY29tcGxldGlvbiBvZiBlYWNoIHN0ZXAsCi8vIHRoZW4gcGFzc2VzIHRoZSByZXN1bHQgb2YgdGhlIHByZXZpb3VzIHN0ZXAgYXMgdGhlIGlucHV0IHRvIHRoZSBuZXh0IHN0ZXAuCmZ1bmN0aW9uIF91cGxvYWRGaWxlc0NvbnRpbnVlKG91dHB1dElkKSB7CiAgY29uc3Qgb3V0cHV0RWxlbWVudCA9IGRvY3VtZW50LmdldEVsZW1lbnRCeUlkKG91dHB1dElkKTsKICBjb25zdCBzdGVwcyA9IG91dHB1dEVsZW1lbnQuc3RlcHM7CgogIGNvbnN0IG5leHQgPSBzdGVwcy5uZXh0KG91dHB1dEVsZW1lbnQubGFzdFByb21pc2VWYWx1ZSk7CiAgcmV0dXJuIFByb21pc2UucmVzb2x2ZShuZXh0LnZhbHVlLnByb21pc2UpLnRoZW4oKHZhbHVlKSA9PiB7CiAgICAvLyBDYWNoZSB0aGUgbGFzdCBwcm9taXNlIHZhbHVlIHRvIG1ha2UgaXQgYXZhaWxhYmxlIHRvIHRoZSBuZXh0CiAgICAvLyBzdGVwIG9mIHRoZSBnZW5lcmF0b3IuCiAgICBvdXRwdXRFbGVtZW50Lmxhc3RQcm9taXNlVmFsdWUgPSB2YWx1ZTsKICAgIHJldHVybiBuZXh0LnZhbHVlLnJlc3BvbnNlOwogIH0pOwp9CgovKioKICogR2VuZXJhdG9yIGZ1bmN0aW9uIHdoaWNoIGlzIGNhbGxlZCBiZXR3ZWVuIGVhY2ggYXN5bmMgc3RlcCBvZiB0aGUgdXBsb2FkCiAqIHByb2Nlc3MuCiAqIEBwYXJhbSB7c3RyaW5nfSBpbnB1dElkIEVsZW1lbnQgSUQgb2YgdGhlIGlucHV0IGZpbGUgcGlja2VyIGVsZW1lbnQuCiAqIEBwYXJhbSB7c3RyaW5nfSBvdXRwdXRJZCBFbGVtZW50IElEIG9mIHRoZSBvdXRwdXQgZGlzcGxheS4KICogQHJldHVybiB7IUl0ZXJhYmxlPCFPYmplY3Q+fSBJdGVyYWJsZSBvZiBuZXh0IHN0ZXBzLgogKi8KZnVuY3Rpb24qIHVwbG9hZEZpbGVzU3RlcChpbnB1dElkLCBvdXRwdXRJZCkgewogIGNvbnN0IGlucHV0RWxlbWVudCA9IGRvY3VtZW50LmdldEVsZW1lbnRCeUlkKGlucHV0SWQpOwogIGlucHV0RWxlbWVudC5kaXNhYmxlZCA9IGZhbHNlOwoKICBjb25zdCBvdXRwdXRFbGVtZW50ID0gZG9jdW1lbnQuZ2V0RWxlbWVudEJ5SWQob3V0cHV0SWQpOwogIG91dHB1dEVsZW1lbnQuaW5uZXJIVE
1MID0gJyc7CgogIGNvbnN0IHBpY2tlZFByb21pc2UgPSBuZXcgUHJvbWlzZSgocmVzb2x2ZSkgPT4gewogICAgaW5wdXRFbGVtZW50LmFkZEV2ZW50TGlzdGVuZXIoJ2NoYW5nZScsIChlKSA9PiB7CiAgICAgIHJlc29sdmUoZS50YXJnZXQuZmlsZXMpOwogICAgfSk7CiAgfSk7CgogIGNvbnN0IGNhbmNlbCA9IGRvY3VtZW50LmNyZWF0ZUVsZW1lbnQoJ2J1dHRvbicpOwogIGlucHV0RWxlbWVudC5wYXJlbnRFbGVtZW50LmFwcGVuZENoaWxkKGNhbmNlbCk7CiAgY2FuY2VsLnRleHRDb250ZW50ID0gJ0NhbmNlbCB1cGxvYWQnOwogIGNvbnN0IGNhbmNlbFByb21pc2UgPSBuZXcgUHJvbWlzZSgocmVzb2x2ZSkgPT4gewogICAgY2FuY2VsLm9uY2xpY2sgPSAoKSA9PiB7CiAgICAgIHJlc29sdmUobnVsbCk7CiAgICB9OwogIH0pOwoKICAvLyBXYWl0IGZvciB0aGUgdXNlciB0byBwaWNrIHRoZSBmaWxlcy4KICBjb25zdCBmaWxlcyA9IHlpZWxkIHsKICAgIHByb21pc2U6IFByb21pc2UucmFjZShbcGlja2VkUHJvbWlzZSwgY2FuY2VsUHJvbWlzZV0pLAogICAgcmVzcG9uc2U6IHsKICAgICAgYWN0aW9uOiAnc3RhcnRpbmcnLAogICAgfQogIH07CgogIGNhbmNlbC5yZW1vdmUoKTsKCiAgLy8gRGlzYWJsZSB0aGUgaW5wdXQgZWxlbWVudCBzaW5jZSBmdXJ0aGVyIHBpY2tzIGFyZSBub3QgYWxsb3dlZC4KICBpbnB1dEVsZW1lbnQuZGlzYWJsZWQgPSB0cnVlOwoKICBpZiAoIWZpbGVzKSB7CiAgICByZXR1cm4gewogICAgICByZXNwb25zZTogewogICAgICAgIGFjdGlvbjogJ2NvbXBsZXRlJywKICAgICAgfQogICAgfTsKICB9CgogIGZvciAoY29uc3QgZmlsZSBvZiBmaWxlcykgewogICAgY29uc3QgbGkgPSBkb2N1bWVudC5jcmVhdGVFbGVtZW50KCdsaScpOwogICAgbGkuYXBwZW5kKHNwYW4oZmlsZS5uYW1lLCB7Zm9udFdlaWdodDogJ2JvbGQnfSkpOwogICAgbGkuYXBwZW5kKHNwYW4oCiAgICAgICAgYCgke2ZpbGUudHlwZSB8fCAnbi9hJ30pIC0gJHtmaWxlLnNpemV9IGJ5dGVzLCBgICsKICAgICAgICBgbGFzdCBtb2RpZmllZDogJHsKICAgICAgICAgICAgZmlsZS5sYXN0TW9kaWZpZWREYXRlID8gZmlsZS5sYXN0TW9kaWZpZWREYXRlLnRvTG9jYWxlRGF0ZVN0cmluZygpIDoKICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgJ24vYSd9IC0gYCkpOwogICAgY29uc3QgcGVyY2VudCA9IHNwYW4oJzAlIGRvbmUnKTsKICAgIGxpLmFwcGVuZENoaWxkKHBlcmNlbnQpOwoKICAgIG91dHB1dEVsZW1lbnQuYXBwZW5kQ2hpbGQobGkpOwoKICAgIGNvbnN0IGZpbGVEYXRhUHJvbWlzZSA9IG5ldyBQcm9taXNlKChyZXNvbHZlKSA9PiB7CiAgICAgIGNvbnN0IHJlYWRlciA9IG5ldyBGaWxlUmVhZGVyKCk7CiAgICAgIHJlYWRlci5vbmxvYWQgPSAoZSkgPT4gewogICAgICAgIHJlc29sdmUoZS50YXJnZXQucmVzdWx0KTsKICAgICAgfTsKICAgICAgcmVhZGVyLnJlYWRBc0FycmF5QnVmZmVyKGZpbGUpOwogICAgfSk7CiAgICAvLyBXYWl0IGZvciB0aGUgZGF0YSB0byBiZSByZWFkeS4KICAgIGxldCBmaWxlRGF0YSA9IHlpZWxkIHsKICAgICAgcHJvbWlzZTogZmlsZURhdGFQcm9taXNlLAogICAgICByZXNwb25zZTogewogICAgICAgIGFjdGlvbjogJ2NvbnRpbnVlJywKICAgICAgfQogICAgfTsKCiAgICAvLyBVc2UgYSBjaHVua2VkIHNlbmRpbmcgdG8gYXZvaWQgbWVzc2FnZSBzaXplIGxpbWl0cy4gU2VlIGIvNjIxMTU2NjAuCiAgICBsZXQgcG9zaXRpb24gPSAwOwogICAgd2hpbGUgKHBvc2l0aW9uIDwgZmlsZURhdGEuYnl0ZUxlbmd0aCkgewogICAgICBjb25zdCBsZW5ndGggPSBNYXRoLm1pbihmaWxlRGF0YS5ieXRlTGVuZ3RoIC0gcG9zaXRpb24sIE1BWF9QQVlMT0FEX1NJWkUpOwogICAgICBjb25zdCBjaHVuayA9IG5ldyBVaW50OEFycmF5KGZpbGVEYXRhLCBwb3NpdGlvbiwgbGVuZ3RoKTsKICAgICAgcG9zaXRpb24gKz0gbGVuZ3RoOwoKICAgICAgY29uc3QgYmFzZTY0ID0gYnRvYShTdHJpbmcuZnJvbUNoYXJDb2RlLmFwcGx5KG51bGwsIGNodW5rKSk7CiAgICAgIHlpZWxkIHsKICAgICAgICByZXNwb25zZTogewogICAgICAgICAgYWN0aW9uOiAnYXBwZW5kJywKICAgICAgICAgIGZpbGU6IGZpbGUubmFtZSwKICAgICAgICAgIGRhdGE6IGJhc2U2NCwKICAgICAgICB9LAogICAgICB9OwogICAgICBwZXJjZW50LnRleHRDb250ZW50ID0KICAgICAgICAgIGAke01hdGgucm91bmQoKHBvc2l0aW9uIC8gZmlsZURhdGEuYnl0ZUxlbmd0aCkgKiAxMDApfSUgZG9uZWA7CiAgICB9CiAgfQoKICAvLyBBbGwgZG9uZS4KICB5aWVsZCB7CiAgICByZXNwb25zZTogewogICAgICBhY3Rpb246ICdjb21wbGV0ZScsCiAgICB9CiAgfTsKfQoKc2NvcGUuZ29vZ2xlID0gc2NvcGUuZ29vZ2xlIHx8IHt9OwpzY29wZS5nb29nbGUuY29sYWIgPSBzY29wZS5nb29nbGUuY29sYWIgfHwge307CnNjb3BlLmdvb2dsZS5jb2xhYi5fZmlsZXMgPSB7CiAgX3VwbG9hZEZpbGVzLAogIF91cGxvYWRGaWxlc0NvbnRpbnVlLAp9Owp9KShzZWxmKTsK",
"ok": true,
"headers": [
[
"content-type",
"application/javascript"
]
],
"status": 200,
"status_text": ""
}
},
"base_uri": "https://localhost:8080/",
"height": 89
},
"outputId": "98be6105-c5b7-4e07-ae4f-e02b1cda47d0"
},
"source": [
"from google.colab import files\n",
"uploaded = files.upload()\n",
"print(\"done\")"
],
"execution_count": null,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/html": [
"\n",
" <input type=\"file\" id=\"files-014046c1-ce74-4a65-bc7c-70c1a9929292\" name=\"files[]\" multiple disabled\n",
" style=\"border:none\" />\n",
" <output id=\"result-014046c1-ce74-4a65-bc7c-70c1a9929292\">\n",
" Upload widget is only available when the cell has been executed in the\n",
" current browser session. Please rerun this cell to enable.\n",
" </output>\n",
" <script src=\"/nbextensions/google.colab/files.js\"></script> "
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "stream",
"text": [
"Saving key.json to key.json\n",
"done\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "WHPHrUnhpKnI",
"colab_type": "text"
},
"source": [
"I'll install the API"
]
},
{
"cell_type": "code",
"metadata": {
"id": "zq0ltp2xn4yt",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 139
},
"outputId": "da8426d1-ec6f-4e4c-b02f-57016b1254d2"
},
"source": [
"!pip install openai\n",
"import openai, json, pandas as pd, numpy as np"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"Requirement already satisfied: openai in /usr/local/lib/python3.6/dist-packages (0.2.4)\n",
"Requirement already satisfied: requests>=2.20; python_version >= \"3.0\" in /usr/local/lib/python3.6/dist-packages (from openai) (2.23.0)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests>=2.20; python_version >= \"3.0\"->openai) (2020.6.20)\n",
"Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests>=2.20; python_version >= \"3.0\"->openai) (3.0.4)\n",
"Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests>=2.20; python_version >= \"3.0\"->openai) (1.24.3)\n",
"Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests>=2.20; python_version >= \"3.0\"->openai) (2.10)\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Q2yE0jcnpMEV",
"colab_type": "text"
},
"source": [
"Loading in key.json that I uploaded; I do this so I don't need to worry about accidently leaking creds if I share the colab (which I'm 99% sure is just a json file that won't expose them)"
]
},
{
"cell_type": "code",
"metadata": {
"id": "bwNXXwHen5x9",
"colab_type": "code",
"colab": {}
},
"source": [
"openai.api_key = json.load(open(\"key.json\", \"r\"))[\"key\"]"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "k67w5H0fpTkT",
"colab_type": "text"
},
"source": [
"Default keyword arguments to pass the aPI"
]
},
{
"cell_type": "code",
"metadata": {
"id": "e1EwpqqJkTYh",
"colab_type": "code",
"colab": {}
},
"source": [
"#arguments to send the API\n",
"kwargs = {\n",
"\"engine\":\"davinci\",\n",
"\"temperature\":0,\n",
"\"max_tokens\":10,\n",
"\"stop\":\"\\n\",\n",
"}"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "zZubgPoOpWDH",
"colab_type": "text"
},
"source": [
"Quick wrapper to automatically save prompts and responses sent for later analysis if needed"
]
},
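{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch of such a wrapper (illustrative only; the helper name `create_and_log` and the in-memory `log` list are assumptions, not part of the original API):"
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Illustrative sketch: wrap openai.Completion.create so every prompt/response pair is recorded.\n",
"# The helper name and the in-memory `log` list are assumptions for this example;\n",
"# each entry could just as well be dumped to a JSON file instead.\n",
"log = []\n",
"\n",
"def create_and_log(prompt, **kw):\n",
"    response = openai.Completion.create(prompt=prompt, **kw)\n",
"    log.append({\"prompt\": prompt, \"kwargs\": kw, \"response\": response})\n",
"    return response"
],
"execution_count": null,
"outputs": []
},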
{
"cell_type": "code",
"metadata": {
"id": "kY9t_siFKaPc",
"colab_type": "code",
"colab": {}
},
"source": [
"prompt = \"\"\"q: what is the capital of France\n",
"a:\"\"\""
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "WkLAGMCSKqqn",
"colab_type": "code",
"colab": {}
},
"source": [
"r = openai.Completion.create(prompt=prompt, **kwargs)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "XoJucYblKvX4",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 35
},
"outputId": "4a796fde-4bc0-4b15-e1cd-15633619f157"
},
"source": [
"r[\"choices\"][0][\"text\"]"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"application/vnd.google.colaboratory.intrinsic+json": {
"type": "string"
},
"text/plain": [
"' Paris'"
]
},
"metadata": {
"tags": []
},
"execution_count": 21
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "9Po7MjUJKzGJ",
"colab_type": "code",
"colab": {}
},
"source": [
"kwargs[\"logprobs\"] = 5"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "0teEX7n6K6qP",
"colab_type": "code",
"colab": {}
},
"source": [
"r = openai.Completion.create(prompt=prompt, **kwargs)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "aH9gsglBMfAi",
"colab_type": "text"
},
"source": [
"So here's all the logprobs for the subsequent tokens; it hit the stop (\\n), generated a few moe followups but still stopped."
]
},
{
"cell_type": "code",
"metadata": {
"id": "WN_UTtS-K7xg",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 297
},
"outputId": "025c7d14-4287-49b4-8d16-9921742e5d66"
},
"source": [
"pd.DataFrame(r[\"choices\"][0][\"logprobs\"])"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>tokens</th>\n",
" <th>token_logprobs</th>\n",
" <th>top_logprobs</th>\n",
" <th>text_offset</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Paris</td>\n",
" <td>-0.828964</td>\n",
" <td>{' par': -1.6102142, ' Par': -4.235214, ' PAR'...</td>\n",
" <td>35</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>\\n</td>\n",
" <td>-0.364414</td>\n",
" <td>{',': -3.1456642, '.': -2.6144142, '\n",
"': -0.364...</td>\n",
" <td>41</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>q</td>\n",
" <td>-1.213570</td>\n",
" <td>{'\n",
"': -1.5885696, 'The': -4.2291946, 'b': -2.4...</td>\n",
" <td>41</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>:</td>\n",
" <td>-0.004189</td>\n",
" <td>{' :': -7.0354385, '.': -7.0354385, '1': -8.53...</td>\n",
" <td>41</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>what</td>\n",
" <td>-0.479179</td>\n",
" <td>{' What': -2.2916794, ' who': -3.4791794, ' wh...</td>\n",
" <td>41</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>is</td>\n",
" <td>-0.297340</td>\n",
" <td>{' country': -4.4223404, ' color': -4.0473404,...</td>\n",
" <td>41</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>the</td>\n",
" <td>-0.146500</td>\n",
" <td>{' a': -4.0527496, ' the': -0.14649963, ' 1': ...</td>\n",
" <td>41</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <td>capital</td>\n",
" <td>-0.774006</td>\n",
" <td>{' name': -3.586506, ' color': -3.867756, ' ca...</td>\n",
" <td>41</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" tokens ... text_offset\n",
"0 Paris ... 35\n",
"1 \\n ... 41\n",
"2 q ... 41\n",
"3 : ... 41\n",
"4 what ... 41\n",
"5 is ... 41\n",
"6 the ... 41\n",
"7 capital ... 41\n",
"\n",
"[8 rows x 4 columns]"
]
},
"metadata": {
"tags": []
},
"execution_count": 31
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "le1fHX-BMtGj",
"colab_type": "text"
},
"source": [
"we can look more at the possibilites it considered for paris, converting the logprobs to % by taking e**logprob\n",
"\n",
"Paris wins with 43%, although it almost went par"
]
},
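{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick worked example of that conversion: the ' Paris' logprob of about -0.829 from the table above gives e**-0.829 ≈ 0.436, i.e. roughly 43.6%."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Sanity check of the logprob -> probability conversion, p = e**logprob,\n",
"# using the ' Paris' logprob from the run above.\n",
"print(100 * np.e ** -0.828964)"
],
"execution_count": null,
"outputs": []
},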
{
"cell_type": "code",
"metadata": {
"id": "eGfnBWvHK8hS",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 204
},
"outputId": "cbf7be34-3ca6-46f2-d523-1237edbe4012"
},
"source": [
"scores = pd.DataFrame([r[\"choices\"][0][\"logprobs\"][\"top_logprobs\"][0]]).T\n",
"scores.columns = [\"logprob\"]\n",
"scores[\"%\"] = scores[\"logprob\"].apply(lambda x: 100*np.e**x)\n",
"scores"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>logprob</th>\n",
" <th>%</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>par</th>\n",
" <td>-1.610214</td>\n",
" <td>19.984480</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Par</th>\n",
" <td>-4.235214</td>\n",
" <td>1.447671</td>\n",
" </tr>\n",
" <tr>\n",
" <th>PAR</th>\n",
" <td>-4.172714</td>\n",
" <td>1.541038</td>\n",
" </tr>\n",
" <tr>\n",
" <th>Paris</th>\n",
" <td>-0.828964</td>\n",
" <td>43.650117</td>\n",
" </tr>\n",
" <tr>\n",
" <th>what</th>\n",
" <td>-4.422714</td>\n",
" <td>1.200162</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" logprob %\n",
" par -1.610214 19.984480\n",
" Par -4.235214 1.447671\n",
" PAR -4.172714 1.541038\n",
" Paris -0.828964 43.650117\n",
" what -4.422714 1.200162"
]
},
"metadata": {
"tags": []
},
"execution_count": 45
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "UJKQ0T9gP9Ce",
"colab_type": "text"
},
"source": [
"We can see if we increase the temperature, it takes non-optimal answers. However, it still tries to complete the task and eventually makes it back to Paris (although that's not guaranteed)"
]
},
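{
"cell_type": "markdown",
"metadata": {},
"source": [
"Conceptually, temperature rescales the distribution before sampling: probabilities go as e**(logprob / T), renormalized, so T > 1 flattens the distribution and T < 1 sharpens it. Here's a rough sketch of that effect on the ' Paris' top_logprobs from above (an illustration of the usual definition of temperature, not a claim about the API's exact internals; it only uses the returned top-5 tokens, so the percentages are relative to those 5):"
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Rough sketch: renormalize e**(logprob / T) over the returned top-5 alternatives\n",
"# to see how a higher temperature flattens the distribution.\n",
"def rescale(top_logprobs, T):\n",
"    weights = {tok: np.exp(lp / T) for tok, lp in top_logprobs.items()}\n",
"    total = sum(weights.values())\n",
"    return {tok: 100 * w / total for tok, w in weights.items()}\n",
"\n",
"top = r[\"choices\"][0][\"logprobs\"][\"top_logprobs\"][0]\n",
"pd.DataFrame([rescale(top, 1.0), rescale(top, 1.2)], index=[\"T=1.0\", \"T=1.2\"]).T"
],
"execution_count": null,
"outputs": []
},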
{
"cell_type": "code",
"metadata": {
"id": "6a4yfgKmP5H9",
"colab_type": "code",
"colab": {}
},
"source": [
"kwargs[\"temperature\"] = 1.2\n",
"r = openai.Completion.create(prompt=prompt, **kwargs)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "5f0sTMWCP5NY",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 359
},
"outputId": "64cd64ce-80ae-4aad-9138-c34fbf6553a8"
},
"source": [
"pd.DataFrame(r[\"choices\"][0][\"logprobs\"])"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>tokens</th>\n",
" <th>token_logprobs</th>\n",
" <th>top_logprobs</th>\n",
" <th>text_offset</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>that</td>\n",
" <td>-5.997170</td>\n",
" <td>{' Paris': -0.8409195, ' par': -1.5284195, ' P...</td>\n",
" <td>35</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>'s</td>\n",
" <td>-0.899242</td>\n",
" <td>{' is': -1.1492424, ''s': -0.8992424, 'bytes:\\...</td>\n",
" <td>40</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>an</td>\n",
" <td>-3.084446</td>\n",
" <td>{' easy': -3.006321, ' a': -1.475071, ' an': -...</td>\n",
" <td>42</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>easy</td>\n",
" <td>-1.239227</td>\n",
" <td>{' example': -3.5361023, ' easy': -1.2392273, ...</td>\n",
" <td>45</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>one</td>\n",
" <td>-0.112442</td>\n",
" <td>{' q': -6.128067, ' question': -2.424942, ' an...</td>\n",
" <td>50</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>-</td>\n",
" <td>-4.639725</td>\n",
" <td>{',': -1.1084747, '.': -2.1397247, ':': -2.483...</td>\n",
" <td>54</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>Paris</td>\n",
" <td>-2.285274</td>\n",
" <td>{' par': -3.0977745, ' it': -2.3321495, 'Paris...</td>\n",
" <td>55</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <td>(</td>\n",
" <td>-4.684345</td>\n",
" <td>{'\n",
"': -0.52809525, '.': -2.0280952, '!': -2.24...</td>\n",
" <td>61</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <td>Y</td>\n",
" <td>-7.411562</td>\n",
" <td>{'or': -2.895937, 'the': -3.380312, 'correct':...</td>\n",
" <td>63</td>\n",
" </tr>\n",
" <tr>\n",
" <th>9</th>\n",
" <td>ay</td>\n",
" <td>-1.383808</td>\n",
" <td>{'ahoo': -3.2275581, 'ay': -1.3838081, 'AY': -...</td>\n",
" <td>64</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" tokens ... text_offset\n",
"0 that ... 35\n",
"1 's ... 40\n",
"2 an ... 42\n",
"3 easy ... 45\n",
"4 one ... 50\n",
"5 - ... 54\n",
"6 Paris ... 55\n",
"7 ( ... 61\n",
"8 Y ... 63\n",
"9 ay ... 64\n",
"\n",
"[10 rows x 4 columns]"
]
},
"metadata": {
"tags": []
},
"execution_count": 87
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "kU5aacCKM1lS",
"colab_type": "code",
"colab": {}
},
"source": [
"prompt = \"\"\"These word rhyme:\n",
"red:led\n",
"dog:frog\n",
"small:tall\n",
"train:\"\"\"\n",
"kwargs[\"logprobs\"] = 10\n",
"kwargs[\"max_tokens\"] = 20\n",
"kwargs[\"temperature\"] = 0\n",
"\n",
"r = openai.Completion.create(prompt=prompt, **kwargs)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "kuewt1akNPRM",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 359
},
"outputId": "54ad0bfa-f8aa-4e9c-8b15-750e88998796"
},
"source": [
"scores = pd.DataFrame([r[\"choices\"][0][\"logprobs\"][\"top_logprobs\"][0]]).T\n",
"scores.columns = [\"logprob\"]\n",
"scores[\"%\"] = scores[\"logprob\"].apply(lambda x: 100*np.e**x)\n",
"scores.sort_values(by=\"%\", ascending=False)"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>logprob</th>\n",
" <th>%</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>pain</th>\n",
" <td>-1.277435</td>\n",
" <td>27.875130</td>\n",
" </tr>\n",
" <tr>\n",
" <th>rain</th>\n",
" <td>-2.277435</td>\n",
" <td>10.254687</td>\n",
" </tr>\n",
" <tr>\n",
" <th>brain</th>\n",
" <td>-2.621185</td>\n",
" <td>7.271662</td>\n",
" </tr>\n",
" <tr>\n",
" <th>chain</th>\n",
" <td>-3.277435</td>\n",
" <td>3.772489</td>\n",
" </tr>\n",
" <tr>\n",
" <th>str</th>\n",
" <td>-3.355560</td>\n",
" <td>3.488982</td>\n",
" </tr>\n",
" <tr>\n",
" <th>plane</th>\n",
" <td>-3.621185</td>\n",
" <td>2.675095</td>\n",
" </tr>\n",
" <tr>\n",
" <th>gain</th>\n",
" <td>-3.746185</td>\n",
" <td>2.360763</td>\n",
" </tr>\n",
" <tr>\n",
" <th>main</th>\n",
" <td>-3.933685</td>\n",
" <td>1.957141</td>\n",
" </tr>\n",
" <tr>\n",
" <th>p</th>\n",
" <td>-4.027435</td>\n",
" <td>1.781997</td>\n",
" </tr>\n",
" <tr>\n",
" <th>plant</th>\n",
" <td>-4.089935</td>\n",
" <td>1.674032</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" logprob %\n",
"pain -1.277435 27.875130\n",
"rain -2.277435 10.254687\n",
"brain -2.621185 7.271662\n",
"chain -3.277435 3.772489\n",
"str -3.355560 3.488982\n",
"plane -3.621185 2.675095\n",
"gain -3.746185 2.360763\n",
"main -3.933685 1.957141\n",
"p -4.027435 1.781997\n",
"plant -4.089935 1.674032"
]
},
"metadata": {
"tags": []
},
"execution_count": 116
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "-EoMNmusN1cw",
"colab_type": "code",
"colab": {}
},
"source": [
"prompt = \"\"\"These pairs of sentences rhyme:\n",
"My favorite color is red\n",
"ends with: \"red\"\n",
"\"red\" rhymes with \"bed\"\n",
"Rhyme: It's the color of my bed\n",
"-----\n",
"I once had a dog\n",
"ends with: \"dog\"\n",
"\"dog\" rhymes with \"frog\"\n",
"Rhyme: That good boy ate a frog\n",
"-----\n",
"I wish I was small\n",
"ends with: \"small\"\n",
"\"small\" rhymes with \"tall\"\n",
"Rhyme: Instead I'm so tall ='(\n",
"-----\n",
"That's a cool train\n",
"ends with:\"\"\"\n",
"kwargs[\"logprobs\"] = 5\n",
"kwargs[\"max_tokens\"] = 40\n",
"kwargs[\"temperature\"] = 0\n",
"kwargs[\"stop\"] = \"-----\"\n",
"\n",
"r = openai.Completion.create(prompt=prompt, **kwargs)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "LIovt7zaS_eP",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 35
},
"outputId": "63f13251-d56d-4001-d54b-9cd6fb44dbd9"
},
"source": [
"r[\"choices\"][0][\"text\"]"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"application/vnd.google.colaboratory.intrinsic+json": {
"type": "string"
},
"text/plain": [
"' \"train\"\\n\"train\" rhymes with \"rain\"\\nRhyme: I like to ride the rain\\n'"
]
},
"metadata": {
"tags": []
},
"execution_count": 170
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "pu_sgx2vSuWW",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 204
},
"outputId": "8348e9fc-1efe-43cc-cc96-d27a4f70a5e1"
},
"source": [
"scores = pd.DataFrame([r[\"choices\"][0][\"logprobs\"][\"top_logprobs\"][11]]).T\n",
"scores.columns = [\"logprob\"]\n",
"scores[\"%\"] = scores[\"logprob\"].apply(lambda x: 100*np.e**x)\n",
"scores.sort_values(by=\"%\", ascending=False)"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>logprob</th>\n",
" <th>%</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>rain</th>\n",
" <td>-1.411346</td>\n",
" <td>24.381479</td>\n",
" </tr>\n",
" <tr>\n",
" <th>pain</th>\n",
" <td>-2.051971</td>\n",
" <td>12.848137</td>\n",
" </tr>\n",
" <tr>\n",
" <th>brain</th>\n",
" <td>-2.567596</td>\n",
" <td>7.671973</td>\n",
" </tr>\n",
" <tr>\n",
" <th>plane</th>\n",
" <td>-3.458221</td>\n",
" <td>3.148571</td>\n",
" </tr>\n",
" <tr>\n",
" <th>str</th>\n",
" <td>-3.583221</td>\n",
" <td>2.778604</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" logprob %\n",
"rain -1.411346 24.381479\n",
"pain -2.051971 12.848137\n",
"brain -2.567596 7.671973\n",
"plane -3.458221 3.148571\n",
"str -3.583221 2.778604"
]
},
"metadata": {
"tags": []
},
"execution_count": 171
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "URQ2ENPwSveY",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 483
},
"outputId": "3673773e-8673-4f79-ee59-68ae8b37681b"
},
"source": [
"pd.DataFrame(r[\"choices\"][0][\"logprobs\"])[18:]"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>tokens</th>\n",
" <th>token_logprobs</th>\n",
" <th>top_logprobs</th>\n",
" <th>text_offset</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>18</th>\n",
" <td>I</td>\n",
" <td>-1.807033</td>\n",
" <td>{' That': -1.9320335, ' I': -1.8070335, ' The'...</td>\n",
" <td>407</td>\n",
" </tr>\n",
" <tr>\n",
" <th>19</th>\n",
" <td>like</td>\n",
" <td>-1.590843</td>\n",
" <td>{''m': -2.6533432, ' love': -2.5908432, ' wish...</td>\n",
" <td>409</td>\n",
" </tr>\n",
" <tr>\n",
" <th>20</th>\n",
" <td>to</td>\n",
" <td>-1.032188</td>\n",
" <td>{' trains': -2.5634384, ' the': -2.0946884, ' ...</td>\n",
" <td>414</td>\n",
" </tr>\n",
" <tr>\n",
" <th>21</th>\n",
" <td>ride</td>\n",
" <td>-1.397640</td>\n",
" <td>{' watch': -1.5538902, ' hear': -3.6476402, ' ...</td>\n",
" <td>417</td>\n",
" </tr>\n",
" <tr>\n",
" <th>22</th>\n",
" <td>the</td>\n",
" <td>-1.274738</td>\n",
" <td>{' a': -2.3059883, ' the': -1.2747383, ' in': ...</td>\n",
" <td>422</td>\n",
" </tr>\n",
" <tr>\n",
" <th>23</th>\n",
" <td>rain</td>\n",
" <td>-0.667145</td>\n",
" <td>{' Rain': -5.68277, ' subway': -4.823395, ' \"'...</td>\n",
" <td>426</td>\n",
" </tr>\n",
" <tr>\n",
" <th>24</th>\n",
" <td>\\n</td>\n",
" <td>-0.295723</td>\n",
" <td>{'.': -3.389473, '\n",
"': -0.29572296, ' train': -...</td>\n",
" <td>431</td>\n",
" </tr>\n",
" <tr>\n",
" <th>25</th>\n",
" <td>-----</td>\n",
" <td>-0.310501</td>\n",
" <td>{'\n",
"': -2.529251, '-----': -0.3105011, 'R': -4....</td>\n",
" <td>432</td>\n",
" </tr>\n",
" <tr>\n",
" <th>26</th>\n",
" <td>\\n</td>\n",
" <td>-0.019367</td>\n",
" <td>{'\n",
"': -0.019367218, ' ': -6.050617, ' I': -8.0...</td>\n",
" <td>432</td>\n",
" </tr>\n",
" <tr>\n",
" <th>27</th>\n",
" <td>I</td>\n",
" <td>-1.208935</td>\n",
" <td>{'\n",
"': -2.8651848, 'That': -3.0839348, 'My': -2...</td>\n",
" <td>432</td>\n",
" </tr>\n",
" <tr>\n",
" <th>28</th>\n",
" <td>like</td>\n",
" <td>-1.733200</td>\n",
" <td>{''m': -2.51445, ' love': -2.70195, ' have': -...</td>\n",
" <td>432</td>\n",
" </tr>\n",
" <tr>\n",
" <th>29</th>\n",
" <td>to</td>\n",
" <td>-0.730225</td>\n",
" <td>{' the': -3.2614746, ' to': -0.7302246, ' that...</td>\n",
" <td>432</td>\n",
" </tr>\n",
" <tr>\n",
" <th>30</th>\n",
" <td>eat</td>\n",
" <td>-1.868992</td>\n",
" <td>{' read': -2.9002419, ' sing': -3.4002419, ' e...</td>\n",
" <td>432</td>\n",
" </tr>\n",
" <tr>\n",
" <th>31</th>\n",
" <td>\\n</td>\n",
" <td>-2.445057</td>\n",
" <td>{' pizza': -3.257557, '\n",
"': -2.445057, ' pie': ...</td>\n",
" <td>432</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" tokens ... text_offset\n",
"18 I ... 407\n",
"19 like ... 409\n",
"20 to ... 414\n",
"21 ride ... 417\n",
"22 the ... 422\n",
"23 rain ... 426\n",
"24 \\n ... 431\n",
"25 ----- ... 432\n",
"26 \\n ... 432\n",
"27 I ... 432\n",
"28 like ... 432\n",
"29 to ... 432\n",
"30 eat ... 432\n",
"31 \\n ... 432\n",
"\n",
"[14 rows x 4 columns]"
]
},
"metadata": {
"tags": []
},
"execution_count": 172
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "BLRp1vNBTYI-",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 80
},
"outputId": "af7037bb-a3fc-49d0-fc53-93c7494b4a51"
},
"source": [
"pd.DataFrame([r[\"choices\"][0][\"logprobs\"][\"top_logprobs\"][23]]).apply(lambda x: 100*np.e**x)"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Rain</th>\n",
" <th>subway</th>\n",
" <th>\"</th>\n",
" <th>train</th>\n",
" <th>rain</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>0.340412</td>\n",
" <td>0.803945</td>\n",
" <td>0.458074</td>\n",
" <td>42.543428</td>\n",
" <td>51.31717</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Rain subway \" train rain\n",
"0 0.340412 0.803945 0.458074 42.543428 51.31717"
]
},
"metadata": {
"tags": []
},
"execution_count": 173
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "TZGcB9AXYrgs",
"colab_type": "code",
"colab": {}
},
"source": [
"rhymed = pd.DataFrame(r[\"choices\"][0][\"logprobs\"])[18:]"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "NdNI_F9wVEUG",
"colab_type": "text"
},
"source": [
"What if we make this more creative"
]
},
{
"cell_type": "code",
"metadata": {
"id": "wRxYDGQdUuMS",
"colab_type": "code",
"colab": {}
},
"source": [
"prompt = \"\"\"These pairs of sentences rhyme:\n",
"My favorite color is red\n",
"ends with: \"red\"\n",
"\"red\" rhymes with \"bed\"\n",
"Rhyme: It's the color of my bed\n",
"-----\n",
"I once had a dog\n",
"ends with: \"dog\"\n",
"\"dog\" rhymes with \"frog\"\n",
"Rhyme: That good boy ate a frog\n",
"-----\n",
"I wish I was small\n",
"ends with: \"small\"\n",
"\"small\" rhymes with \"tall\"\n",
"Rhyme: Instead I'm so tall ='(\n",
"-----\n",
"That's a cool train\n",
"ends with:\"\"\"\n",
"kwargs[\"logprobs\"] = 5\n",
"kwargs[\"max_tokens\"] = 40\n",
"kwargs[\"temperature\"] = .5\n",
"kwargs[\"stop\"] = \"-----\"\n",
"\n",
"r = openai.Completion.create(prompt=prompt, **kwargs)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "hStXV_UCVGjO",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 35
},
"outputId": "715b868a-132a-4959-8498-4b64dcc0ff50"
},
"source": [
"r[\"choices\"][0][\"text\"]"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"application/vnd.google.colaboratory.intrinsic+json": {
"type": "string"
},
"text/plain": [
"' \"train\"\\n\"train\" rhymes with \"plane\"\\nRhyme: It\\'s not a plane\\n'"
]
},
"metadata": {
"tags": []
},
"execution_count": 192
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "08OYB7FjVYw2",
"colab_type": "code",
"colab": {}
},
"source": [
"df = pd.DataFrame(r[\"choices\"][0][\"logprobs\"])[18:]"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "5udyaq7NVoCv",
"colab_type": "code",
"colab": {}
},
"source": [
"def getTopValueFromDict(someDict):\n",
" myDict = dict(someDict)\n",
" vals = [(myDict[x], x) for x in myDict]\n",
" return max(vals)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "Mn0atxbbVa34",
"colab_type": "code",
"colab": {}
},
"source": [
"df[\"actual_top_logprob\"] = df.top_logprobs.apply(lambda x: getTopValueFromDict(x))"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "XnKf1xSoVIPq",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 483
},
"outputId": "679f20bd-1467-4469-c366-1e553f53d3b4"
},
"source": [
"df"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>tokens</th>\n",
" <th>token_logprobs</th>\n",
" <th>top_logprobs</th>\n",
" <th>text_offset</th>\n",
" <th>actual_top_logprob</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>18</th>\n",
" <td>It</td>\n",
" <td>-2.047520</td>\n",
" <td>{' That': -1.9850197, ' I': -1.6725197, ' The'...</td>\n",
" <td>408</td>\n",
" <td>(-1.6725197, I)</td>\n",
" </tr>\n",
" <tr>\n",
" <th>19</th>\n",
" <td>'s</td>\n",
" <td>-0.899273</td>\n",
" <td>{' flies': -2.649273, ' makes': -3.586773, ' g...</td>\n",
" <td>411</td>\n",
" <td>(-0.8992729, 's)</td>\n",
" </tr>\n",
" <tr>\n",
" <th>20</th>\n",
" <td>not</td>\n",
" <td>-2.945133</td>\n",
" <td>{' a': -1.6326332, ' the': -2.4451332, ' not':...</td>\n",
" <td>413</td>\n",
" <td>(-1.6326332, a)</td>\n",
" </tr>\n",
" <tr>\n",
" <th>21</th>\n",
" <td>a</td>\n",
" <td>-1.259831</td>\n",
" <td>{' a': -1.2598305, ' as': -2.4160805, ' like':...</td>\n",
" <td>417</td>\n",
" <td>(-1.2598305, a)</td>\n",
" </tr>\n",
" <tr>\n",
" <th>22</th>\n",
" <td>plane</td>\n",
" <td>-0.845264</td>\n",
" <td>{' big': -3.0952644, ' cool': -3.1577644, ' pl...</td>\n",
" <td>419</td>\n",
" <td>(-0.84526443, plane)</td>\n",
" </tr>\n",
" <tr>\n",
" <th>23</th>\n",
" <td>\\n</td>\n",
" <td>-1.518738</td>\n",
" <td>{',': -1.3312378, '\n",
"': -1.5187378, ' or': -1.7...</td>\n",
" <td>425</td>\n",
" <td>(-1.3312378, ,)</td>\n",
" </tr>\n",
" <tr>\n",
" <th>24</th>\n",
" <td>-----</td>\n",
" <td>-0.515564</td>\n",
" <td>{'\"': -3.343689, 'It': -2.859314, '\n",
"': -3.4843...</td>\n",
" <td>426</td>\n",
" <td>(-0.51556396, -----)</td>\n",
" </tr>\n",
" <tr>\n",
" <th>25</th>\n",
" <td>\\n</td>\n",
" <td>-0.013977</td>\n",
" <td>{'\n",
"': -0.013977051, ' ': -6.263977, ' I': -7.9...</td>\n",
" <td>426</td>\n",
" <td>(-0.013977051, \\n)</td>\n",
" </tr>\n",
" <tr>\n",
" <th>26</th>\n",
" <td>That</td>\n",
" <td>-2.864292</td>\n",
" <td>{'\n",
"': -3.2392921, 'That': -2.8642921, 'My': -2...</td>\n",
" <td>426</td>\n",
" <td>(-1.2080421, I)</td>\n",
" </tr>\n",
" <tr>\n",
" <th>27</th>\n",
" <td>'s</td>\n",
" <td>-4.619221</td>\n",
" <td>{''s': -4.6192207, ' is': -5.1348457, ' was': ...</td>\n",
" <td>426</td>\n",
" <td>(-4.6192207, 's)</td>\n",
" </tr>\n",
" <tr>\n",
" <th>28</th>\n",
" <td>a</td>\n",
" <td>-0.000057</td>\n",
" <td>{' a': -5.722046e-05, ' my': -10.562557, ' not...</td>\n",
" <td>426</td>\n",
" <td>(-5.722046e-05, a)</td>\n",
" </tr>\n",
" <tr>\n",
" <th>29</th>\n",
" <td>cool</td>\n",
" <td>-2.566330</td>\n",
" <td>{' cool': -2.56633, ' big': -2.738205, ' nice'...</td>\n",
" <td>426</td>\n",
" <td>(-2.56633, cool)</td>\n",
" </tr>\n",
" <tr>\n",
" <th>30</th>\n",
" <td>plane</td>\n",
" <td>-0.089695</td>\n",
" <td>{' plane': -0.08969498, ' train': -2.589695, '...</td>\n",
" <td>426</td>\n",
" <td>(-0.08969498, plane)</td>\n",
" </tr>\n",
" <tr>\n",
" <th>31</th>\n",
" <td>\\n</td>\n",
" <td>-0.940624</td>\n",
" <td>{'\n",
"': -0.94062424, '!': -1.5343742, '.': -1.53...</td>\n",
" <td>426</td>\n",
" <td>(-0.94062424, \\n)</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" tokens token_logprobs ... text_offset actual_top_logprob\n",
"18 It -2.047520 ... 408 (-1.6725197, I)\n",
"19 's -0.899273 ... 411 (-0.8992729, 's)\n",
"20 not -2.945133 ... 413 (-1.6326332, a)\n",
"21 a -1.259831 ... 417 (-1.2598305, a)\n",
"22 plane -0.845264 ... 419 (-0.84526443, plane)\n",
"23 \\n -1.518738 ... 425 (-1.3312378, ,)\n",
"24 ----- -0.515564 ... 426 (-0.51556396, -----)\n",
"25 \\n -0.013977 ... 426 (-0.013977051, \\n)\n",
"26 That -2.864292 ... 426 (-1.2080421, I)\n",
"27 's -4.619221 ... 426 (-4.6192207, 's)\n",
"28 a -0.000057 ... 426 (-5.722046e-05, a)\n",
"29 cool -2.566330 ... 426 (-2.56633, cool)\n",
"30 plane -0.089695 ... 426 (-0.08969498, plane)\n",
"31 \\n -0.940624 ... 426 (-0.94062424, \\n)\n",
"\n",
"[14 rows x 5 columns]"
]
},
"metadata": {
"tags": []
},
"execution_count": 196
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "ZtOVl-C9ZN9m",
"colab_type": "code",
"colab": {}
},
"source": [
"rhyming_pt5 = df.copy()"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "GSFVvNSSYKZ4",
"colab_type": "text"
},
"source": [
"How does the logprobs for the bad compare to logprobs for good?"
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "vjwcNkRsZUln",
"colab": {}
},
"source": [
"prompt = \"\"\"These pairs of sentences rhyme:\n",
"My favorite color is red\n",
"ends with: \"red\"\n",
"\"red\" rhymes with \"bed\"\n",
"Rhyme: It's the color of my bed\n",
"-----\n",
"I once had a dog\n",
"ends with: \"dog\"\n",
"\"dog\" rhymes with \"frog\"\n",
"Rhyme: That good boy ate a frog\n",
"-----\n",
"I wish I was small\n",
"ends with: \"small\"\n",
"\"small\" rhymes with \"tall\"\n",
"Rhyme: Instead I'm so tall ='(\n",
"-----\n",
"That's a cool train\n",
"ends with:\"\"\"\n",
"kwargs[\"logprobs\"] = 5\n",
"kwargs[\"max_tokens\"] = 40\n",
"kwargs[\"temperature\"] = .5\n",
"kwargs[\"stop\"] = \"-----\"\n",
"\n",
"r = openai.Completion.create(prompt=prompt, **kwargs)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "cKiTbWxRZUlr",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 35
},
"outputId": "697448e5-d856-4d65-84de-e9a599f7944d"
},
"source": [
"r[\"choices\"][0][\"text\"]"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"application/vnd.google.colaboratory.intrinsic+json": {
"type": "string"
},
"text/plain": [
"' \"train\"\\n\"train\" rhymes with \"rain\"\\nRhyme: The rain is so cool\\n'"
]
},
"metadata": {
"tags": []
},
"execution_count": 199
}
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "SiJNYKcIZUlu",
"colab": {}
},
"source": [
"df = pd.DataFrame(r[\"choices\"][0][\"logprobs\"])[18:]"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "9P30yDlrZUlx",
"colab": {}
},
"source": [
"def getTopValueFromDict(someDict):\n",
" myDict = dict(someDict)\n",
" vals = [(myDict[x], x) for x in myDict]\n",
" return max(vals)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "mW_4_NmEZUlz",
"colab": {}
},
"source": [
"df[\"actual_top_logprob\"] = df.top_logprobs.apply(lambda x: getTopValueFromDict(x))"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "aVIqqoYEZUl2",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 483
},
"outputId": "42e04913-1c2a-4f4a-bc47-d530682fca51"
},
"source": [
"df"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>tokens</th>\n",
" <th>token_logprobs</th>\n",
" <th>top_logprobs</th>\n",
" <th>text_offset</th>\n",
" <th>actual_top_logprob</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>18</th>\n",
" <td>The</td>\n",
" <td>-2.859199</td>\n",
" <td>{' That': -1.9529495, ' I': -1.7966995, ' The'...</td>\n",
" <td>407</td>\n",
" <td>(-1.7966995, I)</td>\n",
" </tr>\n",
" <tr>\n",
" <th>19</th>\n",
" <td>rain</td>\n",
" <td>-0.811218</td>\n",
" <td>{' water': -4.4674683, ' cool': -4.1237183, ' ...</td>\n",
" <td>411</td>\n",
" <td>(-0.81121826, rain)</td>\n",
" </tr>\n",
" <tr>\n",
" <th>20</th>\n",
" <td>is</td>\n",
" <td>-2.080967</td>\n",
" <td>{' came': -1.799717, ' makes': -2.799717, ' co...</td>\n",
" <td>416</td>\n",
" <td>(-1.799717, came)</td>\n",
" </tr>\n",
" <tr>\n",
" <th>21</th>\n",
" <td>so</td>\n",
" <td>-1.509876</td>\n",
" <td>{' coming': -2.7911263, ' a': -2.6348763, ' co...</td>\n",
" <td>419</td>\n",
" <td>(-1.5098763, so)</td>\n",
" </tr>\n",
" <tr>\n",
" <th>22</th>\n",
" <td>cool</td>\n",
" <td>-1.401932</td>\n",
" <td>{' fun': -2.4019318, ' cool': -1.4019318, ' tr...</td>\n",
" <td>422</td>\n",
" <td>(-1.4019318, cool)</td>\n",
" </tr>\n",
" <tr>\n",
" <th>23</th>\n",
" <td>\\n</td>\n",
" <td>-0.392258</td>\n",
" <td>{'!': -3.4860077, ',': -3.7047577, '.': -3.517...</td>\n",
" <td>427</td>\n",
" <td>(-0.3922577, \\n)</td>\n",
" </tr>\n",
" <tr>\n",
" <th>24</th>\n",
" <td>-----</td>\n",
" <td>-0.415947</td>\n",
" <td>{'\n",
"': -2.228447, '-----': -0.41594696, 'R': -4...</td>\n",
" <td>428</td>\n",
" <td>(-0.41594696, -----)</td>\n",
" </tr>\n",
" <tr>\n",
" <th>25</th>\n",
" <td>\\n</td>\n",
" <td>-0.016262</td>\n",
" <td>{'\n",
"': -0.016262054, ' ': -6.297512, ' I': -7.7...</td>\n",
" <td>428</td>\n",
" <td>(-0.016262054, \\n)</td>\n",
" </tr>\n",
" <tr>\n",
" <th>26</th>\n",
" <td>I</td>\n",
" <td>-1.232170</td>\n",
" <td>{'\n",
"': -2.95092, 'That': -2.91967, 'My': -2.482...</td>\n",
" <td>428</td>\n",
" <td>(-1.2321701, I)</td>\n",
" </tr>\n",
" <tr>\n",
" <th>27</th>\n",
" <td>like</td>\n",
" <td>-1.846847</td>\n",
" <td>{''m': -2.5030975, ' love': -2.5030975, ' have...</td>\n",
" <td>428</td>\n",
" <td>(-1.8468475, like)</td>\n",
" </tr>\n",
" <tr>\n",
" <th>28</th>\n",
" <td>to</td>\n",
" <td>-1.022278</td>\n",
" <td>{' the': -2.9597778, ' to': -1.0222778, ' that...</td>\n",
" <td>428</td>\n",
" <td>(-1.0222778, to)</td>\n",
" </tr>\n",
" <tr>\n",
" <th>29</th>\n",
" <td>eat</td>\n",
" <td>-1.644001</td>\n",
" <td>{' read': -2.925251, ' sing': -3.300251, ' eat...</td>\n",
" <td>428</td>\n",
" <td>(-1.644001, eat)</td>\n",
" </tr>\n",
" <tr>\n",
" <th>30</th>\n",
" <td>rice</td>\n",
" <td>-3.705196</td>\n",
" <td>{'\n",
"': -2.8301964, ' ice': -2.8614464, ' pizza'...</td>\n",
" <td>428</td>\n",
" <td>(-2.8301964, \\n)</td>\n",
" </tr>\n",
" <tr>\n",
" <th>31</th>\n",
" <td>\\n</td>\n",
" <td>-0.055901</td>\n",
" <td>{',': -5.5871506, '.': -4.6809006, '\n",
"': -0.055...</td>\n",
" <td>428</td>\n",
" <td>(-0.055900574, \\n)</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" tokens token_logprobs ... text_offset actual_top_logprob\n",
"18 The -2.859199 ... 407 (-1.7966995, I)\n",
"19 rain -0.811218 ... 411 (-0.81121826, rain)\n",
"20 is -2.080967 ... 416 (-1.799717, came)\n",
"21 so -1.509876 ... 419 (-1.5098763, so)\n",
"22 cool -1.401932 ... 422 (-1.4019318, cool)\n",
"23 \\n -0.392258 ... 427 (-0.3922577, \\n)\n",
"24 ----- -0.415947 ... 428 (-0.41594696, -----)\n",
"25 \\n -0.016262 ... 428 (-0.016262054, \\n)\n",
"26 I -1.232170 ... 428 (-1.2321701, I)\n",
"27 like -1.846847 ... 428 (-1.8468475, like)\n",
"28 to -1.022278 ... 428 (-1.0222778, to)\n",
"29 eat -1.644001 ... 428 (-1.644001, eat)\n",
"30 rice -3.705196 ... 428 (-2.8301964, \\n)\n",
"31 \\n -0.055901 ... 428 (-0.055900574, \\n)\n",
"\n",
"[14 rows x 5 columns]"
]
},
"metadata": {
"tags": []
},
"execution_count": 203
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "b2ZItSGmVRU7",
"colab_type": "code",
"colab": {}
},
"source": [
"bad_pt5 = df.copy()"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "3DUTQZDpcC1f",
"colab_type": "text"
},
"source": [
"K, real quick, how do the average logprobs compare? The highest logprob average rhymes! So this is a good indication that an average high logprob will be the correct answer"
]
},
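{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make that comparison easier to read, here's a small helper that wraps the slicing used in the next three cells (the function name `mean_logprob_of_first_line` is just for this sketch):"
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Average token logprob of the generated rhyme line, i.e. everything up to the\n",
"# first newline token. This is exactly the slicing used in the cells below.\n",
"def mean_logprob_of_first_line(logprob_df):\n",
"    first_newline = logprob_df.tokens.to_list().index(\"\\n\")\n",
"    return logprob_df[:first_newline].token_logprobs.mean()"
],
"execution_count": null,
"outputs": []
},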
{
"cell_type": "code",
"metadata": {
"id": "38Ghx9zpcadg",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"outputId": "d18ef09d-a8f9-467e-f5a2-0b5cbe21a2b6"
},
"source": [
"rhymed[:rhymed.tokens.to_list().index(\"\\n\")].token_logprobs.mean()"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"-1.2949314"
]
},
"metadata": {
"tags": []
},
"execution_count": 213
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "6fo2G7SeZau2",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"outputId": "a2babeaf-11b4-4f56-a48a-12b7c8ce7b6b"
},
"source": [
"rhyming_pt5[:rhyming_pt5.tokens.to_list().index(\"\\n\")].token_logprobs.mean()"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"-1.5994041460000001"
]
},
"metadata": {
"tags": []
},
"execution_count": 211
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "cTbGva6hcIKv",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"outputId": "ed0fc4cd-c915-4d87-c78a-a7a8c7b6a8d8"
},
"source": [
"bad_pt5[:bad_pt5.tokens.to_list().index(\"\\n\")].token_logprobs.mean()"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"-1.7326385720000002"
]
},
"metadata": {
"tags": []
},
"execution_count": 212
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "k5NDvpIGrQnY",
"colab_type": "text"
},
"source": [
"Cool! This actually how best_of works; for instance, let's get n=10 at temp=.5"
]
},
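{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch of that idea: generate n completions and rank them by the mean logprob of the rhyme line. The helper name `rank_by_mean_logprob` and the `skip` parameter are just for this sketch; the next cells do the same thing step by step, and after the n=10 request below you could call `rank_by_mean_logprob(r)`."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Illustrative sketch: reproduce the \"keep the highest mean logprob\" ranking manually.\n",
"# `skip` is the number of scaffold tokens (the ends with / rhymes with lines) before\n",
"# the rhyme line; 18 happens to be right for this prompt.\n",
"def rank_by_mean_logprob(response, skip=18):\n",
"    scored = []\n",
"    for choice in response[\"choices\"]:\n",
"        df = pd.DataFrame(choice[\"logprobs\"])[skip:]\n",
"        first_newline = df.tokens.to_list().index(\"\\n\")\n",
"        scored.append((df[:first_newline].token_logprobs.mean(), choice[\"text\"]))\n",
"    return sorted(scored, reverse=True)"
],
"execution_count": null,
"outputs": []
},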
{
"cell_type": "code",
"metadata": {
"id": "HazROq7GcZ0_",
"colab_type": "code",
"colab": {}
},
"source": [
"prompt = \"\"\"These pairs of sentences rhyme:\n",
"My favorite color is red\n",
"ends with: \"red\"\n",
"\"red\" rhymes with \"bed\"\n",
"Rhyme: It's the color of my bed\n",
"-----\n",
"I once had a dog\n",
"ends with: \"dog\"\n",
"\"dog\" rhymes with \"frog\"\n",
"Rhyme: That good boy ate a frog\n",
"-----\n",
"I wish I was small\n",
"ends with: \"small\"\n",
"\"small\" rhymes with \"tall\"\n",
"Rhyme: Instead I'm so tall ='(\n",
"-----\n",
"That's a cool train\n",
"ends with:\"\"\"\n",
"kwargs[\"logprobs\"] = 5\n",
"kwargs[\"max_tokens\"] = 40\n",
"kwargs[\"temperature\"] = .5\n",
"kwargs[\"stop\"] = \"-----\"\n",
"kwargs[\"n\"] = 10\n",
"\n",
"r = openai.Completion.create(prompt=prompt, **kwargs)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "4_9zp9qlrZwC",
"colab_type": "code",
"colab": {}
},
"source": [
"texts = [r[\"choices\"][i][\"text\"].split(\"\\n\")[-2][7:] for i in range(10)]"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "dRk6wSS3r8x1",
"colab_type": "code",
"colab": {}
},
"source": [
"logprobs = []\n",
"for i in range(10):\n",
" df = pd.DataFrame(r[\"choices\"][i][\"logprobs\"])[18:]\n",
" df[\"actual_top_logprob\"] = df.top_logprobs.apply(lambda x: getTopValueFromDict(x))\n",
" logprobs.append(df[:df.tokens.to_list().index(\"\\n\")].token_logprobs.mean())"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "TJviL5vXri9-",
"colab_type": "code",
"colab": {}
},
"source": [
"df = pd.DataFrame([texts]).T\n",
"df.columns=[\"text\"]\n",
"df[\"logprob\"] = logprobs\n",
"df[\"%\"] = df.logprob.apply(lambda x: 100*np.e**x)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "-vaHwe_1rkTQ",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 359
},
"outputId": "3b00ce51-686d-44eb-aade-497f7c2fb4cf"
},
"source": [
"df.sort_values(by=\"logprob\", ascending=False)"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>text</th>\n",
" <th>logprob</th>\n",
" <th>%</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>I like the rain</td>\n",
" <td>-1.092228</td>\n",
" <td>33.546824</td>\n",
" </tr>\n",
" <tr>\n",
" <th>9</th>\n",
" <td>That's a cool raincoat</td>\n",
" <td>-1.343981</td>\n",
" <td>26.080522</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>It's really fun to ride</td>\n",
" <td>-1.435801</td>\n",
" <td>23.792467</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>It goes \"Chugga chugga chugga\"</td>\n",
" <td>-1.829431</td>\n",
" <td>16.050481</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <td>I can't find my brain</td>\n",
" <td>-1.907298</td>\n",
" <td>14.848097</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>It's a very long train</td>\n",
" <td>-3.161690</td>\n",
" <td>4.235411</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <td>That's the brain train</td>\n",
" <td>-3.285344</td>\n",
" <td>3.742771</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>It's made of tin and rain</td>\n",
" <td>-3.425831</td>\n",
" <td>3.252225</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>I'm getting wet again</td>\n",
" <td>-4.019777</td>\n",
" <td>1.795697</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>The strain of that train</td>\n",
" <td>-4.875232</td>\n",
" <td>0.763333</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" text logprob %\n",
"5 I like the rain -1.092228 33.546824\n",
"9 That's a cool raincoat -1.343981 26.080522\n",
"6 It's really fun to ride -1.435801 23.792467\n",
"2 It goes \"Chugga chugga chugga\" -1.829431 16.050481\n",
"7 I can't find my brain -1.907298 14.848097\n",
"0 It's a very long train -3.161690 4.235411\n",
"8 That's the brain train -3.285344 3.742771\n",
"3 It's made of tin and rain -3.425831 3.252225\n",
"4 I'm getting wet again -4.019777 1.795697\n",
"1 The strain of that train -4.875232 0.763333"
]
},
"metadata": {
"tags": []
},
"execution_count": 242
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "4L5-zOSus9qt",
"colab_type": "text"
},
"source": [
"So the problem now is that the average logprob isn't even the best! We'll skin that cat later, but for now, what if we also get rid of the repitition. This might not work because we repeat but we'll penalize repetition just a tad"
]
},
{
"cell_type": "code",
"metadata": {
"id": "rafy5RJGrqRi",
"colab_type": "code",
"colab": {}
},
"source": [
"kwargs[\"logprobs\"] = 1\n",
"kwargs[\"max_tokens\"] = 40\n",
"kwargs[\"temperature\"] = .5\n",
"kwargs[\"stop\"] = \"-----\"\n",
"kwargs[\"n\"] = 5\n",
"kwargs[\"frequency_penalty\"] = .1\n"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "IIcHlTH6tild",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 204
},
"outputId": "3166ea6b-f2bf-45e2-e0fd-c896ed0c603e"
},
"source": [
"rslts = []\n",
"for frequency_penalty in range(-5, 6):\n",
" print(frequency_penalty)\n",
" kwargs[\"frequency_penalty\"] = np.round(.1 * frequency_penalty, 1)\n",
" r = openai.Completion.create(prompt=prompt, **kwargs)\n",
" texts = [r[\"choices\"][i][\"text\"].split(\"\\n\")[-2][7:] for i in range(5)]\n",
" logprobs = []\n",
" for i in range(5):\n",
" df = pd.DataFrame(r[\"choices\"][i][\"logprobs\"])[18:]\n",
" df[\"actual_top_logprob\"] = df.top_logprobs.apply(lambda x: getTopValueFromDict(x))\n",
" logprobs.append(df[:df.tokens.to_list().index(\"\\n\")].token_logprobs.mean())\n",
" df = pd.DataFrame([texts]).T\n",
" df.columns=[\"text\"]\n",
" df[\"logprob\"] = logprobs\n",
" df[\"%\"] = df.logprob.apply(lambda x: 100*np.e**x)\n",
" df.sort_values(by=\"logprob\", ascending=False)\n",
" df[\"frequency_penalty\"] = np.round(.1 * frequency_penalty, 1)\n",
" rslts.append(df.copy())"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"-5\n",
"-4\n",
"-3\n",
"-2\n",
"-1\n",
"0\n",
"1\n",
"2\n",
"3\n",
"4\n",
"5\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "yTdEYFUtyJq8",
"colab_type": "text"
},
"source": [
"Now, one of the big things we should realize is that changing the penalty likely influences the absolute value of the logprobs; \"that' a cool rain\" has basically the same logprob at .3 for some reason, but it drops off significantly at -.3 and .5."
]
},
{
"cell_type": "code",
"metadata": {
"id": "p9z6KxubxTm8",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
},
"outputId": "150a8056-a1c5-4842-f178-0f75f8294a62"
},
"source": [
"pd.concat(rslts).sort_values(by=\"%\", ascending=False)"
],
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>text</th>\n",
" <th>logprob</th>\n",
" <th>%</th>\n",
" <th>frequency_penalty</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>That's a cool train</td>\n",
" <td>-0.304744</td>\n",
" <td>73.731213</td>\n",
" <td>-0.4</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td></td>\n",
" <td>-0.618064</td>\n",
" <td>53.898677</td>\n",
" <td>-0.4</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>That train is so pain ='(</td>\n",
" <td>-0.764359</td>\n",
" <td>46.563253</td>\n",
" <td>0.1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>That's a cool plane</td>\n",
" <td>-0.946062</td>\n",
" <td>38.826682</td>\n",
" <td>0.3</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>That's a cool rain</td>\n",
" <td>-0.949751</td>\n",
" <td>38.683722</td>\n",
" <td>0.3</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>That's a cool rain</td>\n",
" <td>-0.949796</td>\n",
" <td>38.682010</td>\n",
" <td>0.3</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>That train was a strain</td>\n",
" <td>-1.077294</td>\n",
" <td>34.051571</td>\n",
" <td>0.1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>It makes a lot of noise</td>\n",
" <td>-1.172296</td>\n",
" <td>30.965518</td>\n",
" <td>-0.4</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>My train has a strain</td>\n",
" <td>-1.250902</td>\n",
" <td>28.624662</td>\n",
" <td>0.1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>It's so long and it's so tall</td>\n",
" <td>-1.306161</td>\n",
" <td>27.085801</td>\n",
" <td>-0.5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>It's raining outside</td>\n",
" <td>-1.312717</td>\n",
" <td>26.908783</td>\n",
" <td>-0.3</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>That's a cool rain</td>\n",
" <td>-1.356037</td>\n",
" <td>25.767990</td>\n",
" <td>-0.3</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>That's a painful train</td>\n",
" <td>-1.401003</td>\n",
" <td>24.634987</td>\n",
" <td>0.1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>I want to be a train driver</td>\n",
" <td>-1.417063</td>\n",
" <td>24.242492</td>\n",
" <td>0.2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>It's raining outside</td>\n",
" <td>-1.427773</td>\n",
" <td>23.984234</td>\n",
" <td>0.2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>That train was a pain</td>\n",
" <td>-1.431793</td>\n",
" <td>23.888017</td>\n",
" <td>0.3</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>That's a train of pain</td>\n",
" <td>-1.437199</td>\n",
" <td>23.759227</td>\n",
" <td>-0.1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>That's a cool rain</td>\n",
" <td>-1.451405</td>\n",
" <td>23.424087</td>\n",
" <td>0.5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>It's so cool it's a plane</td>\n",
" <td>-1.453801</td>\n",
" <td>23.368034</td>\n",
" <td>0.4</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>I got hit by a train ='(</td>\n",
" <td>-1.465098</td>\n",
" <td>23.105533</td>\n",
" <td>0.4</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td></td>\n",
" <td>-1.488245</td>\n",
" <td>22.576853</td>\n",
" <td>-0.3</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>But I don't like pain ='(</td>\n",
" <td>-1.499659</td>\n",
" <td>22.320621</td>\n",
" <td>0.4</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>That train flies through the air</td>\n",
" <td>-1.502256</td>\n",
" <td>22.262740</td>\n",
" <td>0.2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>That's a really cool train</td>\n",
" <td>-1.514158</td>\n",
" <td>21.999342</td>\n",
" <td>0.2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>I can't find my train</td>\n",
" <td>-1.523579</td>\n",
" <td>21.793059</td>\n",
" <td>-0.2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>That train made a big rain ='(</td>\n",
" <td>-1.538598</td>\n",
" <td>21.468188</td>\n",
" <td>0.4</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>It makes me feel pain</td>\n",
" <td>-1.558070</td>\n",
" <td>21.054195</td>\n",
" <td>-0.2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>The train is a pain</td>\n",
" <td>-1.633714</td>\n",
" <td>19.520319</td>\n",
" <td>-0.2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>It goes on and on and on</td>\n",
" <td>-1.637339</td>\n",
" <td>19.449682</td>\n",
" <td>-0.1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>That's a cool pain</td>\n",
" <td>-1.637495</td>\n",
" <td>19.446663</td>\n",
" <td>-0.2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>But it's a pain to clean up the tracks</td>\n",
" <td>-1.639166</td>\n",
" <td>19.414195</td>\n",
" <td>0.5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>That train's a pain</td>\n",
" <td>-1.647220</td>\n",
" <td>19.258458</td>\n",
" <td>-0.1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>That train goes through the rain</td>\n",
" <td>-1.686775</td>\n",
" <td>18.511552</td>\n",
" <td>0.1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td></td>\n",
" <td>-1.691131</td>\n",
" <td>18.431098</td>\n",
" <td>-0.5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>It rains on the train</td>\n",
" <td>-1.793993</td>\n",
" <td>16.629476</td>\n",
" <td>-0.1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>My brain's on a train</td>\n",
" <td>-1.799681</td>\n",
" <td>16.535157</td>\n",
" <td>-0.5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>I'm gonna fly on that plane</td>\n",
" <td>-1.814585</td>\n",
" <td>16.290550</td>\n",
" <td>-0.2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>It's the rain that made it cool</td>\n",
" <td>-1.823289</td>\n",
" <td>16.149373</td>\n",
" <td>-0.3</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>It goes on and on</td>\n",
" <td>-1.829408</td>\n",
" <td>16.050864</td>\n",
" <td>0.3</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>That's a big brain</td>\n",
" <td>-1.866597</td>\n",
" <td>15.464904</td>\n",
" <td>-0.5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>That's a great rain</td>\n",
" <td>-1.927979</td>\n",
" <td>14.544180</td>\n",
" <td>0.5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>It's a big rain</td>\n",
" <td>-1.937891</td>\n",
" <td>14.400728</td>\n",
" <td>-0.3</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>I love to watch it rain</td>\n",
" <td>-1.984703</td>\n",
" <td>13.742141</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>That train runs on pain</td>\n",
" <td>-1.992455</td>\n",
" <td>13.636021</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>The train is so fast</td>\n",
" <td>-1.997649</td>\n",
" <td>13.565388</td>\n",
" <td>-0.4</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>The rain is so fun</td>\n",
" <td>-2.000404</td>\n",
" <td>13.528068</td>\n",
" <td>-0.1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>That's the brain train</td>\n",
" <td>-2.005497</td>\n",
" <td>13.459339</td>\n",
" <td>0.5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>I want to ride that train</td>\n",
" <td>-2.009030</td>\n",
" <td>13.411870</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>It goes up and down the rain</td>\n",
" <td>-2.016399</td>\n",
" <td>13.313404</td>\n",
" <td>0.5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>I had a train pain</td>\n",
" <td>-2.069944</td>\n",
" <td>12.619285</td>\n",
" <td>0.4</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>It is so cool to play with pain</td>\n",
" <td>-2.074948</td>\n",
" <td>12.556298</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>That train's so awesome</td>\n",
" <td>-2.165770</td>\n",
" <td>11.466161</td>\n",
" <td>-0.4</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>That train was so long</td>\n",
" <td>-2.189222</td>\n",
" <td>11.200386</td>\n",
" <td>-0.5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>It makes me feel happy again</td>\n",
" <td>-2.263620</td>\n",
" <td>10.397338</td>\n",
" <td>0.2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>It makes me feel fine</td>\n",
" <td>-2.304157</td>\n",
" <td>9.984290</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" text ... frequency_penalty\n",
"3 That's a cool train ... -0.4\n",
"0 ... -0.4\n",
"3 That train is so pain ='( ... 0.1\n",
"1 That's a cool plane ... 0.3\n",
"0 That's a cool rain ... 0.3\n",
"4 That's a cool rain ... 0.3\n",
"2 That train was a strain ... 0.1\n",
"2 It makes a lot of noise ... -0.4\n",
"4 My train has a strain ... 0.1\n",
"1 It's so long and it's so tall ... -0.5\n",
"0 It's raining outside ... -0.3\n",
"3 That's a cool rain ... -0.3\n",
"0 That's a painful train ... 0.1\n",
"2 I want to be a train driver ... 0.2\n",
"3 It's raining outside ... 0.2\n",
"2 That train was a pain ... 0.3\n",
"3 That's a train of pain ... -0.1\n",
"3 That's a cool rain ... 0.5\n",
"0 It's so cool it's a plane ... 0.4\n",
"3 I got hit by a train ='( ... 0.4\n",
"1 ... -0.3\n",
"4 But I don't like pain ='( ... 0.4\n",
"0 That train flies through the air ... 0.2\n",
"1 That's a really cool train ... 0.2\n",
"0 I can't find my train ... -0.2\n",
"2 That train made a big rain ='( ... 0.4\n",
"3 It makes me feel pain ... -0.2\n",
"4 The train is a pain ... -0.2\n",
"1 It goes on and on and on ... -0.1\n",
"2 That's a cool pain ... -0.2\n",
"0 But it's a pain to clean up the tracks ... 0.5\n",
"2 That train's a pain ... -0.1\n",
"1 That train goes through the rain ... 0.1\n",
"2 ... -0.5\n",
"4 It rains on the train ... -0.1\n",
"4 My brain's on a train ... -0.5\n",
"1 I'm gonna fly on that plane ... -0.2\n",
"4 It's the rain that made it cool ... -0.3\n",
"3 It goes on and on ... 0.3\n",
"3 That's a big brain ... -0.5\n",
"1 That's a great rain ... 0.5\n",
"2 It's a big rain ... -0.3\n",
"3 I love to watch it rain ... 0.0\n",
"1 That train runs on pain ... 0.0\n",
"1 The train is so fast ... -0.4\n",
"0 The rain is so fun ... -0.1\n",
"4 That's the brain train ... 0.5\n",
"2 I want to ride that train ... 0.0\n",
"2 It goes up and down the rain ... 0.5\n",
"1 I had a train pain ... 0.4\n",
"0 It is so cool to play with pain ... 0.0\n",
"4 That train's so awesome ... -0.4\n",
"0 That train was so long ... -0.5\n",
"4 It makes me feel happy again ... 0.2\n",
"4 It makes me feel fine ... 0.0\n",
"\n",
"[55 rows x 4 columns]"
]
},
"metadata": {
"tags": []
},
"execution_count": 272
}
]
}
]
}