@sayan1999
Created August 30, 2020 05:59
word_seq2seq-extended.ipynb
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"colab": {
"name": "word_seq2seq-extended.ipynb",
"provenance": [],
"collapsed_sections": [],
"toc_visible": true,
"include_colab_link": true
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/sayan1999/d008ef965c72371602c399284b7ab189/word_seq2seq-extended.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"metadata": {
"id": "zMy7gQFI9qXy",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 900
},
"outputId": "a69f4c6c-39b4-4dc4-b481-ffd690d9b25a"
},
"source": [
"!pip install bidict\n",
"!pip install pixiedust\n",
"import nltk\n",
"nltk.download('punkt')"
],
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"text": [
"Collecting bidict\n",
" Downloading https://files.pythonhosted.org/packages/7a/7a/1fcfc397e61b22091267aa767266d8ab200a00b7dbf3aadead7fd41a74b9/bidict-0.21.0-py2.py3-none-any.whl\n",
"Installing collected packages: bidict\n",
"Successfully installed bidict-0.21.0\n",
"Collecting pixiedust\n",
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/16/ba/7488f06b48238205562f9d63aaae2303c060c5dfd63b1ddd3bd9d4656eb1/pixiedust-1.1.18.tar.gz (197kB)\n",
"\u001b[K |████████████████████████████████| 204kB 8.7MB/s \n",
"\u001b[?25hCollecting mpld3\n",
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/66/31/89bd2afd21b920e3612996623e7b3aac14d741537aa77600ea5102a34be0/mpld3-0.5.1.tar.gz (1.0MB)\n",
"\u001b[K |████████████████████████████████| 1.0MB 16.7MB/s \n",
"\u001b[?25hRequirement already satisfied: lxml in /usr/local/lib/python3.6/dist-packages (from pixiedust) (4.2.6)\n",
"Collecting geojson\n",
" Downloading https://files.pythonhosted.org/packages/e4/8d/9e28e9af95739e6d2d2f8d4bef0b3432da40b7c3588fbad4298c1be09e48/geojson-2.5.0-py2.py3-none-any.whl\n",
"Requirement already satisfied: astunparse in /usr/local/lib/python3.6/dist-packages (from pixiedust) (1.6.3)\n",
"Requirement already satisfied: markdown in /usr/local/lib/python3.6/dist-packages (from pixiedust) (3.2.2)\n",
"Collecting colour\n",
" Downloading https://files.pythonhosted.org/packages/74/46/e81907704ab203206769dee1385dc77e1407576ff8f50a0681d0a6b541be/colour-0.1.5-py2.py3-none-any.whl\n",
"Requirement already satisfied: requests in /usr/local/lib/python3.6/dist-packages (from pixiedust) (2.23.0)\n",
"Requirement already satisfied: jinja2 in /usr/local/lib/python3.6/dist-packages (from mpld3->pixiedust) (2.11.2)\n",
"Requirement already satisfied: matplotlib in /usr/local/lib/python3.6/dist-packages (from mpld3->pixiedust) (3.2.2)\n",
"Requirement already satisfied: six<2.0,>=1.6.1 in /usr/local/lib/python3.6/dist-packages (from astunparse->pixiedust) (1.15.0)\n",
"Requirement already satisfied: wheel<1.0,>=0.23.0 in /usr/local/lib/python3.6/dist-packages (from astunparse->pixiedust) (0.35.1)\n",
"Requirement already satisfied: importlib-metadata; python_version < \"3.8\" in /usr/local/lib/python3.6/dist-packages (from markdown->pixiedust) (1.7.0)\n",
"Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests->pixiedust) (1.24.3)\n",
"Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests->pixiedust) (2.10)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests->pixiedust) (2020.6.20)\n",
"Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests->pixiedust) (3.0.4)\n",
"Requirement already satisfied: MarkupSafe>=0.23 in /usr/local/lib/python3.6/dist-packages (from jinja2->mpld3->pixiedust) (1.1.1)\n",
"Requirement already satisfied: numpy>=1.11 in /usr/local/lib/python3.6/dist-packages (from matplotlib->mpld3->pixiedust) (1.18.5)\n",
"Requirement already satisfied: python-dateutil>=2.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib->mpld3->pixiedust) (2.8.1)\n",
"Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.6/dist-packages (from matplotlib->mpld3->pixiedust) (0.10.0)\n",
"Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib->mpld3->pixiedust) (1.2.0)\n",
"Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib->mpld3->pixiedust) (2.4.7)\n",
"Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.6/dist-packages (from importlib-metadata; python_version < \"3.8\"->markdown->pixiedust) (3.1.0)\n",
"Building wheels for collected packages: pixiedust, mpld3\n",
" Building wheel for pixiedust (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
" Created wheel for pixiedust: filename=pixiedust-1.1.18-cp36-none-any.whl size=321727 sha256=aca85894b80a6fe25fb05217aa1095c664d98c2cb4fdd4eb17715d867b26db89\n",
" Stored in directory: /root/.cache/pip/wheels/e8/b1/86/c2f2e16e6bf9bfe556f9dbf8adb9f41816c476d73078c7d0eb\n",
" Building wheel for mpld3 (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
" Created wheel for mpld3: filename=mpld3-0.5.1-cp36-none-any.whl size=364064 sha256=6ca44fc5e92e2085a5b69fe24c1291795c168a4c20cd1f7e5e4d6066281036b0\n",
" Stored in directory: /root/.cache/pip/wheels/38/68/06/d119af6c3f9a2d1e123c1f72d276576b457131b3a7bf94e402\n",
"Successfully built pixiedust mpld3\n",
"Installing collected packages: mpld3, geojson, colour, pixiedust\n",
"Successfully installed colour-0.1.5 geojson-2.5.0 mpld3-0.5.1 pixiedust-1.1.18\n",
"[nltk_data] Downloading package punkt to /root/nltk_data...\n",
"[nltk_data] Unzipping tokenizers/punkt.zip.\n"
],
"name": "stdout"
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"True"
]
},
"metadata": {
"tags": []
},
"execution_count": 1
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "aLt_9sNL9qX9",
"colab_type": "code",
"colab": {}
},
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"import random\n",
"import re, os, difflib\n",
"from matplotlib import pyplot as plt\n",
"from sklearn.utils import shuffle\n",
"from sklearn.model_selection import train_test_split\n",
"from tensorflow.keras.layers import Input, LSTM, Embedding, Dense, Attention, Bidirectional, Concatenate\n",
"from tensorflow.keras.models import Model, Sequential\n",
"import tensorflow as tf\n",
"from nltk import word_tokenize\n",
"from gensim.models import Word2Vec\n",
"from sklearn.preprocessing import OneHotEncoder\n",
"from bidict import bidict\n",
"from tensorflow.keras.preprocessing.sequence import pad_sequences\n",
"from sklearn.metrics import classification_report\n",
"from tensorflow.keras.callbacks import TensorBoard\n",
"from tensorflow.keras.utils import plot_model\n",
"from tensorflow.keras import backend as K"
],
"execution_count": 2,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "osRUcUVb9qYA",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 92
},
"outputId": "a6171656-ac7d-4778-8cd5-9cb57693895e"
},
"source": [
"# parameters\n",
"# if running on colab turn this false, and select GPU runtime\n",
"batch_size=32 if not tf.test.is_gpu_available() else 256\n",
"colab=True\n",
"training=True\n",
"validation=True\n",
"ctx_vec_len=128\n",
"embedding_dim=128\n",
"epochs=25\n",
"# either length or list of index such as range(1, 2200)\n",
"training_samples=100\n",
"dropout=0.2\n",
"weight_file='word-seq2seq.hdf5'"
],
"execution_count": 3,
"outputs": [
{
"output_type": "stream",
"text": [
"WARNING:tensorflow:From <ipython-input-3-3badf1bfea45>:3: is_gpu_available (from tensorflow.python.framework.test_util) is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"Use `tf.config.list_physical_devices('GPU')` instead.\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "wf7r24jQ9qYH",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 225
},
"outputId": "71b10c13-8b88-4a4b-e0c2-fc1267cf6437"
},
"source": [
"from IPython.display import display, Markdown\n",
"if not colab:\n",
" display(Markdown('''## Architecture For Neural Machine Trans\n",
"![Architecture Neural Machine Trans](image/NeuralMachineTrans.jpg)'''))\n",
" \n",
"else:\n",
" display(Markdown('''## Architecture For Neural Machine Trans\n",
"![Architecture Neural Machine Trans](data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAVIAAACVCAMAAAA9kYJlAAABa1BMVEX///+S0FBEcsQAsPAAAACU004oW638/Py+vr7Pz8/39/f/3ZrY2ur/0GcAs/Pa3uYAb7OU1EhrpF1BccagqsI4a8AAgMP/7s//vQDl5vLv7+9tiZzs6+/j4+NhmmJNg2t1jKVGe26DwU+Ol7W2vc/IztopVZ1VcKd+u1LZ2dmMlsDKzeJzrlmutNKhoaG3vNdRUVFtbW1eXl59fX1qamq1tbWmpqZ3d3eLi4tHR0c3NzeXl5e5ubklJSUxMTEfHx9fg4sTExP/xiH/9eMyvPL/8s3/9+L/6sB6j7dge64AnN85X6Pf9P3/zVr/3oz/2G//56j/y0X/13dJaKRLS1KusL1FR2AAADJTc3t+mKomOyBSf0phlG9gjIORortkmWkARplRg7ZcgzJEYCVUg3yK0va8wc5/hI4SHjMJEiEwUIo6UiN2qEAxMlI6VzolQHC14vljyPUzLzgeN2JqgqQvRz9hYnOJj6xfZYZlJtHGAAAShElEQVR4nO2djX/UNprHBShYErRsgGyCiCgNoaCD9fpNfpM9w0AbaHkphe217NH29qXs7l332u719u7Pv0eemTAvdqwkkwSCfx8YO/YzsvT1oxd79IJQp06dOnXq1KlTp7dFceIfdRSOmUKNuJr429UTf0R0/gt9dtBRessleuaTpUng8CSMCM8zjdIkZSlhESkDoJ1SRLUIExf4J3GMIymSJEUk9GuAd0IyM59pjAqfDwTLUewiN0ZxyjPw3wpaDIyJL8FlQ4WokzOUUxTGJHOOOO5vqEQJH07AEcpFjpyB47ooiZJEobB0hkh57gSo5yeJyAT8mXMEVKlP0iOO+psqDoSQA2WmiPgYaVrACSfzyKgsDcFPK5qeRMzpMTTg4Lod0ibJTCdUejoTvI8YZqKnuAeHIiKw4waGKcUUkUAnHMwSFIZCRdpjKjzqqL+xYpybD6jHR/9Y9df2H6jagWPO6AP+m63TFaWd9iAax6qm0RgUhx+V46JcClGTG31y+FE5LspN4RerBKoUNyGIJiFHLNFQcxeJRiJOVXsYnaaUxUqyXoGgJaRRTDPT6PQo7xESIqXlQOwiLNk9GRnlVHJ4/oHGeCQR0gqa8DRgkPHTIPRD6e0mLNqVFkZVxjdImXnq0S5yPF4hDQ0fGewmrA5ppVwRykvklIxEJBQejUOUpqpPeE4V7bx0D5KUUgZZXjpARCBBKNT/VApudpE5Y68O6cLVIV24OqQLV4d04eqQLlwd0ntnthYbYId09e4niw2wQ3qISNnovTRCDmcTR4+bDhEp5mT4SsXx9OunB4/vENj4HYx8m1z/sJByKvucc4AkEcGCCyrgWU1w2iNwjjb4ajz65S9O2i/MTd+XYuoOsWw379AWpUNCKjISYpbEKqUeLTBVWFHMkda0JCIgblTDlCdhqJFMfIHi178AOtsfk5/wIfqwJez1CUDaE6//MMedut1F65CQ+gVCAxaq0LDhGBHPFARIu07GKOVxPv8egZUS5ZphxnKxjZRmac5ZkEYp6wVhxpHrJT5iUZgLjZWI4ATcvtTjpntRyXuCB2FvFHYUpCEKBg5KFMJhumOJsx99cvdQGlEBMa8Ow5iFvicM0gA5FVKWcRkFRTaP1LwAU9oNhQjU9u/UGUXCo3CGc9h1U9lDKCRaA1XRQ3GB5IB7ULS4qTY9DXIeutW9q+JgruH3FAoVy1h1Vw9En6wdCtJIQcoAKeyqSAyRDhhKASmD1LMapDQzSNNIKbLtpbyXhAmEU4bSvNAVXpyHYZYm5pqiRCzNUlwVAFUnOEAqPajg8LBEjSqkccYMUg5hvS0/ftcjdfLCN2WpTohHwUsLcBEgMXBRrmhOUjz/8wobCBRo0/UqI2OkUAZAeQk7EnMPHC1SAbAUCRQrBJBWN04AKi71EKkw2WOMNIAcH1EdpxVSdmA5/8n9xYbXVONLKZCp8aVEjkQcUsmkhANCQiuAi5rkiSAMXBRHUYzUuM+P9lmaipDTkuUu8OIeEbmkGVe+6DthyDXmbsrDmARc57wUhSfc0f2ovDSiPE+KykvzxSb8tT5/vNjwDvjpibjKfLice7ELrs1cF+4ONAuQUyjHdaliSLngtYVLC6Y4oq47+mrVpyiiKMYFy5kpGtq0urp6DzZbsF11hn+vrsL23urozNwBY3rm6drq1vikMzq5NQ4HrJz6cBrLoUUjdeP64zybdOzQwuViH8C7sqqn/CRp7/dxf21t7QlsH8N2DVJ+z2y/gANnzI7J3Wa7NrKsTN8zFluPq/275uDW6KRx3M9HXxuGAwE+HX9tbFovW6TM8km0n9ffPUdOHOcDvNNlmdHIXNm3/h8/szY9UNkide26RBKMLfplaIy9HervOAWN4lVGVtc1elP6w9kiLbFNhcs8jMvWlHGMrcgjc4sGu/p18hC19aQh61siBe9z262MGcat9YjKwUuteq2yAGN7Nz1cba01PMjaIWUAwcZNicJ9ZdElyMN2RWRhbpFlF6N7C35Ub9XdhvasHdLK+6z8qt+3geVhu+zsJrgMLft4rq7Z2S1M+0MqaYmJlWP1rDq5Bbauh3Bm+9L70JGu7QspQrllVs2tqpPIFqkzyG2fQ51VS8NFqWr118gWaWaZVe3s7JGWvaN4Lb0v2SL1LCFkVnb+jk39SeXl7pGun5vXXPHBH9x4PqtPv501cy7NB7XecvkWpOtnh7r0O/zleLeucGOXRif/9aux3dlLdYkdntr88vrYqiaGzublO2P1B18Pd15cb0nJtjafn57Tw09nvn7u+cqFOa2sPJ+OzfqDhzVhPdi55b0z0ot3rpwf6vf430Z7Ly+fnbO7dPXl+Vld+ebb2Uvfujxv9t2dzdnQrl3ZWFoeaun3eGO4s/HyNzsm5PW3T184OacLF/59iim7sTJvBFr5dDLO68C9JqyVG/NetX7tV/P6w7wbXD2/vDTSH/+0Md5d+nCW6aWvTyzNa/nmL9Nmrz5YrjP77uK02eb55RM1Wn45fytrdO5hLauTFz6d5HC9yerhuQmra/XcT56+OHvV9T98dur9WZ069etZw7Mvl+rSduLE1Zm7dPVmreHS+SkHXL9TH9zSN9NFxJ36iy7dfGGDdLMBwzSsJquTpycd63mD0cqnMxd1HnwGAGv0q5li7WI9qRNLH8wY/rkewomNKQhnrzQEd3s663/YcCOXLtsgvdgIa/dIP2owunBj5qLrv64neuqzGTf9l9oMCGm7MlPzfNAAYfnqpNWt2w1mG9Ol5IcNl11+C5GeutYh7ZB2SDukHdIOaYe0Q9oh7ZB2SDukHdIOaYf0wJDyhN5qQ+qYecl2i/T97+fI7h4pE5ZIHWGDtND8AJEyDhJCSIz/0fC+9OTpV0MbEeCeakfKYz6B9ONHOyGlhFkhxYmwQ5oFXzYjZaN0uHgQiqaXe/tEuvKXLMvzXr8/MJ0S8F9
roT78a569tvnbf7QhZbjUzBKpwh5pRnqdEkJUHLvadJmQzUgZpSPDEuP//KEB6d89SGw5Sgf+r3qr/SPFg7LMs8wLDNF6Wg/xoF/2jE0P496PDUGd/DGJgsDEuYpywreR/vToJ7Pz06NHP4+Rpn4UQPp6Q9s/NiD97p94Uv9oevu/fLWYMsQ/1BpufIWH6fAyc4cab9B+M77J95wxhgbZX5qsPno1skn6mn1U88tT5aX/bRIzMPRhW6bONtKPK0f9+dGp998fI83HtuBU+KsGr1q68mWUhGGqtaswzmPe7KUiGBvm+G9/2qi327g4SqtDsS8PocYvnFetNT7lO1RPzylISik4xqmYzPhA9dSp73+ayPiSVLZCFLhfPLCo8bHL7Wr8LN1sr56E6QTwVjWi0pnqyTjqz+ClH388V5YK164RVf1Ga4PUeXMaUYtEOtQU0u8fnfp+oiyd0nFtlx4Y0qEeffzuNfUPGOlUQ6pD2j2Qdkg7pB3SDmmHtEPaIe2Qdkg7pG810uvHBOn5BrOb00gbekHPdABuUrP/TSJdb+rgfGOyN3LTrylzvaCtkV7caED6ck+9oM+9rL9FS1emRwl8a9VNvUnXTzd41vRgkc26QRInZwZJPGi4PSszoBqRvj/bC/pWQ0/w5TuzffXr3y0v3Z4K0HlRa7a0cWd65Mn61xt1oyRu/mI33v5BLdMZWID+xvagm8ahPOsP65heWHk+Nzyqsa/+7Eilq7eX5yEsLX1wa8bu3J/h6Jzd8sYMhPVfbm+MBt8sn9jeuX159rLr1z64MqvvvvnWCihC7NrD0yuVfny4Mtbp53NDfPj6SNG58d76LKpzN14HsR3UwxvzQ7XWb9iNKEHrL17evjnSD+Od89/MZ8Bbd87fnNXtly9mWbGLLy4Pded/RjuXX8yNiIObvn5ppH+Ody61jYmb0NlXm69ewX/8u1cjbV6vGdc2ktxptgV2fTuIf462m9frcovtuCeI3eZvh/rN3/93tHerLnHr14cn/++3r3V2p7k2rCahQNRycolaKbuh/AEe2JQoyi7GuxDBFrM9ItZrn3rIiOd2k0FkFtN6NMYlw9hiUjmBMW6YNWgqNK99NoxdCuJnMYRZ4Z7VvYwxHlh4ELWeA6U2LqD20b+J+Um4fehvgfEeBgjvJDNtRPvaJ2a+Dpt7yUwvgazVzDF9GKzH488qTnHPbb1vTCfYi9thmV/jbfKptZwIbmXZGj9ld8eR0B5O4tYszdwQR+7efcNuxglqg0p6JfZ2tZ5OmxypcFQ30d60Eij7AquR9hpbzVhCbYq5Jjl9Kw+ndsW6tp5CwFbS6sLconSopO1YUauplBpkOeMEtYuz3k/jo1YSBxZ1r8Ce3ewurh2r/SG1m3FC2rQLrJ1gF5KDnaZ0G4sPMrvaJLZrkuwTqdUES5ZILWO8C1kiLS0r6ENB6uVWSHOrjL94pGKQW2Rpa6TqEJDazmbJrJ4mYrzota25bm/0gFV/YNmMtEvH/pCCVJsHBrZzgDF5RAtftDe0dqW9IjX+6TjMzDI9+mNqwv6xUmaP9LBn8nSQM76HrbltPM2ozXSjBmllNh1o2zdd349RFmjPIQFKk9RnmKPe7AO1U/S1iJQL7U2q21pIZvr6Q5VHxKiJkU4sg5zUvhUIR5YWkzsjToQw9bH2J5nylpdCwpMy4x4FFyQRg2o1ISnh8807M5d9kHBPyMgskrtzmO3Pz3YiQUR1kvmhoKbZ7Pm+5n6UJVPXd6QICDOTmUtoRREmpBAMMSEyBeckn7ENlZn83EHpZJ4WFIKHABxHwNfR0AKyrUBFyRFX5i8TFBQsTPDSXKy5NqceKSjyZIVUhlC1xDRRNUUItPjMRNNEe2G0EzKdSEiMGTujIenwiOxKpPY2VS3PmFmC04OwEi+u1mWA7COm58Z2fK0HlATA1ddi4FKsSU4gXTLXPIrV5FouLEwVPIrmSidTSN2A5OCPJA35wIdvsEi7HnhGEdC4L2lAwF4FKuNBDAHzkpGAeI2loJnKm6JsiNTpGW6OV9Z4tkFKzSrDOz9rkhBRUUreMwuNpK6OeRnzbG+FK8tMshkgdQIBYUB1Dk09OY3UrEtg4l4EJtKl4NCO8gBpACnhLiLexNO8WVsjLMyCO4GcQGqSLT1qFiPgcAdUWCRwE4kpQzRcTQmhBqJnBmP5CskAvBR4subsT4OEIF+gEKKFSARfQmmdF2bgyxRFBRQTO83q5mShAxmfDxyzhEUkAxqG0rdkOCsRmLgDUqXNLR34EIlZpGbi7QS8FOkgIAZp9hopCyMdTTzIGVS6CL0wCcgEUlkmSRKg2AuUWXiDRmkWhgGpamO4Go0CgmVpLKMKaU8a17ea9nc8j21YV//ImFNwaIForHZ0OpXJCincdJowL6ZRbDlRbY2KvELqB2kvBS+l0RzS2K+81DQAZe6UUIRys2xMhdQsXBJNJIaaZV8KA7MQE0gFPEY45k/ZF1CLxKGG20RkZLyU5igjUCMZL5UkURDwyEsH7Q1FGo2wO3ofrTviFj7U+BxDzGgkkS65u9f30RxyeWaQirJafnsgTFk6gxRFWmNASgOapghLAblYeTQI4ITANB5MlJlOCmWpQr0i9pzJGl8HNJOQ9XXEylAFggdunJlVp5JY9GQaUR87KiKeJB6JAqjxYesd3tJnlJo6xTFrPptqFBq7fK/zqDM/icAdAlGYhzGob7jJ5rPrcDuUCmboU1MqsGpRArOOCeLUEVTyqfsJtny4rtnUU0EVVQjAAf8zg7hYtfIZr0KU5gJm/eqhDRfmChDwHhP1zklYvso6dpI+1CfpgST++C0h2eng9PjJUcegVU2zsL+pevaGrEuzg5rWCnhT9ezzo45Bm5y3DemTL446Bq1a9JLhB637X7whyycdL50xuofQltmaXHbf7IBvOGb7FJh/Mj7ztNUU3TuzOjY0S9RNGd4bXw22q2Y7aeqMTVdnTd8+ff4eCLCsmq1pAZyB7eNPgIc58wwSd9+ceQpnno1Mt2ZNnbHps7tV2ffEHAdOW49hx2B5ag6YLGwOvAfb+7DzeNL03pTp6oTpO65na0+POgrHTWcWvdBvp6cd0kXr6d2G9ZE67Vkd0U6dOnXq1Ok4ypHVL3Pzx6Pul7g9SpRE1fayCTqke5TpO8V7DCkzUst8yFgxxGPiSURjigSNO7a7U9UdbSDSQmTM1VJLj8cRigrSFyRlPqV98o7+wLxnVUhLWabpgJhOXilBLKMBZHwRRKmn6V57ar27MkhZZvr+wAaZrnjbSD2T48lCx2S+ExI9JiKFEsVdlrrCpR5TieNROhAqYUR2SHctprQ23uiaT1cLRM2oFKEJFKFEKyQOr59Yp06dOnXqdAT6f2PCMLGwm5Z8AAAAAElFTkSuQmCC)'''))"
],
"execution_count": 4,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/markdown": "## Architecture For Neural Machine Trans\n![Architecture Neural Machine Trans](data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAVIAAACVCAMAAAA9kYJlAAABa1BMVEX///+S0FBEcsQAsPAAAACU004oW638/Py+vr7Pz8/39/f/3ZrY2ur/0GcAs/Pa3uYAb7OU1EhrpF1BccagqsI4a8AAgMP/7s//vQDl5vLv7+9tiZzs6+/j4+NhmmJNg2t1jKVGe26DwU+Ol7W2vc/IztopVZ1VcKd+u1LZ2dmMlsDKzeJzrlmutNKhoaG3vNdRUVFtbW1eXl59fX1qamq1tbWmpqZ3d3eLi4tHR0c3NzeXl5e5ubklJSUxMTEfHx9fg4sTExP/xiH/9eMyvPL/8s3/9+L/6sB6j7dge64AnN85X6Pf9P3/zVr/3oz/2G//56j/y0X/13dJaKRLS1KusL1FR2AAADJTc3t+mKomOyBSf0phlG9gjIORortkmWkARplRg7ZcgzJEYCVUg3yK0va8wc5/hI4SHjMJEiEwUIo6UiN2qEAxMlI6VzolQHC14vljyPUzLzgeN2JqgqQvRz9hYnOJj6xfZYZlJtHGAAAShElEQVR4nO2djX/UNprHBShYErRsgGyCiCgNoaCD9fpNfpM9w0AbaHkphe217NH29qXs7l332u719u7Pv0eemTAvdqwkkwSCfx8YO/YzsvT1oxd79IJQp06dOnXq1KlTp7dFceIfdRSOmUKNuJr429UTf0R0/gt9dtBRessleuaTpUng8CSMCM8zjdIkZSlhESkDoJ1SRLUIExf4J3GMIymSJEUk9GuAd0IyM59pjAqfDwTLUewiN0ZxyjPw3wpaDIyJL8FlQ4WokzOUUxTGJHOOOO5vqEQJH07AEcpFjpyB47ooiZJEobB0hkh57gSo5yeJyAT8mXMEVKlP0iOO+psqDoSQA2WmiPgYaVrACSfzyKgsDcFPK5qeRMzpMTTg4Lod0ibJTCdUejoTvI8YZqKnuAeHIiKw4waGKcUUkUAnHMwSFIZCRdpjKjzqqL+xYpybD6jHR/9Y9df2H6jagWPO6AP+m63TFaWd9iAax6qm0RgUhx+V46JcClGTG31y+FE5LspN4RerBKoUNyGIJiFHLNFQcxeJRiJOVXsYnaaUxUqyXoGgJaRRTDPT6PQo7xESIqXlQOwiLNk9GRnlVHJ4/oHGeCQR0gqa8DRgkPHTIPRD6e0mLNqVFkZVxjdImXnq0S5yPF4hDQ0fGewmrA5ppVwRykvklIxEJBQejUOUpqpPeE4V7bx0D5KUUgZZXjpARCBBKNT/VApudpE5Y68O6cLVIV24OqQLV4d04eqQLlwd0ntnthYbYId09e4niw2wQ3qISNnovTRCDmcTR4+bDhEp5mT4SsXx9OunB4/vENj4HYx8m1z/sJByKvucc4AkEcGCCyrgWU1w2iNwjjb4ajz65S9O2i/MTd+XYuoOsWw379AWpUNCKjISYpbEKqUeLTBVWFHMkda0JCIgblTDlCdhqJFMfIHi178AOtsfk5/wIfqwJez1CUDaE6//MMedut1F65CQ+gVCAxaq0LDhGBHPFARIu07GKOVxPv8egZUS5ZphxnKxjZRmac5ZkEYp6wVhxpHrJT5iUZgLjZWI4ATcvtTjpntRyXuCB2FvFHYUpCEKBg5KFMJhumOJsx99cvdQGlEBMa8Ow5iFvicM0gA5FVKWcRkFRTaP1LwAU9oNhQjU9u/UGUXCo3CGc9h1U9lDKCRaA1XRQ3GB5IB7ULS4qTY9DXIeutW9q+JgruH3FAoVy1h1Vw9En6wdCtJIQcoAKeyqSAyRDhhKASmD1LMapDQzSNNIKbLtpbyXhAmEU4bSvNAVXpyHYZYm5pqiRCzNUlwVAFUnOEAqPajg8LBEjSqkccYMUg5hvS0/ftcjdfLCN2WpTohHwUsLcBEgMXBRrmhOUjz/8wobCBRo0/UqI2OkUAZAeQk7EnMPHC1SAbAUCRQrBJBWN04AKi71EKkw2WOMNIAcH1EdpxVSdmA5/8n9xYbXVONLKZCp8aVEjkQcUsmkhANCQiuAi5rkiSAMXBRHUYzUuM+P9lmaipDTkuUu8OIeEbmkGVe+6DthyDXmbsrDmARc57wUhSfc0f2ovDSiPE+KykvzxSb8tT5/vNjwDvjpibjKfLice7ELrs1cF+4ONAuQUyjHdaliSLngtYVLC6Y4oq47+mrVpyiiKMYFy5kpGtq0urp6DzZbsF11hn+vrsL23urozNwBY3rm6drq1vikMzq5NQ4HrJz6cBrLoUUjdeP64zybdOzQwuViH8C7sqqn/CRp7/dxf21t7QlsH8N2DVJ+z2y/gANnzI7J3Wa7NrKsTN8zFluPq/275uDW6KRx3M9HXxuGAwE+HX9tbFovW6TM8km0n9ffPUdOHOcDvNNlmdHIXNm3/h8/szY9UNkide26RBKMLfplaIy9HervOAWN4lVGVtc1elP6w9kiLbFNhcs8jMvWlHGMrcgjc4sGu/p18hC19aQh61siBe9z262MGcat9YjKwUuteq2yAGN7Nz1cba01PMjaIWUAwcZNicJ9ZdElyMN2RWRhbpFlF6N7C35Ub9XdhvasHdLK+6z8qt+3geVhu+zsJrgMLft4rq7Z2S1M+0MqaYmJlWP1rDq5Bbauh3Bm+9L70JGu7QspQrllVs2tqpPIFqkzyG2fQ51VS8NFqWr118gWaWaZVe3s7JGWvaN4Lb0v2SL1LCFkVnb+jk39SeXl7pGun5vXXPHBH9x4PqtPv501cy7NB7XecvkWpOtnh7r0O/zleLeucGOXRif/9aux3dlLdYkdntr88vrYqiaGzublO2P1B18Pd15cb0nJtjafn57Tw09nvn7u+cqFOa2sPJ+OzfqDhzVhPdi55b0z0ot3rpwf6vf430Z7Ly+fnbO7dPXl+Vld+ebb2Uvfujxv9t2dzdnQrl3ZWFoeaun3eGO4s/HyNzsm5PW3T184OacLF/59iim7sTJvBFr5dDLO68C9JqyVG/NetX7tV/P6w7wbXD2/vDTSH/+0Md5d+nCW6aWvTyzNa/nmL9Nmrz5YrjP77uK02eb55RM1Wn45fytrdO5hLauTFz6d5HC9yerhuQmra/XcT56+OHvV9T98dur9WZ069etZw7Mvl+rSduLE1Zm7dPVmreHS+SkHXL9TH9zSN9NFxJ36iy7dfGGDdLMBwzSsJquTpycd63mD0cqnMxd1HnwGAGv0q5li7WI9qRNLH8wY/rkewomNKQhnrzQEd3s663/YcCOXLtsgvdgIa/dIP2owunBj5qLrv64neuqzGTf9l9oMCGm7MlPzfNAAYfnqpNWt2w1mG9Ol5IcNl11+C5GeutYh7ZB2SDukHdIOaYe0Q9oh7ZB2SDukHdIOaYf0wJDyhN5qQ+qYecl2i/T97+fI7h4pE5ZIHWGDtND8AJEyDhJCSIz/0fC+9OTpV0MbEeCeakfKYz6B
9ONHOyGlhFkhxYmwQ5oFXzYjZaN0uHgQiqaXe/tEuvKXLMvzXr8/MJ0S8F9roT78a569tvnbf7QhZbjUzBKpwh5pRnqdEkJUHLvadJmQzUgZpSPDEuP//KEB6d89SGw5Sgf+r3qr/SPFg7LMs8wLDNF6Wg/xoF/2jE0P496PDUGd/DGJgsDEuYpywreR/vToJ7Pz06NHP4+Rpn4UQPp6Q9s/NiD97p94Uv9oevu/fLWYMsQ/1BpufIWH6fAyc4cab9B+M77J95wxhgbZX5qsPno1skn6mn1U88tT5aX/bRIzMPRhW6bONtKPK0f9+dGp998fI83HtuBU+KsGr1q68mWUhGGqtaswzmPe7KUiGBvm+G9/2qi327g4SqtDsS8PocYvnFetNT7lO1RPzylISik4xqmYzPhA9dSp73+ayPiSVLZCFLhfPLCo8bHL7Wr8LN1sr56E6QTwVjWi0pnqyTjqz+ClH388V5YK164RVf1Ga4PUeXMaUYtEOtQU0u8fnfp+oiyd0nFtlx4Y0qEeffzuNfUPGOlUQ6pD2j2Qdkg7pB3SDmmHtEPaIe2Qdkg7pG810uvHBOn5BrOb00gbekHPdABuUrP/TSJdb+rgfGOyN3LTrylzvaCtkV7caED6ck+9oM+9rL9FS1emRwl8a9VNvUnXTzd41vRgkc26QRInZwZJPGi4PSszoBqRvj/bC/pWQ0/w5TuzffXr3y0v3Z4K0HlRa7a0cWd65Mn61xt1oyRu/mI33v5BLdMZWID+xvagm8ahPOsP65heWHk+Nzyqsa/+7Eilq7eX5yEsLX1wa8bu3J/h6Jzd8sYMhPVfbm+MBt8sn9jeuX159rLr1z64MqvvvvnWCihC7NrD0yuVfny4Mtbp53NDfPj6SNG58d76LKpzN14HsR3UwxvzQ7XWb9iNKEHrL17evjnSD+Od89/MZ8Bbd87fnNXtly9mWbGLLy4Pded/RjuXX8yNiIObvn5ppH+Ody61jYmb0NlXm69ewX/8u1cjbV6vGdc2ktxptgV2fTuIf462m9frcovtuCeI3eZvh/rN3/93tHerLnHr14cn/++3r3V2p7k2rCahQNRycolaKbuh/AEe2JQoyi7GuxDBFrM9ItZrn3rIiOd2k0FkFtN6NMYlw9hiUjmBMW6YNWgqNK99NoxdCuJnMYRZ4Z7VvYwxHlh4ELWeA6U2LqD20b+J+Um4fehvgfEeBgjvJDNtRPvaJ2a+Dpt7yUwvgazVzDF9GKzH488qTnHPbb1vTCfYi9thmV/jbfKptZwIbmXZGj9ld8eR0B5O4tYszdwQR+7efcNuxglqg0p6JfZ2tZ5OmxypcFQ30d60Eij7AquR9hpbzVhCbYq5Jjl9Kw+ndsW6tp5CwFbS6sLconSopO1YUauplBpkOeMEtYuz3k/jo1YSBxZ1r8Ce3ewurh2r/SG1m3FC2rQLrJ1gF5KDnaZ0G4sPMrvaJLZrkuwTqdUES5ZILWO8C1kiLS0r6ENB6uVWSHOrjL94pGKQW2Rpa6TqEJDazmbJrJ4mYrzota25bm/0gFV/YNmMtEvH/pCCVJsHBrZzgDF5RAtftDe0dqW9IjX+6TjMzDI9+mNqwv6xUmaP9LBn8nSQM76HrbltPM2ozXSjBmllNh1o2zdd349RFmjPIQFKk9RnmKPe7AO1U/S1iJQL7U2q21pIZvr6Q5VHxKiJkU4sg5zUvhUIR5YWkzsjToQw9bH2J5nylpdCwpMy4x4FFyQRg2o1ISnh8807M5d9kHBPyMgskrtzmO3Pz3YiQUR1kvmhoKbZ7Pm+5n6UJVPXd6QICDOTmUtoRREmpBAMMSEyBeckn7ENlZn83EHpZJ4WFIKHABxHwNfR0AKyrUBFyRFX5i8TFBQsTPDSXKy5NqceKSjyZIVUhlC1xDRRNUUItPjMRNNEe2G0EzKdSEiMGTujIenwiOxKpPY2VS3PmFmC04OwEi+u1mWA7COm58Z2fK0HlATA1ddi4FKsSU4gXTLXPIrV5FouLEwVPIrmSidTSN2A5OCPJA35wIdvsEi7HnhGEdC4L2lAwF4FKuNBDAHzkpGAeI2loJnKm6JsiNTpGW6OV9Z4tkFKzSrDOz9rkhBRUUreMwuNpK6OeRnzbG+FK8tMshkgdQIBYUB1Dk09OY3UrEtg4l4EJtKl4NCO8gBpACnhLiLexNO8WVsjLMyCO4GcQGqSLT1qFiPgcAdUWCRwE4kpQzRcTQmhBqJnBmP5CskAvBR4subsT4OEIF+gEKKFSARfQmmdF2bgyxRFBRQTO83q5mShAxmfDxyzhEUkAxqG0rdkOCsRmLgDUqXNLR34EIlZpGbi7QS8FOkgIAZp9hopCyMdTTzIGVS6CL0wCcgEUlkmSRKg2AuUWXiDRmkWhgGpamO4Go0CgmVpLKMKaU8a17ea9nc8j21YV//ImFNwaIForHZ0OpXJCincdJowL6ZRbDlRbY2KvELqB2kvBS+l0RzS2K+81DQAZe6UUIRys2xMhdQsXBJNJIaaZV8KA7MQE0gFPEY45k/ZF1CLxKGG20RkZLyU5igjUCMZL5UkURDwyEsH7Q1FGo2wO3ofrTviFj7U+BxDzGgkkS65u9f30RxyeWaQirJafnsgTFk6gxRFWmNASgOapghLAblYeTQI4ITANB5MlJlOCmWpQr0i9pzJGl8HNJOQ9XXEylAFggdunJlVp5JY9GQaUR87KiKeJB6JAqjxYesd3tJnlJo6xTFrPptqFBq7fK/zqDM/icAdAlGYhzGob7jJ5rPrcDuUCmboU1MqsGpRArOOCeLUEVTyqfsJtny4rtnUU0EVVQjAAf8zg7hYtfIZr0KU5gJm/eqhDRfmChDwHhP1zklYvso6dpI+1CfpgST++C0h2eng9PjJUcegVU2zsL+pevaGrEuzg5rWCnhT9ezzo45Bm5y3DemTL446Bq1a9JLhB637X7whyycdL50xuofQltmaXHbf7IBvOGb7FJh/Mj7ztNUU3TuzOjY0S9RNGd4bXw22q2Y7aeqMTVdnTd8+ff4eCLCsmq1pAZyB7eNPgIc58wwSd9+ceQpnno1Mt2ZNnbHps7tV2ffEHAdOW49hx2B5ag6YLGwOvAfb+7DzeNL03pTp6oTpO65na0+POgrHTWcWvdBvp6cd0kXr6d2G9ZE67Vkd0U6dOnXq1Ok4ypHVL3Pzx6Pul7g9SpRE1fayCTqke5TpO8V7DCkzUst8yFgxxGPiSURjigSNO7a7U9UdbSDSQmTM1VJLj8cRigrSFyRlPqV98o7+wLxnVUhLWabpgJhOXilBLKMBZHwRRKmn6V57ar27MkhZZvr+wAaZrnjbSD2T48lCx2S+ExI9JiKFEsVdlrrCpR5TieNROhAqYUR2SHctprQ23uiaT1cLRM2oFKEJFKFEKyQOr59Yp06dOnXqdAT6f2PCMLGwm5Z8AAAAAElFTkSuQmCC)",
"text/plain": [
"<IPython.core.display.Markdown object>"
]
},
"metadata": {
"tags": []
}
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "cJ9Cfl5D9qYN",
"colab_type": "code",
"colab": {}
},
"source": [
"if not colab:\n",
" # if on local machine \n",
" root_dir='.'\n",
" \n",
"else:\n",
" # if using google colab use this code\n",
" from google.colab import drive\n",
" drive.mount('/content/drive')\n",
" root_dir = \"/content/drive/My Drive/Colab Notebooks\""
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "Ct0UBp8y9qYS",
"colab_type": "code",
"colab": {}
},
"source": [
"data_path = os.path.join(root_dir, \"fra.csv\")\n",
"doc = pd.read_csv(data_path, nrows=training_samples)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "tvVhQmsq9qYY",
"colab_type": "code",
"colab": {}
},
"source": [
"# replace contracted forms for english words\n",
"contracted_dict={\"won't\" : \"will not\", \"can\\'t\" : \"can not\", \"n\\'t\" : \" not\", \"\\'re\" : \" are\", \"\\'s\" : \" is\", \"\\'d\" : \" would\", \"\\'ll\" : \" will\", \"\\'t\" : \" not\", \"\\'ve\" : \" have\", \"\\'m\" : \" am\"}\n",
"\n",
"def replace_contracted(text):\n",
"\n",
" regex = re.compile(\"|\".join(map(re.escape, contracted_dict.keys( ))))\n",
" return regex.sub(lambda match: contracted_dict[match.group(0)], text)"
],
"execution_count": null,
"outputs": []
},
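{
"cell_type": "code",
"metadata": {},
"source": [
"# illustrative sanity check of replace_contracted (not part of the original notebook)\n",
"print(replace_contracted(\"we can't go because it's late\"))  # expect: we can not go because it is late"
],
"execution_count": null,
"outputs": []
},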
{
"cell_type": "code",
"metadata": {
"id": "WPWsAj-u9qYd",
"colab_type": "code",
"colab": {}
},
"source": [
"# apply decontraction and lowercase\n",
"doc=doc.apply(np.vectorize(lambda sent : replace_contracted(str(sent).strip().lower())))"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "54f6A85j9qYg",
"colab_type": "code",
"colab": {}
},
"source": [
"# tokenize sentences and add start_ and _end keyword to target sentences\n",
"source_sents=doc.Source.apply(lambda x: x + ' _END').apply(lambda sent: word_tokenize(sent))\n",
"target_sents=doc.Target.apply(lambda x : 'START_ '+ x + ' _END').apply(lambda sent: word_tokenize(sent))\n",
"temp = list(zip(source_sents, target_sents)) \n",
"random.shuffle(temp) \n",
"source_sents, target_sents = zip(*temp)\n",
"source_sents, target_sents = pd.Series(source_sents), pd.Series(target_sents)"
],
"execution_count": null,
"outputs": []
},
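{
"cell_type": "code",
"metadata": {},
"source": [
"# illustrative peek at one tokenized pair after shuffling (not part of the original notebook)\n",
"print(source_sents[0])\n",
"print(target_sents[0])"
],
"execution_count": null,
"outputs": []
},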
{
"cell_type": "code",
"metadata": {
"scrolled": true,
"id": "c0PwKMAA9qYj",
"colab_type": "code",
"colab": {}
},
"source": [
"del(doc)\n",
"# building the vocabulary\n",
"source_vocab=set().union(*source_sents)\n",
"target_vocab=set().union(*target_sents)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "KSuPWE2C9qYo",
"colab_type": "code",
"colab": {}
},
"source": [
"# max sentence length for each language in the dataset\n",
"max_source_len=max(source_sents.apply(len))\n",
"max_target_len=max(target_sents.apply(len))"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "Yhpt_X7Y9qYr",
"colab_type": "code",
"colab": {}
},
"source": [
"# numeric identity for each word in vocab\n",
"source_wordint_rel=bidict(enumerate(source_vocab, 1))\n",
"temp={0:'paddingZero'}\n",
"temp.update(dict(enumerate(target_vocab, 1)))\n",
"target_wordint_rel=bidict(temp)"
],
"execution_count": null,
"outputs": []
},
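{
"cell_type": "code",
"metadata": {},
"source": [
"# illustrative round trip through the bidict mapping (not part of the original notebook)\n",
"some_word = next(iter(source_vocab))\n",
"idx = source_wordint_rel.inv[some_word]      # word -> integer id\n",
"assert source_wordint_rel[idx] == some_word  # integer id -> word"
],
"execution_count": null,
"outputs": []
},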
{
"cell_type": "code",
"metadata": {
"id": "h0_0s05X9qYv",
"colab_type": "code",
"colab": {}
},
"source": [
"# prepare inputs and outputs\n",
"encoder_source_arr=[list(map(lambda word : source_wordint_rel.inv[word], sent)) for sent in source_sents]\n",
"decoder_source_arr=[list(map(lambda word : target_wordint_rel.inv[word], sent)) for sent in target_sents]\n",
"decoder_output_arr=[list(map(lambda word : target_wordint_rel.inv[word], sent[1:])) for sent in target_sents]"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"scrolled": true,
"id": "WtaBqNMZ9qYy",
"colab_type": "code",
"colab": {}
},
"source": [
"# pad the inputs and outputs to max length\n",
"padded_encoder_source_arr=pad_sequences(encoder_source_arr, maxlen=max_source_len, padding='post')\n",
"padded_decoder_source_arr=pad_sequences(decoder_source_arr, maxlen=max_target_len, padding='post')\n",
"padded_decoder_output_arr=pad_sequences(decoder_output_arr, maxlen=max_target_len, padding='post')\n",
"onehotted_decoder_output_arr=tf.one_hot(padded_decoder_output_arr, len(target_vocab)+1).numpy()\n",
"\n",
"del encoder_source_arr, decoder_source_arr, decoder_output_arr, padded_decoder_output_arr"
],
"execution_count": null,
"outputs": []
},
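{
"cell_type": "code",
"metadata": {},
"source": [
"# illustrative shape check (not part of the original notebook):\n",
"# (samples, max_source_len), (samples, max_target_len), (samples, max_target_len, target vocab + 1)\n",
"print(padded_encoder_source_arr.shape, padded_decoder_source_arr.shape, onehotted_decoder_output_arr.shape)"
],
"execution_count": null,
"outputs": []
},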
{
"cell_type": "markdown",
"metadata": {
"id": "T9AF9kdc9qY1",
"colab_type": "text"
},
"source": [
"# Model Preparation"
]
},
{
"cell_type": "code",
"metadata": {
"id": "YH3L903Z9qY2",
"colab_type": "code",
"colab": {}
},
"source": [
"# context-vector length\n",
"latent_dim=ctx_vec_len\n",
"\n",
"# this is the source languge consumtion layer\n",
"encoder_inputs = Input(shape=(None,), name='encoder_sources')\n",
"# embed the 2-d source into 3-d\n",
"enc_emb = Embedding(len(source_vocab)+1, embedding_dim, mask_zero = True, name='enc_emb')(encoder_inputs)\n",
"\n",
"# LSTM layer to encode the source sentence into context-vector representation\n",
"encoder_lstm = Bidirectional(LSTM(latent_dim, return_state=True, return_sequences=True, name='encoder_lstm1', dropout=dropout), name='encoder_bi-lstm1', merge_mode=\"concat\")\n",
"\n",
"encoder_outputs, forward_h, forward_c, backward_h, backward_c = encoder_lstm(enc_emb)\n",
"encoder_states = [forward_h, forward_c, backward_h, backward_c]\n",
"\n",
"encoder_lstm1 = Bidirectional(LSTM(latent_dim, return_state=True, name='encoder_lstm2', dropout=dropout), name='encoder_bi-lstm2', merge_mode=\"concat\")\n",
"encoder_outputs, forward_h, forward_c, backward_h, backward_c = encoder_lstm1(encoder_outputs, initial_state=encoder_states)\n",
"\n",
"state_h = Concatenate()([forward_h, backward_h])\n",
"state_c = Concatenate()([forward_c, backward_c])\n",
"# encoded-states tensor stores the context-vector\n",
"encoder_states = [state_h, state_c]"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "WOssNVT79qY6",
"colab_type": "code",
"colab": {}
},
"source": [
"# this is the target languge consumtion layer\n",
"decoder_inputs = Input(shape=(None,), name='decoder_sources')\n",
"# embed the 2-d source into 3-d\n",
"dec_emb_layer = Embedding(len(target_vocab)+1, embedding_dim, mask_zero = True, name='dec_emb_layer')\n",
"dec_emb = dec_emb_layer(decoder_inputs)\n",
"\n",
"# decoder LSTM, this takes in the context-vector and starting or so-far decoded part of the target sentence\n",
"decoder_lstm1 = LSTM(latent_dim, return_sequences=True, name='decoder_lstm1', dropout=dropout)\n",
"decoder_outputs11 = decoder_lstm1(dec_emb)\n",
"decoder_lstm2 = LSTM(latent_dim, return_sequences=True, name='decoder_lstm2', dropout=dropout)\n",
"decoder_outputs12 = decoder_lstm2(decoder_outputs11)\n",
"decoder_lstm3 = LSTM(latent_dim*2, return_sequences=True, return_state=True, name='decoder_lstm', dropout=dropout)\n",
"decoder_outputs13, _, _ = decoder_lstm3(decoder_outputs12, initial_state=encoder_states)\n",
"\n",
"# final layer that gives a probabilty distribution of the next possible words\n",
"decoder_dense = Dense(len(target_vocab)+1, activation='softmax', name='decoder_dense')\n",
"decoder_outputs14 = decoder_dense(decoder_outputs13)\n",
"\n",
"# Encode the source sequence to get the \"Context vectors\"\n",
"encoder_model = Model(encoder_inputs, encoder_states, name='Model_Encoder')\n",
"encoder_model.summary()\n",
"plot_model(encoder_model, show_shapes=True, show_layer_names=True)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "rYR6yV_KZgvG",
"colab_type": "text"
},
"source": [
"# Custom Loss Function to get rid of padding"
]
},
{
"cell_type": "code",
"metadata": {
"id": "VNE30xzaZZX7",
"colab_type": "code",
"colab": {}
},
"source": [
"vocab_len=len(onehotted_decoder_output_arr[0][0])\n",
"\n",
"def PaddedCategoricalCrossentropy(eps=1e-12):\n",
" def loss(y_true, y_pred):\n",
" mask_value = np.zeros((vocab_len))\n",
" mask_value[0] = 1\n",
" # find out which timesteps in `y_true` are not the padding character \n",
" mask = K.equal(y_true, mask_value)\n",
" mask = 1 - K.cast(mask, K.floatx())\n",
" mask = K.sum(mask,axis=2)/2\n",
" # multplying the loss by the mask. the loss for padding will be zero\n",
" loss = tf.keras.layers.multiply([K.categorical_crossentropy(y_true, y_pred), mask])\n",
" return K.sum(loss) / K.sum(mask)\n",
" return loss"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "tiYu7bVY9qY9",
"colab_type": "code",
"colab": {}
},
"source": [
"# model building and summary\n",
"model = Model([encoder_inputs, decoder_inputs], decoder_outputs14, name='Model_Translation')\n",
"model.compile(optimizer='Adam', loss=PaddedCategoricalCrossentropy(), metrics=['acc'])\n",
"model.summary()\n",
"plot_model(model, show_shapes=True, show_layer_names=True)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "wjYkSeMjVz04",
"colab_type": "text"
},
"source": [
"# Training"
]
},
{
"cell_type": "code",
"metadata": {
"id": "g_dxY6fL9qZB",
"colab_type": "code",
"colab": {}
},
"source": [
"# TensorBoard Callback \n",
"tbCallBack = TensorBoard(log_dir=os.path.join(root_dir, 'Graph'), histogram_freq=0, write_graph=True, write_images=True)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"scrolled": true,
"id": "6SsaUZXF9qZF",
"colab_type": "code",
"colab": {}
},
"source": [
"if training:\n",
" # train the model\n",
" history=model.fit([padded_encoder_source_arr, padded_decoder_source_arr], onehotted_decoder_output_arr, epochs=epochs, validation_split=0.02, callbacks=[tbCallBack], batch_size=batch_size)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "a6-S3Z9QH3iX",
"colab_type": "code",
"colab": {}
},
"source": [
"if training:\n",
" model.save_weights(os.path.join(root_dir, weight_file))\n",
" with plt.style.context('dark_background'):\n",
" plt.plot(history.history['acc'])\n",
" plt.plot(history.history['val_acc'])\n",
" plt.title('model accuracy')\n",
" plt.ylabel('accuracy')\n",
" plt.xlabel('epoch')\n",
" plt.legend(['train', 'val'], loc='upper left')\n",
" plt.show()\n",
" plt.plot(history.history['loss'])\n",
" plt.plot(history.history['val_loss'])\n",
" plt.title('model loss')\n",
" plt.ylabel('loss')\n",
" plt.xlabel('epoch')\n",
" plt.legend(['train', 'val'], loc='upper left')\n",
" plt.show()\n",
" print(f'Accuracy while saving is {model.evaluate([padded_encoder_source_arr, padded_decoder_source_arr], onehotted_decoder_output_arr)}')"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "CZpJP5Fk9qZI",
"colab_type": "text"
},
"source": [
"# Decoder Model"
]
},
{
"cell_type": "code",
"metadata": {
"id": "Ksm-Vbao9qZJ",
"colab_type": "code",
"colab": {}
},
"source": [
"# Decoder setup\n",
"# Below tensors will hold the states of the previous time step\n",
"state_h = Input(shape=(latent_dim*2,))\n",
"state_c = Input(shape=(latent_dim*2,))\n",
"\n",
"\n",
"decoder_state_input = [state_h, state_c]\n",
"# Get the embeddings of the decoder sequence\n",
"dec_emb2= dec_emb_layer(decoder_inputs)\n",
"# To predict the next word in the sequence, set the initial states to the states from the previous time step\n",
"decoder_outputs21 = decoder_lstm1(dec_emb2)\n",
"decoder_outputs22 = decoder_lstm2(decoder_outputs21)\n",
"decoder_outputs23, state_h2, state_c2 = decoder_lstm3(decoder_outputs22, initial_state=decoder_state_input)\n",
"decoder_states2 = [state_h2, state_c2]\n",
"# A dense softmax layer to generate prob dist. over the target vocabulary\n",
"decoder_outputs24 = decoder_dense(decoder_outputs23)\n",
"# Final decoder model\n",
"decoder_model = Model(\n",
" [decoder_inputs] + decoder_state_input,\n",
" [decoder_outputs24] + decoder_states2, name='Model_Decoder')\n",
"decoder_model.summary()\n",
"plot_model(decoder_model, show_shapes=True, show_layer_names=True)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "5yRFPYsS9qZM",
"colab_type": "text"
},
"source": [
"# Decoding Logic"
]
},
{
"cell_type": "code",
"metadata": {
"id": "_hAps5WG9qZN",
"colab_type": "code",
"colab": {}
},
"source": [
"def decode_sequence(source_seq):\n",
" \n",
" # Encode the source as state vectors.\n",
" states_value = encoder_model.predict(source_seq)\n",
" # Generate empty target sequence of length 1.\n",
" target_seq = np.zeros((1,1))\n",
" # Populate the first character of \n",
" #target sequence with the start character.\n",
" target_seq[0, 0] = target_wordint_rel.inv['START_']\n",
" # Sampling loop for a batch of sequences\n",
" # (to simplify, here we assume a batch of size 1).\n",
" stop_condition = False\n",
" decoded_sentence = []\n",
" while not stop_condition:\n",
" output_tokens, h, c = decoder_model.predict([target_seq] + states_value)\n",
" # Sample a token\n",
" sampled_token_index = np.argmax(output_tokens[0, -1, :])\n",
" sampled_word =target_wordint_rel[sampled_token_index]\n",
" decoded_sentence += [sampled_word]\n",
" # Exit condition: either hit max length\n",
" # or find stop character.\n",
" if (sampled_word == '_END' or\n",
" len(decoded_sentence) > 50):\n",
" stop_condition = True\n",
" # Update the target sequence (of length 1).\n",
" target_seq = np.zeros((1,1))\n",
" target_seq[0, 0] = sampled_token_index\n",
" # Update states\n",
" states_value = [h, c]\n",
"\n",
" return decoded_sentence"
],
"execution_count": null,
"outputs": []
},
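{
"cell_type": "code",
"metadata": {},
"source": [
"# illustrative use of decode_sequence on a single padded source row (not part of the original notebook);\n",
"# assumes the model has been trained above or the saved weights have been loaded\n",
"print(' '.join(decode_sequence(padded_encoder_source_arr[0:1])[:-1]))"
],
"execution_count": null,
"outputs": []
},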
{
"cell_type": "markdown",
"metadata": {
"id": "EZFK5mg_9qZQ",
"colab_type": "text"
},
"source": [
"# Prediction"
]
},
{
"cell_type": "code",
"metadata": {
"scrolled": false,
"id": "mWwKlX5a9qZQ",
"colab_type": "code",
"colab": {}
},
"source": [
"start=1000\n",
"offset=100\n",
"def calc_strdiff(true, pred):\n",
" # return sum([1 for char in list(difflib.ndiff(true, pred)) if '+ ' in char or '- ' in char])/(len(true))\n",
" return nltk.translate.bleu_score.sentence_bleu([word_tokenize(true)], word_tokenize(pred))\n",
" \n",
"if validation:\n",
" \n",
" model.load_weights(os.path.join(root_dir, weight_file))\n",
" print(f'Accuracy after loading is {model.evaluate([padded_encoder_source_arr, padded_decoder_source_arr], onehotted_decoder_output_arr)}')\n",
" y_truePred = [(' '.join(source_sents[seq_index][:-1]), ' '.join(target_sents[seq_index][1:-1]), ' '.join(decode_sequence(padded_encoder_source_arr[seq_index:seq_index+1])[:-1])) for seq_index, _ in enumerate(padded_encoder_source_arr[start:start+offset], start)]\n",
" bleu_score=[calc_strdiff(true, pred) for _, true, pred, in y_truePred]\n",
" print(f'Bleu Scores are {bleu_score}')\n",
" print(f'Avg bleu score for {len(y_truePred)} tests was {sum(bleu_score)/len(y_truePred)}.')\n",
" print(f\"{pd.DataFrame(y_truePred, columns=['Source', 'Expected', 'Predicted'])}\")"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "Nsddt0BCzEJz",
"colab_type": "code",
"colab": {}
},
"source": [
"pd.DataFrame(y_truePred, columns=['Source', 'Expected', 'Predicted']).to_excel(os.path.join(root_dir, 'review.xlsx'))"
],
"execution_count": null,
"outputs": []
}
]
}