@brockmanmatt
Created August 8, 2020 03:00
commas_vs_ints_davinci2B.ipynb
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "commas_vs_ints_davinci2B.ipynb",
"provenance": [],
"collapsed_sections": [],
"authorship_tag": "ABX9TyOe3bQGyNGhhzMNiOM9OPET",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/brockmanmatt/572d8f98830f6202446c5737a43ee014/commas_vs_ints_davinci2b.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"metadata": {
"id": "J7wnsgT2kPut",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 89
},
"outputId": "b06c74c9-1fbb-404a-b835-4ead99e9baab"
},
"source": [
"from google.colab import files\n",
"uploaded = files.upload()\n",
"print(\"done\")"
],
"execution_count": 1,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/html": [
"\n",
" <input type=\"file\" id=\"files-0743d66a-d8b9-4f61-95c3-e08bcc962ffc\" name=\"files[]\" multiple disabled\n",
" style=\"border:none\" />\n",
" <output id=\"result-0743d66a-d8b9-4f61-95c3-e08bcc962ffc\">\n",
" Upload widget is only available when the cell has been executed in the\n",
" current browser session. Please rerun this cell to enable.\n",
" </output>\n",
" <script src=\"/nbextensions/google.colab/files.js\"></script> "
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "stream",
"text": [
"Saving key.json to key.json\n",
"done\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "WHPHrUnhpKnI",
"colab_type": "text"
},
"source": [
"I'll install the OpenAI API client and import the libraries used below"
]
},
{
"cell_type": "code",
"metadata": {
"id": "zq0ltp2xn4yt",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 292
},
"outputId": "0eab0c06-b186-4192-c7ea-56a371c55c5e"
},
"source": [
"!pip install openai\n",
"import openai, json, pandas as pd, numpy as np, random"
],
"execution_count": 2,
"outputs": [
{
"output_type": "stream",
"text": [
"Collecting openai\n",
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/a8/65/c7461f4c87984534683f480ea5742777bc39bbf5721123194c2d0347dc1f/openai-0.2.4.tar.gz (157kB)\n",
"\u001b[?25hRequirement already satisfied: requests>=2.20 in /usr/local/lib/python3.6/dist-packages (from openai) (2.23.0)\n",
"Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests>=2.20->openai) (1.24.3)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests>=2.20->openai) (2020.6.20)\n",
"Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests>=2.20->openai) (2.10)\n",
"Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests>=2.20->openai) (3.0.4)\n",
"Building wheels for collected packages: openai\n",
" Building wheel for openai (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
" Created wheel for openai: filename=openai-0.2.4-cp36-none-any.whl size=170709 sha256=4ae5feab1611e97f68fc99db74a07d1f06a1b23fc1a629630d6567a1a0c2ac53\n",
" Stored in directory: /root/.cache/pip/wheels/74/96/c8/c6e170929c276b836613e1b9985343b501fe455e53d85e7d48\n",
"Successfully built openai\n",
"Installing collected packages: openai\n",
"Successfully installed openai-0.2.4\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Q2yE0jcnpMEV",
"colab_type": "text"
},
"source": [
"Loading in the key.json I uploaded. I do this so I don't need to worry about accidentally leaking credentials if I share the Colab (which I'm 99% sure is just a JSON file that won't expose them)."
]
},
{
"cell_type": "code",
"metadata": {
"id": "bwNXXwHen5x9",
"colab_type": "code",
"colab": {}
},
"source": [
"openai.api_key = json.load(open(\"key.json\", \"r\"))[\"key\"]"
],
"execution_count": 3,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "k67w5H0fpTkT",
"colab_type": "text"
},
"source": [
"Default keyword arguments to pass to the API, wrapped in a small `query` helper"
]
},
{
"cell_type": "code",
"metadata": {
"id": "s4Xl3jO9py9V",
"colab_type": "code",
"colab": {}
},
"source": [
"def query(prompt):\n",
" \"\"\"\n",
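" Wrapper for the Completion API: greedy decoding (temperature 0), at most 40 tokens, stopping at the first newline; returns the stripped text.\n",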
" \"\"\"\n",
" kwargs = {\n",
" \"engine\":\"davinci-v2b\",\n",
" \"temperature\":0,\n",
" \"max_tokens\":40,\n",
" \"stop\":\"\\n\",\n",
" }\n",
"\n",
" r = openai.Completion.create(prompt=prompt, **kwargs)[\"choices\"][0][\"text\"].strip()\n",
" return r"
],
"execution_count": 63,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "zZubgPoOpWDH",
"colab_type": "text"
},
"source": [
"Quick wrapper to automatically save prompts and responses sent for later analysis if needed"
]
},
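{
"cell_type": "markdown",
"metadata": {},
"source": [
"No logging cell actually appears in this notebook, so here is a minimal sketch of such a wrapper. `logged_query`, `complete_fn`, and the `query_log.jsonl` path are hypothetical names, not part of the original run; `complete_fn` would be the `query` function defined above. Each call appends one JSON line holding the prompt and the response."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import json\n",
"\n",
"def logged_query(prompt, complete_fn, log_path='query_log.jsonl'):\n",
"    # hypothetical helper: call the completion function (e.g. query) and log the pair\n",
"    response = complete_fn(prompt)\n",
"    with open(log_path, 'a') as f:\n",
"        # one JSON object per line so the log can be appended to and re-read easily\n",
"        f.write(json.dumps({'prompt': prompt, 'response': response}) + '\\n')\n",
"    return response\n",
"\n",
"# usage (not run here): logged_query(context + example, query)"
],
"execution_count": null,
"outputs": []
},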
{
"cell_type": "code",
"metadata": {
"id": "MXxpacj5oPMJ",
"colab_type": "code",
"colab": {}
},
"source": [
"random.seed(42)"
],
"execution_count": 64,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "VoGjY-vDj9Kw",
"colab_type": "code",
"colab": {}
},
"source": [
"train = {}\n",
"test = {}\n",
"\n",
"# for ints of up to int_length digits, randomly add pairs together: 30 train sums (few-shot context) and 100 test sums per length\n",
"for int_length in range(3,12):\n",
" sums = []\n",
" for example in range(30):\n",
" a = random.randint(1,10**int_length)\n",
" b = random.randint(1,10**int_length)\n",
" c = a + b\n",
" sums.append({\"a\":a, \"b\":b, \"c\":c})\n",
" train[int_length] = pd.DataFrame(sums)\n",
" \n",
" sums = []\n",
" for example in range(100):\n",
" a = random.randint(1,10**int_length)\n",
" b = random.randint(1,10**int_length)\n",
" c = a + b\n",
" sums.append({\"a\":a, \"b\":b, \"c\":c})\n",
" test[int_length] = pd.DataFrame(sums)"
],
"execution_count": 65,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "l85AaDxSkvfs",
"colab_type": "code",
"colab": {}
},
"source": [
"ints_context = {}\n",
"commas_context = {}\n",
"\n",
"# set up the solved few-shot contexts; the comma version formats every number with thousands separators\n",
"for int_length in range(3,12):\n",
" int_df = train[int_length].copy()\n",
" int_df[\"format\"] = int_df[\"a\"].astype(str) + \" + \" + int_df[\"b\"].astype(str) + \" = \" + int_df[\"c\"].astype(str)\n",
" ints_context[int_length] = int_df.copy()\n",
" \n",
" comma_df = train[int_length].copy()\n",
" comma_df[\"format\"] = comma_df[\"a\"].apply(lambda x: \"{:,}\".format(x)) + \" + \" + comma_df[\"b\"].apply(lambda x: \"{:,}\".format(x)) + \" = \" + comma_df[\"c\"].apply(lambda x: \"{:,}\".format(x))\n",
" commas_context[int_length] = comma_df.copy()\n"
],
"execution_count": 66,
"outputs": []
},
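{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, a small self-contained sketch of what one few-shot prompt looks like in each format. `fmt` and `build_prompt` are hypothetical helpers that mirror the string building in these cells: solved `a + b = c` lines joined with newlines, followed by an unsolved `a + b =` line for the model to complete."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"def fmt(n, commas):\n",
"    # '{:,}' adds thousands separators, e.g. 1234567 -> '1,234,567'\n",
"    return '{:,}'.format(n) if commas else str(n)\n",
"\n",
"def build_prompt(solved, a, b, commas):\n",
"    # solved: list of (a, b) pairs shown with their answers as context\n",
"    lines = [fmt(x, commas) + ' + ' + fmt(y, commas) + ' = ' + fmt(x + y, commas) for x, y in solved]\n",
"    # the final line ends in ' =' so the model's completion is the sum\n",
"    return '\\n'.join(lines) + '\\n' + fmt(a, commas) + ' + ' + fmt(b, commas) + ' ='\n",
"\n",
"print(build_prompt([(1200, 34)], 5000, 6000, commas=True))\n",
"print(build_prompt([(1200, 34)], 5000, 6000, commas=False))"
],
"execution_count": null,
"outputs": []
},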
{
"cell_type": "code",
"metadata": {
"id": "9jJP5cfioh7w",
"colab_type": "code",
"colab": {}
},
"source": [
"ints_test = {}\n",
"commas_test = {}\n",
"\n",
"# set up the test prompts in the form \"a + b =\" (the model completes the sum)\n",
"for int_length in range(3,12):\n",
" int_df = test[int_length].copy()\n",
" int_df[\"prompt\"] = int_df[\"a\"].astype(str) + \" + \" + int_df[\"b\"].astype(str) + \" =\"\n",
" ints_test[int_length] = int_df.copy()\n",
"\n",
"\n",
" comma_df = test[int_length].copy()\n",
" comma_df[\"prompt\"] = comma_df[\"a\"].apply(lambda x: \"{:,}\".format(x)) + \" + \" + comma_df[\"b\"].apply(lambda x: \"{:,}\".format(x)) + \" =\"\n",
" commas_test[int_length] = comma_df.copy()\n"
],
"execution_count": 67,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "EdFXafcJpZ3Q",
"colab_type": "text"
},
"source": [
"Test to make sure my query works"
]
},
{
"cell_type": "code",
"metadata": {
"id": "B4mEj1pAo_z4",
"colab_type": "code",
"colab": {}
},
"source": [
"int_results = {}\n",
"comma_results = {}"
],
"execution_count": 68,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "4SlyKgjyopPn",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 51
},
"outputId": "252500e0-1fba-45b8-c4c1-2a0c409adc04"
},
"source": [
"#actually do the tests; trying with just 3 first\n",
"for int_length in range(3,4):\n",
" #check ints\n",
" if int_length not in int_results:\n",
" print(\"Checking int results for {}\".format(int_length))\n",
" int_results[int_length] = []\n",
" context = \"\\n\".join(ints_context[int_length].format) + \"\\n\"\n",
" for example in ints_test[int_length][\"prompt\"].to_list():\n",
" int_results[int_length].append(query(context + example))\n",
" \n",
" if int_length not in comma_results:\n",
" print(\"Checking comma results for {}\".format(int_length))\n",
" comma_results[int_length] = []\n",
" context = \"\\n\".join(commas_context[int_length].format) + \"\\n\"\n",
" for example in commas_test[int_length][\"prompt\"].to_list():\n",
" comma_results[int_length].append(query(context + example))\n",
"\n"
],
"execution_count": 69,
"outputs": [
{
"output_type": "stream",
"text": [
"Checking int results for 3\n",
"Checking comma results for 3\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "CFUH3pO_urcw",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 289
},
"outputId": "cc4794a4-cabe-4c37-8678-cedd853cee16"
},
"source": [
"#k, doing the rest now\n",
"for int_length in range(4,12):\n",
" #check ints\n",
" if int_length not in int_results:\n",
" print(\"Checking int results for {}\".format(int_length))\n",
" int_results[int_length] = []\n",
" context = \"\\n\".join(ints_context[int_length].format) + \"\\n\"\n",
" for example in ints_test[int_length][\"prompt\"].to_list():\n",
" int_results[int_length].append(query(context + example))\n",
" \n",
" if int_length not in comma_results:\n",
" print(\"Checking comma results for {}\".format(int_length))\n",
" comma_results[int_length] = []\n",
" context = \"\\n\".join(commas_context[int_length].format) + \"\\n\"\n",
" for example in commas_test[int_length][\"prompt\"].to_list():\n",
" comma_results[int_length].append(query(context + example))\n",
"\n"
],
"execution_count": 70,
"outputs": [
{
"output_type": "stream",
"text": [
"Checking int results for 4\n",
"Checking comma results for 4\n",
"Checking int results for 5\n",
"Checking comma results for 5\n",
"Checking int results for 6\n",
"Checking comma results for 6\n",
"Checking int results for 7\n",
"Checking comma results for 7\n",
"Checking int results for 8\n",
"Checking comma results for 8\n",
"Checking int results for 9\n",
"Checking comma results for 9\n",
"Checking int results for 10\n",
"Checking comma results for 10\n",
"Checking int results for 11\n",
"Checking comma results for 11\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "ycFPWpInpK-Y",
"colab_type": "code",
"colab": {}
},
"source": [
"#now I should have results\n",
"test_scored = {}\n",
"for val in test:\n",
" test_scored[val] = test[val].copy()"
],
"execution_count": 71,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "HJ7QXjNZpXG5",
"colab_type": "code",
"colab": {}
},
"source": [
"for int_length in int_results:\n",
" test_scored[int_length][\"ints\"] = int_results[int_length]\n",
" test_scored[int_length][\"ints\"] = test_scored[int_length][\"ints\"].apply(lambda x: int(x.replace(\",\", \"\")))\n",
" test_scored[int_length][\"commas\"] = comma_results[int_length]\n",
" test_scored[int_length][\"commas\"] = test_scored[int_length][\"commas\"].apply(lambda x: int(x.replace(\",\", \"\")))"
],
"execution_count": 72,
"outputs": []
},
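{
"cell_type": "markdown",
"metadata": {},
"source": [
"`int(x.replace(\",\", \"\"))` raises `ValueError` as soon as the model returns anything other than digits and commas, so the `.fillna(0)` calls below never receive a NaN to fill. A more forgiving parser, sketched here as a hypothetical `parse_completion` helper (not what was run above), returns NaN instead so an unparseable answer is simply scored as a miss:"
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import math\n",
"\n",
"def parse_completion(text):\n",
"    # drop surrounding whitespace and thousands separators, e.g. ' 1,234' -> '1234'\n",
"    cleaned = text.strip().replace(',', '')\n",
"    try:\n",
"        return int(cleaned)\n",
"    except ValueError:\n",
"        # not an integer (empty, words, decimals...): treat as missing\n",
"        return math.nan\n",
"\n",
"# usage (hypothetical): test_scored[n]['ints'] = [parse_completion(x) for x in int_results[n]]"
],
"execution_count": null,
"outputs": []
},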
{
"cell_type": "code",
"metadata": {
"id": "htmYEKkor_Nn",
"colab_type": "code",
"colab": {}
},
"source": [
"mape_df = pd.DataFrame()\n",
"exact_match_df = pd.DataFrame()"
],
"execution_count": 73,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "gaIy7yDVsiM3",
"colab_type": "code",
"colab": {}
},
"source": [
"for int_length in int_results:\n",
" mape_df.at[\"int\", int_length] = 100*np.mean(np.abs(test_scored[int_length][\"ints\"].fillna(0)-test_scored[int_length][\"c\"])/test_scored[int_length][\"c\"])\n",
" mape_df.at[\"commas\", int_length] = 100*np.mean(np.abs(test_scored[int_length][\"commas\"].fillna(0)-test_scored[int_length][\"c\"])/test_scored[int_length][\"c\"])\n",
"\n",
" exact_match_df.at[\"int\", int_length] = 100*sum(test_scored[int_length][\"ints\"].fillna(0) == test_scored[int_length][\"c\"])/len(test_scored[int_length])\n",
" exact_match_df.at[\"commas\", int_length] = 100*sum(test_scored[int_length][\"commas\"].fillna(0) == test_scored[int_length][\"c\"])/len(test_scored[int_length])"
],
"execution_count": 74,
"outputs": []
},
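{
"cell_type": "markdown",
"metadata": {},
"source": [
"The two metrics above, written out as small standalone functions (a sketch; `mape` and `exact_match` are hypothetical names). MAPE is 100 * mean(|prediction - c| / c) and exact match is the percentage of predictions equal to c; as in the cell above, a missing prediction is replaced by 0 first. For example, predictions [110, 200] against true sums [100, 200] give a MAPE of 5% and an exact-match rate of 50%."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import numpy as np\n",
"\n",
"def _fill_missing(pred):\n",
"    # mirror .fillna(0): a missing prediction counts as 0\n",
"    pred = np.asarray(pred, dtype=float)\n",
"    return np.where(np.isnan(pred), 0.0, pred)\n",
"\n",
"def mape(pred, actual):\n",
"    # mean absolute percentage error, in percent\n",
"    actual = np.asarray(actual, dtype=float)\n",
"    return 100 * np.mean(np.abs(_fill_missing(pred) - actual) / actual)\n",
"\n",
"def exact_match(pred, actual):\n",
"    # percentage of predictions exactly equal to the true sum\n",
"    return 100 * np.mean(_fill_missing(pred) == np.asarray(actual, dtype=float))\n",
"\n",
"print(mape([110, 200], [100, 200]), exact_match([110, 200], [100, 200]))"
],
"execution_count": null,
"outputs": []
},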
{
"cell_type": "code",
"metadata": {
"id": "0mmKqB2G1ivK",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 89
},
"outputId": "656bb1b1-e118-49c3-9bad-0635ed10cc3b"
},
"source": [
"from google.colab import files\n",
"uploaded = files.upload()\n",
"print(\"done\")"
],
"execution_count": 53,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/html": [
"\n",
" <input type=\"file\" id=\"files-8049ba48-b1d4-435e-b742-32ee2e74ab96\" name=\"files[]\" multiple disabled\n",
" style=\"border:none\" />\n",
" <output id=\"result-8049ba48-b1d4-435e-b742-32ee2e74ab96\">\n",
" Upload widget is only available when the cell has been executed in the\n",
" current browser session. Please rerun this cell to enable.\n",
" </output>\n",
" <script src=\"/nbextensions/google.colab/files.js\"></script> "
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "stream",
"text": [
"Saving math_df.pkl to math_df (2).pkl\n",
"done\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "EG8Pg1Rr1yxQ",
"colab_type": "code",
"colab": {}
},
"source": [
"old_df_combined = pd.read_pickle(\"math_df.pkl\")"
],
"execution_count": 75,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "1vy9o_o93VRN",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 111
},
"outputId": "933626cb-c074-4fd6-f50a-8df4f142ebca"
},
"source": [
"mape_df"
],
"execution_count": 52,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>3</th>\n",
" <th>4</th>\n",
" <th>5</th>\n",
" <th>6</th>\n",
" <th>7</th>\n",
" <th>8</th>\n",
" <th>9</th>\n",
" <th>10</th>\n",
" <th>11</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>int</th>\n",
" <td>0.818156</td>\n",
" <td>4.770848</td>\n",
" <td>7.712621</td>\n",
" <td>11.668516</td>\n",
" <td>14.694527</td>\n",
" <td>26.053263</td>\n",
" <td>47.400036</td>\n",
" <td>27.766476</td>\n",
" <td>26.772077</td>\n",
" </tr>\n",
" <tr>\n",
" <th>commas</th>\n",
" <td>0.485746</td>\n",
" <td>1.014961</td>\n",
" <td>0.127750</td>\n",
" <td>0.126378</td>\n",
" <td>0.228100</td>\n",
" <td>0.056372</td>\n",
" <td>0.454149</td>\n",
" <td>0.582585</td>\n",
" <td>17.114274</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" 3 4 5 ... 9 10 11\n",
"int 0.818156 4.770848 7.712621 ... 47.400036 27.766476 26.772077\n",
"commas 0.485746 1.014961 0.127750 ... 0.454149 0.582585 17.114274\n",
"\n",
"[2 rows x 9 columns]"
]
},
"metadata": {
"tags": []
},
"execution_count": 52
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "MxP_qlJU2Ksc",
"colab_type": "code",
"colab": {}
},
"source": [
"old_df = {}\n",
"for i in old_df_combined[\"size\"].unique():\n",
" old_df[i] = old_df_combined[old_df_combined[\"size\"]==i]"
],
"execution_count": 76,
"outputs": []
},
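{
"cell_type": "markdown",
"metadata": {},
"source": [
"The same per-size split can be written with `groupby` (a sketch; `split_by_size` is a hypothetical helper). It should give the same sub-DataFrames as the boolean-mask loop above for every non-missing `size` value:"
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import pandas as pd\n",
"\n",
"def split_by_size(df):\n",
"    # one sub-DataFrame per distinct value of the 'size' column\n",
"    return {size: group for size, group in df.groupby('size')}\n",
"\n",
"# usage: old_df = split_by_size(old_df_combined)"
],
"execution_count": null,
"outputs": []
},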
{
"cell_type": "code",
"metadata": {
"id": "sAQ0PNMT1hlk",
"colab_type": "code",
"colab": {}
},
"source": [
"# Score the earlier (\"old\") run per max digit length; missing predictions count as 0.\n",
"for int_length, df_len in old_df.items():\n",
"    truth = df_len[\"c\"]\n",
"    ints_pred = df_len[\"ints\"].fillna(0)\n",
"    commas_pred = df_len[\"commas\"].fillna(0)\n",
"    # Mean absolute percentage error (MAPE), in percent\n",
"    mape_df.at[\"old_int\", int_length] = 100*np.mean(np.abs(ints_pred - truth)/truth)\n",
"    mape_df.at[\"old_commas\", int_length] = 100*np.mean(np.abs(commas_pred - truth)/truth)\n",
"    # Percent of rows where the prediction exactly equals the true value\n",
"    exact_match_df.at[\"old_int\", int_length] = 100*sum(ints_pred == truth)/len(df_len)\n",
"    exact_match_df.at[\"old_commas\", int_length] = 100*sum(commas_pred == truth)/len(df_len)"
],
"execution_count": 77,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "IHGyLfO13mB4",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 173
},
"outputId": "1dba0dae-6aa0-4b3d-9dc8-e07f07186f22"
},
"source": [
"# Mean absolute percentage error (MAPE, %) by max digit length\n",
"mape_df"
],
"execution_count": 78,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>3</th>\n",
" <th>4</th>\n",
" <th>5</th>\n",
" <th>6</th>\n",
" <th>7</th>\n",
" <th>8</th>\n",
" <th>9</th>\n",
" <th>10</th>\n",
" <th>11</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>int</th>\n",
" <td>0.757821</td>\n",
" <td>4.103469</td>\n",
" <td>7.548277</td>\n",
" <td>11.696423</td>\n",
" <td>14.751093</td>\n",
" <td>24.693287</td>\n",
" <td>47.711725</td>\n",
" <td>26.289521</td>\n",
" <td>27.070311</td>\n",
" </tr>\n",
" <tr>\n",
" <th>commas</th>\n",
" <td>0.396372</td>\n",
" <td>1.014230</td>\n",
" <td>0.127750</td>\n",
" <td>0.126378</td>\n",
" <td>0.227927</td>\n",
" <td>0.063925</td>\n",
" <td>0.453638</td>\n",
" <td>0.773123</td>\n",
" <td>17.114365</td>\n",
" </tr>\n",
" <tr>\n",
" <th>old_int</th>\n",
" <td>0.818156</td>\n",
" <td>4.770848</td>\n",
" <td>7.712621</td>\n",
" <td>11.668516</td>\n",
" <td>14.694527</td>\n",
" <td>26.053263</td>\n",
" <td>47.400036</td>\n",
" <td>27.766476</td>\n",
" <td>26.772077</td>\n",
" </tr>\n",
" <tr>\n",
" <th>old_commas</th>\n",
" <td>0.485746</td>\n",
" <td>1.014961</td>\n",
" <td>0.127750</td>\n",
" <td>0.126378</td>\n",
" <td>0.228100</td>\n",
" <td>0.056372</td>\n",
" <td>0.454149</td>\n",
" <td>0.582585</td>\n",
" <td>17.114274</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" 3 4 5 ... 9 10 11\n",
"int 0.757821 4.103469 7.548277 ... 47.711725 26.289521 27.070311\n",
"commas 0.396372 1.014230 0.127750 ... 0.453638 0.773123 17.114365\n",
"old_int 0.818156 4.770848 7.712621 ... 47.400036 27.766476 26.772077\n",
"old_commas 0.485746 1.014961 0.127750 ... 0.454149 0.582585 17.114274\n",
"\n",
"[4 rows x 9 columns]"
]
},
"metadata": {
"tags": []
},
"execution_count": 78
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "kDDew2Mx2mGr",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 173
},
"outputId": "a33ec602-75c8-4b93-dda0-acb4914c12c3"
},
"source": [
"exact_match_df"
],
"execution_count": 79,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>3</th>\n",
" <th>4</th>\n",
" <th>5</th>\n",
" <th>6</th>\n",
" <th>7</th>\n",
" <th>8</th>\n",
" <th>9</th>\n",
" <th>10</th>\n",
" <th>11</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>int</th>\n",
" <td>87.0</td>\n",
" <td>24.0</td>\n",
" <td>12.0</td>\n",
" <td>4.0</td>\n",
" <td>2.0</td>\n",
" <td>1.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>commas</th>\n",
" <td>93.0</td>\n",
" <td>93.0</td>\n",
" <td>84.0</td>\n",
" <td>81.0</td>\n",
" <td>74.0</td>\n",
" <td>78.0</td>\n",
" <td>62.0</td>\n",
" <td>61.0</td>\n",
" <td>52.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>old_int</th>\n",
" <td>86.0</td>\n",
" <td>24.0</td>\n",
" <td>12.0</td>\n",
" <td>4.0</td>\n",
" <td>2.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>old_commas</th>\n",
" <td>93.0</td>\n",
" <td>92.0</td>\n",
" <td>84.0</td>\n",
" <td>81.0</td>\n",
" <td>72.0</td>\n",
" <td>78.0</td>\n",
" <td>61.0</td>\n",
" <td>60.0</td>\n",
" <td>52.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" 3 4 5 6 7 8 9 10 11\n",
"int 87.0 24.0 12.0 4.0 2.0 1.0 0.0 0.0 0.0\n",
"commas 93.0 93.0 84.0 81.0 74.0 78.0 62.0 61.0 52.0\n",
"old_int 86.0 24.0 12.0 4.0 2.0 0.0 0.0 0.0 0.0\n",
"old_commas 93.0 92.0 84.0 81.0 72.0 78.0 61.0 60.0 52.0"
]
},
"metadata": {
"tags": []
},
"execution_count": 79
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "ITRgz0N02Zsy",
"colab_type": "code",
"colab": {}
},
"source": [
"import matplotlib.pyplot as plt"
],
"execution_count": 80,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "vFGLFaXQ2bzc",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 367
},
"outputId": "2ffd9aae-1c9a-455b-d916-3ad242763700"
},
"source": [
"fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(10, 5), facecolor=\"w\")\n",
"# Transpose so each row (int, commas, old_int, old_commas) is drawn as one line over digit length.\n",
"mape_df.T.plot(title=\"mean absolute percent error (100 per length)\", ax=axs[0])\n",
"axs[0].set_xlabel(\"max digits\")\n",
"axs[0].set_ylabel(\"mean absolute percent error\")\n",
"exact_match_df.T.plot(title=\"percent exact match (100 per length)\", ax=axs[1])\n",
"axs[1].set_xlabel(\"max digits\")\n",
"axs[1].set_ylabel(\"percent exact match\")"
],
"execution_count": 81,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"Text(0, 0.5, 'percent exact match')"
]
},
"metadata": {
"tags": []
},
"execution_count": 81
},
{
"output_type": "display_data",
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 720x360 with 2 Axes>"
]
},
"metadata": {
"tags": []
}
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "hP-pW2BxuT1q",
"colab_type": "code",
"colab": {}
},
"source": [
"# Tag each per-length result frame with its digit length so they can be concatenated.\n",
"all_dfs = []\n",
"for length in test_scored:\n",
" tmp = test_scored[length].copy()\n",
" tmp[\"size\"] = length\n",
" all_dfs.append(tmp)"
],
"execution_count": 82,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "9wvPr_h3uZnL",
"colab_type": "code",
"colab": {}
},
"source": [
"pd.concat(all_dfs).to_pickle(\"math_df.pkl\")"
],
"execution_count": 83,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "lE85y3hY2RDo",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 17
},
"outputId": "2f7a88fc-d956-464c-8298-191e46234322"
},
"source": [
"from google.colab import files\n",
"files.download('math_df.pkl')"
],
"execution_count": 84,
"outputs": [
{
"output_type": "display_data",
"data": {
"application/javascript": [
"\n",
" async function download(id, filename, size) {\n",
" if (!google.colab.kernel.accessAllowed) {\n",
" return;\n",
" }\n",
" const div = document.createElement('div');\n",
" const label = document.createElement('label');\n",
" label.textContent = `Downloading \"${filename}\": `;\n",
" div.appendChild(label);\n",
" const progress = document.createElement('progress');\n",
" progress.max = size;\n",
" div.appendChild(progress);\n",
" document.body.appendChild(div);\n",
"\n",
" const buffers = [];\n",
" let downloaded = 0;\n",
"\n",
" const channel = await google.colab.kernel.comms.open(id);\n",
" // Send a message to notify the kernel that we're ready.\n",
" channel.send({})\n",
"\n",
" for await (const message of channel.messages) {\n",
" // Send a message to notify the kernel that we're ready.\n",
" channel.send({})\n",
" if (message.buffers) {\n",
" for (const buffer of message.buffers) {\n",
" buffers.push(buffer);\n",
" downloaded += buffer.byteLength;\n",
" progress.value = downloaded;\n",
" }\n",
" }\n",
" }\n",
" const blob = new Blob(buffers, {type: 'application/binary'});\n",
" const a = document.createElement('a');\n",
" a.href = window.URL.createObjectURL(blob);\n",
" a.download = filename;\n",
" div.appendChild(a);\n",
" a.click();\n",
" div.remove();\n",
" }\n",
" "
],
"text/plain": [
"<IPython.core.display.Javascript object>"
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "display_data",
"data": {
"application/javascript": [
"download(\"download_bd001835-e84d-49bf-b44d-c6180abb806f\", \"math_df.pkl\", 51144)"
],
"text/plain": [
"<IPython.core.display.Javascript object>"
]
},
"metadata": {
"tags": []
}
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "A50bkKri4Alf",
"colab_type": "code",
"colab": {}
},
"source": [
""
],
"execution_count": null,
"outputs": []
}
]
}
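The loop over `old_df` scores each maximum digit length with two metrics: mean absolute percentage error and the percentage of exact matches. The sketch below does the same scoring in one vectorized pass with `groupby`, assuming a single long DataFrame with a `size` column (digit length), `c` (the true value), and `ints` / `commas` (parsed predictions), as `old_df_combined` has. The helper `score_by_length` is hypothetical, and unlike the notebook it masks rows whose true value is 0 rather than letting them produce `inf`.

```python
import numpy as np
import pandas as pd


def score_by_length(df, pred_cols=("ints", "commas"), truth_col="c", group_col="size"):
    """Score each prediction column per digit length (hypothetical helper).

    Returns (mape, exact): DataFrames with one row per prediction column and
    one column per value of `group_col`, both expressed in percent.
    """
    truth = df[truth_col]
    groups = df[group_col]
    # Assumption: a true value of 0 is masked (NaN) instead of dividing by zero;
    # the per-group mean then skips those NaNs.
    safe_truth = truth.where(truth != 0)
    mape_rows, exact_rows = {}, {}
    for col in pred_cols:
        # As in the notebook, a missing (unparseable) prediction is scored as 0.
        pred = df[col].fillna(0)
        ape = (pred - truth).abs() / safe_truth
        mape_rows[col] = 100 * ape.groupby(groups).mean()
        exact_rows[col] = 100 * (pred == truth).astype(float).groupby(groups).mean()
    # Transpose so rows are prediction columns and columns are digit lengths,
    # matching the layout of mape_df / exact_match_df.
    return pd.DataFrame(mape_rows).T, pd.DataFrame(exact_rows).T


# Toy data: two rows for each of two digit lengths.
toy = pd.DataFrame({
    "size": [3, 3, 4, 4],
    "c": [100, 200, 1000, 2000],
    "ints": [100, 210, np.nan, 2000],
    "commas": [100, 200, 1000, 1900],
})
mape, exact = score_by_length(toy)
print(mape.round(2))
print(exact)
```

If the pickled frame carries the same `c`, `ints` and `commas` columns (an assumption), `score_by_length(pd.read_pickle("math_df.pkl"))` should give tables laid out like `mape_df` and `exact_match_df`, with rows named after the prediction columns.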