Skip to content

Instantly share code, notes, and snippets.

@habedi
Created March 27, 2024 16:56
Show Gist options
  • Save habedi/9d028b464d183ff8f8d0526fd4add524 to your computer and use it in GitHub Desktop.
Save habedi/9d028b464d183ff8f8d0526fd4add524 to your computer and use it in GitHub Desktop.
Embed discourses
Display the source blob
Display the rendered blob
Raw
{
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.7.12",
"mimetype": "text/x-python",
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"pygments_lexer": "ipython3",
"nbconvert_exporter": "python",
"file_extension": ".py"
},
"kaggle": {
"accelerator": "nvidiaTeslaT4",
"dataSources": [
{
"sourceId": 31779,
"databundleVersionId": 2970755,
"sourceType": "competition"
},
{
"sourceId": 3161854,
"sourceType": "datasetVersion",
"datasetId": 1913766
},
{
"sourceId": 3185027,
"sourceType": "datasetVersion",
"datasetId": 1934428
},
{
"sourceId": 3219458,
"sourceType": "datasetVersion",
"datasetId": 1952618
},
{
"sourceId": 3224725,
"sourceType": "datasetVersion",
"datasetId": 1952316
},
{
"sourceId": 3228278,
"sourceType": "datasetVersion",
"datasetId": 1957675
},
{
"sourceId": 3229886,
"sourceType": "datasetVersion",
"datasetId": 1958424
},
{
"sourceId": 3232969,
"sourceType": "datasetVersion",
"datasetId": 1960014
},
{
"sourceId": 3248794,
"sourceType": "datasetVersion",
"datasetId": 1968880
},
{
"sourceId": 3276749,
"sourceType": "datasetVersion",
"datasetId": 1984876
},
{
"sourceId": 3286735,
"sourceType": "datasetVersion",
"datasetId": 1990259
},
{
"sourceId": 3291540,
"sourceType": "datasetVersion",
"datasetId": 1992408
},
{
"sourceId": 3301199,
"sourceType": "datasetVersion",
"datasetId": 1996951
},
{
"sourceId": 3303452,
"sourceType": "datasetVersion",
"datasetId": 1997502
},
{
"sourceId": 7833451,
"sourceType": "datasetVersion",
"datasetId": 4591288
}
],
"dockerImageVersionId": 30163,
"isInternetEnabled": false,
"language": "python",
"sourceType": "notebook",
"isGpuEnabled": true
},
"colab": {
"provenance": [],
"gpuType": "T4",
"include_colab_link": true
},
"accelerator": "GPU"
},
"nbformat_minor": 0,
"nbformat": 4,
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/habedi/9d028b464d183ff8f8d0526fd4add524/embed-discourses.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"source": [
"%pip install -q sentence_transformers==2.6.1"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "b2dTeZNQsVuu",
"outputId": "cfc0f29b-697b-4ee7-ece4-fbc9b8e2b459"
},
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Requirement already satisfied: sentence_transformers in /usr/local/lib/python3.10/dist-packages (2.6.1)\n",
"Requirement already satisfied: transformers<5.0.0,>=4.32.0 in /usr/local/lib/python3.10/dist-packages (from sentence_transformers) (4.38.2)\n",
"Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (from sentence_transformers) (4.66.2)\n",
"Requirement already satisfied: torch>=1.11.0 in /usr/local/lib/python3.10/dist-packages (from sentence_transformers) (2.2.1+cu121)\n",
"Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from sentence_transformers) (1.25.2)\n",
"Requirement already satisfied: scikit-learn in /usr/local/lib/python3.10/dist-packages (from sentence_transformers) (1.2.2)\n",
"Requirement already satisfied: scipy in /usr/local/lib/python3.10/dist-packages (from sentence_transformers) (1.11.4)\n",
"Requirement already satisfied: huggingface-hub>=0.15.1 in /usr/local/lib/python3.10/dist-packages (from sentence_transformers) (0.20.3)\n",
"Requirement already satisfied: Pillow in /usr/local/lib/python3.10/dist-packages (from sentence_transformers) (9.4.0)\n",
"Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.15.1->sentence_transformers) (3.13.1)\n",
"Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.15.1->sentence_transformers) (2023.6.0)\n",
"Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.15.1->sentence_transformers) (2.31.0)\n",
"Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.15.1->sentence_transformers) (6.0.1)\n",
"Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.15.1->sentence_transformers) (4.10.0)\n",
"Requirement already satisfied: packaging>=20.9 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.15.1->sentence_transformers) (24.0)\n",
"Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=1.11.0->sentence_transformers) (1.12)\n",
"Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.11.0->sentence_transformers) (3.2.1)\n",
"Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=1.11.0->sentence_transformers) (3.1.3)\n",
"Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch>=1.11.0->sentence_transformers) (12.1.105)\n",
"Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch>=1.11.0->sentence_transformers) (12.1.105)\n",
"Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch>=1.11.0->sentence_transformers) (12.1.105)\n",
"Requirement already satisfied: nvidia-cudnn-cu12==8.9.2.26 in /usr/local/lib/python3.10/dist-packages (from torch>=1.11.0->sentence_transformers) (8.9.2.26)\n",
"Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in /usr/local/lib/python3.10/dist-packages (from torch>=1.11.0->sentence_transformers) (12.1.3.1)\n",
"Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in /usr/local/lib/python3.10/dist-packages (from torch>=1.11.0->sentence_transformers) (11.0.2.54)\n",
"Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in /usr/local/lib/python3.10/dist-packages (from torch>=1.11.0->sentence_transformers) (10.3.2.106)\n",
"Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in /usr/local/lib/python3.10/dist-packages (from torch>=1.11.0->sentence_transformers) (11.4.5.107)\n",
"Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in /usr/local/lib/python3.10/dist-packages (from torch>=1.11.0->sentence_transformers) (12.1.0.106)\n",
"Requirement already satisfied: nvidia-nccl-cu12==2.19.3 in /usr/local/lib/python3.10/dist-packages (from torch>=1.11.0->sentence_transformers) (2.19.3)\n",
"Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch>=1.11.0->sentence_transformers) (12.1.105)\n",
"Requirement already satisfied: triton==2.2.0 in /usr/local/lib/python3.10/dist-packages (from torch>=1.11.0->sentence_transformers) (2.2.0)\n",
"Requirement already satisfied: nvidia-nvjitlink-cu12 in /usr/local/lib/python3.10/dist-packages (from nvidia-cusolver-cu12==11.4.5.107->torch>=1.11.0->sentence_transformers) (12.4.99)\n",
"Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers<5.0.0,>=4.32.0->sentence_transformers) (2023.12.25)\n",
"Requirement already satisfied: tokenizers<0.19,>=0.14 in /usr/local/lib/python3.10/dist-packages (from transformers<5.0.0,>=4.32.0->sentence_transformers) (0.15.2)\n",
"Requirement already satisfied: safetensors>=0.4.1 in /usr/local/lib/python3.10/dist-packages (from transformers<5.0.0,>=4.32.0->sentence_transformers) (0.4.2)\n",
"Requirement already satisfied: joblib>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from scikit-learn->sentence_transformers) (1.3.2)\n",
"Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from scikit-learn->sentence_transformers) (3.3.0)\n",
"Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=1.11.0->sentence_transformers) (2.1.5)\n",
"Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub>=0.15.1->sentence_transformers) (3.3.2)\n",
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub>=0.15.1->sentence_transformers) (3.6)\n",
"Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub>=0.15.1->sentence_transformers) (2.0.7)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub>=0.15.1->sentence_transformers) (2024.2.2)\n",
"Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=1.11.0->sentence_transformers) (1.3.0)\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# Report the attached GPU. On a CPU-only runtime nvidia-smi either says it\n",
"# 'failed' or the shell says 'command not found'; treat both as no GPU.\n",
"smi_lines = !nvidia-smi\n",
"gpu_info = '\\n'.join(smi_lines)\n",
"if 'failed' in gpu_info or 'not found' in gpu_info:\n",
"    print('Not connected to a GPU')\n",
"else:\n",
"    print(gpu_info)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "KjSIiGERHaOd",
"outputId": "4235671d-7721-494d-89b0-ea7dbe78835a"
},
"execution_count": 2,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Wed Mar 27 16:51:54 2024 \n",
"+---------------------------------------------------------------------------------------+\n",
"| NVIDIA-SMI 535.104.05 Driver Version: 535.104.05 CUDA Version: 12.2 |\n",
"|-----------------------------------------+----------------------+----------------------+\n",
"| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |\n",
"| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |\n",
"| | | MIG M. |\n",
"|=========================================+======================+======================|\n",
"| 0 Tesla T4 Off | 00000000:00:04.0 Off | 0 |\n",
"| N/A 39C P8 9W / 70W | 0MiB / 15360MiB | 0% Default |\n",
"| | | N/A |\n",
"+-----------------------------------------+----------------------+----------------------+\n",
" \n",
"+---------------------------------------------------------------------------------------+\n",
"| Processes: |\n",
"| GPU GI CI PID Type Process name GPU Memory |\n",
"| ID ID Usage |\n",
"|=======================================================================================|\n",
"| No running processes found |\n",
"+---------------------------------------------------------------------------------------+\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"from psutil import virtual_memory\n",
"\n",
"# Total physical memory of this runtime, in decimal gigabytes.\n",
"ram_gb = virtual_memory().total / 1e9\n",
"print(f'Your runtime has {ram_gb:.1f} gigabytes of available RAM\\n')\n",
"\n",
"# 20 GB is the cut-off this notebook uses to call a runtime 'high-RAM'.\n",
"high_ram = ram_gb >= 20\n",
"print('You are using a high-RAM runtime!' if high_ram else 'Not using a high-RAM runtime')"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "3Wq3ll5uHh9a",
"outputId": "423e942f-abc2-4c13-af22-0ddf79f3cdaa"
},
"execution_count": 3,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Your runtime has 13.6 gigabytes of available RAM\n",
"\n",
"Not using a high-RAM runtime\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"from tensorflow.keras import mixed_precision"
],
"metadata": {
"id": "MBRpeyPeHq7g"
},
"execution_count": 4,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# NOTE: this is a Keras (TensorFlow) global policy: float16 compute with\n",
"# float32 variables for Keras layers. SentenceTransformer runs on PyTorch,\n",
"# so this setting does not change the precision of the embedding model.\n",
"policy_name = 'mixed_float16'\n",
"policy = mixed_precision.Policy(policy_name)\n",
"mixed_precision.set_global_policy(policy)"
],
],
"metadata": {
"id": "t6XehLFYHrlJ"
},
"execution_count": 5,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Confirm the Keras policy: expect float16 compute and float32 variables.\n",
"print(f'Compute dtype: {policy.compute_dtype}')\n",
"print(f'Variable dtype: {policy.variable_dtype}')"
],
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "vsvU3abjHu39",
"outputId": "7755d6c7-830b-4e4f-8800-e89a312d4985"
},
"execution_count": 6,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Compute dtype: float16\n",
"Variable dtype: float32\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"from pathlib import Path\n",
"from sentence_transformers import SentenceTransformer\n",
"import pandas as pd"
],
"metadata": {
"id": "Br7JIi_pfkvN"
},
"execution_count": 7,
"outputs": []
},
{
"cell_type": "code",
"source": [
"from google.colab import drive\n",
"\n",
"# The input CSVs live on Google Drive; mount it under /content/drive.\n",
"mount_point = '/content/drive'\n",
"drive.mount(mount_point)"
],
"metadata": {
"id": "4te5zAxkMrBp",
"outputId": "4583138f-c15a-4ca1-efcb-bc6b44a60d9e",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"execution_count": 8,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Mounted at /content/drive\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"## Load the embedding model and data"
],
"metadata": {
"id": "IZ36mmYqsHZc"
}
},
{
"cell_type": "code",
"source": [
"# Sentence-embedding model, fetched from the Hugging Face Hub on first use.\n",
"model_name = 'multi-qa-MiniLM-L6-cos-v1'\n",
"model = SentenceTransformer(model_name)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "K5N_bGYgsj8R",
"outputId": "8884b42e-fbde-4a8a-8235-9d54d18ef56e"
},
"execution_count": 9,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_token.py:88: UserWarning: \n",
"The secret `HF_TOKEN` does not exist in your Colab secrets.\n",
"To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.\n",
"You will be able to reuse this secret in all of your notebooks.\n",
"Please note that authentication is recommended but still optional to access public models or datasets.\n",
" warnings.warn(\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# Drive folder holding train.csv and interim_submission.csv.\n",
"# ('موقتی' is Persian for 'temporary'.) Adjust this path to your own Drive layout.\n",
"data_dir = Path(\"/content/drive/MyDrive/موقتی/RATER1\")\n",
"assert data_dir.is_dir(), f\"Data folder not found: {data_dir} (is Google Drive mounted?)\""
],
"metadata": {
"id": "A328nBlptFlF"
},
"execution_count": 12,
"outputs": []
},
{
"cell_type": "code",
"source": [
"train_csv = data_dir.joinpath(\"train.csv\")\n",
"train_df = pd.read_csv(train_csv, low_memory=False)"
],
"metadata": {
"id": "dqaiN6BBtPnN"
},
"execution_count": 13,
"outputs": []
},
{
"cell_type": "code",
"source": [
"train_df.head(n=5)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 417
},
"id": "Qkf2S4ymtcCo",
"outputId": "5ea0773a-9eaa-491c-9074-6890f95a17b0"
},
"execution_count": 14,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" essay_id_comp discourse_id discourse_start discourse_end discourse_type \\\n",
"0 423A1CA112E2 1.622628e+12 0 7 Unannotated \n",
"1 423A1CA112E2 1.622628e+12 8 229 Lead \n",
"2 423A1CA112E2 1.622628e+12 230 312 Position \n",
"3 423A1CA112E2 1.622628e+12 313 400 Evidence \n",
"4 423A1CA112E2 1.622628e+12 401 756 Evidence \n",
"\n",
" predictionstring \\\n",
"0 0 \n",
"1 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 1... \n",
"2 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 \n",
"3 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 \n",
"4 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 9... \n",
"\n",
" discourse_text discourse_effectiveness \\\n",
"0 Phones\\n\\n NaN \n",
"1 Modern humans today are always on their phone.... Non-Effective \n",
"2 They are some really bad consequences when stu... Non-Effective \n",
"3 Some certain areas in the United States ban ph... Non-Effective \n",
"4 When people have phones, they know about certa... Non-Effective \n",
"\n",
" discourse_type_num source_text \n",
"0 Unannotated 1 NaN \n",
"1 Lead 1 NaN \n",
"2 Position 1 NaN \n",
"3 Evidence 1 NaN \n",
"4 Evidence 2 NaN "
],
"text/html": [
"\n",
" <div id=\"df-241321a1-451d-4b57-8594-bdc803c8fa13\" class=\"colab-df-container\">\n",
" <div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>essay_id_comp</th>\n",
" <th>discourse_id</th>\n",
" <th>discourse_start</th>\n",
" <th>discourse_end</th>\n",
" <th>discourse_type</th>\n",
" <th>predictionstring</th>\n",
" <th>discourse_text</th>\n",
" <th>discourse_effectiveness</th>\n",
" <th>discourse_type_num</th>\n",
" <th>source_text</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>423A1CA112E2</td>\n",
" <td>1.622628e+12</td>\n",
" <td>0</td>\n",
" <td>7</td>\n",
" <td>Unannotated</td>\n",
" <td>0</td>\n",
" <td>Phones\\n\\n</td>\n",
" <td>NaN</td>\n",
" <td>Unannotated 1</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>423A1CA112E2</td>\n",
" <td>1.622628e+12</td>\n",
" <td>8</td>\n",
" <td>229</td>\n",
" <td>Lead</td>\n",
" <td>1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 1...</td>\n",
" <td>Modern humans today are always on their phone....</td>\n",
" <td>Non-Effective</td>\n",
" <td>Lead 1</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>423A1CA112E2</td>\n",
" <td>1.622628e+12</td>\n",
" <td>230</td>\n",
" <td>312</td>\n",
" <td>Position</td>\n",
" <td>45 46 47 48 49 50 51 52 53 54 55 56 57 58 59</td>\n",
" <td>They are some really bad consequences when stu...</td>\n",
" <td>Non-Effective</td>\n",
" <td>Position 1</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>423A1CA112E2</td>\n",
" <td>1.622628e+12</td>\n",
" <td>313</td>\n",
" <td>400</td>\n",
" <td>Evidence</td>\n",
" <td>60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75</td>\n",
" <td>Some certain areas in the United States ban ph...</td>\n",
" <td>Non-Effective</td>\n",
" <td>Evidence 1</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>423A1CA112E2</td>\n",
" <td>1.622628e+12</td>\n",
" <td>401</td>\n",
" <td>756</td>\n",
" <td>Evidence</td>\n",
" <td>76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 9...</td>\n",
" <td>When people have phones, they know about certa...</td>\n",
" <td>Non-Effective</td>\n",
" <td>Evidence 2</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>\n",
" <div class=\"colab-df-buttons\">\n",
"\n",
" <div class=\"colab-df-container\">\n",
" <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-241321a1-451d-4b57-8594-bdc803c8fa13')\"\n",
" title=\"Convert this dataframe to an interactive table.\"\n",
" style=\"display:none;\">\n",
"\n",
" <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\" viewBox=\"0 -960 960 960\">\n",
" <path d=\"M120-120v-720h720v720H120Zm60-500h600v-160H180v160Zm220 220h160v-160H400v160Zm0 220h160v-160H400v160ZM180-400h160v-160H180v160Zm440 0h160v-160H620v160ZM180-180h160v-160H180v160Zm440 0h160v-160H620v160Z\"/>\n",
" </svg>\n",
" </button>\n",
"\n",
" <style>\n",
" .colab-df-container {\n",
" display:flex;\n",
" gap: 12px;\n",
" }\n",
"\n",
" .colab-df-convert {\n",
" background-color: #E8F0FE;\n",
" border: none;\n",
" border-radius: 50%;\n",
" cursor: pointer;\n",
" display: none;\n",
" fill: #1967D2;\n",
" height: 32px;\n",
" padding: 0 0 0 0;\n",
" width: 32px;\n",
" }\n",
"\n",
" .colab-df-convert:hover {\n",
" background-color: #E2EBFA;\n",
" box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
" fill: #174EA6;\n",
" }\n",
"\n",
" .colab-df-buttons div {\n",
" margin-bottom: 4px;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-convert {\n",
" background-color: #3B4455;\n",
" fill: #D2E3FC;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-convert:hover {\n",
" background-color: #434B5C;\n",
" box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
" filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
" fill: #FFFFFF;\n",
" }\n",
" </style>\n",
"\n",
" <script>\n",
" const buttonEl =\n",
" document.querySelector('#df-241321a1-451d-4b57-8594-bdc803c8fa13 button.colab-df-convert');\n",
" buttonEl.style.display =\n",
" google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
"\n",
" async function convertToInteractive(key) {\n",
" const element = document.querySelector('#df-241321a1-451d-4b57-8594-bdc803c8fa13');\n",
" const dataTable =\n",
" await google.colab.kernel.invokeFunction('convertToInteractive',\n",
" [key], {});\n",
" if (!dataTable) return;\n",
"\n",
" const docLinkHtml = 'Like what you see? Visit the ' +\n",
" '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
" + ' to learn more about interactive tables.';\n",
" element.innerHTML = '';\n",
" dataTable['output_type'] = 'display_data';\n",
" await google.colab.output.renderOutput(dataTable, element);\n",
" const docLink = document.createElement('div');\n",
" docLink.innerHTML = docLinkHtml;\n",
" element.appendChild(docLink);\n",
" }\n",
" </script>\n",
" </div>\n",
"\n",
"\n",
"<div id=\"df-d7331ebf-cc08-43b8-9aea-ee8d7fed23d4\">\n",
" <button class=\"colab-df-quickchart\" onclick=\"quickchart('df-d7331ebf-cc08-43b8-9aea-ee8d7fed23d4')\"\n",
" title=\"Suggest charts\"\n",
" style=\"display:none;\">\n",
"\n",
"<svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
" width=\"24px\">\n",
" <g>\n",
" <path d=\"M19 3H5c-1.1 0-2 .9-2 2v14c0 1.1.9 2 2 2h14c1.1 0 2-.9 2-2V5c0-1.1-.9-2-2-2zM9 17H7v-7h2v7zm4 0h-2V7h2v10zm4 0h-2v-4h2v4z\"/>\n",
" </g>\n",
"</svg>\n",
" </button>\n",
"\n",
"<style>\n",
" .colab-df-quickchart {\n",
" --bg-color: #E8F0FE;\n",
" --fill-color: #1967D2;\n",
" --hover-bg-color: #E2EBFA;\n",
" --hover-fill-color: #174EA6;\n",
" --disabled-fill-color: #AAA;\n",
" --disabled-bg-color: #DDD;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-quickchart {\n",
" --bg-color: #3B4455;\n",
" --fill-color: #D2E3FC;\n",
" --hover-bg-color: #434B5C;\n",
" --hover-fill-color: #FFFFFF;\n",
" --disabled-bg-color: #3B4455;\n",
" --disabled-fill-color: #666;\n",
" }\n",
"\n",
" .colab-df-quickchart {\n",
" background-color: var(--bg-color);\n",
" border: none;\n",
" border-radius: 50%;\n",
" cursor: pointer;\n",
" display: none;\n",
" fill: var(--fill-color);\n",
" height: 32px;\n",
" padding: 0;\n",
" width: 32px;\n",
" }\n",
"\n",
" .colab-df-quickchart:hover {\n",
" background-color: var(--hover-bg-color);\n",
" box-shadow: 0 1px 2px rgba(60, 64, 67, 0.3), 0 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
" fill: var(--button-hover-fill-color);\n",
" }\n",
"\n",
" .colab-df-quickchart-complete:disabled,\n",
" .colab-df-quickchart-complete:disabled:hover {\n",
" background-color: var(--disabled-bg-color);\n",
" fill: var(--disabled-fill-color);\n",
" box-shadow: none;\n",
" }\n",
"\n",
" .colab-df-spinner {\n",
" border: 2px solid var(--fill-color);\n",
" border-color: transparent;\n",
" border-bottom-color: var(--fill-color);\n",
" animation:\n",
" spin 1s steps(1) infinite;\n",
" }\n",
"\n",
" @keyframes spin {\n",
" 0% {\n",
" border-color: transparent;\n",
" border-bottom-color: var(--fill-color);\n",
" border-left-color: var(--fill-color);\n",
" }\n",
" 20% {\n",
" border-color: transparent;\n",
" border-left-color: var(--fill-color);\n",
" border-top-color: var(--fill-color);\n",
" }\n",
" 30% {\n",
" border-color: transparent;\n",
" border-left-color: var(--fill-color);\n",
" border-top-color: var(--fill-color);\n",
" border-right-color: var(--fill-color);\n",
" }\n",
" 40% {\n",
" border-color: transparent;\n",
" border-right-color: var(--fill-color);\n",
" border-top-color: var(--fill-color);\n",
" }\n",
" 60% {\n",
" border-color: transparent;\n",
" border-right-color: var(--fill-color);\n",
" }\n",
" 80% {\n",
" border-color: transparent;\n",
" border-right-color: var(--fill-color);\n",
" border-bottom-color: var(--fill-color);\n",
" }\n",
" 90% {\n",
" border-color: transparent;\n",
" border-bottom-color: var(--fill-color);\n",
" }\n",
" }\n",
"</style>\n",
"\n",
" <script>\n",
" async function quickchart(key) {\n",
" const quickchartButtonEl =\n",
" document.querySelector('#' + key + ' button');\n",
" quickchartButtonEl.disabled = true; // To prevent multiple clicks.\n",
" quickchartButtonEl.classList.add('colab-df-spinner');\n",
" try {\n",
" const charts = await google.colab.kernel.invokeFunction(\n",
" 'suggestCharts', [key], {});\n",
" } catch (error) {\n",
" console.error('Error during call to suggestCharts:', error);\n",
" }\n",
" quickchartButtonEl.classList.remove('colab-df-spinner');\n",
" quickchartButtonEl.classList.add('colab-df-quickchart-complete');\n",
" }\n",
" (() => {\n",
" let quickchartButtonEl =\n",
" document.querySelector('#df-d7331ebf-cc08-43b8-9aea-ee8d7fed23d4 button');\n",
" quickchartButtonEl.style.display =\n",
" google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
" })();\n",
" </script>\n",
"</div>\n",
" </div>\n",
" </div>\n"
],
"application/vnd.google.colaboratory.intrinsic+json": {
"type": "dataframe",
"variable_name": "train_df"
}
},
"metadata": {},
"execution_count": 14
}
]
},
{
"cell_type": "code",
"source": [
"interim_csv = data_dir.joinpath(\"interim_submission.csv\")\n",
"interim_submission_df = pd.read_csv(interim_csv, low_memory=False)"
],
"metadata": {
"id": "rsng0kE5tj1w"
},
"execution_count": 15,
"outputs": []
},
{
"cell_type": "code",
"source": [
"interim_submission_df.head(n=5)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 383
},
"id": "cRZ1FWgCtnS8",
"outputId": "128a8ce5-be76-4c52-c615-1c135157ea4d"
},
"execution_count": 16,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" essay_id_comp predictionstring \\\n",
"0 E74F2616693B 299 300 301 302 303 304 305 306 307 308 309 31... \n",
"1 E74F2616693B 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 \n",
"2 E74F2616693B 247 248 249 250 251 252 253 254 255 256 257 25... \n",
"3 E74F2616693B 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 3... \n",
"4 E74F2616693B 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 9... \n",
"\n",
" predictionstring_text \\\n",
"0 l phones during lunch, brunch, \n",
"1 : TEACHER_NAME,\\n\\n \n",
"2 ing school. I think that you should let kids use \n",
"3 ,\\n\\nKids theses days \n",
"4 o school. I think that you should let \n",
"\n",
" score_discourse_effectiveness_0 score_discourse_effectiveness_1 \\\n",
"0 0.999 0.001 \n",
"1 0.999 0.001 \n",
"2 0.999 0.001 \n",
"3 0.999 0.001 \n",
"4 0.999 0.001 \n",
"\n",
" class discourse_type \n",
"0 Concluding Statement 6 \n",
"1 Lead 0 \n",
"2 Evidence 3 \n",
"3 Position 1 \n",
"4 Evidence 3 "
],
"text/html": [
"\n",
" <div id=\"df-a8d8fb8f-1464-4451-b1af-72636a954239\" class=\"colab-df-container\">\n",
" <div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>essay_id_comp</th>\n",
" <th>predictionstring</th>\n",
" <th>predictionstring_text</th>\n",
" <th>score_discourse_effectiveness_0</th>\n",
" <th>score_discourse_effectiveness_1</th>\n",
" <th>class</th>\n",
" <th>discourse_type</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>E74F2616693B</td>\n",
" <td>299 300 301 302 303 304 305 306 307 308 309 31...</td>\n",
" <td>l phones during lunch, brunch,</td>\n",
" <td>0.999</td>\n",
" <td>0.001</td>\n",
" <td>Concluding Statement</td>\n",
" <td>6</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>E74F2616693B</td>\n",
" <td>2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18</td>\n",
" <td>: TEACHER_NAME,\\n\\n</td>\n",
" <td>0.999</td>\n",
" <td>0.001</td>\n",
" <td>Lead</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>E74F2616693B</td>\n",
" <td>247 248 249 250 251 252 253 254 255 256 257 25...</td>\n",
" <td>ing school. I think that you should let kids use</td>\n",
" <td>0.999</td>\n",
" <td>0.001</td>\n",
" <td>Evidence</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>E74F2616693B</td>\n",
" <td>16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 3...</td>\n",
" <td>,\\n\\nKids theses days</td>\n",
" <td>0.999</td>\n",
" <td>0.001</td>\n",
" <td>Position</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>E74F2616693B</td>\n",
" <td>80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 9...</td>\n",
" <td>o school. I think that you should let</td>\n",
" <td>0.999</td>\n",
" <td>0.001</td>\n",
" <td>Evidence</td>\n",
" <td>3</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>\n",
" <div class=\"colab-df-buttons\">\n",
"\n",
" <div class=\"colab-df-container\">\n",
" <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-a8d8fb8f-1464-4451-b1af-72636a954239')\"\n",
" title=\"Convert this dataframe to an interactive table.\"\n",
" style=\"display:none;\">\n",
"\n",
" <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\" viewBox=\"0 -960 960 960\">\n",
" <path d=\"M120-120v-720h720v720H120Zm60-500h600v-160H180v160Zm220 220h160v-160H400v160Zm0 220h160v-160H400v160ZM180-400h160v-160H180v160Zm440 0h160v-160H620v160ZM180-180h160v-160H180v160Zm440 0h160v-160H620v160Z\"/>\n",
" </svg>\n",
" </button>\n",
"\n",
" <style>\n",
" .colab-df-container {\n",
" display:flex;\n",
" gap: 12px;\n",
" }\n",
"\n",
" .colab-df-convert {\n",
" background-color: #E8F0FE;\n",
" border: none;\n",
" border-radius: 50%;\n",
" cursor: pointer;\n",
" display: none;\n",
" fill: #1967D2;\n",
" height: 32px;\n",
" padding: 0 0 0 0;\n",
" width: 32px;\n",
" }\n",
"\n",
" .colab-df-convert:hover {\n",
" background-color: #E2EBFA;\n",
" box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
" fill: #174EA6;\n",
" }\n",
"\n",
" .colab-df-buttons div {\n",
" margin-bottom: 4px;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-convert {\n",
" background-color: #3B4455;\n",
" fill: #D2E3FC;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-convert:hover {\n",
" background-color: #434B5C;\n",
" box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
" filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
" fill: #FFFFFF;\n",
" }\n",
" </style>\n",
"\n",
" <script>\n",
" const buttonEl =\n",
" document.querySelector('#df-a8d8fb8f-1464-4451-b1af-72636a954239 button.colab-df-convert');\n",
" buttonEl.style.display =\n",
" google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
"\n",
" async function convertToInteractive(key) {\n",
" const element = document.querySelector('#df-a8d8fb8f-1464-4451-b1af-72636a954239');\n",
" const dataTable =\n",
" await google.colab.kernel.invokeFunction('convertToInteractive',\n",
" [key], {});\n",
" if (!dataTable) return;\n",
"\n",
" const docLinkHtml = 'Like what you see? Visit the ' +\n",
" '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
" + ' to learn more about interactive tables.';\n",
" element.innerHTML = '';\n",
" dataTable['output_type'] = 'display_data';\n",
" await google.colab.output.renderOutput(dataTable, element);\n",
" const docLink = document.createElement('div');\n",
" docLink.innerHTML = docLinkHtml;\n",
" element.appendChild(docLink);\n",
" }\n",
" </script>\n",
" </div>\n",
"\n",
"\n",
"<div id=\"df-28bef502-96c7-4904-8f3b-3f2af4353fd8\">\n",
" <button class=\"colab-df-quickchart\" onclick=\"quickchart('df-28bef502-96c7-4904-8f3b-3f2af4353fd8')\"\n",
" title=\"Suggest charts\"\n",
" style=\"display:none;\">\n",
"\n",
"<svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
" width=\"24px\">\n",
" <g>\n",
" <path d=\"M19 3H5c-1.1 0-2 .9-2 2v14c0 1.1.9 2 2 2h14c1.1 0 2-.9 2-2V5c0-1.1-.9-2-2-2zM9 17H7v-7h2v7zm4 0h-2V7h2v10zm4 0h-2v-4h2v4z\"/>\n",
" </g>\n",
"</svg>\n",
" </button>\n",
"\n",
"<style>\n",
" .colab-df-quickchart {\n",
" --bg-color: #E8F0FE;\n",
" --fill-color: #1967D2;\n",
" --hover-bg-color: #E2EBFA;\n",
" --hover-fill-color: #174EA6;\n",
" --disabled-fill-color: #AAA;\n",
" --disabled-bg-color: #DDD;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-quickchart {\n",
" --bg-color: #3B4455;\n",
" --fill-color: #D2E3FC;\n",
" --hover-bg-color: #434B5C;\n",
" --hover-fill-color: #FFFFFF;\n",
" --disabled-bg-color: #3B4455;\n",
" --disabled-fill-color: #666;\n",
" }\n",
"\n",
" .colab-df-quickchart {\n",
" background-color: var(--bg-color);\n",
" border: none;\n",
" border-radius: 50%;\n",
" cursor: pointer;\n",
" display: none;\n",
" fill: var(--fill-color);\n",
" height: 32px;\n",
" padding: 0;\n",
" width: 32px;\n",
" }\n",
"\n",
" .colab-df-quickchart:hover {\n",
" background-color: var(--hover-bg-color);\n",
" box-shadow: 0 1px 2px rgba(60, 64, 67, 0.3), 0 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
" fill: var(--button-hover-fill-color);\n",
" }\n",
"\n",
" .colab-df-quickchart-complete:disabled,\n",
" .colab-df-quickchart-complete:disabled:hover {\n",
" background-color: var(--disabled-bg-color);\n",
" fill: var(--disabled-fill-color);\n",
" box-shadow: none;\n",
" }\n",
"\n",
" .colab-df-spinner {\n",
" border: 2px solid var(--fill-color);\n",
" border-color: transparent;\n",
" border-bottom-color: var(--fill-color);\n",
" animation:\n",
" spin 1s steps(1) infinite;\n",
" }\n",
"\n",
" @keyframes spin {\n",
" 0% {\n",
" border-color: transparent;\n",
" border-bottom-color: var(--fill-color);\n",
" border-left-color: var(--fill-color);\n",
" }\n",
" 20% {\n",
" border-color: transparent;\n",
" border-left-color: var(--fill-color);\n",
" border-top-color: var(--fill-color);\n",
" }\n",
" 30% {\n",
" border-color: transparent;\n",
" border-left-color: var(--fill-color);\n",
" border-top-color: var(--fill-color);\n",
" border-right-color: var(--fill-color);\n",
" }\n",
" 40% {\n",
" border-color: transparent;\n",
" border-right-color: var(--fill-color);\n",
" border-top-color: var(--fill-color);\n",
" }\n",
" 60% {\n",
" border-color: transparent;\n",
" border-right-color: var(--fill-color);\n",
" }\n",
" 80% {\n",
" border-color: transparent;\n",
" border-right-color: var(--fill-color);\n",
" border-bottom-color: var(--fill-color);\n",
" }\n",
" 90% {\n",
" border-color: transparent;\n",
" border-bottom-color: var(--fill-color);\n",
" }\n",
" }\n",
"</style>\n",
"\n",
" <script>\n",
" async function quickchart(key) {\n",
" const quickchartButtonEl =\n",
" document.querySelector('#' + key + ' button');\n",
" quickchartButtonEl.disabled = true; // To prevent multiple clicks.\n",
" quickchartButtonEl.classList.add('colab-df-spinner');\n",
" try {\n",
" const charts = await google.colab.kernel.invokeFunction(\n",
" 'suggestCharts', [key], {});\n",
" } catch (error) {\n",
" console.error('Error during call to suggestCharts:', error);\n",
" }\n",
" quickchartButtonEl.classList.remove('colab-df-spinner');\n",
" quickchartButtonEl.classList.add('colab-df-quickchart-complete');\n",
" }\n",
" (() => {\n",
" let quickchartButtonEl =\n",
" document.querySelector('#df-28bef502-96c7-4904-8f3b-3f2af4353fd8 button');\n",
" quickchartButtonEl.style.display =\n",
" google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
" })();\n",
" </script>\n",
"</div>\n",
" </div>\n",
" </div>\n"
],
"application/vnd.google.colaboratory.intrinsic+json": {
"type": "dataframe",
"variable_name": "interim_submission_df",
"summary": "{\n \"name\": \"interim_submission_df\",\n \"rows\": 87792,\n \"fields\": [\n {\n \"column\": \"essay_id_comp\",\n \"properties\": {\n \"dtype\": \"category\",\n \"num_unique_values\": 10402,\n \"samples\": [\n \"41CE9804B226\",\n \"740969A36CE1\",\n \"8EA41EF1EC70\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"predictionstring\",\n \"properties\": {\n \"dtype\": \"category\",\n \"num_unique_values\": 36744,\n \"samples\": [\n \"198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287\",\n \"510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562\",\n \"409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"predictionstring_text\",\n \"properties\": {\n \"dtype\": \"string\",\n \"num_unique_values\": 85649,\n \"samples\": [\n \"g to convince you.\\n\\nSo luke is trying to tell you to become a seagoing cowboy but luke qutoes that he had experenced being a cowbo\",\n \"ps up in ev\",\n \"long as the phones are turne\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"score_discourse_effectiveness_0\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 3.33068804312966e-16,\n 
\"min\": 0.999,\n \"max\": 0.999,\n \"num_unique_values\": 1,\n \"samples\": [\n 0.999\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"score_discourse_effectiveness_1\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 6.505250084237618e-19,\n \"min\": 0.001,\n \"max\": 0.001,\n \"num_unique_values\": 1,\n \"samples\": [\n 0.001\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"class\",\n \"properties\": {\n \"dtype\": \"category\",\n \"num_unique_values\": 7,\n \"samples\": [\n \"Concluding Statement\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"discourse_type\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 1,\n \"min\": 0,\n \"max\": 6,\n \"num_unique_values\": 7,\n \"samples\": [\n 6\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n }\n ]\n}"
}
},
"metadata": {},
"execution_count": 16
}
]
},
{
"cell_type": "code",
"source": [
"# Encode every training discourse text; expect one embedding row per row of train_df.\n",
"train_embeddings = model.encode(\n",
"    train_df[\"discourse_text\"].tolist()\n",
")"
],
"metadata": {
"id": "cbZEc41vtqZi"
},
"execution_count": 17,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# NOTE(review): sample predictionstring_text values in the dataframe summary above\n",
"# (e.g. \"ps up in ev\", \"long as the phones are turne\") look like fragments cut mid-word,\n",
"# possibly word indices applied as character offsets -- confirm how this column is built\n",
"# upstream before relying on these embeddings.\n",
"interim_submission_embeddings = model.encode(\n",
"    interim_submission_df[\"predictionstring_text\"].tolist()\n",
")"
],
"metadata": {
"id": "aYDFP3VTuNr0"
},
"execution_count": 18,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Fail fast if an embedding matrix stops lining up row-for-row with the frame it was encoded from.\n",
"assert len(train_embeddings) == len(train_df), \"train embeddings / train_df row mismatch\"\n",
"assert len(interim_submission_embeddings) == len(interim_submission_df), \"interim embeddings / interim_submission_df row mismatch\"\n",
"train_embeddings.shape, interim_submission_embeddings.shape"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "YUmoOzirxG69",
"outputId": "a817dd98-70c9-45ea-f2ff-879d5fee9ab0"
},
"execution_count": 19,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"((173266, 384), (87792, 384))"
]
},
"metadata": {},
"execution_count": 19
}
]
},
{
"cell_type": "code",
"source": [
"# Source row counts, for comparison with the embedding shapes shown in the previous cell.\n",
"(train_df[\"discourse_text\"].shape, interim_submission_df.shape)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "lWWGi-qlxPk5",
"outputId": "271298dc-085e-49b8-aa22-47f564c62ac4"
},
"execution_count": 20,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"((173266,), (87792, 7))"
]
},
"metadata": {},
"execution_count": 20
}
]
},
{
"cell_type": "code",
"source": [
"import pickle"
],
"metadata": {
"id": "BxdJIrgavJoG"
},
"execution_count": 21,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Serialize the training embedding array to train_embeddings.pkl in the working directory.\n",
"# NOTE: unpickling can execute arbitrary code -- only reload pickle files you produced yourself.\n",
"with open(\"train_embeddings.pkl\", mode=\"wb\") as pkl_file:\n",
"    pickle.dump(train_embeddings, file=pkl_file)"
],
"metadata": {
"id": "YSRod5x0vg0H"
},
"execution_count": 22,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Serialize the interim-submission embedding array to interim_submission_embeddings.pkl.\n",
"# NOTE: unpickling can execute arbitrary code -- only reload pickle files you produced yourself.\n",
"with open(\"interim_submission_embeddings.pkl\", mode=\"wb\") as pkl_file:\n",
"    pickle.dump(interim_submission_embeddings, file=pkl_file)"
],
"metadata": {
"id": "o4C8UjR3vkg_"
},
"execution_count": 23,
"outputs": []
},
{
"cell_type": "code",
"source": [],
"metadata": {
"id": "p4GplBHLvomx"
},
"execution_count": null,
"outputs": []
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment