WiC_SelfContextStuffingImproved3.ipynb
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "WiC_SelfContextStuffingImproved3.ipynb",
"provenance": [],
"collapsed_sections": [],
"authorship_tag": "ABX9TyNuBEqDy3UUUPOTiw8FcABr",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/brockmanmatt/55e8fb65c5dea180f47aef8921b60ec7/wic_selfcontextstuffingimproved3.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"metadata": {
"id": "J7wnsgT2kPut",
"colab_type": "code",
"colab": {
"resources": {
"http://localhost:8080/nbextensions/google.colab/files.js": {
"data": "Ly8gQ29weXJpZ2h0IDIwMTcgR29vZ2xlIExMQwovLwovLyBMaWNlbnNlZCB1bmRlciB0aGUgQXBhY2hlIExpY2Vuc2UsIFZlcnNpb24gMi4wICh0aGUgIkxpY2Vuc2UiKTsKLy8geW91IG1heSBub3QgdXNlIHRoaXMgZmlsZSBleGNlcHQgaW4gY29tcGxpYW5jZSB3aXRoIHRoZSBMaWNlbnNlLgovLyBZb3UgbWF5IG9idGFpbiBhIGNvcHkgb2YgdGhlIExpY2Vuc2UgYXQKLy8KLy8gICAgICBodHRwOi8vd3d3LmFwYWNoZS5vcmcvbGljZW5zZXMvTElDRU5TRS0yLjAKLy8KLy8gVW5sZXNzIHJlcXVpcmVkIGJ5IGFwcGxpY2FibGUgbGF3IG9yIGFncmVlZCB0byBpbiB3cml0aW5nLCBzb2Z0d2FyZQovLyBkaXN0cmlidXRlZCB1bmRlciB0aGUgTGljZW5zZSBpcyBkaXN0cmlidXRlZCBvbiBhbiAiQVMgSVMiIEJBU0lTLAovLyBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC4KLy8gU2VlIHRoZSBMaWNlbnNlIGZvciB0aGUgc3BlY2lmaWMgbGFuZ3VhZ2UgZ292ZXJuaW5nIHBlcm1pc3Npb25zIGFuZAovLyBsaW1pdGF0aW9ucyB1bmRlciB0aGUgTGljZW5zZS4KCi8qKgogKiBAZmlsZW92ZXJ2aWV3IEhlbHBlcnMgZm9yIGdvb2dsZS5jb2xhYiBQeXRob24gbW9kdWxlLgogKi8KKGZ1bmN0aW9uKHNjb3BlKSB7CmZ1bmN0aW9uIHNwYW4odGV4dCwgc3R5bGVBdHRyaWJ1dGVzID0ge30pIHsKICBjb25zdCBlbGVtZW50ID0gZG9jdW1lbnQuY3JlYXRlRWxlbWVudCgnc3BhbicpOwogIGVsZW1lbnQudGV4dENvbnRlbnQgPSB0ZXh0OwogIGZvciAoY29uc3Qga2V5IG9mIE9iamVjdC5rZXlzKHN0eWxlQXR0cmlidXRlcykpIHsKICAgIGVsZW1lbnQuc3R5bGVba2V5XSA9IHN0eWxlQXR0cmlidXRlc1trZXldOwogIH0KICByZXR1cm4gZWxlbWVudDsKfQoKLy8gTWF4IG51bWJlciBvZiBieXRlcyB3aGljaCB3aWxsIGJlIHVwbG9hZGVkIGF0IGEgdGltZS4KY29uc3QgTUFYX1BBWUxPQURfU0laRSA9IDEwMCAqIDEwMjQ7CgpmdW5jdGlvbiBfdXBsb2FkRmlsZXMoaW5wdXRJZCwgb3V0cHV0SWQpIHsKICBjb25zdCBzdGVwcyA9IHVwbG9hZEZpbGVzU3RlcChpbnB1dElkLCBvdXRwdXRJZCk7CiAgY29uc3Qgb3V0cHV0RWxlbWVudCA9IGRvY3VtZW50LmdldEVsZW1lbnRCeUlkKG91dHB1dElkKTsKICAvLyBDYWNoZSBzdGVwcyBvbiB0aGUgb3V0cHV0RWxlbWVudCB0byBtYWtlIGl0IGF2YWlsYWJsZSBmb3IgdGhlIG5leHQgY2FsbAogIC8vIHRvIHVwbG9hZEZpbGVzQ29udGludWUgZnJvbSBQeXRob24uCiAgb3V0cHV0RWxlbWVudC5zdGVwcyA9IHN0ZXBzOwoKICByZXR1cm4gX3VwbG9hZEZpbGVzQ29udGludWUob3V0cHV0SWQpOwp9CgovLyBUaGlzIGlzIHJvdWdobHkgYW4gYXN5bmMgZ2VuZXJhdG9yIChub3Qgc3VwcG9ydGVkIGluIHRoZSBicm93c2VyIHlldCksCi8vIHdoZXJlIHRoZXJlIGFyZSBtdWx0aXBsZSBhc3luY2hyb25vdXMgc3RlcHMgYW5kIHRoZSB
QeXRob24gc2lkZSBpcyBnb2luZwovLyB0byBwb2xsIGZvciBjb21wbGV0aW9uIG9mIGVhY2ggc3RlcC4KLy8gVGhpcyB1c2VzIGEgUHJvbWlzZSB0byBibG9jayB0aGUgcHl0aG9uIHNpZGUgb24gY29tcGxldGlvbiBvZiBlYWNoIHN0ZXAsCi8vIHRoZW4gcGFzc2VzIHRoZSByZXN1bHQgb2YgdGhlIHByZXZpb3VzIHN0ZXAgYXMgdGhlIGlucHV0IHRvIHRoZSBuZXh0IHN0ZXAuCmZ1bmN0aW9uIF91cGxvYWRGaWxlc0NvbnRpbnVlKG91dHB1dElkKSB7CiAgY29uc3Qgb3V0cHV0RWxlbWVudCA9IGRvY3VtZW50LmdldEVsZW1lbnRCeUlkKG91dHB1dElkKTsKICBjb25zdCBzdGVwcyA9IG91dHB1dEVsZW1lbnQuc3RlcHM7CgogIGNvbnN0IG5leHQgPSBzdGVwcy5uZXh0KG91dHB1dEVsZW1lbnQubGFzdFByb21pc2VWYWx1ZSk7CiAgcmV0dXJuIFByb21pc2UucmVzb2x2ZShuZXh0LnZhbHVlLnByb21pc2UpLnRoZW4oKHZhbHVlKSA9PiB7CiAgICAvLyBDYWNoZSB0aGUgbGFzdCBwcm9taXNlIHZhbHVlIHRvIG1ha2UgaXQgYXZhaWxhYmxlIHRvIHRoZSBuZXh0CiAgICAvLyBzdGVwIG9mIHRoZSBnZW5lcmF0b3IuCiAgICBvdXRwdXRFbGVtZW50Lmxhc3RQcm9taXNlVmFsdWUgPSB2YWx1ZTsKICAgIHJldHVybiBuZXh0LnZhbHVlLnJlc3BvbnNlOwogIH0pOwp9CgovKioKICogR2VuZXJhdG9yIGZ1bmN0aW9uIHdoaWNoIGlzIGNhbGxlZCBiZXR3ZWVuIGVhY2ggYXN5bmMgc3RlcCBvZiB0aGUgdXBsb2FkCiAqIHByb2Nlc3MuCiAqIEBwYXJhbSB7c3RyaW5nfSBpbnB1dElkIEVsZW1lbnQgSUQgb2YgdGhlIGlucHV0IGZpbGUgcGlja2VyIGVsZW1lbnQuCiAqIEBwYXJhbSB7c3RyaW5nfSBvdXRwdXRJZCBFbGVtZW50IElEIG9mIHRoZSBvdXRwdXQgZGlzcGxheS4KICogQHJldHVybiB7IUl0ZXJhYmxlPCFPYmplY3Q+fSBJdGVyYWJsZSBvZiBuZXh0IHN0ZXBzLgogKi8KZnVuY3Rpb24qIHVwbG9hZEZpbGVzU3RlcChpbnB1dElkLCBvdXRwdXRJZCkgewogIGNvbnN0IGlucHV0RWxlbWVudCA9IGRvY3VtZW50LmdldEVsZW1lbnRCeUlkKGlucHV0SWQpOwogIGlucHV0RWxlbWVudC5kaXNhYmxlZCA9IGZhbHNlOwoKICBjb25zdCBvdXRwdXRFbGVtZW50ID0gZG9jdW1lbnQuZ2V0RWxlbWVudEJ5SWQob3V0cHV0SWQpOwogIG91dHB1dEVsZW1lbnQuaW5uZXJIVE1MID0gJyc7CgogIGNvbnN0IHBpY2tlZFByb21pc2UgPSBuZXcgUHJvbWlzZSgocmVzb2x2ZSkgPT4gewogICAgaW5wdXRFbGVtZW50LmFkZEV2ZW50TGlzdGVuZXIoJ2NoYW5nZScsIChlKSA9PiB7CiAgICAgIHJlc29sdmUoZS50YXJnZXQuZmlsZXMpOwogICAgfSk7CiAgfSk7CgogIGNvbnN0IGNhbmNlbCA9IGRvY3VtZW50LmNyZWF0ZUVsZW1lbnQoJ2J1dHRvbicpOwogIGlucHV0RWxlbWVudC5wYXJlbnRFbGVtZW50LmFwcGVuZENoaWxkKGNhbmNlbCk7CiAgY2FuY2VsLnRleHRDb250ZW50ID0gJ0NhbmNlbCB1cGxvYWQnOwogIGNvbnN0IGNhbmNlbFByb21pc2UgPSBuZXcgUHJvbWl
zZSgocmVzb2x2ZSkgPT4gewogICAgY2FuY2VsLm9uY2xpY2sgPSAoKSA9PiB7CiAgICAgIHJlc29sdmUobnVsbCk7CiAgICB9OwogIH0pOwoKICAvLyBXYWl0IGZvciB0aGUgdXNlciB0byBwaWNrIHRoZSBmaWxlcy4KICBjb25zdCBmaWxlcyA9IHlpZWxkIHsKICAgIHByb21pc2U6IFByb21pc2UucmFjZShbcGlja2VkUHJvbWlzZSwgY2FuY2VsUHJvbWlzZV0pLAogICAgcmVzcG9uc2U6IHsKICAgICAgYWN0aW9uOiAnc3RhcnRpbmcnLAogICAgfQogIH07CgogIGNhbmNlbC5yZW1vdmUoKTsKCiAgLy8gRGlzYWJsZSB0aGUgaW5wdXQgZWxlbWVudCBzaW5jZSBmdXJ0aGVyIHBpY2tzIGFyZSBub3QgYWxsb3dlZC4KICBpbnB1dEVsZW1lbnQuZGlzYWJsZWQgPSB0cnVlOwoKICBpZiAoIWZpbGVzKSB7CiAgICByZXR1cm4gewogICAgICByZXNwb25zZTogewogICAgICAgIGFjdGlvbjogJ2NvbXBsZXRlJywKICAgICAgfQogICAgfTsKICB9CgogIGZvciAoY29uc3QgZmlsZSBvZiBmaWxlcykgewogICAgY29uc3QgbGkgPSBkb2N1bWVudC5jcmVhdGVFbGVtZW50KCdsaScpOwogICAgbGkuYXBwZW5kKHNwYW4oZmlsZS5uYW1lLCB7Zm9udFdlaWdodDogJ2JvbGQnfSkpOwogICAgbGkuYXBwZW5kKHNwYW4oCiAgICAgICAgYCgke2ZpbGUudHlwZSB8fCAnbi9hJ30pIC0gJHtmaWxlLnNpemV9IGJ5dGVzLCBgICsKICAgICAgICBgbGFzdCBtb2RpZmllZDogJHsKICAgICAgICAgICAgZmlsZS5sYXN0TW9kaWZpZWREYXRlID8gZmlsZS5sYXN0TW9kaWZpZWREYXRlLnRvTG9jYWxlRGF0ZVN0cmluZygpIDoKICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgJ24vYSd9IC0gYCkpOwogICAgY29uc3QgcGVyY2VudCA9IHNwYW4oJzAlIGRvbmUnKTsKICAgIGxpLmFwcGVuZENoaWxkKHBlcmNlbnQpOwoKICAgIG91dHB1dEVsZW1lbnQuYXBwZW5kQ2hpbGQobGkpOwoKICAgIGNvbnN0IGZpbGVEYXRhUHJvbWlzZSA9IG5ldyBQcm9taXNlKChyZXNvbHZlKSA9PiB7CiAgICAgIGNvbnN0IHJlYWRlciA9IG5ldyBGaWxlUmVhZGVyKCk7CiAgICAgIHJlYWRlci5vbmxvYWQgPSAoZSkgPT4gewogICAgICAgIHJlc29sdmUoZS50YXJnZXQucmVzdWx0KTsKICAgICAgfTsKICAgICAgcmVhZGVyLnJlYWRBc0FycmF5QnVmZmVyKGZpbGUpOwogICAgfSk7CiAgICAvLyBXYWl0IGZvciB0aGUgZGF0YSB0byBiZSByZWFkeS4KICAgIGxldCBmaWxlRGF0YSA9IHlpZWxkIHsKICAgICAgcHJvbWlzZTogZmlsZURhdGFQcm9taXNlLAogICAgICByZXNwb25zZTogewogICAgICAgIGFjdGlvbjogJ2NvbnRpbnVlJywKICAgICAgfQogICAgfTsKCiAgICAvLyBVc2UgYSBjaHVua2VkIHNlbmRpbmcgdG8gYXZvaWQgbWVzc2FnZSBzaXplIGxpbWl0cy4gU2VlIGIvNjIxMTU2NjAuCiAgICBsZXQgcG9zaXRpb24gPSAwOwogICAgd2hpbGUgKHBvc2l0aW9uIDwgZmlsZURhdGEuYnl0ZUxlbmd0aCkgewogICAgICBjb25zdCBsZW5ndGggPSBNYXRoLm1pbihmaWxlRGF0YS5
ieXRlTGVuZ3RoIC0gcG9zaXRpb24sIE1BWF9QQVlMT0FEX1NJWkUpOwogICAgICBjb25zdCBjaHVuayA9IG5ldyBVaW50OEFycmF5KGZpbGVEYXRhLCBwb3NpdGlvbiwgbGVuZ3RoKTsKICAgICAgcG9zaXRpb24gKz0gbGVuZ3RoOwoKICAgICAgY29uc3QgYmFzZTY0ID0gYnRvYShTdHJpbmcuZnJvbUNoYXJDb2RlLmFwcGx5KG51bGwsIGNodW5rKSk7CiAgICAgIHlpZWxkIHsKICAgICAgICByZXNwb25zZTogewogICAgICAgICAgYWN0aW9uOiAnYXBwZW5kJywKICAgICAgICAgIGZpbGU6IGZpbGUubmFtZSwKICAgICAgICAgIGRhdGE6IGJhc2U2NCwKICAgICAgICB9LAogICAgICB9OwogICAgICBwZXJjZW50LnRleHRDb250ZW50ID0KICAgICAgICAgIGAke01hdGgucm91bmQoKHBvc2l0aW9uIC8gZmlsZURhdGEuYnl0ZUxlbmd0aCkgKiAxMDApfSUgZG9uZWA7CiAgICB9CiAgfQoKICAvLyBBbGwgZG9uZS4KICB5aWVsZCB7CiAgICByZXNwb25zZTogewogICAgICBhY3Rpb246ICdjb21wbGV0ZScsCiAgICB9CiAgfTsKfQoKc2NvcGUuZ29vZ2xlID0gc2NvcGUuZ29vZ2xlIHx8IHt9OwpzY29wZS5nb29nbGUuY29sYWIgPSBzY29wZS5nb29nbGUuY29sYWIgfHwge307CnNjb3BlLmdvb2dsZS5jb2xhYi5fZmlsZXMgPSB7CiAgX3VwbG9hZEZpbGVzLAogIF91cGxvYWRGaWxlc0NvbnRpbnVlLAp9Owp9KShzZWxmKTsK",
"ok": true,
"headers": [
[
"content-type",
"application/javascript"
]
],
"status": 200,
"status_text": ""
}
},
"base_uri": "https://localhost:8080/",
"height": 89
},
"outputId": "38d2da67-67ed-4462-f952-8371e559379e"
},
"source": [
"from google.colab import files\n",
"uploaded = files.upload()\n",
"print(\"done\")"
],
"execution_count": 106,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/html": [
"\n",
" <input type=\"file\" id=\"files-b38436f9-ce90-4d8e-9cc4-bee8ab0b344f\" name=\"files[]\" multiple disabled\n",
" style=\"border:none\" />\n",
" <output id=\"result-b38436f9-ce90-4d8e-9cc4-bee8ab0b344f\">\n",
" Upload widget is only available when the cell has been executed in the\n",
" current browser session. Please rerun this cell to enable.\n",
" </output>\n",
" <script src=\"/nbextensions/google.colab/files.js\"></script> "
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "stream",
"text": [
"Saving key.json to key (1).json\n",
"done\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "WHPHrUnhpKnI",
"colab_type": "text"
},
"source": [
"I'll install the API"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "h4BX5shF5urh",
"colab_type": "text"
},
"source": [
"The cell below installs the package and imports the libraries; the WiC dataset itself is downloaded further down."
]
},
{
"cell_type": "code",
"metadata": {
"id": "zq0ltp2xn4yt",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 139
},
"outputId": "894894ba-1f56-46a0-980c-f7b2902b4ee8"
},
"source": [
"!pip install openai\n",
"import openai, json, pandas as pd, numpy as np, random"
],
"execution_count": 107,
"outputs": [
{
"output_type": "stream",
"text": [
"Requirement already satisfied: openai in /usr/local/lib/python3.6/dist-packages (0.2.4)\n",
"Requirement already satisfied: requests>=2.20; python_version >= \"3.0\" in /usr/local/lib/python3.6/dist-packages (from openai) (2.23.0)\n",
"Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests>=2.20; python_version >= \"3.0\"->openai) (2.10)\n",
"Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests>=2.20; python_version >= \"3.0\"->openai) (1.24.3)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests>=2.20; python_version >= \"3.0\"->openai) (2020.6.20)\n",
"Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests>=2.20; python_version >= \"3.0\"->openai) (3.0.4)\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "k67w5H0fpTkT",
"colab_type": "text"
},
"source": [
"Arguments to pass to the API: `kwargs` stops at the first newline (one-line answers), `kwargs2` stops at the first blank line (two newlines)."
]
},
{
"cell_type": "code",
"metadata": {
"id": "e1EwpqqJkTYh",
"colab_type": "code",
"colab": {}
},
"source": [
"#arguments to send the API\n",
"kwargs = {\n",
"\"engine\":\"davinci\",\n",
"\"temperature\":0,\n",
"\"max_tokens\":200,\n",
"\"stop\":\"\\n\",\n",
"}\n",
"\n",
"kwargs2 = {\n",
"\"engine\":\"davinci\",\n",
"\"temperature\":0,\n",
"\"max_tokens\":200,\n",
"\"stop\":\"\\n\\n\",\n",
"}\n",
"\n"
],
"execution_count": 240,
"outputs": []
},
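{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal offline sketch of what the two `stop` settings above do, assuming the API simply cuts the completion at the first occurrence of the stop sequence. `truncate_at_stop` and the sample text are hypothetical, written only for this illustration; nothing here is part of the `openai` package and no API call is made."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"def truncate_at_stop(text, stop):\n",
"    \"\"\"Return text up to, but not including, the first occurrence of stop.\"\"\"\n",
"    index = text.find(stop)\n",
"    return text if index == -1 else text[:index]\n",
"\n",
"# Made-up raw completion, for illustration only\n",
"raw = \" 2\\nq: what is 2+2?\\na: 4\\n\\nq: next\"\n",
"one_line = truncate_at_stop(raw, \"\\n\")    # like kwargs: stop at the first newline\n",
"two_line = truncate_at_stop(raw, \"\\n\\n\")  # like kwargs2: stop at the first blank line\n",
"print(repr(one_line))\n",
"print(repr(two_line))"
],
"execution_count": null,
"outputs": []
},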
{
"cell_type": "code",
"metadata": {
"id": "XPqahGflR_Tc",
"colab_type": "code",
"colab": {}
},
"source": [
"openai.api_key = json.load(open(\"key.json\", \"r\"))[\"key\"]"
],
"execution_count": 109,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "sXTDJx0An9Bl",
"colab_type": "code",
"colab": {}
},
"source": [
"def queryTwoLine(prompt, myKwargs = kwargs2):\n",
" \"\"\"\n",
" Wrapper for the API. With the default kwargs2 the completion stops at the first blank line (two newlines).\n",
" \"\"\"\n",
" r = openai.Completion.create(prompt=prompt, **myKwargs)[\"choices\"][0][\"text\"].strip()\n",
" return r\n",
"\n",
"\n",
"def queryOneLine(prompt, myKwargs = kwargs):\n",
" \"\"\"\n",
" Wrapper for the API. With the default kwargs the completion stops at the first newline.\n",
" \"\"\"\n",
" r = openai.Completion.create(prompt=prompt, **myKwargs)[\"choices\"][0][\"text\"].strip()\n",
" return r"
],
"execution_count": 413,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "EdFXafcJpZ3Q",
"colab_type": "text"
},
"source": [
"Test to make sure my query works"
]
},
{
"cell_type": "code",
"metadata": {
"id": "4SlyKgjyopPn",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 35
},
"outputId": "29292c93-9851-4a12-db38-1c761489c7d2"
},
"source": [
"newKwargs = kwargs.copy()\n",
"newKwargs[\"stop\"] = \"\\n\"\n",
"queryOneLine(\"q: what is 1+1?\\na:\", myKwargs=newKwargs)"
],
"execution_count": 111,
"outputs": [
{
"output_type": "execute_result",
"data": {
"application/vnd.google.colaboratory.intrinsic+json": {
"type": "string"
},
"text/plain": [
"'2'"
]
},
"metadata": {
"tags": []
},
"execution_count": 111
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "AAR9Guv1QPZk",
"colab_type": "text"
},
"source": [
"Get the WiC dataset"
]
},
{
"cell_type": "code",
"metadata": {
"id": "XDy96ovI9hJm",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 224
},
"outputId": "7ac46d54-7501-48ed-f9ec-c17291a20b86"
},
"source": [
"!wget https://pilehvar.github.io/wic/package/WiC_dataset.zip"
],
"execution_count": 112,
"outputs": [
{
"output_type": "stream",
"text": [
"--2020-07-31 05:43:48-- https://pilehvar.github.io/wic/package/WiC_dataset.zip\n",
"Resolving pilehvar.github.io (pilehvar.github.io)... 185.199.109.153, 185.199.111.153, 185.199.110.153, ...\n",
"Connecting to pilehvar.github.io (pilehvar.github.io)|185.199.109.153|:443... connected.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 275513 (269K) [application/zip]\n",
"Saving to: ‘WiC_dataset.zip.1’\n",
"\n",
"\rWiC_dataset.zip.1 0%[ ] 0 --.-KB/s \rWiC_dataset.zip.1 100%[===================>] 269.06K --.-KB/s in 0.04s \n",
"\n",
"2020-07-31 05:43:48 (6.01 MB/s) - ‘WiC_dataset.zip.1’ saved [275513/275513]\n",
"\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "mh8Ztw-JQRMr",
"colab_type": "text"
},
"source": [
"Unzip the dataset"
]
},
{
"cell_type": "code",
"metadata": {
"id": "Gx-SlrBpSVXJ",
"colab_type": "code",
"colab": {}
},
"source": [
"import zipfile\n",
"with zipfile.ZipFile(\"WiC_dataset.zip\",\"r\") as zip_ref:\n",
" zip_ref.extractall(\".\")\n"
],
"execution_count": 113,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "wOjsEzALQTXW",
"colab_type": "text"
},
"source": [
"Read in the training data and add the gold T/F label (whether the target word has the same meaning in both contexts)."
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "CAHaOQKx9ogG",
"colab": {}
},
"source": [
"train = pd.read_csv(\"train/train.data.txt\", sep='\\t', header=None)\n",
"train.columns = [\"target\", \"pos\", \"position\", \"context-1\", \"context-2\"]\n",
"train_gold = pd.read_csv(\"train/train.gold.txt\", sep='\\t', header=None)\n",
"train_gold.columns = [\"label\"]\n",
"train = pd.concat([train_gold,train], axis=1)"
],
"execution_count": 114,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "XC1pGirRwF8Z",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 272
},
"outputId": "37fb3248-ec7d-452a-e53a-9d2d18b457f4"
},
"source": [
"train.head()"
],
"execution_count": 167,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>label</th>\n",
" <th>target</th>\n",
" <th>pos</th>\n",
" <th>position</th>\n",
" <th>context-1</th>\n",
" <th>context-2</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>F</td>\n",
" <td>carry</td>\n",
" <td>V</td>\n",
" <td>2-1</td>\n",
" <td>You must carry your camping gear .</td>\n",
" <td>Sound carries well over water .</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>F</td>\n",
" <td>go</td>\n",
" <td>V</td>\n",
" <td>2-6</td>\n",
" <td>Messages must go through diplomatic channels .</td>\n",
" <td>Do you think the sofa will go through the door ?</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>F</td>\n",
" <td>break</td>\n",
" <td>V</td>\n",
" <td>0-2</td>\n",
" <td>Break an alibi .</td>\n",
" <td>The wholesaler broke the container loads into palettes and boxes for local retailers .</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>T</td>\n",
" <td>cup</td>\n",
" <td>N</td>\n",
" <td>8-4</td>\n",
" <td>He wore a jock strap with a metal cup .</td>\n",
" <td>Bees filled the waxen cups with honey .</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>F</td>\n",
" <td>academy</td>\n",
" <td>N</td>\n",
" <td>1-2</td>\n",
" <td>The Academy of Music .</td>\n",
" <td>The French Academy .</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" label ... context-2\n",
"0 F ... Sound carries well over water . \n",
"1 F ... Do you think the sofa will go through the door ? \n",
"2 F ... The wholesaler broke the container loads into palettes and boxes for local retailers .\n",
"3 T ... Bees filled the waxen cups with honey . \n",
"4 F ... The French Academy . \n",
"\n",
"[5 rows x 6 columns]"
]
},
"metadata": {
"tags": []
},
"execution_count": 167
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3neGYQj6d9zU",
"colab_type": "text"
},
"source": [
"Function to bootstrap meanings: the prompt asks what a term means in context, and the function returns the prompt plus the model's response."
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "JE6GhNTi4hTT",
"colab": {}
},
"source": [
"def getContextualMeaningExample(content, term):\n",
" prompt = \"Tom said '{}'.\\n\".format(content)\n",
" prompt += \"I asked Tom what '{}' means in this context, he clarified it is another word for\".format(term) \n",
"\n",
" return (prompt + \" \" +queryOneLine(prompt))"
],
"execution_count": 352,
"outputs": []
},
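{
"cell_type": "markdown",
"metadata": {},
"source": [
"To check the wording without spending an API call, here is a small sketch that only builds the prompt text. `build_meaning_prompt` is a hypothetical name used for illustration; it mirrors the two lines assembled above and returns the string that would be sent."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"def build_meaning_prompt(content, term):\n",
"    \"\"\"Build the two-line meaning prompt; no API call is made.\"\"\"\n",
"    prompt = \"Tom said '{}'.\\n\".format(content)\n",
"    prompt += \"I asked Tom what '{}' means in this context, he clarified it is another word for\".format(term)\n",
"    return prompt\n",
"\n",
"# The model's completion would be appended after the final word \"for\"\n",
"print(build_meaning_prompt(\"Answer the riddle .\", \"answer\"))"
],
"execution_count": null,
"outputs": []
},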
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "g3IO1TK64hTX"
},
"source": [
"Run this on the last 8 rows of the training set to get my few-shot examples, and print them to make sure they make sense."
]
},
{
"cell_type": "code",
"metadata": {
"colab_type": "code",
"id": "FD6Sh9oY4hTX",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 289
},
"outputId": "1b8240b8-049b-4cc8-eeb1-6f07285a4141"
},
"source": [
"fewShotsDefinition = \"\" #I'll build this from the training set\n",
"for row in train.tail(8).iterrows():\n",
" s1 = row[1][\"context-1\"]\n",
" target = row[1][\"target\"]\n",
" r = getContextualMeaningExample(s1, target)\n",
" print(r)\n",
" fewShotsDefinition += \" \".join(r.split(\"\\n\")) #add the definition to fewShotsDefinition as 1 line\n",
" fewShotsDefinition += \"\\n\""
],
"execution_count": 354,
"outputs": [
{
"output_type": "stream",
"text": [
"Tom said 'We added a new rosebush to our rose bed .'.\n",
"I asked Tom what 'bed' means in this context, he clarified it is another word for 'garden'.\n",
"Tom said 'His state of health .'.\n",
"I asked Tom what 'state' means in this context, he clarified it is another word for 'condition'.\n",
"Tom said 'Likes a drink before dinner .'.\n",
"I asked Tom what 'drink' means in this context, he clarified it is another word for 'alcoholic drink'.\n",
"Tom said 'Piecas kronas — five krona .'.\n",
"I asked Tom what 'krona' means in this context, he clarified it is another word for 'dollar'.\n",
"Tom said 'The harder the conflict the more glorious the triumph \"-- Thomas Paine .'.\n",
"I asked Tom what 'conflict' means in this context, he clarified it is another word for 'war'.\n",
"Tom said 'Answer the riddle .'.\n",
"I asked Tom what 'answer' means in this context, he clarified it is another word for 'solve'.\n",
"Tom said 'Play the casinos in Trouville .'.\n",
"I asked Tom what 'play' means in this context, he clarified it is another word for gamble.\n",
"Tom said 'An invasion of bees .'.\n",
"I asked Tom what 'invasion' means in this context, he clarified it is another word for 'attack'.\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "WvZMJAhDQjgb",
"colab_type": "text"
},
"source": [
"Now use the few shots for my actual method. I can context-stuff the previous answer into the prompt: if the meaning is the same, the model will likely choose it again, and if the meanings differ, it will hopefully pick up on the difference and choose something else. I'm using `presence_penalty` to increase the odds of choosing a new token.\n",
"\n",
"The penalty value is potentially a parameter to tune."
]
},
{
"cell_type": "code",
"metadata": {
"id": "BEmsnBCErBZX",
"colab_type": "code",
"colab": {}
},
"source": [
"def getContextualMeaning(content, term, contexts=None):\n",
" # contexts: earlier answers (dicts with content, term, meaning) to stuff into the prompt\n",
" prompt = fewShotsDefinition\n",
" \n",
" for context in (contexts or []):\n",
" prompt += \"Tom said '{}'.\\n\".format(context[\"content\"])\n",
" prompt += \"I asked Tom what '{}' means in this context, he clarified it is another word for {}\\n\\n\".format(context[\"term\"], context[\"meaning\"]) \n",
"\n",
" prompt += \"Tom said '{}'.\\n\".format(content)\n",
" prompt += \"I asked Tom what '{}' means in this context, he clarified it is another word for\".format(term) \n",
"\n",
" r = queryOneLine(prompt, myKwargs = {'engine': 'davinci', 'max_tokens': 20, 'stop': '\\n', 'temperature': 0, \"presence_penalty\":.5})\n",
" if not r.startswith(\"'\"):\n",
" r = \"'\" + r\n",
" \n",
" return r"
],
"execution_count": 412,
"outputs": []
},
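{
"cell_type": "markdown",
"metadata": {},
"source": [
"A rough sketch of how the context-stuffed prompt is laid out for the second sentence, again without calling the API. `TEMPLATE`, `build_stuffed_prompt` and `normalize_answer` are hypothetical names for illustration, and the few-shot block is passed in as a plain string (empty here) rather than taken from `fewShotsDefinition`."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"TEMPLATE = (\"Tom said '{}'.\\n\"\n",
"            \"I asked Tom what '{}' means in this context, he clarified it is another word for\")\n",
"\n",
"def build_stuffed_prompt(few_shots, content, term, contexts=None):\n",
"    \"\"\"Few shots, then any earlier answers, then the open question.\"\"\"\n",
"    prompt = few_shots\n",
"    for context in contexts or []:\n",
"        prompt += TEMPLATE.format(context[\"content\"], context[\"term\"])\n",
"        prompt += \" {}\\n\\n\".format(context[\"meaning\"])\n",
"    prompt += TEMPLATE.format(content, term)\n",
"    return prompt\n",
"\n",
"def normalize_answer(answer):\n",
"    \"\"\"Restore the opening quote when the completion omits it.\"\"\"\n",
"    return answer if answer.startswith(\"'\") else \"'\" + answer\n",
"\n",
"demo = build_stuffed_prompt(\n",
"    \"\",  # no few shots in this toy call\n",
"    \"Render fat in a casserole .\",\n",
"    \"render\",\n",
"    contexts=[{\"content\": \"Render thanks .\", \"term\": \"render\", \"meaning\": \"'give'.\"}],\n",
")\n",
"print(demo)\n",
"print(normalize_answer(\"cook'.\"))"
],
"execution_count": null,
"outputs": []
},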
{
"cell_type": "markdown",
"metadata": {
"id": "pPWT-NZwRq5r",
"colab_type": "text"
},
"source": [
"Now I make a method to build comparison examples: it takes two strings (in practice the two generated meanings) and a label (the gold answer, T or F) and generates a worked example ending in Yes or No."
]
},
{
"cell_type": "code",
"metadata": {
"id": "g5E68j_bRii0",
"colab_type": "code",
"colab": {}
},
"source": [
"def generateComparisonExample(s1, s2, label):\n",
" prompt = \"Tom says that this means {} Jerry says this means {}\\n\".format(s1, s2)\n",
" prompt += \"Q: Are Tom and Jerry basically saying the same thing here?\\nA:\"\n",
" if label == \"T\":\n",
" prompt += \" Yes\"\n",
" else:\n",
" prompt += \" No\"\n",
" return prompt"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "jHBUBl2eQ5vs",
"colab_type": "text"
},
"source": [
"Now generate a list of examples from the builder method. I could select a subset from the list, but I end up just joining all of them together for the actual prompt (choosing examples more carefully is a possible improvement)."
]
},
{
"cell_type": "code",
"metadata": {
"id": "DNSFTSUt07CB",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 969
},
"outputId": "820f73dd-c620-4e88-d7e8-763e4fa17fbb"
},
"source": [
"comparisonFewShotExamples = []\n",
"for row in train[-18:-10].iterrows():\n",
" s1 = row[1][\"context-1\"]\n",
" s2 = row[1][\"context-2\"]\n",
" label = row[1][\"label\"]\n",
" target = row[1][\"target\"]\n",
" r1 = getContextualMeaning(s1, target)\n",
" r2 = getContextualMeaning(s2, target, contexts=[{\"content\":s1, \"term\":target, \"meaning\":r1}])\n",
" print(s1)\n",
" print(s2)\n",
" \n",
" print(r1)\n",
" print(r2)\n",
"\n",
" r= generateComparisonExample(r1, r2, label)\n",
" comparisonFewShotExamples.append(r)\n",
" print(r)\n"
],
"execution_count": 356,
"outputs": [
{
"output_type": "stream",
"text": [
"Women carrying home shopping did n't give me a second glance .\n",
"On Saturdays we usually do the shopping .\n",
"'basket'.\n",
"'basket'.\n",
"Tom says that this means 'basket'. Jerry says this means 'basket'.\n",
"Q: Are Tom and Jerry basically saying the same thing here?\n",
"A: No\n",
"While being impulsive can be great for artists , it is not a desirable quality for engineers .\n",
"Security , stability , and efficiency are good qualities of an operating system .\n",
"'characteristic'.\n",
"'characteristic'.\n",
"Tom says that this means 'characteristic'. Jerry says this means 'characteristic'.\n",
"Q: Are Tom and Jerry basically saying the same thing here?\n",
"A: Yes\n",
"The cinema relies on apparent motion .\n",
"He made a motion to adjourn .\n",
"'movement'.\n",
"'suggestion'.\n",
"Tom says that this means 'movement'. Jerry says this means 'suggestion'.\n",
"Q: Are Tom and Jerry basically saying the same thing here?\n",
"A: No\n",
"Render thanks .\n",
"Render fat in a casserole .\n",
"'give'.\n",
"'cook'.\n",
"Tom says that this means 'give'. Jerry says this means 'cook'.\n",
"Q: Are Tom and Jerry basically saying the same thing here?\n",
"A: No\n",
"It vanished into the night .\n",
"The cat disappeared into the night .\n",
"'darkness'.\n",
"'darkness'.\n",
"Tom says that this means 'darkness'. Jerry says this means 'darkness'.\n",
"Q: Are Tom and Jerry basically saying the same thing here?\n",
"A: Yes\n",
"I drive to work every day .\n",
"We drove to the university every morning .\n",
"'travel by car'.\n",
"'travel by car'.\n",
"Tom says that this means 'travel by car'. Jerry says this means 'travel by car'.\n",
"Q: Are Tom and Jerry basically saying the same thing here?\n",
"A: Yes\n",
"He threw the ball into the air .\n",
"A smell of chemicals in the air .\n",
"'sky'.\n",
"'smell'.\n",
"Tom says that this means 'sky'. Jerry says this means 'smell'.\n",
"Q: Are Tom and Jerry basically saying the same thing here?\n",
"A: No\n",
"Keep open the possibility of a merger .\n",
"Keep my seat , please .\n",
"'maintain'.\n",
"'hold'.\n",
"Tom says that this means 'maintain'. Jerry says this means 'hold'.\n",
"Q: Are Tom and Jerry basically saying the same thing here?\n",
"A: Yes\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Frih6lJBSDJ6",
"colab_type": "text"
},
"source": [
"Now I have my actual comparison prompt that uses the few shots from above to let it know I just want a Yes/No answer."
]
},
{
"cell_type": "code",
"metadata": {
"id": "dmtBiasT1SUP",
"colab_type": "code",
"colab": {}
},
"source": [
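"# kwargs2Short is not defined anywhere in this notebook. Assumption: kwargs2 with a small max_tokens, since only Yes/No is needed.\n",
"kwargs2Short = dict(kwargs2, max_tokens=5)\n",
"\n",
"def generateComparison(s1, s2):\n",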
" prompt = \"\\n\\n\".join(comparisonFewShotExamples)\n",
" prompt += \"\\n\\n\"\n",
"\n",
" prompt += \"Tom says that this means {}; Jerry says this means {}\\n\".format(s1, s2)\n",
" prompt += \"Q: Are Tom and Jerry basically saying the same thing here?\\nA:\"\n",
"\n",
"\n",
" return queryTwoLine(prompt, myKwargs = kwargs2Short)"
],
"execution_count": 362,
"outputs": []
},
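{
"cell_type": "markdown",
"metadata": {},
"source": [
"A small sketch of how the final Yes/No prompt is laid out: worked examples separated by blank lines, then the open question. `QUESTION`, `build_comparison_prompt` and the two toy examples are hypothetical and only illustrate the layout; no API call is made."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"QUESTION = \"Q: Are Tom and Jerry basically saying the same thing here?\\nA:\"\n",
"\n",
"def build_comparison_prompt(examples, meaning_1, meaning_2):\n",
"    \"\"\"Join the worked examples, then append the unanswered comparison.\"\"\"\n",
"    prompt = \"\\n\\n\".join(examples) + \"\\n\\n\"\n",
"    prompt += \"Tom says that this means {}; Jerry says this means {}\\n\".format(meaning_1, meaning_2)\n",
"    prompt += QUESTION\n",
"    return prompt\n",
"\n",
"# Toy few-shot examples in the same shape as the generated ones\n",
"toy_examples = [\n",
"    \"Tom says that this means 'give'. Jerry says this means 'cook'.\\n\" + QUESTION + \" No\",\n",
"    \"Tom says that this means 'darkness'. Jerry says this means 'darkness'.\\n\" + QUESTION + \" Yes\",\n",
"]\n",
"comparison_prompt = build_comparison_prompt(toy_examples, \"'bring'.\", \"'travel'.\")\n",
"print(comparison_prompt)"
],
"execution_count": null,
"outputs": []
},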
{
"cell_type": "markdown",
"metadata": {
"id": "U-NL6c6TSLux",
"colab_type": "text"
},
"source": [
"Test to make sure it works OK"
]
},
{
"cell_type": "code",
"metadata": {
"id": "L58zdQUp2B9v",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 377
},
"outputId": "fef4158e-f66d-4925-9947-050578ac2d8e"
},
"source": [
"for row in train[:5].iterrows():\n",
" s1 = row[1][\"context-1\"]\n",
" s2 = row[1][\"context-2\"]\n",
" label = row[1][\"label\"]\n",
" target = row[1][\"target\"]\n",
" r1 = getContextualMeaning(s1, target)\n",
" r2 = getContextualMeaning(s2, target, contexts=[{\"content\":s1, \"term\":target, \"meaning\":r1}])\n",
" r= generateComparison(r1, r2)\n",
" print(\"'{}' v '{}'\".format(s1, s2))\n",
" print(\"{} v {}\".format(r1, r2))\n",
" print(\"returned: {}\".format(r))\n",
" print(\"actual: {}\".format(label))\n"
],
"execution_count": 365,
"outputs": [
{
"output_type": "stream",
"text": [
"'You must carry your camping gear .' v 'Sound carries well over water .'\n",
"'bring'. v 'travel'.\n",
"returned: No\n",
"actual: F\n",
"'Messages must go through diplomatic channels .' v 'Do you think the sofa will go through the door ?'\n",
"'be sent'. v 'fit'.\n",
"returned: No\n",
"actual: F\n",
"'Break an alibi .' v 'The wholesaler broke the container loads into palettes and boxes for local retailers .'\n",
"'destroy'. v 'separate'.\n",
"returned: No\n",
"actual: F\n",
"'He wore a jock strap with a metal cup .' v 'Bees filled the waxen cups with honey .'\n",
"'protector'. v 'honeycomb'.\n",
"returned: No\n",
"actual: T\n",
"'The Academy of Music .' v 'The French Academy .'\n",
"'university'. v 'university'.\n",
"returned: Yes\n",
"actual: F\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bpc6A5pl68RM",
"colab_type": "text"
},
"source": [
"OK, so now I'll load the dev set."
]
},
{
"cell_type": "code",
"metadata": {
"id": "eru96XmC6iWt",
"colab_type": "code",
"colab": {}
},
"source": [
"dev = pd.read_csv(\"dev/dev.data.txt\", sep='\\t', header=None)\n",
"dev.columns = [\"target\", \"pos\", \"position\", \"context-1\", \"context-2\"]\n",
"dev_gold = pd.read_csv(\"dev/dev.gold.txt\", sep='\\t', header=None)\n",
"dev_gold.columns = [\"label\"]\n",
"dev = pd.concat([dev_gold,dev], axis=1)\n"
],
"execution_count": 366,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "v5ycY33vSQ_O",
"colab_type": "text"
},
"source": [
"Keep track of right/wrong as I go along"
]
},
{
"cell_type": "code",
"metadata": {
"id": "CY6ixlHl77U7",
"colab_type": "code",
"colab": {}
},
"source": [
"devResults = {}\n",
"correct = 0\n",
"complete = 0\n"
],
"execution_count": 371,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "d-PZePcLSTZ4",
"colab_type": "text"
},
"source": [
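"Check the entire dev set. Rows already in `devResults` are skipped, so the loop can be resumed after an interruption (the log below picks up at 251 completed)."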
]
},
{
"cell_type": "code",
"metadata": {
"id": "k7P4pzAL2Pxl",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 663
},
"outputId": "556440a1-5802-4934-a403-c633fb539095"
},
"source": [
"for row in dev.iterrows():\n",
" if row[0] in devResults: continue\n",
" s1 = row[1][\"context-1\"]\n",
" s2 = row[1][\"context-2\"]\n",
" label = row[1][\"label\"]\n",
" target = row[1][\"target\"]\n",
" r1 = getContextualMeaning(s1, target)\n",
" r2 = getContextualMeaning(s2, target, contexts=[{\"content\":s1, \"term\":target, \"meaning\":r1}])\n",
" r= generateComparison(r1, r2)\n",
"\n",
" myResults = {}\n",
" myResults[\"s1\"] = s1\n",
" myResults[\"s2\"] = s2\n",
"\n",
" myResults[\"pos\"] = row[1][\"pos\"]\n",
"\n",
" myResults[\"target\"] = target\n",
"\n",
" myResults[\"pred\"] = r\n",
"\n",
" myResults[\"actual\"] = label\n",
"\n",
" devResults[row[0]] = myResults\n",
"\n",
" complete +=1\n",
" if label == \"T\":\n",
" if r.strip()==\"Yes\":\n",
" correct += 1\n",
" if label == \"F\":\n",
" if r.strip()==\"No\":\n",
" correct += 1\n",
"\n",
" if row[0] %10 ==0:print (\"Complete: {} Correct: {} Wrong: {}\".format(complete, correct, complete-correct))\n"
],
"execution_count": 376,
"outputs": [
{
"output_type": "stream",
"text": [
"Complete: 251 Correct: 178 Wrong: 73\n",
"Complete: 261 Correct: 184 Wrong: 77\n",
"Complete: 271 Correct: 192 Wrong: 79\n",
"Complete: 281 Correct: 199 Wrong: 82\n",
"Complete: 291 Correct: 207 Wrong: 84\n",
"Complete: 301 Correct: 210 Wrong: 91\n",
"Complete: 311 Correct: 215 Wrong: 96\n",
"Complete: 321 Correct: 222 Wrong: 99\n",
"Complete: 331 Correct: 230 Wrong: 101\n",
"Complete: 341 Correct: 233 Wrong: 108\n",
"Complete: 351 Correct: 240 Wrong: 111\n",
"Complete: 361 Correct: 245 Wrong: 116\n",
"Complete: 371 Correct: 254 Wrong: 117\n",
"Complete: 381 Correct: 259 Wrong: 122\n",
"Complete: 391 Correct: 264 Wrong: 127\n",
"Complete: 401 Correct: 270 Wrong: 131\n",
"Complete: 411 Correct: 277 Wrong: 134\n",
"Complete: 421 Correct: 284 Wrong: 137\n",
"Complete: 431 Correct: 289 Wrong: 142\n",
"Complete: 441 Correct: 293 Wrong: 148\n",
"Complete: 451 Correct: 302 Wrong: 149\n",
"Complete: 461 Correct: 311 Wrong: 150\n",
"Complete: 471 Correct: 319 Wrong: 152\n",
"Complete: 481 Correct: 324 Wrong: 157\n",
"Complete: 491 Correct: 332 Wrong: 159\n",
"Complete: 501 Correct: 337 Wrong: 164\n",
"Complete: 511 Correct: 345 Wrong: 166\n",
"Complete: 521 Correct: 351 Wrong: 170\n",
"Complete: 531 Correct: 358 Wrong: 173\n",
"Complete: 541 Correct: 363 Wrong: 178\n",
"Complete: 551 Correct: 372 Wrong: 179\n",
"Complete: 561 Correct: 376 Wrong: 185\n",
"Complete: 571 Correct: 382 Wrong: 189\n",
"Complete: 581 Correct: 391 Wrong: 190\n",
"Complete: 591 Correct: 397 Wrong: 194\n",
"Complete: 601 Correct: 405 Wrong: 196\n",
"Complete: 611 Correct: 413 Wrong: 198\n",
"Complete: 631 Correct: 425 Wrong: 206\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vWyri1BiSWDX",
"colab_type": "text"
},
"source": [
"Save the dev-set predictions as a DataFrame, pickle them, and download the file."
]
},
{
"cell_type": "code",
"metadata": {
"id": "55MlUCdD7sUa",
"colab_type": "code",
"colab": {}
},
"source": [
"devDf = pd.DataFrame(devResults).T"
],
"execution_count": 406,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "WcSxf4IX8C0S",
"colab_type": "code",
"colab": {}
},
"source": [
"devDf.to_pickle(\"newDevResults.pkl\")"
],
"execution_count": 390,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "DRVg4lTKHeHZ",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 17
},
"outputId": "7a7cf2ef-99ec-423d-8645-b9e6e049b69a"
},
"source": [
"files.download(\"newDevResults.pkl\")"
],
"execution_count": null,
"outputs": [
{
"output_type": "display_data",
"data": {
"application/javascript": [
"\n",
" async function download(id, filename, size) {\n",
" if (!google.colab.kernel.accessAllowed) {\n",
" return;\n",
" }\n",
" const div = document.createElement('div');\n",
" const label = document.createElement('label');\n",
" label.textContent = `Downloading \"${filename}\": `;\n",
" div.appendChild(label);\n",
" const progress = document.createElement('progress');\n",
" progress.max = size;\n",
" div.appendChild(progress);\n",
" document.body.appendChild(div);\n",
"\n",
" const buffers = [];\n",
" let downloaded = 0;\n",
"\n",
" const channel = await google.colab.kernel.comms.open(id);\n",
" // Send a message to notify the kernel that we're ready.\n",
" channel.send({})\n",
"\n",
" for await (const message of channel.messages) {\n",
" // Send a message to notify the kernel that we're ready.\n",
" channel.send({})\n",
" if (message.buffers) {\n",
" for (const buffer of message.buffers) {\n",
" buffers.push(buffer);\n",
" downloaded += buffer.byteLength;\n",
" progress.value = downloaded;\n",
" }\n",
" }\n",
" }\n",
" const blob = new Blob(buffers, {type: 'application/binary'});\n",
" const a = document.createElement('a');\n",
" a.href = window.URL.createObjectURL(blob);\n",
" a.download = filename;\n",
" div.appendChild(a);\n",
" a.click();\n",
" div.remove();\n",
" }\n",
" "
],
"text/plain": [
"<IPython.core.display.Javascript object>"
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "display_data",
"data": {
"application/javascript": [
"download(\"download_6cb69e4d-d79a-4cc7-bf2a-bed93f9fbd4d\", \"newDevResults.pkl\", 79159)"
],
"text/plain": [
"<IPython.core.display.Javascript object>"
]
},
"metadata": {
"tags": []
}
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "M6JUOi7HHaJm",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 238
},
"outputId": "2a61acef-8132-4cd8-bb33-63e46bab7469"
},
"source": [
"devDf.head()"
],
"execution_count": 392,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>s1</th>\n",
" <th>s2</th>\n",
" <th>pos</th>\n",
" <th>target</th>\n",
" <th>pred</th>\n",
" <th>actual</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Room and board .</td>\n",
" <td>He nailed boards across the windows .</td>\n",
" <td>N</td>\n",
" <td>board</td>\n",
" <td>No</td>\n",
" <td>F</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>Circulate a rumor .</td>\n",
" <td>This letter is being circulated among the faculty .</td>\n",
" <td>V</td>\n",
" <td>circulate</td>\n",
" <td>No</td>\n",
" <td>F</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>Hook a fish .</td>\n",
" <td>He hooked a snake accidentally , and was so scared he dropped his rod into the water .</td>\n",
" <td>V</td>\n",
" <td>hook</td>\n",
" <td>Yes</td>\n",
" <td>T</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>For recreation he wrote poetry and solved crossword puzzles .</td>\n",
" <td>Drug abuse is often regarded as a form of recreation .</td>\n",
" <td>N</td>\n",
" <td>recreation</td>\n",
" <td>No</td>\n",
" <td>T</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>Making a hobby of domesticity .</td>\n",
" <td>A royal family living in unpretentious domesticity .</td>\n",
" <td>N</td>\n",
" <td>domesticity</td>\n",
" <td>No</td>\n",
" <td>F</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" s1 ... actual\n",
"0 Room and board . ... F \n",
"1 Circulate a rumor . ... F \n",
"2 Hook a fish . ... T \n",
"3 For recreation he wrote poetry and solved crossword puzzles . ... T \n",
"4 Making a hobby of domesticity . ... F \n",
"\n",
"[5 rows x 6 columns]"
]
},
"metadata": {
"tags": []
},
"execution_count": 392
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RMbeh9h_SXVy",
"colab_type": "text"
},
"source": [
"Convert the model's Yes/No predictions to the WiC labels (T/F). Any output other than \"Yes\" is mapped to F."
]
},
{
"cell_type": "code",
"metadata": {
"id": "Qkj00g8zOxuq",
"colab_type": "code",
"colab": {}
},
"source": [
"devDf[\"pred\"] = devDf[\"pred\"].apply(lambda x: \"T\" if x.strip() ==\"Yes\" else \"F\")"
],
"execution_count": 407,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "mZDo2UViSaC1",
"colab_type": "text"
},
"source": [
"67.24% accuracy on the dev set overall"
]
},
{
"cell_type": "code",
"metadata": {
"id": "AxnOkF0b8G2P",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"outputId": "ac21df08-0e54-49d6-d447-9d82211d1eed"
},
"source": [
"tmp = devDf.copy()\n",
"tmp[\"accurate\"] = tmp[\"actual\"] == tmp[\"pred\"]\n",
"tmp[\"accurate\"].sum()/len(tmp)"
],
"execution_count": 411,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"0.6724137931034483"
]
},
"metadata": {
"tags": []
},
"execution_count": 411
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Ryj3aNwCSb8h",
"colab_type": "text"
},
"source": [
"70.4% accuracy on nouns"
]
},
{
"cell_type": "code",
"metadata": {
"id": "nsrx68C3OrQ5",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"outputId": "3321ffad-2b3f-4a6a-c4f2-c10b9f3655c2"
},
"source": [
"tmp = devDf[devDf.pos==\"N\"].copy()\n",
"tmp[\"accurate\"] = tmp[\"actual\"] == tmp[\"pred\"]\n",
"tmp[\"accurate\"].sum()/len(tmp)\n"
],
"execution_count": 409,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"0.7037974683544304"
]
},
"metadata": {
"tags": []
},
"execution_count": 409
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LfixNMBPSdkk",
"colab_type": "text"
},
"source": [
"62.1% accuracy on verbs"
]
},
{
"cell_type": "code",
"metadata": {
"id": "91CpA465OrJF",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"outputId": "25129394-6c5f-4438-a164-77ce6ba98723"
},
"source": [
"tmp = devDf[devDf.pos==\"V\"].copy()\n",
"tmp[\"accurate\"] = tmp[\"actual\"] == tmp[\"pred\"]\n",
"tmp[\"accurate\"].sum()/len(tmp)\n"
],
"execution_count": 400,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"0.6213991769547325"
]
},
"metadata": {
"tags": []
},
"execution_count": 400
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "nSYCMotZSfBX",
"colab_type": "text"
},
"source": [
"59.6% accuracy on pairs whose gold label is T (same sense)"
]
},
{
"cell_type": "code",
"metadata": {
"id": "Y6s-Jj1IOvg_",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"outputId": "56d70020-2d59-4b94-a59a-a696ca7deb5a"
},
"source": [
"tmp = devDf[devDf.actual==\"T\"].copy()\n",
"tmp[\"accurate\"] = tmp[\"actual\"] == tmp[\"pred\"]\n",
"tmp[\"accurate\"].sum()/len(tmp)\n"
],
"execution_count": 401,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"0.5956112852664577"
]
},
"metadata": {
"tags": []
},
"execution_count": 401
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "NFDcsJuLSjqv",
"colab_type": "text"
},
"source": [
"74.9% accuracy on pairs whose gold label is F (different senses)"
]
},
{
"cell_type": "code",
"metadata": {
"id": "FU3RVeBNOxGd",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"outputId": "736de4fa-a436-46b8-cda6-99e173810e22"
},
"source": [
"tmp = devDf[devDf.actual==\"F\"].copy()\n",
"tmp[\"accurate\"] = tmp[\"actual\"] == tmp[\"pred\"]\n",
"tmp[\"accurate\"].sum()/len(tmp)\n"
],
"execution_count": 402,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"0.7492163009404389"
]
},
"metadata": {
"tags": []
},
"execution_count": 402
}
]
}
]
}