Text classification via Siamese Network architecture using LSTM encoders.ipynb
{
"metadata": {
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5-final"
},
"orig_nbformat": 2,
"kernelspec": {
"name": "python3",
"display_name": "Python 3.8.5 64-bit ('nlp-generic': conda)",
"metadata": {
"interpreter": {
"hash": "b481e22569f8c970d02674a5d9e45918cde1d33fcb13233e32a75967a14c2314"
}
}
},
"colab": {
"name": "Text classification via Siamese Network architecture using LSTM encoders.ipynb",
"provenance": [],
"collapsed_sections": [],
"include_colab_link": true
}
},
"nbformat": 4,
"nbformat_minor": 0,
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/ishacusp/32cbd2fdd3d31886f8459d71e8b7860a/text-classification-via-siamese-network-architecture-using-lstm-encoders.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"source": [
"# Text classification via Siamese Network architecture using LSTM encoders\n",
"\n",
"**Original Author:** [Subhasis Jethy](https://github.com/subhasisj/Few-Shot-Learning)<br>\n",
"**Notebook modified by:** [Isha Chaturvedi](https://www.linkedin.com/in/isha-chaturvedi-18372619/)<br>\n",
"**Date created:** 2020/12/21<br>\n",
"**Last modified:** 2022/10/04<br>\n",
"**Description:** Training a Siamese network to perform text classification by casting it as a binary task: predicting whether two texts belong to the same class. <br> \n",
"\n",
"\n",
"<br>\n",
"\n",
"**Source:**\n",
"<br>\n",
"https://data4thought.com/fewshot_learning_nlp.html,\n",
"<br>\n",
"https://blog.mlreview.com/implementing-malstm-on-kaggles-quora-question-pairs-competition-8b31b0b16a07 "
],
"metadata": {
"id": "_x9oD4bAWJ3b"
}
},
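{
"cell_type": "markdown",
"metadata": {},
"source": [
"The MaLSTM post linked above scores a pair of sentences by the exponentiated negative Manhattan (L1) distance between their encodings, `exp(-||h_left - h_right||_1)`, which always lies in (0, 1]. The NumPy sketch below only illustrates that formula. The function name `manhattan_similarity` and the toy vectors are hypothetical: the vectors stand in for final LSTM states, and the code does not reproduce this notebook's model.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def manhattan_similarity(h_left, h_right):\n",
"    # MaLSTM-style similarity: exp(-L1 distance), always in (0, 1]\n",
"    # h_left, h_right: arrays of shape (batch, hidden_dim), e.g. final LSTM states\n",
"    return np.exp(-np.sum(np.abs(h_left - h_right), axis=1, keepdims=True))\n",
"\n",
"# Hypothetical toy encodings standing in for LSTM outputs\n",
"a = np.array([[0.2, 0.5, -0.1]])\n",
"b = np.array([[0.2, 0.5, -0.1]])\n",
"c = np.array([[1.2, -0.5, 0.9]])\n",
"print(manhattan_similarity(a, b))  # identical encodings give 1.0\n",
"print(manhattan_similarity(a, c))  # more distant encodings move toward 0\n",
"```\n",
"\n",
"Both branches of a Siamese network share the same encoder weights. A score near 1 can then be read as the same class and a score near 0 as different classes, which is what makes the task a binary classification."
]
},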
{
"cell_type": "markdown",
"source": [
"The main idea of Siamese networks is to learn vector representations by training a model to discriminate between pairs of examples from the same category and pairs of examples from different categories.\n",
"\n"
],
"metadata": {
"id": "xcqwcVjphCLU"
}
},
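{
"cell_type": "markdown",
"metadata": {},
"source": [
"One way to build such pairs from labeled data is to enumerate all unordered pairs. Each pair gets the target 1 when both examples share a class and 0 otherwise. The plain-Python sketch below assumes that scheme. `make_pairs` and the toy commit messages are hypothetical, and the notebook may sample pairs differently.\n",
"\n",
"```python\n",
"import itertools\n",
"\n",
"def make_pairs(texts, labels):\n",
"    # Enumerate every unordered pair (i, j) with i < j;\n",
"    # the target is 1 if both examples share a class, else 0\n",
"    pairs, targets = [], []\n",
"    for i, j in itertools.combinations(range(len(texts)), 2):\n",
"        pairs.append((texts[i], texts[j]))\n",
"        targets.append(int(labels[i] == labels[j]))\n",
"    return pairs, targets\n",
"\n",
"# Hypothetical toy data\n",
"texts = ['fix failing unit tests', 'fix flaky test', 'refactor webapp structure']\n",
"labels = [1, 1, 3]\n",
"pairs, targets = make_pairs(texts, labels)\n",
"for (t1, t2), y in zip(pairs, targets):\n",
"    print(y, '|', t1, '|', t2)\n",
"```"
]
},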
{
"cell_type": "code",
"source": [
"# Install the required libraries\n",
"\n",
"!pip install -U spacy\n",
"# You must restart the runtime in order to use newly installed versions.\n",
"\n",
"!pip install texthero\n",
"# !pip install zeugma\n",
"\n",
"# Dataset location (download both files and place them in the notebook's working folder):\n",
"# 1. Train data: https://github.com/subhasisj/Few-Shot-Learning/blob/master/Few-shot-Learning-Siamese-LSTM/final_fewshot_train.csv\n",
"# 2. Test data: https://github.com/subhasisj/Few-Shot-Learning/blob/master/Few-shot-Learning-Siamese-LSTM/final_fewshot_test.csv "
],
"metadata": {
"id": "_SR6VW54X3F6",
"outputId": "6cf1522b-8854-4c0a-a037-25e3aacf0392",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
"Requirement already satisfied: spacy in /usr/local/lib/python3.7/dist-packages (3.4.1)\n",
"Requirement already satisfied: typer<0.5.0,>=0.3.0 in /usr/local/lib/python3.7/dist-packages (from spacy) (0.4.2)\n",
"Requirement already satisfied: pathy>=0.3.5 in /usr/local/lib/python3.7/dist-packages (from spacy) (0.6.2)\n",
"Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.7/dist-packages (from spacy) (21.3)\n",
"Requirement already satisfied: murmurhash<1.1.0,>=0.28.0 in /usr/local/lib/python3.7/dist-packages (from spacy) (1.0.8)\n",
"Requirement already satisfied: catalogue<2.1.0,>=2.0.6 in /usr/local/lib/python3.7/dist-packages (from spacy) (2.0.8)\n",
"Requirement already satisfied: preshed<3.1.0,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from spacy) (3.0.7)\n",
"Requirement already satisfied: setuptools in /usr/local/lib/python3.7/dist-packages (from spacy) (57.4.0)\n",
"Requirement already satisfied: spacy-legacy<3.1.0,>=3.0.9 in /usr/local/lib/python3.7/dist-packages (from spacy) (3.0.10)\n",
"Requirement already satisfied: wasabi<1.1.0,>=0.9.1 in /usr/local/lib/python3.7/dist-packages (from spacy) (0.10.1)\n",
"Requirement already satisfied: typing-extensions<4.2.0,>=3.7.4 in /usr/local/lib/python3.7/dist-packages (from spacy) (4.1.1)\n",
"Requirement already satisfied: numpy>=1.15.0 in /usr/local/lib/python3.7/dist-packages (from spacy) (1.21.6)\n",
"Requirement already satisfied: pydantic!=1.8,!=1.8.1,<1.10.0,>=1.7.4 in /usr/local/lib/python3.7/dist-packages (from spacy) (1.9.2)\n",
"Requirement already satisfied: tqdm<5.0.0,>=4.38.0 in /usr/local/lib/python3.7/dist-packages (from spacy) (4.64.1)\n",
"Requirement already satisfied: jinja2 in /usr/local/lib/python3.7/dist-packages (from spacy) (2.11.3)\n",
"Requirement already satisfied: thinc<8.2.0,>=8.1.0 in /usr/local/lib/python3.7/dist-packages (from spacy) (8.1.2)\n",
"Requirement already satisfied: srsly<3.0.0,>=2.4.3 in /usr/local/lib/python3.7/dist-packages (from spacy) (2.4.4)\n",
"Requirement already satisfied: cymem<2.1.0,>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from spacy) (2.0.6)\n",
"Requirement already satisfied: langcodes<4.0.0,>=3.2.0 in /usr/local/lib/python3.7/dist-packages (from spacy) (3.3.0)\n",
"Requirement already satisfied: spacy-loggers<2.0.0,>=1.0.0 in /usr/local/lib/python3.7/dist-packages (from spacy) (1.0.3)\n",
"Requirement already satisfied: requests<3.0.0,>=2.13.0 in /usr/local/lib/python3.7/dist-packages (from spacy) (2.23.0)\n",
"Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from catalogue<2.1.0,>=2.0.6->spacy) (3.8.1)\n",
"Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging>=20.0->spacy) (3.0.9)\n",
"Requirement already satisfied: smart-open<6.0.0,>=5.2.1 in /usr/local/lib/python3.7/dist-packages (from pathy>=0.3.5->spacy) (5.2.1)\n",
"Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests<3.0.0,>=2.13.0->spacy) (1.24.3)\n",
"Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests<3.0.0,>=2.13.0->spacy) (3.0.4)\n",
"Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests<3.0.0,>=2.13.0->spacy) (2.10)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests<3.0.0,>=2.13.0->spacy) (2022.9.24)\n",
"Requirement already satisfied: confection<1.0.0,>=0.0.1 in /usr/local/lib/python3.7/dist-packages (from thinc<8.2.0,>=8.1.0->spacy) (0.0.2)\n",
"Requirement already satisfied: blis<0.8.0,>=0.7.8 in /usr/local/lib/python3.7/dist-packages (from thinc<8.2.0,>=8.1.0->spacy) (0.7.8)\n",
"Requirement already satisfied: click<9.0.0,>=7.1.1 in /usr/local/lib/python3.7/dist-packages (from typer<0.5.0,>=0.3.0->spacy) (7.1.2)\n",
"Requirement already satisfied: MarkupSafe>=0.23 in /usr/local/lib/python3.7/dist-packages (from jinja2->spacy) (2.0.1)\n",
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
"Collecting texthero\n",
" Downloading texthero-1.1.0-py3-none-any.whl (24 kB)\n",
"Requirement already satisfied: gensim<4.0,>=3.6.0 in /usr/local/lib/python3.7/dist-packages (from texthero) (3.6.0)\n",
"Requirement already satisfied: matplotlib>=3.1.0 in /usr/local/lib/python3.7/dist-packages (from texthero) (3.2.2)\n",
"Requirement already satisfied: wordcloud>=1.5.0 in /usr/local/lib/python3.7/dist-packages (from texthero) (1.8.2.2)\n",
"Requirement already satisfied: pandas>=1.0.2 in /usr/local/lib/python3.7/dist-packages (from texthero) (1.3.5)\n",
"Requirement already satisfied: plotly>=4.2.0 in /usr/local/lib/python3.7/dist-packages (from texthero) (5.5.0)\n",
"Requirement already satisfied: nltk>=3.3 in /usr/local/lib/python3.7/dist-packages (from texthero) (3.7)\n",
"Requirement already satisfied: tqdm>=4.3 in /usr/local/lib/python3.7/dist-packages (from texthero) (4.64.1)\n",
"Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.7/dist-packages (from texthero) (1.21.6)\n",
"Collecting spacy<3.0.0\n",
" Downloading spacy-2.3.7-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (10.4 MB)\n",
"\u001b[K |████████████████████████████████| 10.4 MB 2.7 MB/s \n",
"\u001b[?25hRequirement already satisfied: scikit-learn>=0.22 in /usr/local/lib/python3.7/dist-packages (from texthero) (1.0.2)\n",
"Collecting unidecode>=1.1.1\n",
" Downloading Unidecode-1.3.6-py3-none-any.whl (235 kB)\n",
"\u001b[K |████████████████████████████████| 235 kB 44.4 MB/s \n",
"\u001b[?25hRequirement already satisfied: six>=1.5.0 in /usr/local/lib/python3.7/dist-packages (from gensim<4.0,>=3.6.0->texthero) (1.15.0)\n",
"Requirement already satisfied: smart-open>=1.2.1 in /usr/local/lib/python3.7/dist-packages (from gensim<4.0,>=3.6.0->texthero) (5.2.1)\n",
"Requirement already satisfied: scipy>=0.18.1 in /usr/local/lib/python3.7/dist-packages (from gensim<4.0,>=3.6.0->texthero) (1.7.3)\n",
"Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib>=3.1.0->texthero) (1.4.4)\n",
"Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib>=3.1.0->texthero) (3.0.9)\n",
"Requirement already satisfied: python-dateutil>=2.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib>=3.1.0->texthero) (2.8.2)\n",
"Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.7/dist-packages (from matplotlib>=3.1.0->texthero) (0.11.0)\n",
"Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from kiwisolver>=1.0.1->matplotlib>=3.1.0->texthero) (4.1.1)\n",
"Requirement already satisfied: regex>=2021.8.3 in /usr/local/lib/python3.7/dist-packages (from nltk>=3.3->texthero) (2022.6.2)\n",
"Requirement already satisfied: click in /usr/local/lib/python3.7/dist-packages (from nltk>=3.3->texthero) (7.1.2)\n",
"Requirement already satisfied: joblib in /usr/local/lib/python3.7/dist-packages (from nltk>=3.3->texthero) (1.2.0)\n",
"Requirement already satisfied: pytz>=2017.3 in /usr/local/lib/python3.7/dist-packages (from pandas>=1.0.2->texthero) (2022.4)\n",
"Requirement already satisfied: tenacity>=6.2.0 in /usr/local/lib/python3.7/dist-packages (from plotly>=4.2.0->texthero) (8.1.0)\n",
"Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.7/dist-packages (from scikit-learn>=0.22->texthero) (3.1.0)\n",
"Collecting thinc<7.5.0,>=7.4.1\n",
" Downloading thinc-7.4.5-cp37-cp37m-manylinux2014_x86_64.whl (1.0 MB)\n",
"\u001b[K |████████████████████████████████| 1.0 MB 36.5 MB/s \n",
"\u001b[?25hCollecting catalogue<1.1.0,>=0.0.7\n",
" Downloading catalogue-1.0.1-py2.py3-none-any.whl (16 kB)\n",
"Collecting plac<1.2.0,>=0.9.6\n",
" Downloading plac-1.1.3-py2.py3-none-any.whl (20 kB)\n",
"Requirement already satisfied: blis<0.8.0,>=0.4.0 in /usr/local/lib/python3.7/dist-packages (from spacy<3.0.0->texthero) (0.7.8)\n",
"Requirement already satisfied: wasabi<1.1.0,>=0.4.0 in /usr/local/lib/python3.7/dist-packages (from spacy<3.0.0->texthero) (0.10.1)\n",
"Requirement already satisfied: requests<3.0.0,>=2.13.0 in /usr/local/lib/python3.7/dist-packages (from spacy<3.0.0->texthero) (2.23.0)\n",
"Requirement already satisfied: setuptools in /usr/local/lib/python3.7/dist-packages (from spacy<3.0.0->texthero) (57.4.0)\n",
"Requirement already satisfied: murmurhash<1.1.0,>=0.28.0 in /usr/local/lib/python3.7/dist-packages (from spacy<3.0.0->texthero) (1.0.8)\n",
"Requirement already satisfied: cymem<2.1.0,>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from spacy<3.0.0->texthero) (2.0.6)\n",
"Requirement already satisfied: preshed<3.1.0,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from spacy<3.0.0->texthero) (3.0.7)\n",
"Collecting srsly<1.1.0,>=1.0.2\n",
" Downloading srsly-1.0.5-cp37-cp37m-manylinux2014_x86_64.whl (184 kB)\n",
"\u001b[K |████████████████████████████████| 184 kB 74.7 MB/s \n",
"\u001b[?25hRequirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from catalogue<1.1.0,>=0.0.7->spacy<3.0.0->texthero) (3.8.1)\n",
"Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests<3.0.0,>=2.13.0->spacy<3.0.0->texthero) (3.0.4)\n",
"Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests<3.0.0,>=2.13.0->spacy<3.0.0->texthero) (1.24.3)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests<3.0.0,>=2.13.0->spacy<3.0.0->texthero) (2022.9.24)\n",
"Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests<3.0.0,>=2.13.0->spacy<3.0.0->texthero) (2.10)\n",
"Requirement already satisfied: pillow in /usr/local/lib/python3.7/dist-packages (from wordcloud>=1.5.0->texthero) (7.1.2)\n",
"Installing collected packages: srsly, plac, catalogue, thinc, unidecode, spacy, texthero\n",
" Attempting uninstall: srsly\n",
" Found existing installation: srsly 2.4.4\n",
" Uninstalling srsly-2.4.4:\n",
" Successfully uninstalled srsly-2.4.4\n",
" Attempting uninstall: catalogue\n",
" Found existing installation: catalogue 2.0.8\n",
" Uninstalling catalogue-2.0.8:\n",
" Successfully uninstalled catalogue-2.0.8\n",
" Attempting uninstall: thinc\n",
" Found existing installation: thinc 8.1.2\n",
" Uninstalling thinc-8.1.2:\n",
" Successfully uninstalled thinc-8.1.2\n",
" Attempting uninstall: spacy\n",
" Found existing installation: spacy 3.4.1\n",
" Uninstalling spacy-3.4.1:\n",
" Successfully uninstalled spacy-3.4.1\n",
"\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
"en-core-web-sm 3.4.0 requires spacy<3.5.0,>=3.4.0, but you have spacy 2.3.7 which is incompatible.\n",
"confection 0.0.2 requires srsly<3.0.0,>=2.4.0, but you have srsly 1.0.5 which is incompatible.\u001b[0m\n",
"Successfully installed catalogue-1.0.1 plac-1.1.3 spacy-2.3.7 srsly-1.0.5 texthero-1.1.0 thinc-7.4.5 unidecode-1.3.6\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# If the installation above does not let you import texthero, run the code below and then restart the runtime\n",
"\n",
"!pip install -U spacy\n",
"import spacy"
],
"metadata": {
"id": "KvwIy41fkLt6",
"outputId": "6592e865-7060-4e25-aa9c-57c487019b82",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
"Requirement already satisfied: spacy in /usr/local/lib/python3.7/dist-packages (2.3.7)\n",
"Collecting spacy\n",
" Downloading spacy-3.4.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (6.3 MB)\n",
"\u001b[K |████████████████████████████████| 6.3 MB 2.1 MB/s \n",
"\u001b[?25hRequirement already satisfied: jinja2 in /usr/local/lib/python3.7/dist-packages (from spacy) (2.11.3)\n",
"Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.7/dist-packages (from spacy) (21.3)\n",
"Requirement already satisfied: spacy-legacy<3.1.0,>=3.0.9 in /usr/local/lib/python3.7/dist-packages (from spacy) (3.0.10)\n",
"Requirement already satisfied: tqdm<5.0.0,>=4.38.0 in /usr/local/lib/python3.7/dist-packages (from spacy) (4.64.1)\n",
"Collecting srsly<3.0.0,>=2.4.3\n",
" Downloading srsly-2.4.4-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (458 kB)\n",
"\u001b[K |████████████████████████████████| 458 kB 57.3 MB/s \n",
"\u001b[?25hRequirement already satisfied: murmurhash<1.1.0,>=0.28.0 in /usr/local/lib/python3.7/dist-packages (from spacy) (1.0.8)\n",
"Requirement already satisfied: wasabi<1.1.0,>=0.9.1 in /usr/local/lib/python3.7/dist-packages (from spacy) (0.10.1)\n",
"Requirement already satisfied: numpy>=1.15.0 in /usr/local/lib/python3.7/dist-packages (from spacy) (1.21.6)\n",
"Requirement already satisfied: pydantic!=1.8,!=1.8.1,<1.10.0,>=1.7.4 in /usr/local/lib/python3.7/dist-packages (from spacy) (1.9.2)\n",
"Requirement already satisfied: requests<3.0.0,>=2.13.0 in /usr/local/lib/python3.7/dist-packages (from spacy) (2.23.0)\n",
"Requirement already satisfied: typing-extensions<4.2.0,>=3.7.4 in /usr/local/lib/python3.7/dist-packages (from spacy) (4.1.1)\n",
"Requirement already satisfied: pathy>=0.3.5 in /usr/local/lib/python3.7/dist-packages (from spacy) (0.6.2)\n",
"Collecting thinc<8.2.0,>=8.1.0\n",
" Downloading thinc-8.1.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (804 kB)\n",
"\u001b[K |████████████████████████████████| 804 kB 46.0 MB/s \n",
"\u001b[?25hRequirement already satisfied: typer<0.5.0,>=0.3.0 in /usr/local/lib/python3.7/dist-packages (from spacy) (0.4.2)\n",
"Requirement already satisfied: spacy-loggers<2.0.0,>=1.0.0 in /usr/local/lib/python3.7/dist-packages (from spacy) (1.0.3)\n",
"Requirement already satisfied: langcodes<4.0.0,>=3.2.0 in /usr/local/lib/python3.7/dist-packages (from spacy) (3.3.0)\n",
"Collecting catalogue<2.1.0,>=2.0.6\n",
" Downloading catalogue-2.0.8-py3-none-any.whl (17 kB)\n",
"Requirement already satisfied: cymem<2.1.0,>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from spacy) (2.0.6)\n",
"Requirement already satisfied: preshed<3.1.0,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from spacy) (3.0.7)\n",
"Requirement already satisfied: setuptools in /usr/local/lib/python3.7/dist-packages (from spacy) (57.4.0)\n",
"Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from catalogue<2.1.0,>=2.0.6->spacy) (3.8.1)\n",
"Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging>=20.0->spacy) (3.0.9)\n",
"Requirement already satisfied: smart-open<6.0.0,>=5.2.1 in /usr/local/lib/python3.7/dist-packages (from pathy>=0.3.5->spacy) (5.2.1)\n",
"Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests<3.0.0,>=2.13.0->spacy) (3.0.4)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests<3.0.0,>=2.13.0->spacy) (2022.9.24)\n",
"Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests<3.0.0,>=2.13.0->spacy) (1.24.3)\n",
"Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests<3.0.0,>=2.13.0->spacy) (2.10)\n",
"Requirement already satisfied: blis<0.8.0,>=0.7.8 in /usr/local/lib/python3.7/dist-packages (from thinc<8.2.0,>=8.1.0->spacy) (0.7.8)\n",
"Requirement already satisfied: confection<1.0.0,>=0.0.1 in /usr/local/lib/python3.7/dist-packages (from thinc<8.2.0,>=8.1.0->spacy) (0.0.2)\n",
"Requirement already satisfied: click<9.0.0,>=7.1.1 in /usr/local/lib/python3.7/dist-packages (from typer<0.5.0,>=0.3.0->spacy) (7.1.2)\n",
"Requirement already satisfied: MarkupSafe>=0.23 in /usr/local/lib/python3.7/dist-packages (from jinja2->spacy) (2.0.1)\n",
"Installing collected packages: catalogue, srsly, thinc, spacy\n",
" Attempting uninstall: catalogue\n",
" Found existing installation: catalogue 1.0.1\n",
" Uninstalling catalogue-1.0.1:\n",
" Successfully uninstalled catalogue-1.0.1\n",
" Attempting uninstall: srsly\n",
" Found existing installation: srsly 1.0.5\n",
" Uninstalling srsly-1.0.5:\n",
" Successfully uninstalled srsly-1.0.5\n",
" Attempting uninstall: thinc\n",
" Found existing installation: thinc 7.4.5\n",
" Uninstalling thinc-7.4.5:\n",
" Successfully uninstalled thinc-7.4.5\n",
" Attempting uninstall: spacy\n",
" Found existing installation: spacy 2.3.7\n",
" Uninstalling spacy-2.3.7:\n",
" Successfully uninstalled spacy-2.3.7\n",
"\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
"texthero 1.1.0 requires spacy<3.0.0, but you have spacy 3.4.1 which is incompatible.\u001b[0m\n",
"Successfully installed catalogue-2.0.8 spacy-3.4.1 srsly-2.4.4 thinc-8.1.3\n"
]
}
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "z5TGYMYQmDLU",
"outputId": "5d378f16-8b28-4cc9-8e6a-42128b886be7",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"[nltk_data] Downloading package stopwords to /root/nltk_data...\n",
"[nltk_data] Unzipping corpora/stopwords.zip.\n"
]
}
],
"source": [
"import pandas as pd \n",
"import numpy as np \n",
"import tensorflow as tf \n",
"import itertools\n",
"import texthero as hero\n",
"# from zeugma import EmbeddingTransformer\n",
"from google.colab import drive\n",
"import os"
]
},
{
"cell_type": "code",
"source": [
"# Mount Google Drive and switch to the folder that holds the notebook and data files\n",
"drive.mount('/content/drive')\n",
"\n",
"os.chdir('/content/drive/My Drive/Colab Notebooks')\n",
"!pwd"
],
"metadata": {
"id": "srLyPop3oxhu",
"outputId": "9f815da4-409b-446b-b262-9a5f6a27b47a",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Mounted at /content/drive\n",
"/content/drive/My Drive/Colab Notebooks\n"
]
}
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "euqY8H1lmDLY"
},
"outputs": [],
"source": [
"# Optional if running locally\n",
"# physical_devices = tf.config.list_physical_devices('GPU')\n",
"# tf.config.experimental.set_memory_growth(physical_devices[0], enable=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "7WVlkpElmDLZ"
},
"outputs": [],
"source": [
"df_train = pd.read_csv('final_fewshot_train.csv')\n",
"df_test = pd.read_csv('final_fewshot_test.csv')\n",
"df_train = df_train[['text', 'class']]\n",
"df_test = df_test[['text', 'class']]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 206
},
"id": "8Z2qNt56mDLZ",
"outputId": "39d7b2da-d456-4564-e6db-f0bfda83a463"
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" text class\n",
"0 [ALLUXIO-2743] Fix failing unit tests 1\n",
"1 #2 Refactored structure of Argument 3\n",
"2 Remove some features from JwtTokenStore 4\n",
"3 Remove duplicated 1.613 section from changelog 2\n",
"4 * webapp structure refactoring 3"
],
"text/html": [
"\n",
" <div id=\"df-694efd91-8328-4311-89cb-7cdbd70f7929\">\n",
" <div class=\"colab-df-container\">\n",
" <div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>text</th>\n",
" <th>class</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>[ALLUXIO-2743] Fix failing unit tests</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>#2 Refactored structure of Argument</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>Remove some features from JwtTokenStore</td>\n",
" <td>4</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>Remove duplicated 1.613 section from changelog</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>* webapp structure refactoring</td>\n",
" <td>3</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>\n",
" </div>\n",
" </div>\n",
" "
]
},
"metadata": {},
"execution_count": 6
}
],
"source": [
"df_train.head()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 238
},
"id": "xKc-goeumDLa",
"outputId": "a35d0027-de4d-4a3c-f47b-c4b457fda56b"
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" text\n",
"class \n",
"1 20\n",
"2 20\n",
"3 20\n",
"4 20\n",
"5 20"
],
"text/html": [
"\n",
" <div id=\"df-7abeefb7-1c71-4658-8436-160f6409cb4e\">\n",
" <div class=\"colab-df-container\">\n",
" <div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>text</th>\n",
" </tr>\n",
" <tr>\n",
" <th>class</th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>20</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>20</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>20</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>20</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>20</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>\n",
" </div>\n",
" </div>\n",
" "
]
},
"metadata": {},
"execution_count": 7
}
],
"source": [
"df_train.groupby('class').count()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Xe9mo8UVmDLh",
"outputId": "e83be5a4-9fdd-45cd-e533-7bc39880f5b0"
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"array([1, 3, 4, 2, 5])"
]
},
"metadata": {},
"execution_count": 8
}
],
"source": [
"labels = df_train['class'].unique()\n",
"labels"
]
},
{
"source": [
"The dataset resembles git commit messages. The training set has 20 examples for each of the 5 classes."
],
"cell_type": "markdown",
"metadata": {
"id": "WXciZgKfmDLa"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "WBBrhcGnmDLf",
"outputId": "eeec1b33-ea79-4e9c-9d2b-679ff9cb7fa9"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"<class 'pandas.core.frame.DataFrame'>\n",
"RangeIndex: 3277 entries, 0 to 3276\n",
"Data columns (total 2 columns):\n",
" # Column Non-Null Count Dtype \n",
"--- ------ -------------- ----- \n",
" 0 text 3277 non-null object\n",
" 1 class 3277 non-null int64 \n",
"dtypes: int64(1), object(1)\n",
"memory usage: 51.3+ KB\n"
]
}
],
"source": [
"df_test.info()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "bE5BBOT2mDLi"
},
"outputs": [],
"source": [
"# Clean text\n",
"# def text_cleaner(s):\n",
"# s = hero.remove_digits(s)\n",
"# s = hero.remove_brackets(s)\n",
"# s = hero.remove_punctuation(s)\n",
"# s = hero.remove_whitespace(s)\n",
"# s = hero.remove_stopwords(s)\n",
"\n",
"# return s\n",
"\n",
"df_train['cleaned_text'] = hero.clean(df_train['text'])\n",
"df_test['cleaned_text'] = hero.clean(df_test['text'])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 206
},
"id": "TQVXcGh7mDLi",
"outputId": "d719da1e-f555-4347-b4da-dcd77861f040"
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" text class \\\n",
"0 [ALLUXIO-2743] Fix failing unit tests 1 \n",
"1 #2 Refactored structure of Argument 3 \n",
"2 Remove some features from JwtTokenStore 4 \n",
"3 Remove duplicated 1.613 section from changelog 2 \n",
"4 * webapp structure refactoring 3 \n",
"\n",
" cleaned_text \n",
"0 alluxio fix failing unit tests \n",
"1 refactored structure argument \n",
"2 remove features jwttokenstore \n",
"3 remove duplicated section changelog \n",
"4 webapp structure refactoring "
]
},
"metadata": {},
"execution_count": 11
}
],
"source": [
"df_train.head()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "B01lwItQmDLj"
},
"outputs": [],
"source": [
"# Build the pair dataset: every pair of texts from the same class is labeled 1 (similar).\n",
"# For each similar pair, also add one dissimilar pair (a random text from the class and a\n",
"# random text from another class), labeled 0.\n",
"\n",
"text_left = []\n",
"text_right = []\n",
"target = []\n",
"\n",
"\n",
"for label in labels:\n",
"\n",
"    similar_texts = df_train[df_train['class']==label]['cleaned_text']\n",
"    group_similar_texts = list(itertools.combinations(similar_texts, 2))\n",
"\n",
"    text_left.extend([group[0] for group in group_similar_texts])\n",
"    text_right.extend([group[1] for group in group_similar_texts])\n",
"    target.extend([1.]*len(group_similar_texts))\n",
"\n",
"    dissimilar_texts = df_train[df_train['class']!=label]['cleaned_text']\n",
"    for i in range(len(group_similar_texts)):\n",
"        text_left.append(np.random.choice(similar_texts))\n",
"        text_right.append(np.random.choice(dissimilar_texts))\n",
"        target.append(0.)\n",
"\n",
"dataset = pd.DataFrame({'text_left': text_left,\n",
"                        'text_right': text_right,\n",
"                        'target': target})"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 363
},
"id": "qJO7n90mmDLk",
"outputId": "7aa32b93-324c-460d-905e-8ad613da563e"
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" text_left \\\n",
"449 poulpe structure component classes improved me... \n",
"1327 fix issue duplicate corrected link paths new j... \n",
"1538 improved performance clearing map instead recr... \n",
"1237 remove duplicate scripts move remaining items ... \n",
"1720 improved mmap management buffer pool full perf... \n",
"1386 refactor test code remove duplicates failovero... \n",
"1792 tiny performance improvement \n",
"1414 commandlinerunner handle uris refactored dupli... \n",
"632 introducing basic service data transfer object... \n",
"1090 stability monitors inheritable thus use listen... \n",
"\n",
" text_right target \n",
"449 refactored db structure deal private public au... 1.0 \n",
"1327 remove duplicate code dependency metadata 1.0 \n",
"1538 performance improvements 1.0 \n",
"1237 remove duplicate code dependency metadata 1.0 \n",
"1720 docearevent structure refactored 0.0 \n",
"1386 improved mmap management buffer pool full perf... 0.0 \n",
"1792 bug cloudstack service offering page fix bug n... 0.0 \n",
"1414 improved performance clearing map instead recr... 0.0 \n",
"632 remove duplicate utils getters 0.0 \n",
"1090 modest improvements use jdbc better use prepar... 0.0 "
]
},
"metadata": {},
"execution_count": 13
}
],
"source": [
"dataset.sample(10)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "t9kx8KEYmDLk",
"outputId": "2d86e494-998e-4078-c225-755ca04aa653"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"<class 'pandas.core.frame.DataFrame'>\n",
"RangeIndex: 1900 entries, 0 to 1899\n",
"Data columns (total 3 columns):\n",
" # Column Non-Null Count Dtype \n",
"--- ------ -------------- ----- \n",
" 0 text_left 1900 non-null object \n",
" 1 text_right 1900 non-null object \n",
" 2 target 1900 non-null float64\n",
"dtypes: float64(1), object(2)\n",
"memory usage: 44.7+ KB\n"
]
}
],
"source": [
"dataset.info()"
]
},
{
"source": [
"From a training set of only 100 samples, we were able to create 1,900 pairs for training the Siamese network."
],
"cell_type": "markdown",
"metadata": {
"id": "OyMN6NeimDLk"
}
},
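{
"cell_type": "markdown",
"metadata": {},
"source": [
"The count of 1,900 pairs follows directly from the pairing loop. Here is a quick arithmetic check, assuming (as stated above) 5 classes with 20 training texts each. Each class yields $\\binom{20}{2} = 190$ similar pairs, and the loop adds one dissimilar pair per similar pair, so the total is $5 \\times (190 + 190) = 1900$."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from math import comb\n",
"\n",
"# Assumed from the text above: 5 classes with 20 training examples each.\n",
"n_classes = 5\n",
"n_per_class = 20\n",
"\n",
"# Every unordered pair of texts within a class is a similar pair (target 1).\n",
"similar_per_class = comb(n_per_class, 2)\n",
"# The loop samples exactly one dissimilar pair (target 0) per similar pair.\n",
"dissimilar_per_class = similar_per_class\n",
"\n",
"total_pairs = n_classes * (similar_per_class + dissimilar_per_class)\n",
"print(similar_per_class, total_pairs)  # 190 1900"
]
},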
{
"source": [
"## Model"
],
"cell_type": "markdown",
"metadata": {
"id": "or2QnX4PmDLl"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "BFMMHrAqmDLl"
},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"from tensorflow.keras import backend as K\n",
"from tensorflow.keras.layers import Input, Dense, Dropout, Lambda, Subtract, LSTM, Embedding, Bidirectional\n",
"from tensorflow.keras.preprocessing.sequence import pad_sequences\n",
"from tensorflow.keras.preprocessing.text import Tokenizer\n",
"from tensorflow.keras.initializers import Constant\n",
"from tensorflow.keras.models import Sequential, Model\n",
"import itertools"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "QGWIKKMBmDLm"
},
"outputs": [],
"source": [
"MAX_SEQ_LENGTH = 100\n",
"VOCAB_SIZE = 10000"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "XxMIExxemDLm",
"outputId": "30b25ad9-e340-40d0-c17b-cb9fbf1545ca",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Found 584 unique tokens.\n",
"(1900, 100)\n",
"(1900, 100)\n"
]
}
],
"source": [
"tokenizer = Tokenizer(num_words=VOCAB_SIZE)\n",
"tokenizer.fit_on_texts(df_train.cleaned_text)\n",
"sequences_left = tokenizer.texts_to_sequences(dataset.text_left)\n",
"sequences_right = tokenizer.texts_to_sequences(dataset.text_right)\n",
"\n",
"word_index = tokenizer.word_index\n",
"print('Found %s unique tokens.' % len(word_index))\n",
"\n",
"# Inputs to the network are zero-padded sequences of word indices. These inputs are vectors of fixed length.\n",
"x_left = pad_sequences(sequences_left, maxlen=MAX_SEQ_LENGTH)\n",
"x_right = pad_sequences(sequences_right, maxlen=MAX_SEQ_LENGTH)\n",
"\n",
"print(x_left.shape)\n",
"print(x_right.shape)"
]
},
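{
"cell_type": "markdown",
"metadata": {},
"source": [
"By default, `pad_sequences` pads and truncates at the *start* of each sequence (`padding='pre'`, `truncating='pre'`). This is why the rows of `x_left` begin with zeros. The hypothetical helper `pre_pad` below is a pure-Python sketch of that default behaviour for a single sequence. It is for illustration only; the notebook itself keeps using `pad_sequences`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def pre_pad(seq, maxlen, value=0):\n",
"    # Hypothetical helper mimicking the pad_sequences defaults for one sequence:\n",
"    # 'pre' truncation drops tokens from the front, 'pre' padding adds values on the left.\n",
"    seq = list(seq)\n",
"    if len(seq) > maxlen:\n",
"        seq = seq[len(seq) - maxlen:]\n",
"    return [value] * (maxlen - len(seq)) + seq\n",
"\n",
"print(pre_pad([66, 38, 10], 5))        # [0, 0, 66, 38, 10]\n",
"print(pre_pad([1, 2, 3, 4, 5, 6], 4))  # [3, 4, 5, 6]"
]
},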
{
"cell_type": "code",
"source": [
"x_left"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "hVC8WiO9oStK",
"outputId": "7cad050a-5b1e-4ec3-8d00-e36436aca532"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"array([[ 0, 0, 0, ..., 66, 38, 10],\n",
" [ 0, 0, 0, ..., 66, 38, 10],\n",
" [ 0, 0, 0, ..., 66, 38, 10],\n",
" ...,\n",
" [ 0, 0, 0, ..., 20, 5, 275],\n",
" [ 0, 0, 0, ..., 7, 10, 47],\n",
" [ 0, 0, 0, ..., 512, 513, 514]], dtype=int32)"
]
},
"metadata": {},
"execution_count": 18
}
]
},
{
"cell_type": "code",
"source": [
"# Download the GloVe vectors (glove.6B.100d.txt) from: https://www.kaggle.com/datasets/danielwillgeorge/glove6b100dtxt?resource=download\n"
],
"metadata": {
"id": "4Yy_A2VXsTAF"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Check the location of the downloaded vectors.\n",
"# !ls"
],
"metadata": {
"id": "XLurFaVvsXEE"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "9LBSHmT6mDLn"
},
"outputs": [],
"source": [
"embeddings_index = {}\n",
"\n",
"with open('glove.6B.100d.txt', encoding='utf8') as f:\n",
"    for line in f:\n",
"        values = line.split(' ')\n",
"        word = values[0]  # the first entry is the word\n",
"        # the remaining entries are the embedding vector for the word\n",
"        coefs = np.asarray(values[1:], dtype='float32')\n",
"        embeddings_index[word] = coefs\n",
"\n",
"print('GloVe data loaded')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "vhdmr4ypmDLn"
},
"outputs": [],
"source": [
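"# Dimensionality of the pre-trained GloVe vectors (glove.6B.100d): each word maps to a\n",
"# 100-dimensional vector. (The BiLSTM encoder output is 2 x 32 = 64-dimensional.)\n",
"\n",
"EMBEDDING_DIM = 100"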
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "8T06quW0mDLo"
},
"outputs": [],
"source": [
"num_words = min(VOCAB_SIZE, len(word_index)) + 1\n",
"embedding_matrix = np.zeros((num_words, EMBEDDING_DIM))\n",
"for word, i in word_index.items():\n",
"    if i > VOCAB_SIZE:\n",
"        continue\n",
"    # This references the loaded embeddings dictionary\n",
"    embedding_vector = embeddings_index.get(word)\n",
"    if embedding_vector is not None:\n",
"        # Words not found in the embedding index stay all-zeros.\n",
"        embedding_matrix[i] = embedding_vector"
]
},
{
"cell_type": "markdown",
"source": [
"These index vectors are then fed into the embedding layer, which looks up the embedding of each word and stacks the embeddings into a matrix. This matrix represents the given text as a sequence of embeddings.\n",
"\n",
"Once the pairs dataset is built and the commit messages are preprocessed, we can turn to the Siamese model itself. It consists of a shared (cloned) sequential network that is applied to both inputs of a pair, `x_left` and `x_right`. The representations of the left and right inputs are used to compute the similarity between the commits:\n",
"\n",
"$$\\mathrm{sim}(x_l, x_r) = \\exp\\left(-\\lVert f(x_l) - f(x_r) \\rVert_1\\right),$$\n",
"\n",
"where $\\mathrm{sim} \\in (0, 1]$, $\\lVert \\cdot \\rVert_1$ is the L1 norm, and $f$ is the function corresponding to the application of the shared sequential network to the left/right input.\n",
"\n",
"This setting is called the Manhattan LSTM (MaLSTM) because the sequential network is built from LSTMs (here a bidirectional LSTM), and the L1 norm used to compute the distance between the two samples of a pair is also called the Manhattan distance."
],
"metadata": {
"id": "Q86UNz_ouCS4"
}
},
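{
"cell_type": "markdown",
"metadata": {},
"source": [
"The embedding layer behaves like a row lookup in `embedding_matrix`. Each word index selects one row. Row 0 (the padding index) is all zeros, and so is the row of any word without a GloVe vector. Below is a minimal NumPy sketch with a toy 6 x 4 matrix. The values are made up for illustration and are not taken from GloVe."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Toy stand-in for embedding_matrix (assumed values): 6 word ids, 4-dim vectors.\n",
"# Rows 1-3 are known words; row 0 (padding) and rows 4-5 (no GloVe vector) stay zero.\n",
"toy_embedding_matrix = np.zeros((6, 4))\n",
"toy_embedding_matrix[1:4] = np.arange(12).reshape(3, 4)\n",
"\n",
"# A pre-padded sequence of word indices, like one row of x_left.\n",
"x_toy = np.array([0, 0, 3, 1, 5])\n",
"\n",
"# Embedding lookup = selecting one row of the matrix per position.\n",
"embedded = toy_embedding_matrix[x_toy]\n",
"print(embedded.shape)  # (5, 4)\n",
"print(embedded[2])     # row of word 3: values 8 to 11"
]
},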
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "3fji3c9_mDLn"
},
"outputs": [],
"source": [
"def exponent_neg_manhattan_distance(arms_difference):\n",
"    \"\"\" Compute exp(-L1 norm) of the difference between the two arms' encodings.\n",
"    This turns the unbounded L1 distance into a similarity measure in (0, 1]:\n",
"    identical encodings give 1, and the value decays towards 0 as they move apart. \"\"\"\n",
"\n",
"    return K.exp(-K.sum(K.abs(arms_difference), axis=1, keepdims=True))"
]
},
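{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see what the Keras function computes on a batch, here is a NumPy mirror of it. `exponent_neg_manhattan_distance_np` is a hypothetical helper, and the toy 4-dimensional encodings stand in for the 64-dimensional BiLSTM outputs. Summing over `axis=1` with `keepdims=True` gives one similarity per pair as a `(batch, 1)` column with values in (0, 1]. That is the shape and range `binary_crossentropy` expects for a single-probability output."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def exponent_neg_manhattan_distance_np(arms_difference):\n",
"    # NumPy mirror of the Keras function above: one similarity per row, shape (batch, 1).\n",
"    return np.exp(-np.sum(np.abs(arms_difference), axis=1, keepdims=True))\n",
"\n",
"# Toy encodings for a batch of 3 pairs (assumed values).\n",
"left = np.array([[0.2, -0.1, 0.4, 0.0],\n",
"                 [1.0, 1.0, 1.0, 1.0],\n",
"                 [0.5, 0.5, 0.5, 0.5]])\n",
"right = np.array([[0.2, -0.1, 0.4, 0.0],   # identical encodings -> similarity 1\n",
"                  [0.0, 0.0, 0.0, 0.0],    # L1 distance 4 -> exp(-4)\n",
"                  [0.4, 0.6, 0.5, 0.5]])   # L1 distance 0.2 -> exp(-0.2)\n",
"\n",
"sim = exponent_neg_manhattan_distance_np(left - right)\n",
"print(sim.shape)  # (3, 1)\n",
"print(sim.ravel())"
]
},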
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "AAHY93dgmDLo"
},
"outputs": [],
"source": [
"def siamese_lstm_model(max_length):\n",
"\n",
"    input_shape = (max_length,)\n",
"    input_left = Input(input_shape, name='input_left')\n",
"    input_right = Input(input_shape, name='input_right')\n",
"\n",
"    # Load the pre-trained word embeddings into an Embedding layer.\n",
"    # trainable=False keeps the GloVe embeddings fixed during training.\n",
"    embedding_layer = Embedding(num_words,\n",
"                                EMBEDDING_DIM,\n",
"                                embeddings_initializer=Constant(embedding_matrix),\n",
"                                input_length=max_length,\n",
"                                trainable=False)\n",
"\n",
"    # Shared encoder: both arms go through the same layers and weights.\n",
"    seq = Sequential(name='sequential_network')\n",
"    seq.add(embedding_layer)\n",
"    seq.add(Bidirectional(LSTM(32, dropout=0.3, recurrent_dropout=0.)))\n",
"\n",
"    output_left = seq(input_left)\n",
"    output_right = seq(input_right)\n",
"\n",
"    # Subtract the right-arm encoding from the left-arm encoding, element-wise.\n",
"    subtracted = Subtract(name='pair_representations_difference')([output_left, output_right])\n",
"    malstm_distance = Lambda(exponent_neg_manhattan_distance,\n",
"                             name='malstm_distance')(subtracted)\n",
"\n",
"    # Pack it into a model\n",
"    siamese_net = Model(inputs=[input_left, input_right], outputs=malstm_distance)\n",
"    siamese_net.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])\n",
"\n",
"    return siamese_net\n",
"\n",
"\n",
"siamese_lstm = siamese_lstm_model(MAX_SEQ_LENGTH)\n",
"siamese_lstm.summary()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "XKcN5jYZmDLo"
},
"outputs": [],
"source": [
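"perm = np.random.permutation(len(dataset))  # dataset is ordered by class, so shuffle before validation_split\n",
"siamese_lstm.fit([x_left[perm], x_right[perm]], dataset.target.values[perm], validation_split=0.3, epochs=12);"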
]
},
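{
"cell_type": "markdown",
"metadata": {},
"source": [
"`Model.fit` takes the `validation_split` data from the *last* samples of the inputs, before any shuffling. Because `dataset` is built class by class, an unshuffled split validates on only some of the classes. The small NumPy illustration below assumes the class order `[1, 3, 4, 2, 5]` returned by `unique()` above and 380 pairs per class (190 similar + 190 dissimilar)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Assumed layout of `dataset`: pairs stored class by class, 380 pairs per class.\n",
"pair_class = np.repeat([1, 3, 4, 2, 5], 380)\n",
"n_val = len(pair_class) * 3 // 10   # validation_split=0.3 -> 570 pairs\n",
"\n",
"# Without shuffling, the validation rows are the last 570 pairs.\n",
"print(np.unique(pair_class[-n_val:]))   # [2 5]\n",
"\n",
"# Shuffling the rows first spreads every class across both splits.\n",
"rng = np.random.default_rng(42)\n",
"shuffled = pair_class[rng.permutation(len(pair_class))]\n",
"print(np.unique(shuffled[-n_val:]))"
]
},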
{
"source": [
"## Predictions\n",
"\n",
"To address the initial problem of finding the class of each git commit, we compute, for each example in the test set, its similarity score with every example in the training set. The predicted category is the class of the most similar training example.\n",
"\n"
],
"cell_type": "markdown",
"metadata": {
"id": "nSmwskYcmDLp"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "tsiUgHkWmDLp"
},
"outputs": [],
"source": [
"reference_sequences = tokenizer.texts_to_sequences(df_train.cleaned_text)\n",
"x_reference_sequences = pad_sequences(reference_sequences, maxlen=MAX_SEQ_LENGTH)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "fEcvEer3mDLp"
},
"outputs": [],
"source": [
"def flatten_text_sequence(sequences):\n",
"    flatten = itertools.chain.from_iterable\n",
"    return list(flatten(sequences))\n",
"\n",
"def get_prediction(text):\n",
"    \"\"\" Get the predicted category and the most similar text in the train set.\n",
"    Note that this way of computing a prediction is far from optimal, but it is\n",
"    sufficient for this small example. \"\"\"\n",
"    x = tokenizer.texts_to_sequences(text.split())\n",
"    x = flatten_text_sequence(x)\n",
"    x = pad_sequences([x], maxlen=MAX_SEQ_LENGTH)\n",
"\n",
"    # Compute the similarity of the text with every text in the train set by pairing\n",
"    # a copy of the query with each reference sequence. This can be slow for large\n",
"    # reference sets; a similarity-search library such as FAISS can speed it up.\n",
"    x_repeated = np.repeat(x, len(x_reference_sequences), axis=0)\n",
"    similarities = siamese_lstm.predict([x_repeated, x_reference_sequences])\n",
"    most_similar_index = np.argmax(similarities)\n",
"\n",
"    # The predicted category is the one of the most similar example from the train set.\n",
"    prediction = df_train['class'].iloc[most_similar_index]\n",
"    most_similar_example = df_train['cleaned_text'].iloc[most_similar_index]\n",
"\n",
"    return prediction, most_similar_example"
]
},
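{
"cell_type": "markdown",
"metadata": {},
"source": [
"`get_prediction` re-encodes all 100 training texts for every query. Because both arms share the same encoder, an alternative is to encode the training texts once and compare the encodings directly. This is a sketch, not the original author's method. In this notebook the encodings could presumably be obtained with `siamese_lstm.get_layer('sequential_network').predict(...)`. The helper `predict_by_nearest_reference` below is hypothetical and runs on toy 2-dimensional encodings. Its broadcasting uses memory proportional to queries x references x dimensions, so a large test set would need batching."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def predict_by_nearest_reference(query_enc, ref_enc, ref_labels):\n",
"    # query_enc: (n_queries, d) encodings of the texts to classify.\n",
"    # ref_enc: (n_refs, d) encodings of the training texts; ref_labels: (n_refs,).\n",
"    # Pairwise L1 distances via broadcasting: (n_queries, n_refs, d) -> (n_queries, n_refs).\n",
"    l1 = np.abs(query_enc[:, None, :] - ref_enc[None, :, :]).sum(axis=-1)\n",
"    sim = np.exp(-l1)              # MaLSTM similarity, in (0, 1]\n",
"    best = sim.argmax(axis=1)      # most similar reference for each query\n",
"    return np.asarray(ref_labels)[best], sim.max(axis=1)\n",
"\n",
"# Toy example with assumed encodings.\n",
"ref_enc = np.array([[0.0, 0.0], [1.0, 1.0], [5.0, 5.0]])\n",
"ref_labels = np.array([1, 3, 4])\n",
"query_enc = np.array([[0.9, 1.2], [4.0, 6.0]])\n",
"\n",
"pred, best_sim = predict_by_nearest_reference(query_enc, ref_enc, ref_labels)\n",
"print(pred)  # [3 4]"
]
},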
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "7AtfmRaamDLq",
"outputId": "175f0282-2d9c-4ec9-9c11-904f887a84cf"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"(100, 100)\n"
]
}
],
"source": [
"# Sanity check of the query construction used in get_prediction: tokenize one text,\n",
"# pad it, and repeat it once per reference example.\n",
"x = df_train['cleaned_text'].iloc[34]\n",
"\n",
"x = tokenizer.texts_to_sequences(x.split())\n",
"x = flatten_text_sequence(x)\n",
"x = pad_sequences([x], maxlen=MAX_SEQ_LENGTH)\n",
"result = np.repeat(x, len(x_reference_sequences), axis=0)\n",
"print(result.shape)"
]
},
{
"cell_type": "markdown",
"source": [
"Testing on a single example from the test set"
],
"metadata": {
"id": "5UV7sf0i1skl"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": [],
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "zeGIcVNOmDLq",
"outputId": "e55f9914-4f9e-4ab8-f07b-993e955e68cf"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Sampled Text: revert cloudstack automation fix test failure test 02 revert vm snapshots smoke test vm snapshots py\n",
"True Class: 1\n",
"Predicted Class : 1\n",
"Most similar example in train set: automation fix test failure test 02 revert vm snapshots smoke test vm snapshots py\n"
]
}
],
"source": [
"sample_idx = 22\n",
"\n",
"pred, most_sim = get_prediction(df_test['cleaned_text'].iloc[sample_idx])\n",
"\n",
"print(f'Sampled Text: {df_test[\"cleaned_text\"].iloc[sample_idx]}')\n",
"print(f'True Class: {df_test[\"class\"].iloc[sample_idx]}')\n",
"print(f'Predicted Class : {pred}')\n",
"print(f'Most similar example in train set: {most_sim}')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "pe1slx4fmDLq",
"outputId": "1a037442-369d-43ac-db3e-82f8ccd0b149"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Test accuracy (siamese model): 70.00 %\n"
]
}
],
"source": [
"from sklearn.metrics import accuracy_score\n",
"\n",
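"# Evaluate on the first 50 test examples only, since each prediction runs the model\n",
"# against every training example.\n",
"df_eval = df_test[:50]\n",
"\n",
"y_pred = [get_prediction(text)[0] for text in df_eval['cleaned_text']]\n",
"accuracy = accuracy_score(df_eval['class'], y_pred)\n",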
"\n",
"print(f'Test accuracy (siamese model): {100*accuracy:.2f} %')"
]
}
]
}