Git Commits Classification via Siamese Network architecture using LSTM encoders.ipynb
{
"metadata": {
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5-final"
},
"orig_nbformat": 2,
"kernelspec": {
"name": "python3",
"display_name": "Python 3.8.5 64-bit ('nlp-generic': conda)",
"metadata": {
"interpreter": {
"hash": "b481e22569f8c970d02674a5d9e45918cde1d33fcb13233e32a75967a14c2314"
}
}
},
"colab": {
"name": "Git Commits Classification via Siamese Network architecture using LSTM encoders.ipynb",
"provenance": [],
"include_colab_link": true
}
},
"nbformat": 4,
"nbformat_minor": 0,
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/ishacusp/32cbd2fdd3d31886f8459d71e8b7860a/text-classification-via-siamese-network-architecture-using-lstm-encoders.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"source": [
"# Git Commits classification via Siamese Network architecture using LSTM encoders\n",
"\n",
"**Original Author:** [\n",
"Subhasis Jethy](https://github.com/subhasisj/Few-Shot-Learning)<br>\n",
"**Notebook modified by:** [Isha Chaturvedi](https://www.linkedin.com/in/isha-chaturvedi-18372619/)<br>\n",
"**Date created:** 2020/12/21<br>\n",
"**Last modified:** 2023/06/05<br>\n",
"**Description:** Training a Siamese Network to perform git commits text classification using binary classification task. <br> \n",
"\n",
"\n",
"<br>\n",
"\n",
"**Source:**\n",
"<br>\n",
"https://data4thought.com/fewshot_learning_nlp.html,\n",
"<br>\n",
"https://blog.mlreview.com/implementing-malstm-on-kaggles-quora-question-pairs-competition-8b31b0b16a07 "
],
"metadata": {
"id": "_x9oD4bAWJ3b"
}
},
{
"cell_type": "markdown",
"source": [
"#### Introduction\n",
"The main idea of siamese networks is to learn vector representation by training a model that discriminates between pairs of examples that are in the same category, and pairs of examples that come from different categories. This is very simple exercise where we are taking git commits and performing text classification. The data has 100 data samples, with 5 labeled classes for git commits. Each class has only 20 data samples. In this exercise we will learn how Few Shot Learning (FSL) can work with small amount of data samples per class. The test set here is 3277 data samples. The focus of the coding exercise is less on which pre trained model is being used, but more on the part of understanding how FSL works.\n",
"\n",
"You all should have two data files in your current working notebook folder: fewshot_train.csv and fewshot_test.csv.\n",
"\n",
"Note: Glove vectors can be downloaded into the above working folder from: https://www.kaggle.com/datasets/danielwillgeorge/glove6b100dtxt?resource=download\n",
"\n",
"This is only for the purpose of this training.\n",
"\n",
"Standard sources: https://github.com/stanfordnlp/GloVe or https://huggingface.co/stanfordnlp\n"
],
"metadata": {
"id": "xcqwcVjphCLU"
}
},
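{
"cell_type": "code",
"source": [
"# Optional: fetch the GloVe vectors directly into the working folder.\n",
"# This is only a sketch -- it assumes the standard Stanford NLP mirror\n",
"# (nlp.stanford.edu/data/glove.6B.zip) is reachable; downloading\n",
"# glove.6B.100d.txt manually from the links above works just as well.\n",
"\n",
"# !wget -q http://nlp.stanford.edu/data/glove.6B.zip\n",
"# !unzip -o glove.6B.zip glove.6B.100d.txt"
],
"metadata": {},
"execution_count": null,
"outputs": []
},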
{
"cell_type": "markdown",
"source": [
"\n",
"#### Importing *Packages*"
],
"metadata": {
"id": "Shp0N-g0P8VV"
}
},
{
"cell_type": "code",
"source": [
"# Need to install these libraries\n",
"\n",
"!pip install -U spacy\n",
"# You must restart the runtime in order to use newly installed versions.\n",
"\n",
"!pip install texthero\n",
"# !pip install zeugma\n",
"\n",
"# Dataset location (download the dataset and locate it in the current working notebook folder):\n",
"# 1. Train data: https://github.com/subhasisj/Few-Shot-Learning/blob/master/Few-shot-Learning-Siamese-LSTM/final_fewshot_train.csv\n",
"# 2. Test data: https://github.com/subhasisj/Few-Shot-Learning/blob/master/Few-shot-Learning-Siamese-LSTM/final_fewshot_test.csv "
],
"metadata": {
"id": "_SR6VW54X3F6",
"outputId": "6cf1522b-8854-4c0a-a037-25e3aacf0392",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
"Requirement already satisfied: spacy in /usr/local/lib/python3.7/dist-packages (3.4.1)\n",
"Requirement already satisfied: typer<0.5.0,>=0.3.0 in /usr/local/lib/python3.7/dist-packages (from spacy) (0.4.2)\n",
"Requirement already satisfied: pathy>=0.3.5 in /usr/local/lib/python3.7/dist-packages (from spacy) (0.6.2)\n",
"Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.7/dist-packages (from spacy) (21.3)\n",
"Requirement already satisfied: murmurhash<1.1.0,>=0.28.0 in /usr/local/lib/python3.7/dist-packages (from spacy) (1.0.8)\n",
"Requirement already satisfied: catalogue<2.1.0,>=2.0.6 in /usr/local/lib/python3.7/dist-packages (from spacy) (2.0.8)\n",
"Requirement already satisfied: preshed<3.1.0,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from spacy) (3.0.7)\n",
"Requirement already satisfied: setuptools in /usr/local/lib/python3.7/dist-packages (from spacy) (57.4.0)\n",
"Requirement already satisfied: spacy-legacy<3.1.0,>=3.0.9 in /usr/local/lib/python3.7/dist-packages (from spacy) (3.0.10)\n",
"Requirement already satisfied: wasabi<1.1.0,>=0.9.1 in /usr/local/lib/python3.7/dist-packages (from spacy) (0.10.1)\n",
"Requirement already satisfied: typing-extensions<4.2.0,>=3.7.4 in /usr/local/lib/python3.7/dist-packages (from spacy) (4.1.1)\n",
"Requirement already satisfied: numpy>=1.15.0 in /usr/local/lib/python3.7/dist-packages (from spacy) (1.21.6)\n",
"Requirement already satisfied: pydantic!=1.8,!=1.8.1,<1.10.0,>=1.7.4 in /usr/local/lib/python3.7/dist-packages (from spacy) (1.9.2)\n",
"Requirement already satisfied: tqdm<5.0.0,>=4.38.0 in /usr/local/lib/python3.7/dist-packages (from spacy) (4.64.1)\n",
"Requirement already satisfied: jinja2 in /usr/local/lib/python3.7/dist-packages (from spacy) (2.11.3)\n",
"Requirement already satisfied: thinc<8.2.0,>=8.1.0 in /usr/local/lib/python3.7/dist-packages (from spacy) (8.1.2)\n",
"Requirement already satisfied: srsly<3.0.0,>=2.4.3 in /usr/local/lib/python3.7/dist-packages (from spacy) (2.4.4)\n",
"Requirement already satisfied: cymem<2.1.0,>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from spacy) (2.0.6)\n",
"Requirement already satisfied: langcodes<4.0.0,>=3.2.0 in /usr/local/lib/python3.7/dist-packages (from spacy) (3.3.0)\n",
"Requirement already satisfied: spacy-loggers<2.0.0,>=1.0.0 in /usr/local/lib/python3.7/dist-packages (from spacy) (1.0.3)\n",
"Requirement already satisfied: requests<3.0.0,>=2.13.0 in /usr/local/lib/python3.7/dist-packages (from spacy) (2.23.0)\n",
"Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from catalogue<2.1.0,>=2.0.6->spacy) (3.8.1)\n",
"Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging>=20.0->spacy) (3.0.9)\n",
"Requirement already satisfied: smart-open<6.0.0,>=5.2.1 in /usr/local/lib/python3.7/dist-packages (from pathy>=0.3.5->spacy) (5.2.1)\n",
"Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests<3.0.0,>=2.13.0->spacy) (1.24.3)\n",
"Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests<3.0.0,>=2.13.0->spacy) (3.0.4)\n",
"Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests<3.0.0,>=2.13.0->spacy) (2.10)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests<3.0.0,>=2.13.0->spacy) (2022.9.24)\n",
"Requirement already satisfied: confection<1.0.0,>=0.0.1 in /usr/local/lib/python3.7/dist-packages (from thinc<8.2.0,>=8.1.0->spacy) (0.0.2)\n",
"Requirement already satisfied: blis<0.8.0,>=0.7.8 in /usr/local/lib/python3.7/dist-packages (from thinc<8.2.0,>=8.1.0->spacy) (0.7.8)\n",
"Requirement already satisfied: click<9.0.0,>=7.1.1 in /usr/local/lib/python3.7/dist-packages (from typer<0.5.0,>=0.3.0->spacy) (7.1.2)\n",
"Requirement already satisfied: MarkupSafe>=0.23 in /usr/local/lib/python3.7/dist-packages (from jinja2->spacy) (2.0.1)\n",
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
"Collecting texthero\n",
" Downloading texthero-1.1.0-py3-none-any.whl (24 kB)\n",
"Requirement already satisfied: gensim<4.0,>=3.6.0 in /usr/local/lib/python3.7/dist-packages (from texthero) (3.6.0)\n",
"Requirement already satisfied: matplotlib>=3.1.0 in /usr/local/lib/python3.7/dist-packages (from texthero) (3.2.2)\n",
"Requirement already satisfied: wordcloud>=1.5.0 in /usr/local/lib/python3.7/dist-packages (from texthero) (1.8.2.2)\n",
"Requirement already satisfied: pandas>=1.0.2 in /usr/local/lib/python3.7/dist-packages (from texthero) (1.3.5)\n",
"Requirement already satisfied: plotly>=4.2.0 in /usr/local/lib/python3.7/dist-packages (from texthero) (5.5.0)\n",
"Requirement already satisfied: nltk>=3.3 in /usr/local/lib/python3.7/dist-packages (from texthero) (3.7)\n",
"Requirement already satisfied: tqdm>=4.3 in /usr/local/lib/python3.7/dist-packages (from texthero) (4.64.1)\n",
"Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.7/dist-packages (from texthero) (1.21.6)\n",
"Collecting spacy<3.0.0\n",
" Downloading spacy-2.3.7-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (10.4 MB)\n",
"\u001b[K |████████████████████████████████| 10.4 MB 2.7 MB/s \n",
"\u001b[?25hRequirement already satisfied: scikit-learn>=0.22 in /usr/local/lib/python3.7/dist-packages (from texthero) (1.0.2)\n",
"Collecting unidecode>=1.1.1\n",
" Downloading Unidecode-1.3.6-py3-none-any.whl (235 kB)\n",
"\u001b[K |████████████████████████████████| 235 kB 44.4 MB/s \n",
"\u001b[?25hRequirement already satisfied: six>=1.5.0 in /usr/local/lib/python3.7/dist-packages (from gensim<4.0,>=3.6.0->texthero) (1.15.0)\n",
"Requirement already satisfied: smart-open>=1.2.1 in /usr/local/lib/python3.7/dist-packages (from gensim<4.0,>=3.6.0->texthero) (5.2.1)\n",
"Requirement already satisfied: scipy>=0.18.1 in /usr/local/lib/python3.7/dist-packages (from gensim<4.0,>=3.6.0->texthero) (1.7.3)\n",
"Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib>=3.1.0->texthero) (1.4.4)\n",
"Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib>=3.1.0->texthero) (3.0.9)\n",
"Requirement already satisfied: python-dateutil>=2.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib>=3.1.0->texthero) (2.8.2)\n",
"Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.7/dist-packages (from matplotlib>=3.1.0->texthero) (0.11.0)\n",
"Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from kiwisolver>=1.0.1->matplotlib>=3.1.0->texthero) (4.1.1)\n",
"Requirement already satisfied: regex>=2021.8.3 in /usr/local/lib/python3.7/dist-packages (from nltk>=3.3->texthero) (2022.6.2)\n",
"Requirement already satisfied: click in /usr/local/lib/python3.7/dist-packages (from nltk>=3.3->texthero) (7.1.2)\n",
"Requirement already satisfied: joblib in /usr/local/lib/python3.7/dist-packages (from nltk>=3.3->texthero) (1.2.0)\n",
"Requirement already satisfied: pytz>=2017.3 in /usr/local/lib/python3.7/dist-packages (from pandas>=1.0.2->texthero) (2022.4)\n",
"Requirement already satisfied: tenacity>=6.2.0 in /usr/local/lib/python3.7/dist-packages (from plotly>=4.2.0->texthero) (8.1.0)\n",
"Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.7/dist-packages (from scikit-learn>=0.22->texthero) (3.1.0)\n",
"Collecting thinc<7.5.0,>=7.4.1\n",
" Downloading thinc-7.4.5-cp37-cp37m-manylinux2014_x86_64.whl (1.0 MB)\n",
"\u001b[K |████████████████████████████████| 1.0 MB 36.5 MB/s \n",
"\u001b[?25hCollecting catalogue<1.1.0,>=0.0.7\n",
" Downloading catalogue-1.0.1-py2.py3-none-any.whl (16 kB)\n",
"Collecting plac<1.2.0,>=0.9.6\n",
" Downloading plac-1.1.3-py2.py3-none-any.whl (20 kB)\n",
"Requirement already satisfied: blis<0.8.0,>=0.4.0 in /usr/local/lib/python3.7/dist-packages (from spacy<3.0.0->texthero) (0.7.8)\n",
"Requirement already satisfied: wasabi<1.1.0,>=0.4.0 in /usr/local/lib/python3.7/dist-packages (from spacy<3.0.0->texthero) (0.10.1)\n",
"Requirement already satisfied: requests<3.0.0,>=2.13.0 in /usr/local/lib/python3.7/dist-packages (from spacy<3.0.0->texthero) (2.23.0)\n",
"Requirement already satisfied: setuptools in /usr/local/lib/python3.7/dist-packages (from spacy<3.0.0->texthero) (57.4.0)\n",
"Requirement already satisfied: murmurhash<1.1.0,>=0.28.0 in /usr/local/lib/python3.7/dist-packages (from spacy<3.0.0->texthero) (1.0.8)\n",
"Requirement already satisfied: cymem<2.1.0,>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from spacy<3.0.0->texthero) (2.0.6)\n",
"Requirement already satisfied: preshed<3.1.0,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from spacy<3.0.0->texthero) (3.0.7)\n",
"Collecting srsly<1.1.0,>=1.0.2\n",
" Downloading srsly-1.0.5-cp37-cp37m-manylinux2014_x86_64.whl (184 kB)\n",
"\u001b[K |████████████████████████████████| 184 kB 74.7 MB/s \n",
"\u001b[?25hRequirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from catalogue<1.1.0,>=0.0.7->spacy<3.0.0->texthero) (3.8.1)\n",
"Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests<3.0.0,>=2.13.0->spacy<3.0.0->texthero) (3.0.4)\n",
"Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests<3.0.0,>=2.13.0->spacy<3.0.0->texthero) (1.24.3)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests<3.0.0,>=2.13.0->spacy<3.0.0->texthero) (2022.9.24)\n",
"Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests<3.0.0,>=2.13.0->spacy<3.0.0->texthero) (2.10)\n",
"Requirement already satisfied: pillow in /usr/local/lib/python3.7/dist-packages (from wordcloud>=1.5.0->texthero) (7.1.2)\n",
"Installing collected packages: srsly, plac, catalogue, thinc, unidecode, spacy, texthero\n",
" Attempting uninstall: srsly\n",
" Found existing installation: srsly 2.4.4\n",
" Uninstalling srsly-2.4.4:\n",
" Successfully uninstalled srsly-2.4.4\n",
" Attempting uninstall: catalogue\n",
" Found existing installation: catalogue 2.0.8\n",
" Uninstalling catalogue-2.0.8:\n",
" Successfully uninstalled catalogue-2.0.8\n",
" Attempting uninstall: thinc\n",
" Found existing installation: thinc 8.1.2\n",
" Uninstalling thinc-8.1.2:\n",
" Successfully uninstalled thinc-8.1.2\n",
" Attempting uninstall: spacy\n",
" Found existing installation: spacy 3.4.1\n",
" Uninstalling spacy-3.4.1:\n",
" Successfully uninstalled spacy-3.4.1\n",
"\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
"en-core-web-sm 3.4.0 requires spacy<3.5.0,>=3.4.0, but you have spacy 2.3.7 which is incompatible.\n",
"confection 0.0.2 requires srsly<3.0.0,>=2.4.0, but you have srsly 1.0.5 which is incompatible.\u001b[0m\n",
"Successfully installed catalogue-1.0.1 plac-1.1.3 spacy-2.3.7 srsly-1.0.5 texthero-1.1.0 thinc-7.4.5 unidecode-1.3.6\n"
]
}
]
},
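{
"cell_type": "code",
"source": [
"# Optional: pull the two CSV files straight from the source repository instead\n",
"# of uploading them by hand. The raw URLs below are derived from the GitHub\n",
"# blob links in the previous cell -- verify them if the repository layout has changed.\n",
"\n",
"# !wget -q https://raw.githubusercontent.com/subhasisj/Few-Shot-Learning/master/Few-shot-Learning-Siamese-LSTM/final_fewshot_train.csv\n",
"# !wget -q https://raw.githubusercontent.com/subhasisj/Few-Shot-Learning/master/Few-shot-Learning-Siamese-LSTM/final_fewshot_test.csv"
],
"metadata": {},
"execution_count": null,
"outputs": []
},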
{
"cell_type": "code",
"source": [
"# If above installation doesn't let you import texthero successfully, run the code below and restart run-time\n",
"\n",
"!pip install -U spacy\n",
"import spacy"
],
"metadata": {
"id": "KvwIy41fkLt6",
"outputId": "05988123-6456-463c-f1cc-d13e7229bef8",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"execution_count": 2,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
"Requirement already satisfied: spacy in /usr/local/lib/python3.10/dist-packages (3.5.3)\n",
"Requirement already satisfied: spacy-legacy<3.1.0,>=3.0.11 in /usr/local/lib/python3.10/dist-packages (from spacy) (3.0.12)\n",
"Requirement already satisfied: spacy-loggers<2.0.0,>=1.0.0 in /usr/local/lib/python3.10/dist-packages (from spacy) (1.0.4)\n",
"Requirement already satisfied: murmurhash<1.1.0,>=0.28.0 in /usr/local/lib/python3.10/dist-packages (from spacy) (1.0.9)\n",
"Requirement already satisfied: cymem<2.1.0,>=2.0.2 in /usr/local/lib/python3.10/dist-packages (from spacy) (2.0.7)\n",
"Requirement already satisfied: preshed<3.1.0,>=3.0.2 in /usr/local/lib/python3.10/dist-packages (from spacy) (3.0.8)\n",
"Requirement already satisfied: thinc<8.2.0,>=8.1.8 in /usr/local/lib/python3.10/dist-packages (from spacy) (8.1.9)\n",
"Requirement already satisfied: wasabi<1.2.0,>=0.9.1 in /usr/local/lib/python3.10/dist-packages (from spacy) (1.1.1)\n",
"Requirement already satisfied: srsly<3.0.0,>=2.4.3 in /usr/local/lib/python3.10/dist-packages (from spacy) (2.4.6)\n",
"Requirement already satisfied: catalogue<2.1.0,>=2.0.6 in /usr/local/lib/python3.10/dist-packages (from spacy) (2.0.8)\n",
"Requirement already satisfied: typer<0.8.0,>=0.3.0 in /usr/local/lib/python3.10/dist-packages (from spacy) (0.7.0)\n",
"Requirement already satisfied: pathy>=0.10.0 in /usr/local/lib/python3.10/dist-packages (from spacy) (0.10.1)\n",
"Requirement already satisfied: smart-open<7.0.0,>=5.2.1 in /usr/local/lib/python3.10/dist-packages (from spacy) (6.3.0)\n",
"Requirement already satisfied: tqdm<5.0.0,>=4.38.0 in /usr/local/lib/python3.10/dist-packages (from spacy) (4.65.0)\n",
"Requirement already satisfied: numpy>=1.15.0 in /usr/local/lib/python3.10/dist-packages (from spacy) (1.22.4)\n",
"Requirement already satisfied: requests<3.0.0,>=2.13.0 in /usr/local/lib/python3.10/dist-packages (from spacy) (2.27.1)\n",
"Requirement already satisfied: pydantic!=1.8,!=1.8.1,<1.11.0,>=1.7.4 in /usr/local/lib/python3.10/dist-packages (from spacy) (1.10.7)\n",
"Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from spacy) (3.1.2)\n",
"Requirement already satisfied: setuptools in /usr/local/lib/python3.10/dist-packages (from spacy) (67.7.2)\n",
"Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from spacy) (23.1)\n",
"Requirement already satisfied: langcodes<4.0.0,>=3.2.0 in /usr/local/lib/python3.10/dist-packages (from spacy) (3.3.0)\n",
"Requirement already satisfied: typing-extensions>=4.2.0 in /usr/local/lib/python3.10/dist-packages (from pydantic!=1.8,!=1.8.1,<1.11.0,>=1.7.4->spacy) (4.5.0)\n",
"Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests<3.0.0,>=2.13.0->spacy) (1.26.15)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests<3.0.0,>=2.13.0->spacy) (2022.12.7)\n",
"Requirement already satisfied: charset-normalizer~=2.0.0 in /usr/local/lib/python3.10/dist-packages (from requests<3.0.0,>=2.13.0->spacy) (2.0.12)\n",
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests<3.0.0,>=2.13.0->spacy) (3.4)\n",
"Requirement already satisfied: blis<0.8.0,>=0.7.8 in /usr/local/lib/python3.10/dist-packages (from thinc<8.2.0,>=8.1.8->spacy) (0.7.9)\n",
"Requirement already satisfied: confection<1.0.0,>=0.0.1 in /usr/local/lib/python3.10/dist-packages (from thinc<8.2.0,>=8.1.8->spacy) (0.0.4)\n",
"Requirement already satisfied: click<9.0.0,>=7.1.1 in /usr/local/lib/python3.10/dist-packages (from typer<0.8.0,>=0.3.0->spacy) (8.1.3)\n",
"Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->spacy) (2.1.2)\n"
]
}
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"id": "z5TGYMYQmDLU"
},
"outputs": [],
"source": [
"import pandas as pd \n",
"import numpy as np \n",
"from sklearn.metrics import accuracy_score\n",
"\n",
"\n",
"# import texthero as hero\n",
"from google.colab import drive\n",
"\n",
"import itertools\n",
"import os\n",
"\n",
"\n",
"import tensorflow as tf\n",
"from tensorflow.keras import backend as K\n",
"from tensorflow.keras.layers import Input, Dense, Dropout, Lambda, Subtract, LSTM, Embedding, Bidirectional\n",
"from tensorflow.keras.preprocessing.sequence import pad_sequences\n",
"from tensorflow.keras.preprocessing.text import Tokenizer\n",
"from tensorflow.keras.initializers import Constant\n",
"from tensorflow.keras.models import Sequential, Model\n"
]
},
{
"cell_type": "code",
"source": [
"# Mount google drive and point to the current repo\n",
"drive.mount('/content/drive')\n",
"\n",
"os.chdir('/content/drive/My Drive/Colab Notebooks')\n",
"!pwd"
],
"metadata": {
"id": "srLyPop3oxhu",
"outputId": "c6d6d47d-8fb2-481a-b9ad-50538db73503",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"execution_count": 5,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Mounted at /content/drive\n",
"/content/drive/My Drive/Colab Notebooks\n"
]
}
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "euqY8H1lmDLY"
},
"outputs": [],
"source": [
"# Optional if running locally\n",
"# physical_devices = tf.config.list_physical_devices('GPU')\n",
"# tf.config.experimental.set_memory_growth(physical_devices[0], enable=True)"
]
},
{
"cell_type": "markdown",
"source": [
"#### Loading the data and data processing"
],
"metadata": {
"id": "CuvOBy2XQqO-"
}
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"id": "7WVlkpElmDLZ"
},
"outputs": [],
"source": [
"df_train = pd.read_csv('final_fewshot_train.csv')\n",
"df_test = pd.read_csv('final_fewshot_test.csv')\n",
"df_train=df_train[['text','class']]\n",
"df_test=df_test[['text','class']]"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 206
},
"id": "8Z2qNt56mDLZ",
"outputId": "4a79aed4-3684-4a72-8f68-2bd98b994a88"
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" text class\n",
"0 [ALLUXIO-2743] Fix failing unit tests 1\n",
"1 #2 Refactored structure of Argument 3\n",
"2 Remove some features from JwtTokenStore 4\n",
"3 Remove duplicated 1.613 section from changelog 2\n",
"4 * webapp structure refactoring 3"
],
"text/html": [
"\n",
" <div id=\"df-fe884b11-d27d-40f9-9e9b-7ab04e62e08d\">\n",
" <div class=\"colab-df-container\">\n",
" <div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>text</th>\n",
" <th>class</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>[ALLUXIO-2743] Fix failing unit tests</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>#2 Refactored structure of Argument</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>Remove some features from JwtTokenStore</td>\n",
" <td>4</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>Remove duplicated 1.613 section from changelog</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>* webapp structure refactoring</td>\n",
" <td>3</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>\n",
" <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-fe884b11-d27d-40f9-9e9b-7ab04e62e08d')\"\n",
" title=\"Convert this dataframe to an interactive table.\"\n",
" style=\"display:none;\">\n",
" \n",
" <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
" width=\"24px\">\n",
" <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n",
" <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n",
" </svg>\n",
" </button>\n",
" \n",
" <style>\n",
" .colab-df-container {\n",
" display:flex;\n",
" flex-wrap:wrap;\n",
" gap: 12px;\n",
" }\n",
"\n",
" .colab-df-convert {\n",
" background-color: #E8F0FE;\n",
" border: none;\n",
" border-radius: 50%;\n",
" cursor: pointer;\n",
" display: none;\n",
" fill: #1967D2;\n",
" height: 32px;\n",
" padding: 0 0 0 0;\n",
" width: 32px;\n",
" }\n",
"\n",
" .colab-df-convert:hover {\n",
" background-color: #E2EBFA;\n",
" box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
" fill: #174EA6;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-convert {\n",
" background-color: #3B4455;\n",
" fill: #D2E3FC;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-convert:hover {\n",
" background-color: #434B5C;\n",
" box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
" filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
" fill: #FFFFFF;\n",
" }\n",
" </style>\n",
"\n",
" <script>\n",
" const buttonEl =\n",
" document.querySelector('#df-fe884b11-d27d-40f9-9e9b-7ab04e62e08d button.colab-df-convert');\n",
" buttonEl.style.display =\n",
" google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
"\n",
" async function convertToInteractive(key) {\n",
" const element = document.querySelector('#df-fe884b11-d27d-40f9-9e9b-7ab04e62e08d');\n",
" const dataTable =\n",
" await google.colab.kernel.invokeFunction('convertToInteractive',\n",
" [key], {});\n",
" if (!dataTable) return;\n",
"\n",
" const docLinkHtml = 'Like what you see? Visit the ' +\n",
" '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
" + ' to learn more about interactive tables.';\n",
" element.innerHTML = '';\n",
" dataTable['output_type'] = 'display_data';\n",
" await google.colab.output.renderOutput(dataTable, element);\n",
" const docLink = document.createElement('div');\n",
" docLink.innerHTML = docLinkHtml;\n",
" element.appendChild(docLink);\n",
" }\n",
" </script>\n",
" </div>\n",
" </div>\n",
" "
]
},
"metadata": {},
"execution_count": 7
}
],
"source": [
"# Check the structure of the training data.\n",
"df_train.head()"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 238
},
"id": "xKc-goeumDLa",
"outputId": "2454461b-defc-4bd7-ada2-3af4525e97ec"
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" text\n",
"class \n",
"1 20\n",
"2 20\n",
"3 20\n",
"4 20\n",
"5 20"
],
"text/html": [
"\n",
" <div id=\"df-2ac57e00-0365-45a4-8d4b-82967e3e6ed6\">\n",
" <div class=\"colab-df-container\">\n",
" <div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>text</th>\n",
" </tr>\n",
" <tr>\n",
" <th>class</th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>20</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>20</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>20</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>20</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>20</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>\n",
" <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-2ac57e00-0365-45a4-8d4b-82967e3e6ed6')\"\n",
" title=\"Convert this dataframe to an interactive table.\"\n",
" style=\"display:none;\">\n",
" \n",
" <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
" width=\"24px\">\n",
" <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n",
" <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n",
" </svg>\n",
" </button>\n",
" \n",
" <style>\n",
" .colab-df-container {\n",
" display:flex;\n",
" flex-wrap:wrap;\n",
" gap: 12px;\n",
" }\n",
"\n",
" .colab-df-convert {\n",
" background-color: #E8F0FE;\n",
" border: none;\n",
" border-radius: 50%;\n",
" cursor: pointer;\n",
" display: none;\n",
" fill: #1967D2;\n",
" height: 32px;\n",
" padding: 0 0 0 0;\n",
" width: 32px;\n",
" }\n",
"\n",
" .colab-df-convert:hover {\n",
" background-color: #E2EBFA;\n",
" box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
" fill: #174EA6;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-convert {\n",
" background-color: #3B4455;\n",
" fill: #D2E3FC;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-convert:hover {\n",
" background-color: #434B5C;\n",
" box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
" filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
" fill: #FFFFFF;\n",
" }\n",
" </style>\n",
"\n",
" <script>\n",
" const buttonEl =\n",
" document.querySelector('#df-2ac57e00-0365-45a4-8d4b-82967e3e6ed6 button.colab-df-convert');\n",
" buttonEl.style.display =\n",
" google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
"\n",
" async function convertToInteractive(key) {\n",
" const element = document.querySelector('#df-2ac57e00-0365-45a4-8d4b-82967e3e6ed6');\n",
" const dataTable =\n",
" await google.colab.kernel.invokeFunction('convertToInteractive',\n",
" [key], {});\n",
" if (!dataTable) return;\n",
"\n",
" const docLinkHtml = 'Like what you see? Visit the ' +\n",
" '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
" + ' to learn more about interactive tables.';\n",
" element.innerHTML = '';\n",
" dataTable['output_type'] = 'display_data';\n",
" await google.colab.output.renderOutput(dataTable, element);\n",
" const docLink = document.createElement('div');\n",
" docLink.innerHTML = docLinkHtml;\n",
" element.appendChild(docLink);\n",
" }\n",
" </script>\n",
" </div>\n",
" </div>\n",
" "
]
},
"metadata": {},
"execution_count": 8
}
],
"source": [
"# Check the number of data samples per class. It should be 20 data samples per class.\n",
"df_train.groupby('class').count()"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Xe9mo8UVmDLh",
"outputId": "8e94b84b-df67-4980-af26-3e9994728592"
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"array([1, 3, 4, 2, 5])"
]
},
"metadata": {},
"execution_count": 9
}
],
"source": [
"# Store the class labels in labels variables\n",
"\n",
"labels = df_train['class'].unique()\n",
"labels"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 580
},
"id": "WBBrhcGnmDLf",
"outputId": "7bd6b377-fcff-491f-a5a3-47c1be09542c"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"<class 'pandas.core.frame.DataFrame'>\n",
"RangeIndex: 3277 entries, 0 to 3276\n",
"Data columns (total 2 columns):\n",
" # Column Non-Null Count Dtype \n",
"--- ------ -------------- ----- \n",
" 0 text 3277 non-null object\n",
" 1 class 3277 non-null int64 \n",
"dtypes: int64(1), object(1)\n",
"memory usage: 51.3+ KB\n"
]
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
" text class\n",
"0 ApiServer: Fix apidiscovery fail case, fix com... 1\n",
"1 [GSCOLLECT-1606] Improve primitive map perform... 5\n",
"2 Big rename and more pointcut features in UI 4\n",
"3 Network-refactor: fix bugs in components.xml d... 1\n",
"4 HADOOP-10659. Refactor AccessControlList to re... 5\n",
"... ... ...\n",
"3272 [JENKINS-35020] Fixed some JSHint errors (#2368) 1\n",
"3273 update feature #3145, Improve code and improve... 3\n",
"3274 [ALLUXIO-2743] Fix failing tests 1\n",
"3275 HIVE-5951 : improve performance of adding part... 5\n",
"3276 Change zadd parameter order to allow duplicate... 2\n",
"\n",
"[3277 rows x 2 columns]"
],
"text/html": [
"\n",
" <div id=\"df-68082739-cb4f-4402-98a3-3c92d4ad4414\">\n",
" <div class=\"colab-df-container\">\n",
" <div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>text</th>\n",
" <th>class</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>ApiServer: Fix apidiscovery fail case, fix com...</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>[GSCOLLECT-1606] Improve primitive map perform...</td>\n",
" <td>5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>Big rename and more pointcut features in UI</td>\n",
" <td>4</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>Network-refactor: fix bugs in components.xml d...</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>HADOOP-10659. Refactor AccessControlList to re...</td>\n",
" <td>5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3272</th>\n",
" <td>[JENKINS-35020] Fixed some JSHint errors (#2368)</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3273</th>\n",
" <td>update feature #3145, Improve code and improve...</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3274</th>\n",
" <td>[ALLUXIO-2743] Fix failing tests</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3275</th>\n",
" <td>HIVE-5951 : improve performance of adding part...</td>\n",
" <td>5</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3276</th>\n",
" <td>Change zadd parameter order to allow duplicate...</td>\n",
" <td>2</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>3277 rows × 2 columns</p>\n",
"</div>\n",
" <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-68082739-cb4f-4402-98a3-3c92d4ad4414')\"\n",
" title=\"Convert this dataframe to an interactive table.\"\n",
" style=\"display:none;\">\n",
" \n",
" <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
" width=\"24px\">\n",
" <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n",
" <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n",
" </svg>\n",
" </button>\n",
" \n",
" <style>\n",
" .colab-df-container {\n",
" display:flex;\n",
" flex-wrap:wrap;\n",
" gap: 12px;\n",
" }\n",
"\n",
" .colab-df-convert {\n",
" background-color: #E8F0FE;\n",
" border: none;\n",
" border-radius: 50%;\n",
" cursor: pointer;\n",
" display: none;\n",
" fill: #1967D2;\n",
" height: 32px;\n",
" padding: 0 0 0 0;\n",
" width: 32px;\n",
" }\n",
"\n",
" .colab-df-convert:hover {\n",
" background-color: #E2EBFA;\n",
" box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
" fill: #174EA6;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-convert {\n",
" background-color: #3B4455;\n",
" fill: #D2E3FC;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-convert:hover {\n",
" background-color: #434B5C;\n",
" box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
" filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
" fill: #FFFFFF;\n",
" }\n",
" </style>\n",
"\n",
" <script>\n",
" const buttonEl =\n",
" document.querySelector('#df-68082739-cb4f-4402-98a3-3c92d4ad4414 button.colab-df-convert');\n",
" buttonEl.style.display =\n",
" google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
"\n",
" async function convertToInteractive(key) {\n",
" const element = document.querySelector('#df-68082739-cb4f-4402-98a3-3c92d4ad4414');\n",
" const dataTable =\n",
" await google.colab.kernel.invokeFunction('convertToInteractive',\n",
" [key], {});\n",
" if (!dataTable) return;\n",
"\n",
" const docLinkHtml = 'Like what you see? Visit the ' +\n",
" '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
" + ' to learn more about interactive tables.';\n",
" element.innerHTML = '';\n",
" dataTable['output_type'] = 'display_data';\n",
" await google.colab.output.renderOutput(dataTable, element);\n",
" const docLink = document.createElement('div');\n",
" docLink.innerHTML = docLinkHtml;\n",
" element.appendChild(docLink);\n",
" }\n",
" </script>\n",
" </div>\n",
" </div>\n",
" "
]
},
"metadata": {},
"execution_count": 10
}
],
"source": [
"# Check the structure of the test data. There should be 3277 samples.\n",
"\n",
"df_test.info()\n",
"\n",
"df_test"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"id": "bE5BBOT2mDLi"
},
"outputs": [],
"source": [
"# Clean text\n",
"# def text_cleaner(s):\n",
"# s = hero.remove_digits(s)\n",
"# s = hero.remove_brackets(s)\n",
"# s = hero.remove_punctuation(s)\n",
"# s = hero.remove_whitespace(s)\n",
"# s = hero.remove_stopwords(s)\n",
"\n",
"# return s"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 206
},
"id": "TQVXcGh7mDLi",
"outputId": "d719da1e-f555-4347-b4da-dcd77861f040"
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" text class \\\n",
"0 [ALLUXIO-2743] Fix failing unit tests 1 \n",
"1 #2 Refactored structure of Argument 3 \n",
"2 Remove some features from JwtTokenStore 4 \n",
"3 Remove duplicated 1.613 section from changelog 2 \n",
"4 * webapp structure refactoring 3 \n",
"\n",
" cleaned_text \n",
"0 alluxio fix failing unit tests \n",
"1 refactored structure argument \n",
"2 remove features jwttokenstore \n",
"3 remove duplicated section changelog \n",
"4 webapp structure refactoring "
],
"text/html": [
"\n",
" <div id=\"df-fe8fa77e-f8f0-433e-a452-0072369b8c3a\">\n",
" <div class=\"colab-df-container\">\n",
" <div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>text</th>\n",
" <th>class</th>\n",
" <th>cleaned_text</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>[ALLUXIO-2743] Fix failing unit tests</td>\n",
" <td>1</td>\n",
" <td>alluxio fix failing unit tests</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>#2 Refactored structure of Argument</td>\n",
" <td>3</td>\n",
" <td>refactored structure argument</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>Remove some features from JwtTokenStore</td>\n",
" <td>4</td>\n",
" <td>remove features jwttokenstore</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>Remove duplicated 1.613 section from changelog</td>\n",
" <td>2</td>\n",
" <td>remove duplicated section changelog</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>* webapp structure refactoring</td>\n",
" <td>3</td>\n",
" <td>webapp structure refactoring</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>\n",
" <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-fe8fa77e-f8f0-433e-a452-0072369b8c3a')\"\n",
" title=\"Convert this dataframe to an interactive table.\"\n",
" style=\"display:none;\">\n",
" \n",
" <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
" width=\"24px\">\n",
" <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n",
" <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n",
" </svg>\n",
" </button>\n",
" \n",
" <style>\n",
" .colab-df-container {\n",
" display:flex;\n",
" flex-wrap:wrap;\n",
" gap: 12px;\n",
" }\n",
"\n",
" .colab-df-convert {\n",
" background-color: #E8F0FE;\n",
" border: none;\n",
" border-radius: 50%;\n",
" cursor: pointer;\n",
" display: none;\n",
" fill: #1967D2;\n",
" height: 32px;\n",
" padding: 0 0 0 0;\n",
" width: 32px;\n",
" }\n",
"\n",
" .colab-df-convert:hover {\n",
" background-color: #E2EBFA;\n",
" box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
" fill: #174EA6;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-convert {\n",
" background-color: #3B4455;\n",
" fill: #D2E3FC;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-convert:hover {\n",
" background-color: #434B5C;\n",
" box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
" filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
" fill: #FFFFFF;\n",
" }\n",
" </style>\n",
"\n",
" <script>\n",
" const buttonEl =\n",
" document.querySelector('#df-fe8fa77e-f8f0-433e-a452-0072369b8c3a button.colab-df-convert');\n",
" buttonEl.style.display =\n",
" google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
"\n",
" async function convertToInteractive(key) {\n",
" const element = document.querySelector('#df-fe8fa77e-f8f0-433e-a452-0072369b8c3a');\n",
" const dataTable =\n",
" await google.colab.kernel.invokeFunction('convertToInteractive',\n",
" [key], {});\n",
" if (!dataTable) return;\n",
"\n",
" const docLinkHtml = 'Like what you see? Visit the ' +\n",
" '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
" + ' to learn more about interactive tables.';\n",
" element.innerHTML = '';\n",
" dataTable['output_type'] = 'display_data';\n",
" await google.colab.output.renderOutput(dataTable, element);\n",
" const docLink = document.createElement('div');\n",
" docLink.innerHTML = docLinkHtml;\n",
" element.appendChild(docLink);\n",
" }\n",
" </script>\n",
" </div>\n",
" </div>\n",
" "
]
},
"metadata": {},
"execution_count": 11
}
],
"source": [
"# Use texthero to clean the data. texthero is a python package to perform text pre-processing.\n",
"\n",
"# df_train['cleaned_text'] = hero.clean(df_train['text'])\n",
"# df_test['cleaned_text'] = hero.clean(df_test['text'])\n",
"\n",
"df_train.head()"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"id": "B01lwItQmDLj"
},
"outputs": [],
"source": [
"# Preparing the training set - group similar class category together, label it 1,\n",
"# and group disimilar class category, label it as 0.\n",
"\n",
"text_left = []\n",
"text_right = []\n",
"target = []\n",
"\n",
"\n",
"for label in labels:\n",
" \n",
" similar_texts = df_train[df_train['class']==label]['text']\n",
" group_similar_texts = list(itertools.combinations(similar_texts,2))\n",
" \n",
" text_left.extend([group[0] for group in group_similar_texts])\n",
" text_right.extend([group[1] for group in group_similar_texts])\n",
" target.extend([1.]*len(group_similar_texts))\n",
"\n",
" dissimilar_texts = df_train[df_train['class']!=label]['text']\n",
" for i in range(len(group_similar_texts)):\n",
" text_left.append(np.random.choice(similar_texts))\n",
" text_right.append(np.random.choice(dissimilar_texts))\n",
" target.append(0.)\n",
" \n",
"dataset = pd.DataFrame({'text_left':text_left,\n",
" 'text_right':text_right,\n",
" 'target': target})"
]
},
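{
"cell_type": "code",
"source": [
"# Sanity check (added for illustration): the loop above creates one dissimilar\n",
"# pair for every similar pair, so the target column should be evenly split\n",
"# between 1.0 and 0.0.\n",
"\n",
"dataset['target'].value_counts()"
],
"metadata": {},
"execution_count": null,
"outputs": []
},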
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 363
},
"id": "qJO7n90mmDLk",
"outputId": "79cba0f2-16b6-4e29-d4e4-03cee0e8f99e"
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" text_left \\\n",
"87 NPE in the payload was causing the ssvm agent ... \n",
"1244 Remove duplicate Utils getters \n",
"13 [ALLUXIO-2743] Fix failing unit tests \n",
"769 Remove some features from JwtTokenStore \n",
"1057 Add AppSales to featured list \n",
"1696 Improved performance by avoiding unnecessary N... \n",
"1895 #691 - Performance Improvements \n",
"148 CLOUDSTACK-5557: UI > Network > Guest Network ... \n",
"287 Fixed adding route for additional public nic o... \n",
"1186 Adds validation of code smells in case express... \n",
"\n",
" text_right target \n",
"87 [Automation] - Fix test failure for test_02_re... 1.0 \n",
"1244 Refactor test code, remove duplicates at Failo... 1.0 \n",
"13 cloudStack 3.0 new UI - NaaS - fix a bug that ... 1.0 \n",
"769 Add AppSales to featured list 1.0 \n",
"1057 new UI - template page, ISO page - fix a bug t... 0.0 \n",
"1696 Improve readability of ComparisonFailure 1.0 \n",
"1895 HBASE-14466 Remove duplicated code from MOB sn... 0.0 \n",
"148 cloudStack 3.0 new UI - NaaS - fix a bug that ... 1.0 \n",
"287 Improved performance by avoiding unnecessary N... 0.0 \n",
"1186 Refactor test code, remove duplicates at Failo... 1.0 "
],
"text/html": [
"\n",
" <div id=\"df-17296452-22c9-4a0a-8263-990b9dd68c1e\">\n",
" <div class=\"colab-df-container\">\n",
" <div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>text_left</th>\n",
" <th>text_right</th>\n",
" <th>target</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>87</th>\n",
" <td>NPE in the payload was causing the ssvm agent ...</td>\n",
" <td>[Automation] - Fix test failure for test_02_re...</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1244</th>\n",
" <td>Remove duplicate Utils getters</td>\n",
" <td>Refactor test code, remove duplicates at Failo...</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>13</th>\n",
" <td>[ALLUXIO-2743] Fix failing unit tests</td>\n",
" <td>cloudStack 3.0 new UI - NaaS - fix a bug that ...</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>769</th>\n",
" <td>Remove some features from JwtTokenStore</td>\n",
" <td>Add AppSales to featured list</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1057</th>\n",
" <td>Add AppSales to featured list</td>\n",
" <td>new UI - template page, ISO page - fix a bug t...</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1696</th>\n",
" <td>Improved performance by avoiding unnecessary N...</td>\n",
" <td>Improve readability of ComparisonFailure</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1895</th>\n",
" <td>#691 - Performance Improvements</td>\n",
" <td>HBASE-14466 Remove duplicated code from MOB sn...</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>148</th>\n",
" <td>CLOUDSTACK-5557: UI &gt; Network &gt; Guest Network ...</td>\n",
" <td>cloudStack 3.0 new UI - NaaS - fix a bug that ...</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>287</th>\n",
" <td>Fixed adding route for additional public nic o...</td>\n",
" <td>Improved performance by avoiding unnecessary N...</td>\n",
" <td>0.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1186</th>\n",
" <td>Adds validation of code smells in case express...</td>\n",
" <td>Refactor test code, remove duplicates at Failo...</td>\n",
" <td>1.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>\n",
" <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-17296452-22c9-4a0a-8263-990b9dd68c1e')\"\n",
" title=\"Convert this dataframe to an interactive table.\"\n",
" style=\"display:none;\">\n",
" \n",
" <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
" width=\"24px\">\n",
" <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n",
" <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n",
" </svg>\n",
" </button>\n",
" \n",
" <style>\n",
" .colab-df-container {\n",
" display:flex;\n",
" flex-wrap:wrap;\n",
" gap: 12px;\n",
" }\n",
"\n",
" .colab-df-convert {\n",
" background-color: #E8F0FE;\n",
" border: none;\n",
" border-radius: 50%;\n",
" cursor: pointer;\n",
" display: none;\n",
" fill: #1967D2;\n",
" height: 32px;\n",
" padding: 0 0 0 0;\n",
" width: 32px;\n",
" }\n",
"\n",
" .colab-df-convert:hover {\n",
" background-color: #E2EBFA;\n",
" box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
" fill: #174EA6;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-convert {\n",
" background-color: #3B4455;\n",
" fill: #D2E3FC;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-convert:hover {\n",
" background-color: #434B5C;\n",
" box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
" filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
" fill: #FFFFFF;\n",
" }\n",
" </style>\n",
"\n",
" <script>\n",
" const buttonEl =\n",
" document.querySelector('#df-17296452-22c9-4a0a-8263-990b9dd68c1e button.colab-df-convert');\n",
" buttonEl.style.display =\n",
" google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
"\n",
" async function convertToInteractive(key) {\n",
" const element = document.querySelector('#df-17296452-22c9-4a0a-8263-990b9dd68c1e');\n",
" const dataTable =\n",
" await google.colab.kernel.invokeFunction('convertToInteractive',\n",
" [key], {});\n",
" if (!dataTable) return;\n",
"\n",
" const docLinkHtml = 'Like what you see? Visit the ' +\n",
" '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
" + ' to learn more about interactive tables.';\n",
" element.innerHTML = '';\n",
" dataTable['output_type'] = 'display_data';\n",
" await google.colab.output.renderOutput(dataTable, element);\n",
" const docLink = document.createElement('div');\n",
" docLink.innerHTML = docLinkHtml;\n",
" element.appendChild(docLink);\n",
" }\n",
" </script>\n",
" </div>\n",
" </div>\n",
" "
]
},
"metadata": {},
"execution_count": 14
}
],
"source": [
"# Check how the prepared training data looks like.\n",
"\n",
"dataset.sample(10)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "t9kx8KEYmDLk",
"outputId": "13cadf27-7c1e-4493-c1e4-f3a85ef8cbb8"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"<class 'pandas.core.frame.DataFrame'>\n",
"RangeIndex: 1900 entries, 0 to 1899\n",
"Data columns (total 3 columns):\n",
" # Column Non-Null Count Dtype \n",
"--- ------ -------------- ----- \n",
" 0 text_left 1900 non-null object \n",
" 1 text_right 1900 non-null object \n",
" 2 target 1900 non-null float64\n",
"dtypes: float64(1), object(2)\n",
"memory usage: 44.7+ KB\n"
]
}
],
"source": [
"# Check the structure of the prepared training data.\n",
"\n",
"dataset.info()"
]
},
{
"source": [
"From a training set of 100 samples were able to create 1900 samples for training the siamese network."
],
"cell_type": "markdown",
"metadata": {
"id": "OyMN6NeimDLk"
}
},
{
"source": [
"#### Model preparation and training"
],
"cell_type": "markdown",
"metadata": {
"id": "or2QnX4PmDLl"
}
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"id": "QGWIKKMBmDLm"
},
"outputs": [],
"source": [
"MAX_SEQ_LENGTH = 100\n",
"VOCAB_SIZE = 10000"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"id": "XxMIExxemDLm",
"outputId": "4095705d-6c53-4ea5-c24d-29723f1e4258",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Found 697 unique tokens.\n",
"(1900, 100)\n",
"(1900, 100)\n"
]
}
],
"source": [
"tokenizer = Tokenizer(num_words=VOCAB_SIZE)\n",
"tokenizer.fit_on_texts(df_train.text)\n",
"sequences_left = tokenizer.texts_to_sequences(dataset.text_left)\n",
"sequences_right = tokenizer.texts_to_sequences(dataset.text_right)\n",
"\n",
"word_index = tokenizer.word_index\n",
"print('Found %s unique tokens.' % len(word_index))\n",
"\n",
"# Inputs to the network are zero-padded sequences of word indices. These inputs are vectors of fixed length.\n",
"x_left = pad_sequences(sequences_left, maxlen=MAX_SEQ_LENGTH)\n",
"x_right = pad_sequences(sequences_right, maxlen=MAX_SEQ_LENGTH)\n",
"\n",
"print(x_left.shape)\n",
"print(x_right.shape)"
]
},
{
"cell_type": "code",
"source": [
"# Check the shape of the x_left tensor.\n",
"x_left"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "hVC8WiO9oStK",
"outputId": "36199e4e-1b34-4955-b7c0-ba50effc3058"
},
"execution_count": 18,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"array([[ 0, 0, 0, ..., 97, 67, 22],\n",
" [ 0, 0, 0, ..., 97, 67, 22],\n",
" [ 0, 0, 0, ..., 97, 67, 22],\n",
" ...,\n",
" [ 0, 0, 0, ..., 41, 662, 192],\n",
" [ 0, 0, 0, ..., 136, 3, 321],\n",
" [ 0, 0, 0, ..., 688, 14, 55]], dtype=int32)"
]
},
"metadata": {},
"execution_count": 18
}
]
},
{
"cell_type": "code",
"source": [
"# Check the location of the downloaded vectors.\n",
"# !ls"
],
"metadata": {
"id": "XLurFaVvsXEE"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
"id": "9LBSHmT6mDLn",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "b59c3931-2f58-4095-c4c6-e265b5f226ee"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"GloVe data loaded\n"
]
}
],
"source": [
"# Load the glove vectors \n",
"\n",
"embeddings_index = {}\n",
"\n",
"f = open('glove.6B.100d.txt',encoding=\"utf8\")\n",
"for line in f:\n",
" values = line.split(' ')\n",
" word = values[0] ## The first entry is the word\n",
"\n",
" # These are the vectors representing the embedding for the word\n",
" coefs = np.asarray(values[1:], dtype='float32') \n",
" embeddings_index[word] = coefs\n",
"f.close()\n",
"\n",
"print('GloVe data loaded')"
]
},
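{
"cell_type": "code",
"source": [
"# Quick look at what was loaded (illustrative check): embeddings_index maps each\n",
"# word to a 100-dimensional vector. 'the' is used here only as an example token.\n",
"\n",
"print(len(embeddings_index), 'word vectors loaded')\n",
"print(embeddings_index['the'][:5])"
],
"metadata": {},
"execution_count": null,
"outputs": []
},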
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
"id": "vhdmr4ypmDLn"
},
"outputs": [],
"source": [
"# The final state of the LSTM for each text is 100 dimensional vector. \n",
"\n",
"EMBEDDING_DIM = 100"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {
"id": "8T06quW0mDLo"
},
"outputs": [],
"source": [
"num_words = min(VOCAB_SIZE, len(word_index)) + 1\n",
"embedding_matrix = np.zeros((num_words, EMBEDDING_DIM))\n",
"for word, i in word_index.items():\n",
" if i > VOCAB_SIZE:\n",
" continue\n",
" # This references the loaded embeddings dictionary\n",
" embedding_vector = embeddings_index.get(word)\n",
" if embedding_vector is not None:\n",
" # words not found in embedding index will be all-zeros.\n",
" embedding_matrix[i] = embedding_vector\n",
"\n"
]
},
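{
"cell_type": "code",
"source": [
"# Coverage check (added for illustration): count how many rows of the embedding\n",
"# matrix were initialised from GloVe; words missing from GloVe stay all-zero.\n",
"\n",
"covered = int((np.abs(embedding_matrix).sum(axis=1) > 0).sum())\n",
"print(f'{covered} of {num_words} embedding rows initialised from GloVe')"
],
"metadata": {},
"execution_count": null,
"outputs": []
},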
{
"cell_type": "markdown",
"source": [
"These vectors are then fed into the embedding layer. This layer looks up the corresponding embedding for each word and encapsulates all them into a matrix. This matrix represents the given text as a series of embeddings.\n",
"\n",
"Once we have created the pairs dataset and preprocessed the code commits we can turn our attention to the siamese model itself. It consists of a cloned sequential network, the input of which is a pair of vectors x_left and x_right. The representations of the right and left inputs are used to compute the similarity between the commits: \n",
"\n",
"sim(x$_{l}$,x$_{r}$) = exp(−‖f(x$_{l}$) − f(x$_{r}$)‖$_{1}$),\n",
"\n",
"where sim ∈ [0, 1], ‖ ⋅ ‖$_{1}$ is the L1 norm, and f is the function corresponding to the application of the cloned sequential network to the left/right input.\n",
"\n",
"This setting is called the Manhattan LSTM because we'll use LSTMs as the sequential network, and the L1 norm (used to compute the distance between two samples of a pair) is also called the Manhattan distance."
],
"metadata": {
"id": "Q86UNz_ouCS4"
}
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {
"id": "3fji3c9_mDLn"
},
"outputs": [],
"source": [
"def exponent_neg_manhattan_distance(arms_difference):\n",
" \"\"\" Compute the exponent of the opposite of the L1 norm of a vector, to get the left/right inputs\n",
" similarity from the inputs differences. This function is used to turn the unbounded\n",
" L1 distance to a similarity measure between 0 and 1\"\"\"\n",
"\n",
" return K.exp(-K.sum(K.abs(arms_difference), axis=1, keepdims=True))"
]
},
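{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick numeric illustration of the similarity formula (a sketch in plain NumPy rather than the Keras backend): identical representations give a similarity of exactly 1, and the similarity decays towards 0 as the Manhattan (L1) distance between the two arms grows."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Toy illustration of sim(x_l, x_r) = exp(-||f(x_l) - f(x_r)||_1) in plain NumPy.\n",
"# (Sketch only; the model itself uses the Keras backend version defined above.)\n",
"import numpy as np\n",
"\n",
"f_left = np.array([[0.1, 0.5, -0.2]])    # toy encoded left arm\n",
"f_right = np.array([[0.1, 0.5, -0.2]])   # identical right arm -> similarity 1.0\n",
"print(np.exp(-np.sum(np.abs(f_left - f_right), axis=1)))\n",
"\n",
"f_right = np.array([[1.1, -0.5, 0.8]])   # very different right arm -> similarity ~0.05\n",
"print(np.exp(-np.sum(np.abs(f_left - f_right), axis=1)))"
]
},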
{
"cell_type": "code",
"execution_count": 23,
"metadata": {
"id": "AAHY93dgmDLo",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "4265acef-3af2-4c2b-d708-9fbf67e00873"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Model: \"model\"\n",
"__________________________________________________________________________________________________\n",
" Layer (type) Output Shape Param # Connected to \n",
"==================================================================================================\n",
" input_left (InputLayer) [(None, 100)] 0 [] \n",
" \n",
" input_right (InputLayer) [(None, 100)] 0 [] \n",
" \n",
" sequential_network (Sequential (None, 64) 103848 ['input_left[0][0]', \n",
" ) 'input_right[0][0]'] \n",
" \n",
" pair_representations_differenc (None, 64) 0 ['sequential_network[0][0]', \n",
" e (Subtract) 'sequential_network[1][0]'] \n",
" \n",
" masltsm_distance (Lambda) (None, 1) 0 ['pair_representations_difference\n",
" [0][0]'] \n",
" \n",
"==================================================================================================\n",
"Total params: 103,848\n",
"Trainable params: 34,048\n",
"Non-trainable params: 69,800\n",
"__________________________________________________________________________________________________\n"
]
}
],
"source": [
"def siamese_lstm_model(max_length):\n",
"\n",
" input_shape = (max_length,)\n",
" input_left = Input(input_shape,name = 'input_left')\n",
" input_right = Input(input_shape,name = 'input_right')\n",
"\n",
" # load pre-trained word embeddings into an Embedding layer\n",
" # note that we set trainable = False so as to keep the embeddings fixed\n",
" embedding_layer = Embedding(num_words,\n",
" EMBEDDING_DIM,\n",
" embeddings_initializer=Constant(embedding_matrix),\n",
" input_length=max_length,\n",
" trainable=False)\n",
"\n",
" seq = Sequential(name='sequential_network')\n",
" seq.add(embedding_layer)\n",
" seq.add(Bidirectional(LSTM(32, dropout=0.3, recurrent_dropout=0.)))\n",
"\n",
" output_left = seq(input_left)\n",
" output_right = seq(input_right)\n",
"\n",
" # Here we subtract the neuron values of the last layer from the left arm \n",
" # with the corresponding values from the right arm.\n",
"\n",
" subtracted = Subtract(name='pair_representations_difference')([output_left, output_right])\n",
" malstm_distance = Lambda(exponent_neg_manhattan_distance, \n",
" name='masltsm_distance')(subtracted)\n",
" \n",
" # Pack it into a model\n",
" siamese_net = Model(inputs=[input_left, input_right], outputs=malstm_distance)\n",
" siamese_net.compile(loss=\"binary_crossentropy\", optimizer='adam', metrics=['accuracy'])\n",
"\n",
" return siamese_net\n",
"\n",
"\n",
"siamese_lstm = siamese_lstm_model(MAX_SEQ_LENGTH)\n",
"siamese_lstm.summary()"
]
},
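{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick check on where the parameter counts in the summary come from: the frozen embedding layer contributes num_words × EMBEDDING_DIM = 698 × 100 = 69,800 non-trainable weights, and each direction of the bidirectional LSTM with 32 units on 100-dimensional inputs has 4 × (32 × (100 + 32) + 32) = 17,024 weights, so 2 × 17,024 = 34,048 trainable weights, giving 103,848 parameters in total."
]
},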
{
"cell_type": "code",
"execution_count": 31,
"metadata": {
"id": "XKcN5jYZmDLo",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "043df138-2232-4f6b-8746-7c709695abef"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Epoch 1/50\n",
"42/42 [==============================] - 8s 198ms/step - loss: 0.5629 - accuracy: 0.7887 - val_loss: 0.8310 - val_accuracy: 0.3737\n",
"Epoch 2/50\n",
"42/42 [==============================] - 9s 216ms/step - loss: 0.5482 - accuracy: 0.8000 - val_loss: 0.8228 - val_accuracy: 0.3912\n",
"Epoch 3/50\n",
"42/42 [==============================] - 6s 153ms/step - loss: 0.5232 - accuracy: 0.8263 - val_loss: 0.8130 - val_accuracy: 0.4263\n",
"Epoch 4/50\n",
"42/42 [==============================] - 9s 208ms/step - loss: 0.5152 - accuracy: 0.8278 - val_loss: 0.8061 - val_accuracy: 0.4596\n",
"Epoch 5/50\n",
"42/42 [==============================] - 6s 153ms/step - loss: 0.4976 - accuracy: 0.8474 - val_loss: 0.7996 - val_accuracy: 0.5105\n",
"Epoch 6/50\n",
"42/42 [==============================] - 9s 209ms/step - loss: 0.4849 - accuracy: 0.8609 - val_loss: 0.7912 - val_accuracy: 0.5246\n",
"Epoch 7/50\n",
"42/42 [==============================] - 7s 155ms/step - loss: 0.4749 - accuracy: 0.8624 - val_loss: 0.7899 - val_accuracy: 0.5474\n",
"Epoch 8/50\n",
"42/42 [==============================] - 9s 209ms/step - loss: 0.4550 - accuracy: 0.8759 - val_loss: 0.7710 - val_accuracy: 0.5772\n",
"Epoch 9/50\n",
"42/42 [==============================] - 6s 152ms/step - loss: 0.4478 - accuracy: 0.8977 - val_loss: 0.7756 - val_accuracy: 0.5912\n",
"Epoch 10/50\n",
"42/42 [==============================] - 9s 213ms/step - loss: 0.4399 - accuracy: 0.8992 - val_loss: 0.7634 - val_accuracy: 0.6088\n",
"Epoch 11/50\n",
"42/42 [==============================] - 6s 153ms/step - loss: 0.4200 - accuracy: 0.9195 - val_loss: 0.7633 - val_accuracy: 0.6053\n",
"Epoch 12/50\n",
"42/42 [==============================] - 9s 213ms/step - loss: 0.4153 - accuracy: 0.9143 - val_loss: 0.7427 - val_accuracy: 0.6263\n",
"Epoch 13/50\n",
"42/42 [==============================] - 7s 156ms/step - loss: 0.4034 - accuracy: 0.9188 - val_loss: 0.7392 - val_accuracy: 0.6281\n",
"Epoch 14/50\n",
"42/42 [==============================] - 9s 226ms/step - loss: 0.3932 - accuracy: 0.9293 - val_loss: 0.7234 - val_accuracy: 0.6421\n",
"Epoch 15/50\n",
"42/42 [==============================] - 6s 152ms/step - loss: 0.3783 - accuracy: 0.9429 - val_loss: 0.7103 - val_accuracy: 0.6439\n",
"Epoch 16/50\n",
"42/42 [==============================] - 9s 212ms/step - loss: 0.3699 - accuracy: 0.9391 - val_loss: 0.6821 - val_accuracy: 0.6667\n",
"Epoch 17/50\n",
"42/42 [==============================] - 6s 155ms/step - loss: 0.3644 - accuracy: 0.9421 - val_loss: 0.6873 - val_accuracy: 0.6544\n",
"Epoch 18/50\n",
"42/42 [==============================] - 9s 217ms/step - loss: 0.3526 - accuracy: 0.9421 - val_loss: 0.6656 - val_accuracy: 0.6737\n",
"Epoch 19/50\n",
"42/42 [==============================] - 6s 149ms/step - loss: 0.3413 - accuracy: 0.9466 - val_loss: 0.6501 - val_accuracy: 0.6965\n",
"Epoch 20/50\n",
"42/42 [==============================] - 9s 217ms/step - loss: 0.3291 - accuracy: 0.9609 - val_loss: 0.6637 - val_accuracy: 0.6596\n",
"Epoch 21/50\n",
"42/42 [==============================] - 7s 156ms/step - loss: 0.3237 - accuracy: 0.9624 - val_loss: 0.6448 - val_accuracy: 0.6895\n",
"Epoch 22/50\n",
"42/42 [==============================] - 9s 210ms/step - loss: 0.3100 - accuracy: 0.9677 - val_loss: 0.6465 - val_accuracy: 0.6544\n",
"Epoch 23/50\n",
"42/42 [==============================] - 6s 148ms/step - loss: 0.3002 - accuracy: 0.9677 - val_loss: 0.6391 - val_accuracy: 0.6842\n",
"Epoch 24/50\n",
"42/42 [==============================] - 9s 210ms/step - loss: 0.2905 - accuracy: 0.9805 - val_loss: 0.6395 - val_accuracy: 0.6947\n",
"Epoch 25/50\n",
"42/42 [==============================] - 6s 148ms/step - loss: 0.2820 - accuracy: 0.9789 - val_loss: 0.6405 - val_accuracy: 0.6825\n",
"Epoch 26/50\n",
"42/42 [==============================] - 9s 225ms/step - loss: 0.2759 - accuracy: 0.9752 - val_loss: 0.6353 - val_accuracy: 0.6895\n",
"Epoch 27/50\n",
"42/42 [==============================] - 6s 151ms/step - loss: 0.2641 - accuracy: 0.9805 - val_loss: 0.6349 - val_accuracy: 0.6895\n",
"Epoch 28/50\n",
"42/42 [==============================] - 9s 220ms/step - loss: 0.2548 - accuracy: 0.9820 - val_loss: 0.6508 - val_accuracy: 0.6737\n",
"Epoch 29/50\n",
"42/42 [==============================] - 6s 149ms/step - loss: 0.2460 - accuracy: 0.9835 - val_loss: 0.6484 - val_accuracy: 0.6842\n",
"Epoch 30/50\n",
"42/42 [==============================] - 9s 211ms/step - loss: 0.2415 - accuracy: 0.9902 - val_loss: 0.6630 - val_accuracy: 0.6719\n",
"Epoch 31/50\n",
"42/42 [==============================] - 6s 152ms/step - loss: 0.2313 - accuracy: 0.9902 - val_loss: 0.6621 - val_accuracy: 0.6684\n",
"Epoch 32/50\n",
"42/42 [==============================] - 9s 222ms/step - loss: 0.2233 - accuracy: 0.9940 - val_loss: 0.6447 - val_accuracy: 0.6772\n",
"Epoch 33/50\n",
"42/42 [==============================] - 7s 159ms/step - loss: 0.2198 - accuracy: 0.9895 - val_loss: 0.6803 - val_accuracy: 0.6561\n",
"Epoch 34/50\n",
"42/42 [==============================] - 9s 218ms/step - loss: 0.2052 - accuracy: 0.9925 - val_loss: 0.6422 - val_accuracy: 0.6737\n",
"Epoch 35/50\n",
"42/42 [==============================] - 7s 176ms/step - loss: 0.2047 - accuracy: 0.9895 - val_loss: 0.6686 - val_accuracy: 0.6667\n",
"Epoch 36/50\n",
"42/42 [==============================] - 9s 209ms/step - loss: 0.1988 - accuracy: 0.9902 - val_loss: 0.6146 - val_accuracy: 0.6772\n",
"Epoch 37/50\n",
"42/42 [==============================] - 6s 150ms/step - loss: 0.1973 - accuracy: 0.9902 - val_loss: 0.6783 - val_accuracy: 0.6684\n",
"Epoch 38/50\n",
"42/42 [==============================] - 9s 211ms/step - loss: 0.1891 - accuracy: 0.9917 - val_loss: 0.6993 - val_accuracy: 0.6632\n",
"Epoch 39/50\n",
"42/42 [==============================] - 6s 152ms/step - loss: 0.1795 - accuracy: 0.9970 - val_loss: 0.6686 - val_accuracy: 0.6754\n",
"Epoch 40/50\n",
"42/42 [==============================] - 9s 207ms/step - loss: 0.1779 - accuracy: 0.9947 - val_loss: 0.6839 - val_accuracy: 0.6684\n",
"Epoch 41/50\n",
"42/42 [==============================] - 6s 152ms/step - loss: 0.1724 - accuracy: 0.9925 - val_loss: 0.6766 - val_accuracy: 0.6702\n",
"Epoch 42/50\n",
"42/42 [==============================] - 9s 212ms/step - loss: 0.1662 - accuracy: 0.9955 - val_loss: 0.6583 - val_accuracy: 0.6789\n",
"Epoch 43/50\n",
"42/42 [==============================] - 6s 149ms/step - loss: 0.1660 - accuracy: 0.9895 - val_loss: 0.6694 - val_accuracy: 0.6667\n",
"Epoch 44/50\n",
"42/42 [==============================] - 9s 208ms/step - loss: 0.1641 - accuracy: 0.9940 - val_loss: 0.6792 - val_accuracy: 0.6719\n",
"Epoch 45/50\n",
"42/42 [==============================] - 6s 151ms/step - loss: 0.1543 - accuracy: 0.9932 - val_loss: 0.6918 - val_accuracy: 0.6561\n",
"Epoch 46/50\n",
"42/42 [==============================] - 9s 218ms/step - loss: 0.1531 - accuracy: 0.9962 - val_loss: 0.6604 - val_accuracy: 0.6649\n",
"Epoch 47/50\n",
"42/42 [==============================] - 6s 151ms/step - loss: 0.1473 - accuracy: 0.9955 - val_loss: 0.6723 - val_accuracy: 0.6702\n",
"Epoch 48/50\n",
"42/42 [==============================] - 9s 208ms/step - loss: 0.1486 - accuracy: 0.9955 - val_loss: 0.6183 - val_accuracy: 0.6842\n",
"Epoch 49/50\n",
"42/42 [==============================] - 6s 148ms/step - loss: 0.1513 - accuracy: 0.9925 - val_loss: 0.6328 - val_accuracy: 0.6895\n",
"Epoch 50/50\n",
"42/42 [==============================] - 9s 206ms/step - loss: 0.1428 - accuracy: 0.9962 - val_loss: 0.7270 - val_accuracy: 0.6404\n"
]
}
],
"source": [
"siamese_lstm.fit([x_left,x_right], dataset.target, validation_split=0.3, epochs=50);"
]
},
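{
"cell_type": "markdown",
"metadata": {},
"source": [
"The log above shows the validation loss flattening out (and the model starting to overfit) well before epoch 50. An optional variant, shown as a sketch below and not what was actually run in this notebook, is to let Keras stop training early and keep the best weights; this assumes the notebook's Keras comes from TensorFlow."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional alternative to the fixed 50 epochs above (sketch; not what was run in this notebook):\n",
"# stop once val_loss has not improved for a few epochs and restore the best weights seen so far.\n",
"from tensorflow.keras.callbacks import EarlyStopping\n",
"\n",
"early_stop = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)\n",
"\n",
"# siamese_lstm.fit([x_left, x_right], dataset.target,\n",
"#                  validation_split=0.3, epochs=50, callbacks=[early_stop])"
]
},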
{
"source": [
"#### Model Predictions\n",
"\n",
"To address the initial problem of finding each git commit class, we compute, for each example in the test set, the similarity score of this example with all the examples in the training set. The predicted category is the one of the closest example in training set.\n",
"\n"
],
"cell_type": "markdown",
"metadata": {
"id": "nSmwskYcmDLp"
}
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {
"id": "tsiUgHkWmDLp"
},
"outputs": [],
"source": [
"reference_sequences = tokenizer.texts_to_sequences(df_train.text)\n",
"x_reference_sequences = pad_sequences(reference_sequences, maxlen=MAX_SEQ_LENGTH)"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {
"id": "fEcvEer3mDLp"
},
"outputs": [],
"source": [
"def flatten_text_sequence(text):\n",
" flatten = itertools.chain.from_iterable\n",
" text = list(flatten(text))\n",
" return text\n",
"\n",
"def get_prediction(text):\n",
" \"\"\" Get the predicted category, and the most similar text\n",
" in the train set. Note that this way of computing a prediction is highly \n",
" not optimal, but it'll be sufficient for us now. \"\"\"\n",
" x = tokenizer.texts_to_sequences(text.split())\n",
" x = flatten_text_sequence(x)\n",
" x = pad_sequences([x], maxlen=MAX_SEQ_LENGTH)\n",
" # x = np.array(x)\n",
" # print([x[0]]*len(x_reference_sequences))\n",
"\n",
" # print(x_reference_sequences.shape)\n",
"\n",
" # Compute similarities of the text with all text's in the train set.\n",
" # Computing similarities can be a slow process\n",
" # You can use something like FAISS for faster search.\n",
" result = np.repeat(x, len(x_reference_sequences), axis=0)\n",
" \n",
" # similarities = siamese_lstm.predict([[x[0]]*len(x_reference_sequences), x_reference_sequences])\n",
" similarities = siamese_lstm.predict([result, x_reference_sequences])\n",
" most_similar_index = np.argmax(similarities)\n",
" \n",
" # The predicted category is the one of the most similar example from the train set.\n",
" # print(most_similar_index)\n",
" prediction = df_train['class'].iloc[most_similar_index]\n",
" most_similar_example = df_train['text'].iloc[most_similar_index]\n",
"\n",
" return prediction, most_similar_example"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "7AtfmRaamDLq",
"outputId": "ca847da5-f8b5-4e61-a219-4de10439fb2e"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"(100, 100)\n"
]
}
],
"source": [
"x = df_train['text'].iloc[34]\n",
"# print(x)\n",
"\n",
"x = tokenizer.texts_to_sequences(x.split())\n",
"x = flatten_text_sequence(x)\n",
"x = pad_sequences([x], maxlen=MAX_SEQ_LENGTH) \n",
"# x\n",
"result = np.repeat(x, len(x_reference_sequences), axis=0)\n",
"print(result.shape)"
]
},
{
"cell_type": "markdown",
"source": [
"Testing on a prediction example"
],
"metadata": {
"id": "5UV7sf0i1skl"
}
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {
"tags": [],
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "zeGIcVNOmDLq",
"outputId": "9e1533ea-af78-4a4d-dcd9-edc4f88aa08c"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"4/4 [==============================] - 0s 34ms/step\n",
"Sampled Text: Revert \"CLOUDSTACK-7762 -[Automation] - Fix test failure for test_02_revert_vm_snapshots in smoke/test_vm_snapshots.py\"\n",
"True Class: 1\n",
"Predicted Class : 1\n",
"Most similar example in train set: [Automation] - Fix test failure for test_02_revert_vm_snapshots in smoke/test_vm_snapshots.py\n"
]
}
],
"source": [
"sample_idx = 22\n",
"\n",
"pred, most_sim = get_prediction(df_test.text[sample_idx])\n",
"\n",
"print(f'Sampled Text: {df_test[\"text\"].iloc[sample_idx]}')\n",
"print(f'True Class: {df_test[\"class\"].iloc[sample_idx]}')\n",
"print(f'Predicted Class : {pred}')\n",
"print(f'Most similar example in train set: {most_sim}')"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "pe1slx4fmDLq",
"outputId": "799b1dc8-0288-4320-eeb1-04b5a798dd51"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"4/4 [==============================] - 0s 37ms/step\n",
"4/4 [==============================] - 0s 36ms/step\n",
"4/4 [==============================] - 0s 28ms/step\n",
"4/4 [==============================] - 0s 29ms/step\n",
"4/4 [==============================] - 0s 30ms/step\n",
"4/4 [==============================] - 0s 30ms/step\n",
"4/4 [==============================] - 0s 29ms/step\n",
"4/4 [==============================] - 0s 30ms/step\n",
"4/4 [==============================] - 0s 28ms/step\n",
"4/4 [==============================] - 0s 34ms/step\n",
"4/4 [==============================] - 0s 30ms/step\n",
"4/4 [==============================] - 0s 29ms/step\n",
"4/4 [==============================] - 0s 29ms/step\n",
"4/4 [==============================] - 0s 31ms/step\n",
"4/4 [==============================] - 0s 31ms/step\n",
"4/4 [==============================] - 0s 31ms/step\n",
"4/4 [==============================] - 0s 30ms/step\n",
"4/4 [==============================] - 0s 29ms/step\n",
"4/4 [==============================] - 0s 28ms/step\n",
"4/4 [==============================] - 0s 30ms/step\n",
"4/4 [==============================] - 0s 29ms/step\n",
"4/4 [==============================] - 0s 33ms/step\n",
"4/4 [==============================] - 0s 30ms/step\n",
"4/4 [==============================] - 0s 33ms/step\n",
"4/4 [==============================] - 0s 46ms/step\n",
"4/4 [==============================] - 0s 59ms/step\n",
"4/4 [==============================] - 0s 52ms/step\n",
"4/4 [==============================] - 0s 53ms/step\n",
"4/4 [==============================] - 0s 54ms/step\n",
"4/4 [==============================] - 0s 54ms/step\n",
"4/4 [==============================] - 0s 55ms/step\n",
"4/4 [==============================] - 0s 56ms/step\n",
"4/4 [==============================] - 0s 53ms/step\n",
"4/4 [==============================] - 0s 56ms/step\n",
"4/4 [==============================] - 0s 56ms/step\n",
"4/4 [==============================] - 0s 50ms/step\n",
"4/4 [==============================] - 0s 52ms/step\n",
"4/4 [==============================] - 0s 39ms/step\n",
"4/4 [==============================] - 0s 30ms/step\n",
"4/4 [==============================] - 0s 30ms/step\n",
"4/4 [==============================] - 0s 27ms/step\n",
"4/4 [==============================] - 0s 28ms/step\n",
"4/4 [==============================] - 0s 30ms/step\n",
"4/4 [==============================] - 0s 30ms/step\n",
"4/4 [==============================] - 0s 28ms/step\n",
"4/4 [==============================] - 0s 28ms/step\n",
"4/4 [==============================] - 0s 30ms/step\n",
"4/4 [==============================] - 0s 31ms/step\n",
"4/4 [==============================] - 0s 29ms/step\n",
"4/4 [==============================] - 0s 29ms/step\n",
"Test accuracy (siamese model): 80.00 %\n"
]
}
],
"source": [
"from sklearn.metrics import accuracy_score\n",
"\n",
"df_eval = df_test[:50]\n",
"\n",
"y_pred = [get_prediction(text)[0] for text in df_eval['text']]\n",
"accuracy = accuracy_score(y_pred, df_eval['class'])\n",
"\n",
"print(f'Test accuracy (siamese model): {100*accuracy:.2f} %')"
]
},
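{
"cell_type": "markdown",
"metadata": {},
"source": [
"The evaluation above calls `siamese_lstm.predict` once per test example, which is slow. Because both arms share the same encoder, an optional speed-up (a sketch, not the code used for the accuracy reported above) is to pull out the shared `sequential_network`, embed the reference and test texts once, and compute all pairwise Manhattan similarities in NumPy. Tokenization here follows the simpler `df_train` path, so the result may differ slightly from the loop above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional vectorized variant of get_prediction (sketch; the accuracy above used the loop version).\n",
"# Both arms share the same encoder, so embed everything once and compare in NumPy.\n",
"encoder = siamese_lstm.get_layer('sequential_network')\n",
"\n",
"ref_emb = encoder.predict(x_reference_sequences)    # (n_train, 64)\n",
"x_eval = pad_sequences(tokenizer.texts_to_sequences(df_eval['text']), maxlen=MAX_SEQ_LENGTH)\n",
"eval_emb = encoder.predict(x_eval)                  # (n_eval, 64)\n",
"\n",
"# Pairwise L1 distances -> similarities, then take the nearest training example per test text.\n",
"l1 = np.abs(eval_emb[:, None, :] - ref_emb[None, :, :]).sum(axis=-1)   # (n_eval, n_train)\n",
"nearest = np.argmax(np.exp(-l1), axis=1)\n",
"y_pred_fast = df_train['class'].iloc[nearest].to_numpy()\n",
"print(f'Test accuracy (vectorized): {100 * accuracy_score(df_eval[\"class\"], y_pred_fast):.2f} %')"
]
},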
{
"cell_type": "markdown",
"source": [
"We used binary cross entropy (BCE) loss function. However using BCE might not be the best choice of a loss function.\n",
"\n",
"We see that with no data cleaning, we are able to perform text classification on git commits. However we had the train the model for around 50 epochs to achieve that. \n",
"\n"
],
"metadata": {
"id": "vdXl8GahXJ1j"
}
}
]
}