Created on Skills Network Labs
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<a href=\"https://www.skills.network/\"><img src=\"https://cf-courses-data.s3.us.cloud-object-storage.appdomain.cloud/IBM-DL0120ENedX/labs/Template%20for%20Instructional%20Hands-on%20Labs/images/IDSNlogo.png\" width=\"400px\" align=\"center\"></a>\n",
"\n",
"<h1 align=\"center\"><font size=\"5\">RECURRENT NETWORKS and LSTM IN DEEP LEARNING</font></h1>\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<h2>Applying Recurrent Neural Networks/LSTM for Language Modeling</h2>\n",
"Hello and welcome to this part. In this notebook, we will go over the topic of Language Modelling, and create a Recurrent Neural Network model based on the Long Short-Term Memory unit to train and benchmark on the Penn Treebank dataset. By the end of this notebook, you should be able to understand how TensorFlow builds and executes a RNN model for Language Modelling.\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<h2>The Objective</h2>\n",
"By now, you should have an understanding of how Recurrent Networks work -- a specialized model to process sequential data by keeping track of the \"state\" or context. In this notebook, we go over a TensorFlow code snippet for creating a model focused on <b>Language Modelling</b> -- a very relevant task that is the cornerstone of many different linguistic problems such as <b>Speech Recognition, Machine Translation and Image Captioning</b>. For this, we will be using the Penn Treebank dataset, which is an often-used dataset for benchmarking Language Modelling models.\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<h2>Table of Contents</h2>\n",
"<ol>\n",
" <li><a href=\"#language_modelling\">What exactly is Language Modelling?</a></li>\n",
" <li><a href=\"#treebank_dataset\">The Penn Treebank dataset</a></li>\n",
" <li><a href=\"#word_embedding\">Word Embedding</a></li>\n",
" <li><a href=\"#building_lstm_model\">Building the LSTM model for Language Modeling</a></li>\n",
" <li><a href=\"#ltsm\">LTSM</a></li>\n",
"</ol>\n",
"<p></p>\n",
"</div>\n",
"<br>\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"button": false,
"deletable": true,
"new_sheet": false,
"run_control": {
"read_only": false
}
},
"source": [
"* * *\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<a id=\"language_modelling\"></a>\n",
"\n",
"<h2>What exactly is Language Modelling?</h2>\n",
"Language Modelling, to put it simply, <b>is the task of assigning probabilities to sequences of words</b>. This means that, given a context of one or a sequence of words in the language the model was trained on, the model should provide the next most probable words or sequence of words that follows from the given sequence of words the sentence. Language Modelling is one of the most important tasks in Natural Language Processing.\n",
"\n",
"<img src=\"https://ibm.box.com/shared/static/1d1i5gub6wljby2vani2vzxp0xsph702.png\" width=\"1080\">\n",
"<center><i>Example of a sentence being predicted</i></center>\n",
"<br><br>\n",
"In this example, one can see the predictions for the next word of a sentence, given the context \"This is an\". As you can see, this boils down to a sequential data analysis task -- you are given a word or a sequence of words (the input data), and, given the context (the state), you need to find out what is the next word (the prediction). This kind of analysis is very important for language-related tasks such as <b>Speech Recognition, Machine Translation, Image Captioning, Text Correction</b> and many other very relevant problems. \n",
"\n",
"<img src=\"https://ibm.box.com/shared/static/az39idf9ipfdpc5ugifpgxnydelhyf3i.png\" width=\"1080\">\n",
"<center><i>The above example is a schema of an RNN in execution</i></center>\n",
"<br><br>\n",
"As the above image shows, Recurrent Network models fit this problem like a glove. Alongside LSTM and its capacity to maintain the model's state for over one thousand time steps, we have all the tools we need to undertake this problem. The goal for this notebook is to create a model that can reach <b>low levels of perplexity</b> on our desired dataset.\n",
"\n",
"For Language Modelling problems, <b>perplexity</b> is the way to gauge efficiency. Perplexity is simply a measure of how well a probabilistic model is able to predict its sample. A higher-level way to explain this would be saying that <b>low perplexity means a higher degree of trust in the predictions the model makes</b>. Therefore, the lower perplexity is, the better.\n"
]
},
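To make the two ideas above concrete, here is a minimal sketch in plain Python (no TensorFlow). It assumes a toy bigram model, in which each word is conditioned only on the previous word and the probabilities come from counting, trained on a three-sentence corpus made up for illustration. It is not the LSTM model built later in this notebook, and the helper names are hypothetical. The sketch scores a sentence with the chain rule and computes its perplexity.

```python
import math
from collections import Counter, defaultdict

def train_bigram(sentences):
    """Estimate P(next | prev) by counting adjacent word pairs (no smoothing)."""
    pair_counts = defaultdict(Counter)
    for sent in sentences:
        words = ["<s>"] + sent.split() + ["<eos>"]
        for prev, nxt in zip(words, words[1:]):
            pair_counts[prev][nxt] += 1
    return {prev: {w: c / sum(cnt.values()) for w, c in cnt.items()}
            for prev, cnt in pair_counts.items()}

def word_probs(model, sentence):
    """P(w_i | w_{i-1}) for every predicted word, including the final <eos>."""
    words = ["<s>"] + sentence.split() + ["<eos>"]
    return [model.get(prev, {}).get(nxt, 0.0) for prev, nxt in zip(words, words[1:])]

def sentence_prob(model, sentence):
    """Chain rule: the sequence probability is the product of the conditionals."""
    prob = 1.0
    for p in word_probs(model, sentence):
        prob *= p
    return prob

def perplexity(model, sentence):
    """exp of the average negative log-probability per predicted word (needs all p > 0)."""
    probs = word_probs(model, sentence)
    return math.exp(-sum(math.log(p) for p in probs) / len(probs))

def predict_next(model, prev_word):
    """Most probable next word, given only the previous word."""
    return max(model[prev_word], key=model[prev_word].get)

corpus = ["this is an example", "this is an apple", "that is an example"]
model = train_bigram(corpus)
print(sentence_prob(model, "this is an example"))
print(perplexity(model, "this is an example"), perplexity(model, "that is an apple"))
print(predict_next(model, "an"))  # next word after "... an", cf. the "This is an" example
```

The sentence built from the more frequent continuations gets the higher probability and the lower perplexity. An unseen word pair would get probability 0 in this toy model. Real language models therefore use smoothing or, as in this notebook, a neural network whose softmax output gives every word a non-zero probability.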
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<a id=\"treebank_dataset\"></a>\n",
"\n",
"<h2>The Penn Treebank dataset</h2>\n",
"Historically, datasets big enough for Natural Language Processing are hard to come by. This is in part due to the necessity of the sentences to be broken down and tagged with a certain degree of correctness -- or else the models trained on it won't be able to be correct at all. This means that we need a <b>large amount of data, annotated by or at least corrected by humans</b>. This is, of course, not an easy task at all.\n",
"\n",
"The Penn Treebank, or PTB for short, is a dataset maintained by the University of Pennsylvania. It is <i>huge</i> -- there are over <b>four million and eight hundred thousand</b> annotated words in it, all corrected by humans. It is composed of many different sources, from abstracts of Department of Energy papers to texts from the Library of America. Since it is verifiably correct and of such a huge size, the Penn Treebank is commonly used as a benchmark dataset for Language Modelling.\n",
"\n",
"The dataset is divided in different kinds of annotations, such as Piece-of-Speech, Syntactic and Semantic skeletons. For this example, we will simply use a sample of clean, non-annotated words (with the exception of one tag --<code>&lt;unk></code>\n",
", which is used for rare words such as uncommon proper nouns) for our model. This means that we just want to predict what the next words would be, not what they mean in context or their classes on a given sentence.\n",
"\n",
"<center>Example of text from the dataset we are going to use, <b>ptb.train</b></center>\n",
"<br><br>\n",
"\n",
"<div class=\"alert alert-block alert-info\" style=\"margin-top: 20px\">\n",
" <center>the percentage of lung cancer deaths among the workers at the west <code>&lt;unk&gt;</code> mass. paper factory appears to be the highest for any asbestos workers studied in western industrialized countries he said \n",
" the plant which is owned by <code>&lt;unk&gt;</code> & <code>&lt;unk&gt;</code> co. was under contract with <code>&lt;unk&gt;</code> to make the cigarette filters \n",
" the finding probably will support those who argue that the U.S. should regulate the class of asbestos including <code>&lt;unk&gt;</code> more <code>&lt;unk&gt;</code> than the common kind of asbestos <code>&lt;unk&gt;</code> found in most schools and other buildings dr. <code>&lt;unk&gt;</code> said</center>\n",
"</div>\n"
]
},
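The <code>&lt;unk&gt;</code> tag is the result of capping the vocabulary: every word outside the V most frequent words is replaced by a single placeholder. The PTB files used here already contain <code>&lt;unk&gt;</code>, and this notebook uses a vocabulary of 10000 words. A minimal sketch of that preprocessing step follows. It assumes a simple frequency cut-off, and `build_vocab` and `replace_rare` are hypothetical helpers. The example sentence is adapted from the sample above, with a made-up company name.

```python
from collections import Counter

def build_vocab(tokens, max_size):
    """Keep the max_size most frequent tokens (ties broken by first occurrence)."""
    counts = Counter(tokens)
    return {word for word, _ in counts.most_common(max_size)}

def replace_rare(tokens, vocab, unk="<unk>"):
    """Map every out-of-vocabulary token to the <unk> placeholder."""
    return [t if t in vocab else unk for t in tokens]

# Toy sentence; "acme" is an invented stand-in for a rare proper noun
text = "the plant which is owned by acme was under contract with the firm the".split()
vocab = build_vocab(text, max_size=3)
print(replace_rare(text, vocab))
```

With such a tiny vocabulary almost everything becomes <code>&lt;unk&gt;</code>. With V = 10000 on a corpus the size of PTB, only rare words are affected.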
{
"cell_type": "markdown",
"metadata": {
"button": false,
"deletable": true,
"new_sheet": false,
"run_control": {
"read_only": false
}
},
"source": [
"<a id=\"word_embedding\"></a>\n",
"\n",
"<h2>Word Embeddings</h2><br/>\n",
"\n",
"For better processing, in this example, we will make use of <a href=\"https://www.tensorflow.org/tutorials/word2vec/\"><b>word embeddings</b></a>, which is <b>a way of representing sentence structures or words as n-dimensional vectors (where n is a reasonably high number, such as 200 or 500) of real numbers</b>. Basically, we will assign each word a randomly-initialized vector, and input those into the network to be processed. After a number of iterations, these vectors are expected to assume values that help the network to correctly predict what it needs to -- in our case, the probable next word in the sentence. This is shown to be a very effective task in Natural Language Processing, and is a commonplace practice.\n",
"<br><br>\n",
"<font size=\"4\"><strong>\n",
"$$Vec(\"Example\") = [0.02, 0.00, 0.00, 0.92, 0.30, \\ldots]$$\n",
"</strong></font>\n",
"<br>\n",
"Word Embedding tends to group up similarly used words <i>reasonably</i> close together in the vectorial space. For example, if we use T-SNE (a dimensional reduction visualization algorithm) to flatten the dimensions of our vectors into a 2-dimensional space and plot these words in a 2-dimensional space, we might see something like this:\n",
"\n",
"<img src=\"https://ibm.box.com/shared/static/bqhc5dg879gcoabzhxra1w8rkg3od1cu.png\" width=\"800\">\n",
"<center><i>T-SNE Mockup with clusters marked for easier visualization</i></center>\n",
"<br><br>\n",
"As you can see, words that are frequently used together, in place of each other, or in the same places as them tend to be grouped together -- being closer together the higher they are correlated. For example, \"None\" is pretty semantically close to \"Zero\", while a phrase that uses \"Italy\", you could probably also fit \"Germany\" in it, with little damage to the sentence structure. The vectorial \"closeness\" for similar words like this is a great indicator of a well-built model.\n",
"\n",
"<hr>\n",
" \n"
]
},
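"Closeness" in the embedding space is commonly measured with cosine similarity. The sketch below shows the computation on hand-picked 4-dimensional toy vectors. These vectors are invented for illustration and are not trained embeddings, and `cosine` and `nearest` are hypothetical helpers.

```python
import math

def cosine(u, v):
    """Cosine similarity: dot product divided by the product of the vector norms."""
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))

# Toy 4-dimensional "embeddings", invented for illustration only
emb = {
    "italy":   [0.9, 0.8, 0.1, 0.0],
    "germany": [0.8, 0.9, 0.2, 0.1],
    "zero":    [0.0, 0.1, 0.9, 0.8],
    "none":    [0.1, 0.0, 0.8, 0.9],
}

def nearest(word):
    """Return the other word whose vector has the highest cosine similarity."""
    return max((w for w in emb if w != word), key=lambda w: cosine(emb[word], emb[w]))

print(nearest("italy"), nearest("zero"))
```

Cosine similarity ignores vector length and compares only direction, which is why it is a common choice for comparing embeddings.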
{
"cell_type": "markdown",
"metadata": {
"button": false,
"deletable": true,
"new_sheet": false,
"run_control": {
"read_only": false
}
},
"source": [
"We need to import the necessary modules for our code. We need <b><code>numpy</code></b> and <b><code>tensorflow</code></b>, obviously. Additionally, we can import directly the <b><code>tensorflow.models.rnn</code></b> model, which includes the function for building RNNs, and <b><code>tensorflow.models.rnn.ptb.reader</code></b> which is the helper module for getting the input data from the dataset we just downloaded.\n",
"\n",
"If you want to learn more take a look at <https://github.com/tensorflow/models/blob/master/tutorials/rnn/ptb/reader.py>\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: tensorflow==2.2.0rc0 in /Users/smadhavan/.pyenv/versions/3.7.4/lib/python3.7/site-packages (2.2.0rc0)\n",
"Requirement already satisfied: gast==0.3.3 in /Users/smadhavan/.pyenv/versions/3.7.4/lib/python3.7/site-packages (from tensorflow==2.2.0rc0) (0.3.3)\n",
"Requirement already satisfied: opt-einsum>=2.3.2 in /Users/smadhavan/.pyenv/versions/3.7.4/lib/python3.7/site-packages (from tensorflow==2.2.0rc0) (2.3.2)\n",
"Requirement already satisfied: keras-preprocessing>=1.1.0 in /Users/smadhavan/.pyenv/versions/3.7.4/lib/python3.7/site-packages (from tensorflow==2.2.0rc0) (1.1.2)\n",
"Requirement already satisfied: numpy<2.0,>=1.16.0 in /Users/smadhavan/.pyenv/versions/3.7.4/lib/python3.7/site-packages (from tensorflow==2.2.0rc0) (1.16.4)\n",
"Requirement already satisfied: protobuf>=3.8.0 in /Users/smadhavan/.pyenv/versions/3.7.4/lib/python3.7/site-packages (from tensorflow==2.2.0rc0) (3.11.3)\n",
"Requirement already satisfied: wheel>=0.26; python_version >= \"3\" in /Users/smadhavan/.pyenv/versions/3.7.4/lib/python3.7/site-packages (from tensorflow==2.2.0rc0) (0.34.2)\n",
"Requirement already satisfied: wrapt>=1.11.1 in /Users/smadhavan/.pyenv/versions/3.7.4/lib/python3.7/site-packages (from tensorflow==2.2.0rc0) (1.12.1)\n",
"Requirement already satisfied: grpcio>=1.8.6 in /Users/smadhavan/.pyenv/versions/3.7.4/lib/python3.7/site-packages (from tensorflow==2.2.0rc0) (1.24.3)\n",
"Requirement already satisfied: six>=1.12.0 in /Users/smadhavan/.pyenv/versions/3.7.4/lib/python3.7/site-packages (from tensorflow==2.2.0rc0) (1.12.0)\n",
"Requirement already satisfied: google-pasta>=0.1.8 in /Users/smadhavan/.pyenv/versions/3.7.4/lib/python3.7/site-packages (from tensorflow==2.2.0rc0) (0.2.0)\n",
"Requirement already satisfied: scipy==1.4.1; python_version >= \"3\" in /Users/smadhavan/.pyenv/versions/3.7.4/lib/python3.7/site-packages (from tensorflow==2.2.0rc0) (1.4.1)\n",
"Requirement already satisfied: absl-py>=0.7.0 in /Users/smadhavan/.pyenv/versions/3.7.4/lib/python3.7/site-packages (from tensorflow==2.2.0rc0) (0.9.0)\n",
"Requirement already satisfied: tensorflow-estimator<2.2.0,>=2.1.0 in /Users/smadhavan/.pyenv/versions/3.7.4/lib/python3.7/site-packages (from tensorflow==2.2.0rc0) (2.1.0)\n",
"Requirement already satisfied: h5py<2.11.0,>=2.10.0 in /Users/smadhavan/.pyenv/versions/3.7.4/lib/python3.7/site-packages (from tensorflow==2.2.0rc0) (2.10.0)\n",
"Requirement already satisfied: tensorboard<2.2.0,>=2.1.0 in /Users/smadhavan/.pyenv/versions/3.7.4/lib/python3.7/site-packages (from tensorflow==2.2.0rc0) (2.1.1)\n",
"Requirement already satisfied: termcolor>=1.1.0 in /Users/smadhavan/.pyenv/versions/3.7.4/lib/python3.7/site-packages (from tensorflow==2.2.0rc0) (1.1.0)\n",
"Requirement already satisfied: astunparse==1.6.3 in /Users/smadhavan/.pyenv/versions/3.7.4/lib/python3.7/site-packages (from tensorflow==2.2.0rc0) (1.6.3)\n",
"Requirement already satisfied: setuptools in /Users/smadhavan/.pyenv/versions/3.7.4/lib/python3.7/site-packages (from protobuf>=3.8.0->tensorflow==2.2.0rc0) (47.3.0)\n",
"Requirement already satisfied: werkzeug>=0.11.15 in /Users/smadhavan/.pyenv/versions/3.7.4/lib/python3.7/site-packages (from tensorboard<2.2.0,>=2.1.0->tensorflow==2.2.0rc0) (1.0.1)\n",
"Requirement already satisfied: google-auth<2,>=1.6.3 in /Users/smadhavan/.pyenv/versions/3.7.4/lib/python3.7/site-packages (from tensorboard<2.2.0,>=2.1.0->tensorflow==2.2.0rc0) (1.19.1)\n",
"Requirement already satisfied: requests<3,>=2.21.0 in /Users/smadhavan/.pyenv/versions/3.7.4/lib/python3.7/site-packages (from tensorboard<2.2.0,>=2.1.0->tensorflow==2.2.0rc0) (2.22.0)\n",
"Requirement already satisfied: google-auth-oauthlib<0.5,>=0.4.1 in /Users/smadhavan/.pyenv/versions/3.7.4/lib/python3.7/site-packages (from tensorboard<2.2.0,>=2.1.0->tensorflow==2.2.0rc0) (0.4.1)\n",
"Requirement already satisfied: markdown>=2.6.8 in /Users/smadhavan/.pyenv/versions/3.7.4/lib/python3.7/site-packages (from tensorboard<2.2.0,>=2.1.0->tensorflow==2.2.0rc0) (3.2.2)\n",
"Requirement already satisfied: pyasn1-modules>=0.2.1 in /Users/smadhavan/.pyenv/versions/3.7.4/lib/python3.7/site-packages (from google-auth<2,>=1.6.3->tensorboard<2.2.0,>=2.1.0->tensorflow==2.2.0rc0) (0.2.8)\n",
"Requirement already satisfied: rsa<5,>=3.1.4; python_version >= \"3\" in /Users/smadhavan/.pyenv/versions/3.7.4/lib/python3.7/site-packages (from google-auth<2,>=1.6.3->tensorboard<2.2.0,>=2.1.0->tensorflow==2.2.0rc0) (4.6)\n",
"Requirement already satisfied: cachetools<5.0,>=2.0.0 in /Users/smadhavan/.pyenv/versions/3.7.4/lib/python3.7/site-packages (from google-auth<2,>=1.6.3->tensorboard<2.2.0,>=2.1.0->tensorflow==2.2.0rc0) (4.1.1)\n",
"Requirement already satisfied: idna<2.9,>=2.5 in /Users/smadhavan/.pyenv/versions/3.7.4/lib/python3.7/site-packages (from requests<3,>=2.21.0->tensorboard<2.2.0,>=2.1.0->tensorflow==2.2.0rc0) (2.8)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /Users/smadhavan/.pyenv/versions/3.7.4/lib/python3.7/site-packages (from requests<3,>=2.21.0->tensorboard<2.2.0,>=2.1.0->tensorflow==2.2.0rc0) (2019.9.11)\n",
"Requirement already satisfied: chardet<3.1.0,>=3.0.2 in /Users/smadhavan/.pyenv/versions/3.7.4/lib/python3.7/site-packages (from requests<3,>=2.21.0->tensorboard<2.2.0,>=2.1.0->tensorflow==2.2.0rc0) (3.0.4)\n",
"Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /Users/smadhavan/.pyenv/versions/3.7.4/lib/python3.7/site-packages (from requests<3,>=2.21.0->tensorboard<2.2.0,>=2.1.0->tensorflow==2.2.0rc0) (1.25.6)\n",
"Requirement already satisfied: requests-oauthlib>=0.7.0 in /Users/smadhavan/.pyenv/versions/3.7.4/lib/python3.7/site-packages (from google-auth-oauthlib<0.5,>=0.4.1->tensorboard<2.2.0,>=2.1.0->tensorflow==2.2.0rc0) (1.3.0)\n",
"Requirement already satisfied: importlib-metadata; python_version < \"3.8\" in /Users/smadhavan/.pyenv/versions/3.7.4/lib/python3.7/site-packages (from markdown>=2.6.8->tensorboard<2.2.0,>=2.1.0->tensorflow==2.2.0rc0) (1.6.0)\n",
"Requirement already satisfied: pyasn1<0.5.0,>=0.4.6 in /Users/smadhavan/.pyenv/versions/3.7.4/lib/python3.7/site-packages (from pyasn1-modules>=0.2.1->google-auth<2,>=1.6.3->tensorboard<2.2.0,>=2.1.0->tensorflow==2.2.0rc0) (0.4.8)\n",
"Requirement already satisfied: oauthlib>=3.0.0 in /Users/smadhavan/.pyenv/versions/3.7.4/lib/python3.7/site-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<0.5,>=0.4.1->tensorboard<2.2.0,>=2.1.0->tensorflow==2.2.0rc0) (3.1.0)\n",
"Requirement already satisfied: zipp>=0.5 in /Users/smadhavan/.pyenv/versions/3.7.4/lib/python3.7/site-packages (from importlib-metadata; python_version < \"3.8\"->markdown>=2.6.8->tensorboard<2.2.0,>=2.1.0->tensorflow==2.2.0rc0) (3.1.0)\n",
"Requirement already satisfied: numpy in /Users/smadhavan/.pyenv/versions/3.7.4/lib/python3.7/site-packages (1.16.4)\n"
]
}
],
"source": [
"!pip install tensorflow==2.2.0rc0\n",
"!pip install numpy\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import time\n",
"import numpy as np\n",
"import tensorflow as tf\n",
"if not tf.__version__ == '2.2.0-rc0':\n",
" print(tf.__version__)\n",
" raise ValueError('please upgrade to TensorFlow 2.2.0-rc0, or restart your Kernel (Kernel->Restart & Clear Output)')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"IMPORTANT! => Please restart the kernel by clicking on \"Kernel\"->\"Restart and Clear Outout\" and wait until all output disapears. Then your changes are beeing picked up\n"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"button": false,
"deletable": true,
"new_sheet": false,
"run_control": {
"read_only": false
},
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Archive: data/ptb.zip\n",
" creating: data/ptb/\n",
" inflating: data/ptb/reader.py \n",
" creating: data/__MACOSX/\n",
" creating: data/__MACOSX/ptb/\n",
" inflating: data/__MACOSX/ptb/._reader.py \n",
" inflating: data/__MACOSX/._ptb \n"
]
}
],
"source": [
"!mkdir data\n",
"!mkdir data/ptb\n",
"!wget -q -O data/ptb/reader.py https://cf-courses-data.s3.us.cloud-object-storage.appdomain.cloud/IBMDeveloperSkillsNetwork-DL0120EN-SkillsNetwork/labs/Week3/data/ptb/reader.py\n",
"!cp data/ptb/reader.py . \n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2.2.0-rc0\n"
]
}
],
"source": [
"import reader"
]
},
{
"cell_type": "markdown",
"metadata": {
"button": false,
"deletable": true,
"new_sheet": false,
"run_control": {
"read_only": false
}
},
"source": [
"<a id=\"building_lstm_model\"></a>\n",
"\n",
"<h2>Building the LSTM model for Language Modeling</h2>\n",
"Now that we know exactly what we are doing, we can start building our model using TensorFlow. The very first thing we need to do is download and extract the <code>simple-examples</code> dataset, which can be done by executing the code cell below.\n"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"button": false,
"deletable": true,
"new_sheet": false,
"run_control": {
"read_only": false
},
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"--2020-08-31 00:12:47-- http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz\n",
"Resolving www.fit.vutbr.cz (www.fit.vutbr.cz)...147.229.9.23\n",
"Connecting to www.fit.vutbr.cz (www.fit.vutbr.cz)|147.229.9.23|:80...connected.\n",
"HTTP request sent, awaiting response...200 OK\n",
"Length: 34869662 (33M) [application/x-gtar]\n",
"Saving to: ‘simple-examples.tgz’\n",
"\n",
"simple-examples.tgz 100%[===================>] 33.25M 58.2KB/s in 8m 28s \n",
"\n",
"2020-08-31 00:21:16 (67.1 KB/s) - ‘simple-examples.tgz’ saved [34869662/34869662]\n",
"\n"
]
}
],
"source": [
"!wget http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz \n",
"!tar xzf simple-examples.tgz -C data/"
]
},
{
"cell_type": "markdown",
"metadata": {
"button": false,
"deletable": true,
"new_sheet": false,
"run_control": {
"read_only": false
}
},
"source": [
"Additionally, for the sake of making it easy to play around with the model's hyperparameters, we can declare them beforehand. Feel free to change these -- you will see a difference in performance each time you change those! \n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"button": false,
"deletable": true,
"new_sheet": false,
"run_control": {
"read_only": false
}
},
"outputs": [],
"source": [
"#Initial weight scale\n",
"init_scale = 0.1\n",
"#Initial learning rate\n",
"learning_rate = 1.0\n",
"#Maximum permissible norm for the gradient (For gradient clipping -- another measure against Exploding Gradients)\n",
"max_grad_norm = 5\n",
"#The number of layers in our model\n",
"num_layers = 2\n",
"#The total number of recurrence steps, also known as the number of layers when our RNN is \"unfolded\"\n",
"num_steps = 20\n",
"#The number of processing units (neurons) in the hidden layers\n",
"hidden_size_l1 = 256\n",
"hidden_size_l2 = 128\n",
"#The maximum number of epochs trained with the initial learning rate\n",
"max_epoch_decay_lr = 4\n",
"#The total number of epochs in training\n",
"max_epoch = 15\n",
"#The probability for keeping data in the Dropout Layer (This is an optimization, but is outside our scope for this notebook!)\n",
"#At 1, we ignore the Dropout Layer wrapping.\n",
"keep_prob = 1\n",
"#The decay for the learning rate\n",
"decay = 0.5\n",
"#The size for each batch of data\n",
"batch_size = 30\n",
"#The size of our vocabulary\n",
"vocab_size = 10000\n",
"embeding_vector_size= 200\n",
"#Training flag to separate training from testing\n",
"is_training = 1\n",
"#Data directory for our dataset\n",
"data_dir = \"data/simple-examples/data/\""
]
},
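The `learning_rate`, `decay`, `max_epoch_decay_lr` and `max_grad_norm` settings work together during training. The plain-Python sketch below shows that arithmetic under two assumptions. The first is the schedule of the original TensorFlow PTB tutorial: keep the initial rate for the first `max_epoch_decay_lr` epochs, then multiply it by `decay` once per epoch. The second is clipping by global norm, which TensorFlow offers as `tf.clip_by_global_norm`. The helper names are hypothetical, and the training loop later in this notebook may differ in detail.

```python
import math

def lr_for_epoch(epoch, learning_rate=1.0, decay=0.5, max_epoch_decay_lr=4):
    """Learning rate for a 0-based epoch: constant at first, then decayed geometrically."""
    return learning_rate * decay ** max(epoch + 1 - max_epoch_decay_lr, 0)

def clip_by_global_norm(grads, max_grad_norm=5.0):
    """Rescale all gradients together so that their joint L2 norm is at most max_grad_norm."""
    global_norm = math.sqrt(sum(g * g for grad in grads for g in grad))
    if global_norm <= max_grad_norm:
        return grads, global_norm
    scale = max_grad_norm / global_norm
    return [[g * scale for g in grad] for grad in grads], global_norm

print([lr_for_epoch(e) for e in range(7)])
# Two toy "gradient tensors" with a joint norm of 13, clipped down to norm 5
clipped, norm = clip_by_global_norm([[3.0, 4.0], [12.0]])
print(norm, clipped)
```

Clipping by the global norm scales every gradient by the same factor, so it limits the step size without changing the step's direction.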
{
"cell_type": "markdown",
"metadata": {
"button": false,
"new_sheet": false,
"run_control": {
"read_only": false
}
},
"source": [
"Some clarifications for LSTM architecture based on the arguments:\n",
"\n",
"Network structure:\n",
"\n",
"<ul>\n",
" <li>In this network, the number of LSTM cells are 2. To give the model more expressive power, we can add multiple layers of LSTMs to process the data. The output of the first layer will become the input of the second and so on.\n",
" </li>\n",
" <li>The recurrence steps is 20, that is, when our RNN is \"Unfolded\", the recurrence step is 20.</li> \n",
" <li>the structure is like:\n",
" <ul>\n",
" <li>200 input units -> [200x200] Weight -> 200 Hidden units (first layer) -> [200x200] Weight matrix -> 200 Hidden units (second layer) -> [200] weight Matrix -> 200 unit output</li>\n",
" </ul>\n",
" </li>\n",
"</ul>\n",
"<br>\n",
"\n",
"Input layer: \n",
"\n",
"<ul>\n",
" <li>The network has 200 input units.</li>\n",
" <li>Suppose each word is represented by an embedding vector of dimensionality e=200. The input layer of each cell will have 200 linear units. These e=200 linear units are connected to each of the h=200 LSTM units in the hidden layer (assuming there is only one hidden layer, though our case has 2 layers).\n",
" </li>\n",
" <li>The input shape is [batch_size, num_steps], that is [30x20]. It will turn into [30x20x200] after embedding, and then 20x[30x200]\n",
" </li>\n",
"</ul>\n",
"<br>\n",
"\n",
"Hidden layer:\n",
"\n",
"<ul>\n",
" <li>Each LSTM has 200 hidden units which is equivalent to the dimensionality of the embedding words and output.</li>\n",
"</ul>\n",
"<br>\n"
]
},
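To double-check the sizes above, the following sketch derives each layer's output shape and trainable-parameter count from the hyperparameters. It assumes Keras-style LSTM layers: four gates, each with an input kernel, a recurrent kernel and one bias vector. It also assumes a dense output layer with one logit per vocabulary word. The layer names and list layout are our own.

```python
batch_size, num_steps = 30, 20
vocab_size, embedding_size = 10000, 200
hidden_sizes = [256, 128]  # hidden_size_l1, hidden_size_l2

def lstm_params(input_dim, units):
    """4 gates x ([input_dim x units] kernel + [units x units] recurrent kernel + bias)."""
    return 4 * (input_dim * units + units * units + units)

# (name, output shape, trainable parameters)
layers = [("embedding", (batch_size, num_steps, embedding_size), vocab_size * embedding_size)]
input_dim = embedding_size
for i, units in enumerate(hidden_sizes, start=1):
    layers.append((f"lstm_{i}", (batch_size, num_steps, units), lstm_params(input_dim, units)))
    input_dim = units
# Output layer: one logit per vocabulary word at every time step
layers.append(("dense", (batch_size, num_steps, vocab_size), input_dim * vocab_size + vocab_size))

for name, shape, params in layers:
    print(f"{name:10s} output {shape}  params {params:,}")
print("total", sum(p for _, _, p in layers))
```

Most of the parameters sit in the embedding and output layers, which both scale with the vocabulary size.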
{
"cell_type": "markdown",
"metadata": {
"button": false,
"deletable": true,
"new_sheet": false,
"run_control": {
"read_only": false
}
},
"source": [
"There is a lot to be done and a ton of information to process at the same time, so go over this code slowly. It may seem complex at first, but if you try to apply what you just learned about language modelling to the code you see, you should be able to understand it.\n",
"\n",
"This code is adapted from the <a href=\"https://github.com/tensorflow/models\">PTBModel</a> example bundled with the TensorFlow source code.\n",
"\n",
"<h3>Training data</h3>\n",
"The story starts from data:\n",
"<ul>\n",
" <li>Train data is a list of words, of size 929589, represented by numbers, e.g. [9971, 9972, 9974, 9975,...]</li>\n",
" <li>We read data as mini-batch of size b=30. Assume the size of each sentence is 20 words (num_steps = 20). Then it will take $$floor(\\frac{N}{b \\times h})+1=1548$$ iterations for the learner to go through all sentences once. Where N is the size of the list of words, b is batch size, and h is size of each sentence. So, the number of iterators is 1548\n",
" </li>\n",
" <li>Each batch data is read from train dataset of size 600, and shape of [30x20]</li>\n",
"</ul>\n"
]
},
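The iteration count above follows from how the word IDs are cut into mini-batches. The plain-Python sketch below re-implements that logic. It mirrors our reading of `ptb_iterator` in the TensorFlow tutorial's `reader.py` but is not that file, and `epoch_size` and `batch_iterator` are hypothetical names.

```python
def epoch_size(n_words, batch_size, num_steps):
    """Number of [batch_size x num_steps] mini-batches in one pass over the data."""
    batch_len = n_words // batch_size       # words per row after splitting into batch_size rows
    return (batch_len - 1) // num_steps     # -1 keeps room for the shifted targets

def batch_iterator(word_ids, batch_size, num_steps):
    """Yield (inputs, targets) pairs; targets are the inputs shifted one word to the right."""
    batch_len = len(word_ids) // batch_size
    # Row i holds a contiguous slice of the corpus; leftover words at the end are dropped
    rows = [word_ids[i * batch_len:(i + 1) * batch_len] for i in range(batch_size)]
    for step in range(epoch_size(len(word_ids), batch_size, num_steps)):
        start = step * num_steps
        x = [row[start:start + num_steps] for row in rows]
        y = [row[start + 1:start + num_steps + 1] for row in rows]
        yield x, y

print(epoch_size(929589, 30, 20))  # PTB training set with this notebook's settings: 1549
x, y = next(batch_iterator(list(range(26)), batch_size=2, num_steps=4))
print(x, y)
```

Because each row is a contiguous slice of the corpus, the second row of a batch does not continue the first. This matches the first batch printed further below, whose second row starts with a different part of the text.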
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"button": false,
"deletable": true,
"new_sheet": false,
"run_control": {
"read_only": false
}
},
"outputs": [],
"source": [
"# Reads the data and separates it into training data, validation data and testing data\n",
"raw_data = reader.ptb_raw_data(data_dir)\n",
"train_data, valid_data, test_data, vocab, word_to_id = raw_data"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"929589"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(train_data)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['aer', 'banknote', 'berlitz', 'calloway', 'centrust', 'cluett', 'fromstein', 'gitano', 'guterman', 'hydro-quebec', 'ipo', 'kia', 'memotec', 'mlx', 'nahb', 'punts', 'rake', 'regatta', 'rubens', 'sim', 'snack-food', 'ssangyong', 'swapo', 'wachter', '<eos>', 'pierre', '<unk>', 'N', 'years', 'old', 'will', 'join', 'the', 'board', 'as', 'a', 'nonexecutive', 'director', 'nov.', 'N', '<eos>', 'mr.', '<unk>', 'is', 'chairman', 'of', '<unk>', 'n.v.', 'the', 'dutch', 'publishing', 'group', '<eos>', 'rudolph', '<unk>', 'N', 'years', 'old', 'and', 'former', 'chairman', 'of', 'consolidated', 'gold', 'fields', 'plc', 'was', 'named', 'a', 'nonexecutive', 'director', 'of', 'this', 'british', 'industrial', 'conglomerate', '<eos>', 'a', 'form', 'of', 'asbestos', 'once', 'used', 'to', 'make', 'kent', 'cigarette', 'filters', 'has', 'caused', 'a', 'high', 'percentage', 'of', 'cancer', 'deaths', 'among', 'a', 'group', 'of']\n"
]
}
],
"source": [
"def id_to_word(id_list):\n",
" line = []\n",
" for w in id_list:\n",
" for word, wid in word_to_id.items():\n",
" if wid == w:\n",
" line.append(word)\n",
" return line \n",
" \n",
"\n",
"print(id_to_word(train_data[0:100]))"
]
},
{
"cell_type": "markdown",
"metadata": {
"button": false,
"deletable": true,
"new_sheet": false,
"run_control": {
"read_only": false
}
},
"source": [
"Lets just read one mini-batch now and feed our network:\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"button": false,
"deletable": true,
"new_sheet": false,
"run_control": {
"read_only": false
}
},
"outputs": [],
"source": [
"itera = reader.ptb_iterator(train_data, batch_size, num_steps)\n",
"first_touple = itera.__next__()\n",
"_input_data = first_touple[0]\n",
"_targets = first_touple[1]"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"button": false,
"deletable": true,
"new_sheet": false,
"run_control": {
"read_only": false
}
},
"outputs": [
{
"data": {
"text/plain": [
"(30, 20)"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"_input_data.shape"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(30, 20)"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"_targets.shape"
]
},
{
"cell_type": "markdown",
"metadata": {
"button": false,
"deletable": true,
"new_sheet": false,
"run_control": {
"read_only": false
}
},
"source": [
"Lets look at 3 sentences of our input x:\n"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"button": false,
"deletable": true,
"new_sheet": false,
"run_control": {
"read_only": false
}
},
"outputs": [
{
"data": {
"text/plain": [
"array([[9970, 9971, 9972, 9974, 9975, 9976, 9980, 9981, 9982, 9983, 9984,\n",
" 9986, 9987, 9988, 9989, 9991, 9992, 9993, 9994, 9995],\n",
" [2654, 6, 334, 2886, 4, 1, 233, 711, 834, 11, 130,\n",
" 123, 7, 514, 2, 63, 10, 514, 8, 605],\n",
" [ 0, 1071, 4, 0, 185, 24, 368, 20, 31, 3109, 954,\n",
" 12, 3, 21, 2, 2915, 2, 12, 3, 21]],\n",
" dtype=int32)"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"_input_data[0:3]"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['aer', 'banknote', 'berlitz', 'calloway', 'centrust', 'cluett', 'fromstein', 'gitano', 'guterman', 'hydro-quebec', 'ipo', 'kia', 'memotec', 'mlx', 'nahb', 'punts', 'rake', 'regatta', 'rubens', 'sim']\n"
]
}
],
"source": [
"print(id_to_word(_input_data[0,:]))"
]
},
{
"cell_type": "markdown",
"metadata": {
"button": false,
"deletable": true,
"new_sheet": false,
"run_control": {
"read_only": false
}
},
"source": [
"<h3>Embeddings</h3>\n",
"We have to convert the words in our dataset to vectors of numbers. The traditional approach is to use one-hot encoding method that is usually used for converting categorical values to numerical values. However, One-hot encoded vectors are high-dimensional, sparse and in a big dataset, computationally inefficient. So, we use word2vec approach. It is, in fact, a layer in our LSTM network, where the word IDs will be represented as a dense representation before feeding to the LSTM. \n",
"\n",
"The embedded vectors also get updated during the training process of the deep neural network.\n",
"We create the embeddings for our input data. <b>embedding_vocab</b> is matrix of [10000x200] for all 10000 unique words.\n"
]
},
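Multiplying a one-hot vector by the embedding matrix selects exactly one row of that matrix. An embedding lookup therefore gives the same result as the one-hot approach, without the large, mostly-zero multiplication. A toy sketch follows, with a 5-word vocabulary, 3-dimensional embeddings and small random values chosen for illustration only.

```python
import random

vocab_size, embedding_dim = 5, 3
random.seed(0)
# Randomly initialized embedding matrix: one row per word ID (these would be trained)
embedding = [[random.uniform(-0.05, 0.05) for _ in range(embedding_dim)]
             for _ in range(vocab_size)]

def one_hot(word_id, size):
    """A vector of zeros with a single 1 at position word_id."""
    return [1.0 if i == word_id else 0.0 for i in range(size)]

def vec_times_matrix(vec, matrix):
    """Row vector times matrix: result[j] = sum_i vec[i] * matrix[i][j]."""
    return [sum(vec[i] * matrix[i][j] for i in range(len(vec))) for j in range(len(matrix[0]))]

word_id = 3
via_one_hot = vec_times_matrix(one_hot(word_id, vocab_size), embedding)
via_lookup = embedding[word_id]  # an embedding lookup simply indexes the row
print(via_one_hot == via_lookup)
```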
{
"cell_type": "markdown",
"metadata": {
"button": false,
"deletable": true,
"new_sheet": false,
"run_control": {
"read_only": false
}
},
"source": [
"<b>embedding_lookup()</b> finds the embedded values for our batch of 30x20 words. It goes to each row of <code>input_data</code>, and for each word in the row/sentence, finds the correspond vector in <code>embedding_dic<code>. <br>\n",
"It creates a [30x20x200] tensor, so, the first element of <b>inputs</b> (the first sentence), is a matrix of 20x200, which each row of it, is vector representing a word in the sentence.\n"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"embedding_layer = tf.keras.layers.Embedding(vocab_size, embeding_vector_size,batch_input_shape=(batch_size, num_steps),trainable=True,name=\"embedding_vocab\") "
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<tf.Tensor: shape=(30, 20, 200), dtype=float32, numpy=\n",
"array([[[-0.00963352, 0.04447431, 0.03000686, ..., -0.01430992,\n",
" -0.04179025, -0.03740842],\n",
" [-0.00185136, 0.0061741 , -0.02399608, ..., 0.00052229,\n",
" 0.02421384, -0.03178833],\n",
" [ 0.03165312, -0.03180691, -0.02924181, ..., -0.04003506,\n",
" 0.04339501, 0.00341809],\n",
" ...,\n",
" [ 0.00384145, -0.025701 , -0.03223218, ..., -0.04989583,\n",
" -0.04297003, 0.03399796],\n",
" [ 0.04402603, 0.01031557, -0.04961705, ..., -0.04415311,\n",
" -0.04264161, -0.04333409],\n",
" [-0.04641173, 0.0193573 , 0.03973095, ..., -0.01120675,\n",
" -0.03314363, -0.02827821]],\n",
"\n",
" [[-0.03361372, -0.04295586, 0.03282306, ..., -0.04212505,\n",
" 0.03222534, 0.04298704],\n",
" [ 0.0467957 , 0.01119499, -0.03936114, ..., 0.01421765,\n",
" -0.00408707, 0.00464406],\n",
" [-0.00853928, 0.04816118, 0.03704181, ..., -0.04109078,\n",
" -0.01007197, 0.0286814 ],\n",
" ...,\n",
" [ 0.02400419, -0.02699733, -0.01459912, ..., 0.02876126,\n",
" -0.03713653, -0.03767141],\n",
" [ 0.01717831, 0.03590203, -0.02132884, ..., 0.02870237,\n",
" 0.02762144, -0.0386932 ],\n",
" [ 0.03142231, 0.03410158, 0.02940441, ..., -0.02251859,\n",
" 0.03102667, -0.0225789 ]],\n",
"\n",
" [[ 0.04789158, -0.0051275 , 0.00197969, ..., 0.0385602 ,\n",
" 0.01963456, -0.00577633],\n",
" [-0.00659157, -0.03292717, -0.04912363, ..., -0.02876605,\n",
" -0.02011316, 0.04732842],\n",
" [ 0.02080854, -0.02982075, 0.02682466, ..., 0.04946811,\n",
" -0.00229813, 0.00112446],\n",
" ...,\n",
" [ 0.02607346, -0.0415265 , -0.01123742, ..., -0.00301447,\n",
" 0.03501078, 0.04961843],\n",
" [ 0.03771366, 0.02780659, -0.04763606, ..., 0.03223178,\n",
" -0.02288019, 0.01915834],\n",
" [ 0.0474854 , -0.00783978, 0.03445766, ..., -0.01544482,\n",
" -0.02868042, 0.00567645]],\n",
"\n",
" ...,\n",
"\n",
" [[-0.03359639, -0.01192484, 0.02161874, ..., 0.03021271,\n",
" -0.00923153, 0.02466537],\n",
" [ 0.02889576, -0.02996948, -0.00271885, ..., 0.04775413,\n",
" -0.03437793, 0.00707468],\n",
" [-0.01108208, 0.04819736, -0.0070881 , ..., -0.03020909,\n",
" -0.04625753, -0.02776117],\n",
" ...,\n",
" [-0.01802858, -0.01735945, -0.0152363 , ..., -0.00422896,\n",
" 0.04264318, -0.04981344],\n",
" [ 0.00044448, 0.04429383, 0.02860153, ..., 0.02268933,\n",
" -0.00661413, 0.04705914],\n",
" [-0.00394057, 0.00958618, 0.01651721, ..., -0.01800476,\n",
" 0.03917832, -0.02807972]],\n",
"\n",
" [[-0.03200712, 0.0478869 , 0.03639423, ..., 0.01515701,\n",
" 0.01208561, -0.00037247],\n",
" [ 0.04942827, 0.0239869 , -0.03067038, ..., -0.01130854,\n",
" 0.02087886, 0.03564385],\n",
" [ 0.02080854, -0.02982075, 0.02682466, ..., 0.04946811,\n",
" -0.00229813, 0.00112446],\n",
" ...,\n",
" [ 0.03879717, -0.00401578, 0.03618746, ..., 0.02274625,\n",
" -0.00253565, 0.0069579 ],\n",
" [ 0.0467957 , 0.01119499, -0.03936114, ..., 0.01421765,\n",
" -0.00408707, 0.00464406],\n",
" [-0.007146 , -0.0061972 , 0.02657751, ..., 0.03626447,\n",
" -0.0126319 , -0.00684062]],\n",
"\n",
" [[ 0.02167756, 0.04020509, -0.03158095, ..., 0.03396587,\n",
" 0.00901557, 0.04349915],\n",
" [-0.03342368, 0.01288665, 0.04773717, ..., 0.01916863,\n",
" 0.04265571, 0.0195909 ],\n",
" [-0.01865839, 0.03886256, -0.00242085, ..., 0.00906827,\n",
" 0.00160943, -0.00667285],\n",
" ...,\n",
" [-0.00134995, 0.01540628, -0.00959247, ..., 0.01157932,\n",
" 0.02552957, 0.03260735],\n",
" [-0.02965094, 0.02561637, -0.02820656, ..., -0.01047771,\n",
" 0.02718032, -0.04407642],\n",
" [ 0.00973027, -0.02719208, -0.03486856, ..., -0.00013089,\n",
" 0.01386044, 0.02843121]]], dtype=float32)>"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Look up the embedding vector of every word ID in the batch: (30, 20) -> (30, 20, 200)\n",
"inputs = embedding_layer(_input_data)\n",
"inputs"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<h3>Constructing Recurrent Neural Networks</h3>\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"button": false,
"deletable": true,
"new_sheet": false,
"run_control": {
"read_only": false
}
},
"source": [
"In this step, we create a two-layer stacked LSTM by wrapping two LSTM cells in <b>tf.keras.layers.StackedRNNCells</b>:\n"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"lstm_cell_l1 = tf.keras.layers.LSTMCell(hidden_size_l1)\n",
"lstm_cell_l2 = tf.keras.layers.LSTMCell(hidden_size_l2)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"button": false,
"deletable": true,
"new_sheet": false,
"run_control": {
"read_only": false
}
},
"outputs": [],
"source": [
"stacked_lstm = tf.keras.layers.StackedRNNCells([lstm_cell_l1, lstm_cell_l2])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<b>tf.keras.layers.RNN</b> creates a recurrent neural network from <b>stacked_lstm</b> by applying the stacked cells to the input one time step at a time.\n",
"\n",
"The input should be a tensor of shape [batch_size, max_time, embedding_vector_size]; in our case, (30, 20, 200).\n"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
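"# return_sequences=True: output the last LSTM layer's hidden state at every time step\n",
"# stateful=True: carry the final states of one batch over as the initial states of the next\n",
"layer = tf.keras.layers.RNN(stacked_lstm, return_sequences=True, return_state=False, stateful=True, trainable=True)"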
]
},
{
"cell_type": "markdown",
"metadata": {
"button": false,
"deletable": true,
"new_sheet": false,
"run_control": {
"read_only": false
}
},
"source": [
"Next, we initialize the states of the network:\n",
"\n",
"<h4>_initial_state</h4>\n",
"\n",
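"For each LSTM layer, there are two state matrices: c_state, the \"Cell State\", and m_state, the \"Memory State\" (the hidden state h, which is also the layer's output). Each layer keeps one state vector per sequence in the batch, and its length equals the number of hidden units in that layer. Each state is therefore a [batch_size x hidden_size] matrix: [30 x hidden_size_l1] for the first layer and [30 x hidden_size_l2] for the second. The <code>init_state</code> variable below is a [30x200] matrix of zeros (batch_size x embeding_vector_size).\n"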
]
},
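{
"cell_type": "markdown",
"metadata": {},
"source": [
"This minimal NumPy sketch shows the zero states a two-layer LSTM keeps. It assumes hidden_size_l1 = 256 and hidden_size_l2 = 128, the sizes implied by the kernel shapes in this notebook's gradient output. Each layer holds one (c_state, m_state) pair, and each matrix has shape [batch_size, hidden_size]:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"batch_size = 30\n",
"hidden_sizes = [256, 128]  # assumed values of hidden_size_l1 and hidden_size_l2\n",
"\n",
"# One (c_state, m_state) pair per LSTM layer, all starting at zero\n",
"initial_states = [\n",
"    (np.zeros((batch_size, units)), np.zeros((batch_size, units)))\n",
"    for units in hidden_sizes\n",
"]\n",
"\n",
"for layer_index, (c_state, m_state) in enumerate(initial_states):\n",
"    print(layer_index, c_state.shape, m_state.shape)\n",
"# 0 (30, 256) (30, 256)\n",
"# 1 (30, 128) (30, 128)\n",
"```\n"
]
},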
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"init_state = tf.Variable(tf.zeros([batch_size,embeding_vector_size]),trainable=False)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
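"# Note: Keras does not read this attribute. Assigning a tf.Variable to a layer attribute only registers\n",
"# init_state as a non-trainable weight of the layer (the 6,000 non-trainable params in the model summary).\n",
"# A stateful RNN starts from zero states, and layer.reset_states() sets them back to zero.\n",
"layer.inital_state = init_state"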
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<tf.Variable 'Variable:0' shape=(30, 200) dtype=float32, numpy=\n",
"array([[0., 0., 0., ..., 0., 0., 0.],\n",
" [0., 0., 0., ..., 0., 0., 0.],\n",
" [0., 0., 0., ..., 0., 0., 0.],\n",
" ...,\n",
" [0., 0., 0., ..., 0., 0., 0.],\n",
" [0., 0., 0., ..., 0., 0., 0.],\n",
" [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)>"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"layer.inital_state"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let's look at the outputs. The RNN layer returns the output of the last LSTM layer (128 hidden units) at each of the 20 time steps for all 30 sentences, so <b>outputs</b> is a [30x20x128] tensor. Afterwards, a dense (linear) layer will map each 128-dimensional output vector to a score for every word in the vocabulary.\n"
]
},
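{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the shapes concrete, here is a simplified NumPy sketch of what the stacked RNN computes. It uses toy sizes (8-dimensional inputs and 6 and 4 hidden units, much smaller than the notebook's layers) and random weights. It is a standard LSTM cell, not a line-by-line copy of the Keras implementation. At each time step, the first layer reads the embedding, the second layer reads the first layer's output, and the second layer's hidden state is collected as the output. With <code>stateful=True</code>, Keras keeps the final states and uses them as the initial states for the next batch.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def sigmoid(z):\n",
"    return 1.0 / (1.0 + np.exp(-z))\n",
"\n",
"def lstm_step(x, h, c, kernel, recurrent_kernel, bias):\n",
"    # One time step of a standard LSTM cell (gates: input, forget, candidate, output)\n",
"    z = x @ kernel + h @ recurrent_kernel + bias\n",
"    i, f, g, o = np.split(z, 4, axis=-1)\n",
"    c_new = sigmoid(f) * c + sigmoid(i) * np.tanh(g)\n",
"    h_new = sigmoid(o) * np.tanh(c_new)\n",
"    return h_new, c_new\n",
"\n",
"def make_cell(input_dim, units, rng):\n",
"    # Small random kernel and recurrent kernel, zero bias\n",
"    return (rng.normal(scale=0.1, size=(input_dim, 4 * units)),\n",
"            rng.normal(scale=0.1, size=(units, 4 * units)),\n",
"            np.zeros(4 * units))\n",
"\n",
"def stacked_lstm_forward(inputs, cells):\n",
"    # inputs: [batch, time, features]; returns the last layer's output at every step\n",
"    batch, steps, _ = inputs.shape\n",
"    states = []\n",
"    for _, recurrent_kernel, _ in cells:\n",
"        units = recurrent_kernel.shape[0]\n",
"        states.append((np.zeros((batch, units)), np.zeros((batch, units))))\n",
"    outputs = []\n",
"    for t in range(steps):\n",
"        x = inputs[:, t, :]\n",
"        for layer, params in enumerate(cells):\n",
"            h, c = states[layer]\n",
"            h, c = lstm_step(x, h, c, *params)\n",
"            states[layer] = (h, c)\n",
"            x = h  # the output of this layer is the input of the next one\n",
"        outputs.append(x)\n",
"    return np.stack(outputs, axis=1), states\n",
"\n",
"rng = np.random.default_rng(0)\n",
"embed_dim, units_l1, units_l2 = 8, 6, 4  # toy sizes\n",
"cells = [make_cell(embed_dim, units_l1, rng), make_cell(units_l1, units_l2, rng)]\n",
"toy_inputs = rng.normal(size=(30, 20, embed_dim))\n",
"toy_outputs, final_states = stacked_lstm_forward(toy_inputs, cells)\n",
"print(toy_outputs.shape)  # (30, 20, 4)\n",
"```\n"
]
},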
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"outputs = layer(inputs)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<tf.Tensor: shape=(30, 20, 128), dtype=float32, numpy=\n",
"array([[[-1.15399947e-03, -5.68167306e-04, -6.36982557e-04, ...,\n",
" -4.73892607e-04, 5.77440369e-04, 9.37731005e-04],\n",
" [-2.25645606e-03, -4.12659574e-04, -1.41658777e-04, ...,\n",
" -2.64253211e-03, -4.77828173e-04, 2.03687442e-03],\n",
" [-3.47123714e-03, -1.35197095e-03, 1.65716890e-04, ...,\n",
" -3.61537002e-03, -4.11159941e-04, 1.61307643e-03],\n",
" ...,\n",
" [-8.39538034e-03, -2.45015253e-04, -1.09094812e-03, ...,\n",
" 1.08035376e-04, -5.07893157e-04, 2.33715540e-03],\n",
" [-7.56491860e-03, -1.81483256e-03, -1.33696629e-03, ...,\n",
" 2.04627588e-03, 3.20677907e-04, 1.67608715e-03],\n",
" [-6.53917482e-03, -2.30778474e-03, -2.84525519e-03, ...,\n",
" 4.18785587e-03, 6.00512576e-05, -1.63874269e-04]],\n",
"\n",
" [[ 1.09756053e-04, 1.32923410e-03, 1.75143490e-04, ...,\n",
" -4.65967751e-04, 1.16835954e-03, -1.53120316e-03],\n",
" [ 3.86477594e-04, 1.76512927e-03, 1.25126052e-03, ...,\n",
" -6.69570640e-04, 1.57780340e-03, -3.60413967e-03],\n",
" [ 7.42918928e-04, 3.66925145e-04, 9.12140007e-04, ...,\n",
" -4.80953662e-04, 6.38809230e-04, -2.65075359e-03],\n",
" ...,\n",
" [-3.85934510e-03, -4.22394089e-03, 6.22767257e-03, ...,\n",
" -3.76218013e-05, -2.43352191e-03, 3.04848235e-03],\n",
" [-3.42462212e-03, -5.17992396e-03, 6.16987981e-03, ...,\n",
" 9.50013928e-04, -1.55145070e-03, 2.25520926e-03],\n",
" [-3.16660781e-03, -5.87866781e-03, 6.85461890e-03, ...,\n",
" 7.26393540e-04, 5.82168694e-04, 1.00521476e-03]],\n",
"\n",
" [[-7.00488628e-04, 6.71217800e-04, 1.14916719e-03, ...,\n",
" 5.90832242e-05, -1.10659108e-03, 5.05686738e-04],\n",
" [-1.68715615e-03, 2.62441719e-03, 7.52402237e-04, ...,\n",
" 2.93933379e-04, -7.50989828e-04, 1.84876705e-03],\n",
" [-2.15385109e-03, 3.76083306e-03, 8.66153336e-04, ...,\n",
" 7.48437655e-04, -3.39015998e-04, 1.33799261e-03],\n",
" ...,\n",
" [-3.54648242e-03, 2.66019814e-03, -1.99942477e-03, ...,\n",
" -3.88026633e-03, -1.18369795e-03, -2.26780702e-03],\n",
" [-4.94966330e-03, 2.58278195e-03, -1.76210597e-03, ...,\n",
" -3.51913180e-03, -2.96921353e-03, -2.58426508e-03],\n",
" [-4.13994351e-03, 2.01505190e-03, -2.82721012e-03, ...,\n",
" -1.18802756e-03, -3.00479867e-03, -1.76892546e-03]],\n",
"\n",
" ...,\n",
"\n",
" [[ 1.09716470e-03, 1.16411468e-03, 1.45541516e-03, ...,\n",
" -1.31794252e-03, -6.36122539e-04, -9.64721185e-05],\n",
" [ 9.03060660e-04, 2.16806470e-03, 3.56120500e-03, ...,\n",
" -2.42436817e-03, -2.66180676e-03, -5.67734824e-04],\n",
" [ 4.85980941e-04, 3.34257772e-03, 2.30879360e-03, ...,\n",
" -2.14602519e-03, -2.78160814e-03, -5.42481255e-04],\n",
" ...,\n",
" [-1.08194526e-03, -1.35212112e-03, -1.02357357e-03, ...,\n",
" 6.48161629e-03, 9.64760606e-04, 2.00056820e-03],\n",
" [-2.12396239e-03, -1.86850911e-03, -1.43049145e-03, ...,\n",
" 4.64034779e-03, 1.80464529e-03, 2.95537570e-03],\n",
" [-2.70976266e-03, -2.69680098e-03, -1.07835757e-03, ...,\n",
" 3.10675241e-03, 2.75482493e-03, 1.96297700e-03]],\n",
"\n",
" [[-3.01606662e-04, 1.11515663e-04, -2.01646253e-04, ...,\n",
" -1.00146246e-03, -3.86775209e-04, 5.55271708e-06],\n",
" [-2.78348831e-04, -5.37605898e-04, -9.52472619e-04, ...,\n",
" 8.93422839e-05, -8.30825826e-04, -7.54974666e-04],\n",
" [-3.05426482e-04, -4.47027967e-04, -9.24672117e-04, ...,\n",
" 1.30620867e-03, -1.10295252e-03, -2.41234433e-03],\n",
" ...,\n",
" [-3.81959789e-03, 4.48417058e-03, 7.90258683e-03, ...,\n",
" -4.18833457e-03, -2.37779669e-03, -1.72640721e-03],\n",
" [-4.27715806e-03, 4.85268515e-03, 8.99966620e-03, ...,\n",
" -4.45619645e-03, -2.15246575e-03, -2.49825441e-03],\n",
" [-4.11562435e-03, 4.08452749e-03, 1.00230826e-02, ...,\n",
" -5.72837796e-03, -1.68386812e-03, -1.94654311e-03]],\n",
"\n",
" [[-9.87263280e-04, -1.43964004e-04, -1.23946884e-04, ...,\n",
" 2.93438963e-04, -7.24759069e-04, -5.73443482e-04],\n",
" [-1.64883956e-03, 1.52932844e-04, -1.52630778e-03, ...,\n",
" 5.01818024e-04, -1.31549395e-03, -6.85221690e-04],\n",
" [-2.14515440e-03, 2.30228249e-03, -9.52944567e-04, ...,\n",
" -5.82493143e-04, -2.65693013e-03, -1.24638854e-03],\n",
" ...,\n",
" [-1.71335950e-03, -1.80269952e-03, -1.54979946e-03, ...,\n",
" 8.14396713e-04, 1.70287199e-03, -2.42939615e-03],\n",
" [-1.80875650e-03, -1.14083337e-03, -1.14884553e-03, ...,\n",
" 1.86088914e-03, 8.82311724e-04, -2.89117754e-03],\n",
" [-1.89534924e-03, -2.00069742e-03, -8.35343730e-04, ...,\n",
" 3.46254394e-03, -1.56247336e-03, -2.52691004e-03]]],\n",
" dtype=float32)>"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"outputs"
]
},
{
"cell_type": "markdown",
"metadata": {
"button": false,
"deletable": true,
"new_sheet": false,
"run_control": {
"read_only": false
}
},
"source": [
"<h2>Dense layer</h2>\n",
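"We now create a densely connected layer that maps the outputs tensor from [30 x 20 x 128] to [30 x 20 x 10000]. At every time step, it projects the 128-dimensional LSTM output to one unnormalized score (logit) per word in the vocabulary.\n"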
]
},
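{
"cell_type": "markdown",
"metadata": {},
"source": [
"This NumPy sketch shows what the dense layer computes. It assumes toy sizes (a 50-word vocabulary instead of 10000) and uses random values in place of the real LSTM outputs and weights. The same kernel and bias are applied to the last axis at every time step.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"rng = np.random.default_rng(0)\n",
"batch_size, num_steps = 30, 20\n",
"hidden_units, vocab_size_toy = 128, 50  # toy vocabulary; the notebook uses 10000\n",
"\n",
"lstm_outputs = rng.normal(size=(batch_size, num_steps, hidden_units))\n",
"kernel = rng.normal(scale=0.01, size=(hidden_units, vocab_size_toy))\n",
"bias = np.zeros(vocab_size_toy)\n",
"\n",
"# Matrix multiplication broadcasts over the batch and time axes\n",
"logits = lstm_outputs @ kernel + bias\n",
"print(logits.shape)  # (30, 20, 50)\n",
"\n",
"# Same result as projecting one time step at a time\n",
"print(np.allclose(logits[:, 5], lstm_outputs[:, 5] @ kernel + bias))  # True\n",
"```\n",
"\n",
"With the real sizes, the kernel holds 128 x 10000 weights plus 10000 biases. That is the 1,290,000 parameters that <code>model.summary()</code> reports for the dense layer.\n"
]
},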
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"dense = tf.keras.layers.Dense(vocab_size)"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"logits_outputs = dense(outputs)"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"shape of the output from dense layer: (30, 20, 10000)\n"
]
}
],
"source": [
"print(\"shape of the output from dense layer: \", logits_outputs.shape) #(batch_size, sequence_length, vocab_size)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<h2>Activation layer</h2>\n",
"\n",
"A softmax activation layer is then applied to turn the logits into a probability distribution over the 10000 possible words.\n"
]
},
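{
"cell_type": "markdown",
"metadata": {},
"source": [
"This NumPy sketch implements the softmax on near-zero random logits, like those of an untrained model. Subtracting the maximum logit first is a common trick for numerical stability. The example also shows why the untrained model's probabilities printed in this notebook are all close to 1/10000 = 0.0001:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def softmax(logits, axis=-1):\n",
"    # Subtracting the max avoids overflow in exp and does not change the result\n",
"    shifted = logits - np.max(logits, axis=axis, keepdims=True)\n",
"    exp = np.exp(shifted)\n",
"    return exp / np.sum(exp, axis=axis, keepdims=True)\n",
"\n",
"vocab_size = 10000\n",
"rng = np.random.default_rng(0)\n",
"# Near-zero logits, as an untrained model tends to produce\n",
"logits = rng.normal(scale=1e-3, size=(20, vocab_size))\n",
"probs = softmax(logits)\n",
"\n",
"print(probs.shape)                                      # (20, 10000)\n",
"print(np.allclose(probs.sum(axis=-1), 1.0))             # True\n",
"print(np.allclose(probs, 1.0 / vocab_size, rtol=1e-2))  # True\n",
"```\n"
]
},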
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"activation = tf.keras.layers.Activation('softmax')"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [],
"source": [
"output_words_prob = activation(logits_outputs)"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"shape of the output from the activation layer: (30, 20, 10000)\n"
]
}
],
"source": [
"print(\"shape of the output from the activation layer: \", output_words_prob.shape) #(batch_size, sequence_length, vocab_size)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
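"Let's look at the probabilities of observing each word at the 20 time steps (t=0 to t=19) of the first sentence. Since the model has not been trained yet, every probability is close to 1/10000 = 0.0001:\n"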
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The probability of observing words in t=0 to t=20 tf.Tensor(\n",
"[[9.99780095e-05 9.99922995e-05 9.99947006e-05 ... 1.00020879e-04\n",
" 9.99827971e-05 9.99914919e-05]\n",
" [9.99729527e-05 9.99912663e-05 9.99962358e-05 ... 1.00030331e-04\n",
" 9.99581753e-05 9.99895565e-05]\n",
" [9.99803524e-05 9.99943295e-05 1.00012709e-04 ... 1.00035977e-04\n",
" 9.99577169e-05 9.99934273e-05]\n",
" ...\n",
" [9.99300610e-05 1.00030367e-04 1.00004378e-04 ... 1.00011341e-04\n",
" 9.99942495e-05 9.98878968e-05]\n",
" [9.99065378e-05 1.00010395e-04 9.99664844e-05 ... 1.00023492e-04\n",
" 9.99747208e-05 9.98966716e-05]\n",
" [9.98915930e-05 1.00003337e-04 9.99345066e-05 ... 1.00023390e-04\n",
" 9.99777913e-05 9.98746909e-05]], shape=(20, 10000), dtype=float32)\n"
]
}
],
"source": [
"print(\"The probability of observing words in t=0 to t=20\", output_words_prob[0,0:num_steps])"
]
},
{
"cell_type": "markdown",
"metadata": {
"button": false,
"deletable": true,
"new_sheet": false,
"run_control": {
"read_only": false
}
},
"source": [
"<h3>Prediction</h3>\n",
"Which word does each probability distribution point to? Let's take the word ID with the maximum probability at each time step:\n"
]
},
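{
"cell_type": "markdown",
"metadata": {},
"source": [
"This minimal sketch shows the greedy decoding step. It uses a hypothetical five-word vocabulary and made-up probabilities; the notebook's own ID-to-word mapping comes from its Penn Treebank data and is not shown here:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"# Hypothetical toy vocabulary mapping word IDs to words\n",
"id_to_word = {0: '<unk>', 1: 'the', 2: 'market', 3: 'rose', 4: '<eos>'}\n",
"\n",
"# Made-up probabilities for 3 time steps over the 5-word vocabulary\n",
"probs = np.array([\n",
"    [0.10, 0.60, 0.10, 0.10, 0.10],\n",
"    [0.05, 0.05, 0.70, 0.10, 0.10],\n",
"    [0.10, 0.10, 0.10, 0.50, 0.20],\n",
"])\n",
"\n",
"predicted_ids = np.argmax(probs, axis=1)  # most likely word ID at each step\n",
"predicted_words = [id_to_word[int(i)] for i in predicted_ids]\n",
"print(predicted_ids)    # [1 2 3]\n",
"print(predicted_words)  # ['the', 'market', 'rose']\n",
"```\n"
]
},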
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([1464, 1137, 9909, 8233, 8233, 8233, 9604, 1260, 976, 976, 7646,\n",
" 7129, 7366, 7366, 7366, 7366, 1732, 1732, 1732, 1732])"
]
},
"execution_count": 38,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.argmax(output_words_prob[0,0:num_steps], axis=1)"
]
},
{
"cell_type": "markdown",
"metadata": {
"button": false,
"deletable": true,
"new_sheet": false,
"run_control": {
"read_only": false
}
},
"source": [
"So, what is the ground truth for the first sentence? The target word IDs are stored in the <code>_targets</code> tensor:\n"
]
},
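{
"cell_type": "markdown",
"metadata": {},
"source": [
"In a language model, the target at each position is simply the next word. This sketch cuts inputs and targets from one stream of word IDs. It assumes the notebook's reader follows the common Penn Treebank convention of offsetting the targets by one position; the function <code>make_batches</code> and the toy stream are hypothetical:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def make_batches(word_ids, batch_size, num_steps):\n",
"    # Split one long ID stream into batch_size rows, then cut windows of num_steps.\n",
"    # The targets are the same window shifted one position to the right.\n",
"    data = np.asarray(word_ids)\n",
"    batch_len = len(data) // batch_size\n",
"    data = data[: batch_size * batch_len].reshape(batch_size, batch_len)\n",
"    num_windows = (batch_len - 1) // num_steps\n",
"    for i in range(num_windows):\n",
"        x = data[:, i * num_steps:(i + 1) * num_steps]\n",
"        y = data[:, i * num_steps + 1:(i + 1) * num_steps + 1]\n",
"        yield x, y\n",
"\n",
"toy_stream = list(range(100, 124))  # hypothetical word IDs\n",
"x, y = next(make_batches(toy_stream, batch_size=2, num_steps=5))\n",
"print(x[0])  # [100 101 102 103 104]\n",
"print(y[0])  # [101 102 103 104 105]\n",
"```\n"
]
},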
{
"cell_type": "code",
"execution_count": 39,
"metadata": {
"button": false,
"deletable": true,
"new_sheet": false,
"run_control": {
"read_only": false
}
},
"outputs": [
{
"data": {
"text/plain": [
"array([9971, 9972, 9974, 9975, 9976, 9980, 9981, 9982, 9983, 9984, 9986,\n",
" 9987, 9988, 9989, 9991, 9992, 9993, 9994, 9995, 9996], dtype=int32)"
]
},
"execution_count": 39,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"_targets[0]"
]
},
{
"cell_type": "markdown",
"metadata": {
"button": false,
"new_sheet": false,
"run_control": {
"read_only": false
}
},
"source": [
"<h4>Objective function</h4>\n",
"\n",
"How similar are the predicted words to the target words?\n",
"\n",
"Now we have to define our objective function, which measures how far the predicted values are from the ground truth and penalizes the model for the error. Our objective is to minimize the loss function, that is, the average negative log probability of the target words:\n",
"\n",
"$$\\\\text{loss} = -\\\\frac{1}{N}\\\\sum_{i=1}^{N} \\\\ln p_{\\\\text{target}\\_i}$$\n",
"\n",
"This function is already implemented in TensorFlow as <i>tf.keras.losses.sparse_categorical_crossentropy</i>. It computes the cross-entropy between the integer <b>target</b> word IDs and the predicted distributions and returns one loss value per word. The averaging is done separately.\n",
"\n",
"The arguments of this function are: \n",
"\n",
"<ul>\n",
" <li>y_true: the target word IDs, an integer tensor of shape [batch_size x num_steps].</li> \n",
" <li>y_pred: the predicted probabilities, a float tensor of shape [batch_size x num_steps x vocab_size]. Pass <i>from_logits=True</i> if you give it raw logits instead of softmax outputs.</li> \n",
"</ul>\n"
]
},
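{
"cell_type": "markdown",
"metadata": {},
"source": [
"This NumPy sketch computes the per-word loss that the formula above describes. The function <code>sparse_crossentropy</code> is our own illustration, not the TensorFlow implementation. It picks the predicted probability of each target word and takes its negative log, clipping probabilities away from zero to avoid ln(0). With a uniform prediction over 10000 words, every loss equals ln(10000):\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def sparse_crossentropy(y_true, y_pred, eps=1e-7):\n",
"    # y_true: integer IDs [batch, steps]; y_pred: probabilities [batch, steps, vocab]\n",
"    # Gather the predicted probability of each target word, then take -ln of it\n",
"    p_target = np.take_along_axis(y_pred, y_true[..., None], axis=-1)[..., 0]\n",
"    return -np.log(np.clip(p_target, eps, 1.0))\n",
"\n",
"vocab_size = 10000\n",
"y_true = np.array([[3, 7], [1, 9999]])\n",
"y_pred = np.full((2, 2, vocab_size), 1.0 / vocab_size)  # uniform guess, like an untrained model\n",
"\n",
"per_word_loss = sparse_crossentropy(y_true, y_pred)\n",
"print(per_word_loss.shape)                   # (2, 2)\n",
"print(round(float(per_word_loss[0, 0]), 4))  # 9.2103\n",
"print(round(float(np.log(vocab_size)), 4))   # 9.2103\n",
"```\n"
]
},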
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [],
"source": [
"def crossentropy(y_true, y_pred):\n",
" return tf.keras.losses.sparse_categorical_crossentropy(y_true, y_pred)"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [],
"source": [
"loss = crossentropy(_targets, output_words_prob)"
]
},
{
"cell_type": "markdown",
"metadata": {
"button": false,
"new_sheet": false,
"run_control": {
"read_only": false
}
},
"source": [
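"Let's look at the first 10 values of the loss (the per-word losses of the first sentence). For a near-uniform prediction over 10000 words, each value should be close to ln(10000), about 9.21:\n"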
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<tf.Tensor: shape=(10,), dtype=float32, numpy=\n",
"array([9.2101345, 9.210351 , 9.209828 , 9.210473 , 9.210363 , 9.209582 ,\n",
" 9.209699 , 9.210181 , 9.210007 , 9.210093 ], dtype=float32)>"
]
},
"execution_count": 42,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"loss[0,:10]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
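"Now, we define the cost as the total loss per sentence, averaged over the batch: we sum all per-word losses and divide by batch_size. With 20 words per sentence, each contributing about 9.21, this gives roughly 20 x 9.21, about 184.2:\n"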
]
},
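{
"cell_type": "markdown",
"metadata": {},
"source": [
"This quick NumPy check of the cost arithmetic assumes every per-word loss equals ln(10000), as for a perfectly uniform prediction:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"batch_size, num_steps = 30, 20\n",
"# Assumed per-word losses of a perfectly uniform prediction over 10000 words\n",
"loss = np.full((batch_size, num_steps), np.log(10000))\n",
"\n",
"cost = np.sum(loss / batch_size)  # same formula as tf.reduce_sum(loss / batch_size)\n",
"print(round(float(cost), 2))                      # 184.21\n",
"print(np.isclose(cost, num_steps * loss.mean()))  # True\n",
"```\n",
"\n",
"The cost is therefore num_steps times the mean per-word loss, not the mean itself.\n"
]
},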
{
"cell_type": "code",
"execution_count": 43,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<tf.Tensor: shape=(), dtype=float32, numpy=184.20605>"
]
},
"execution_count": 43,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Sum of all per-word losses divided by batch_size, i.e. the average total loss per sentence\n",
"cost = tf.reduce_sum(loss / batch_size)\n",
"cost"
]
},
{
"cell_type": "markdown",
"metadata": {
"button": false,
"deletable": true,
"new_sheet": false,
"run_control": {
"read_only": false
}
},
"source": [
"<h3>Training</h3>\n",
"\n",
"To train our network, we have to take the following steps:\n",
"\n",
"<ol>\n",
" <li>Define the optimizer.</li>\n",
" <li>Assemble layers to build model.</li>\n",
" <li>Calculate the gradients based on the loss function.</li>\n",
" <li>Apply the optimizer to the variables/gradients tuple.</li>\n",
"</ol>\n"
]
},
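{
"cell_type": "markdown",
"metadata": {},
"source": [
"The four steps can be illustrated on a much simpler problem. This NumPy sketch fits a hypothetical linear model with mean squared error and plain gradient descent. It is not the notebook's LSTM, but each step maps onto one item of the list:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"rng = np.random.default_rng(0)\n",
"x = rng.normal(size=(100, 3))        # toy inputs\n",
"true_w = np.array([2.0, -1.0, 0.5])  # hypothetical weights to recover\n",
"y = x @ true_w                       # toy targets\n",
"\n",
"# 1. Define the optimizer: plain gradient descent with a fixed learning rate\n",
"learning_rate = 0.1\n",
"# 2. Assemble the model: a single linear layer with weights w\n",
"w = np.zeros(3)\n",
"\n",
"losses = []\n",
"for _ in range(200):\n",
"    # Forward pass and mean squared error loss\n",
"    error = x @ w - y\n",
"    losses.append(np.mean(error ** 2))\n",
"    # 3. Gradient of the loss with respect to w\n",
"    grad = 2.0 * x.T @ error / len(y)\n",
"    # 4. Apply the optimizer: move the weights against the gradient\n",
"    w -= learning_rate * grad\n",
"\n",
"print(np.round(w, 3))  # approximately [2, -1, 0.5]\n",
"```\n"
]
},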
{
"cell_type": "markdown",
"metadata": {
"button": false,
"deletable": true,
"new_sheet": false,
"run_control": {
"read_only": false
}
},
"source": [
"<h4>1. Define Optimizer</h4>\n"
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {},
"outputs": [],
"source": [
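"# Create a variable for the learning rate (initialized to 0.0; assign the real value before training)\n",
"lr = tf.Variable(0.0, trainable=False)\n",
"# SGD optimizer; clipnorm rescales any gradient whose L2 norm exceeds max_grad_norm\n",
"optimizer = tf.keras.optimizers.SGD(learning_rate=lr, clipnorm=max_grad_norm)"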
]
},
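{
"cell_type": "markdown",
"metadata": {},
"source": [
"This NumPy sketch shows the rescaling that <code>clipnorm</code> performs. If a gradient's L2 norm exceeds the threshold, the gradient is scaled down to that norm while keeping its direction, which helps against exploding gradients in RNNs. As we understand it, Keras applies <code>clipnorm</code> to each variable's gradient separately. The helper <code>clip_by_norm</code> below is our own:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def clip_by_norm(grad, max_norm):\n",
"    # Rescale grad so that its L2 norm is at most max_norm, keeping its direction\n",
"    norm = np.linalg.norm(grad)\n",
"    if norm > max_norm:\n",
"        return grad * (max_norm / norm)\n",
"    return grad\n",
"\n",
"g = np.array([3.0, 4.0])      # L2 norm 5\n",
"print(clip_by_norm(g, 1.0))   # [0.6 0.8]\n",
"print(clip_by_norm(g, 10.0))  # [3. 4.] (unchanged)\n",
"```\n"
]
},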
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<h4>2. Assemble layers to build model.</h4>\n"
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: \"sequential\"\n",
"_________________________________________________________________\n",
"Layer (type) Output Shape Param # \n",
"=================================================================\n",
"embedding_vocab (Embedding) (30, 20, 200) 2000000 \n",
"_________________________________________________________________\n",
"rnn (RNN) (30, 20, 128) 671088 \n",
"_________________________________________________________________\n",
"dense (Dense) (30, 20, 10000) 1290000 \n",
"_________________________________________________________________\n",
"activation (Activation) (30, 20, 10000) 0 \n",
"=================================================================\n",
"Total params: 3,961,088\n",
"Trainable params: 3,955,088\n",
"Non-trainable params: 6,000\n",
"_________________________________________________________________\n"
]
}
],
"source": [
"model = tf.keras.Sequential()\n",
"model.add(embedding_layer)\n",
"model.add(layer)\n",
"model.add(dense)\n",
"model.add(activation)\n",
"model.compile(loss=crossentropy, optimizer=optimizer)\n",
"model.summary()"
]
},
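{
"cell_type": "markdown",
"metadata": {},
"source": [
"The parameter counts in the summary can be checked by hand. This plain-Python sketch assumes hidden_size_l1 = 256 and hidden_size_l2 = 128, the sizes implied by the kernel shapes in this notebook's gradient output. An LSTM layer has 4 gates, each with an input kernel, a recurrent kernel and a bias. The 6,000 non-trainable parameters of the rnn layer are the [30x200] <code>init_state</code> variable attached to it:\n",
"\n",
"```python\n",
"def lstm_params(input_dim, units):\n",
"    # 4 gates, each with an input kernel, a recurrent kernel and a bias\n",
"    return 4 * (units * (input_dim + units) + units)\n",
"\n",
"vocab_size, embed_dim, batch_size = 10000, 200, 30\n",
"hidden_l1, hidden_l2 = 256, 128  # assumed values of hidden_size_l1 and hidden_size_l2\n",
"\n",
"embedding = vocab_size * embed_dim  # 2,000,000\n",
"rnn_trainable = lstm_params(embed_dim, hidden_l1) + lstm_params(hidden_l1, hidden_l2)\n",
"init_state = batch_size * embed_dim  # the non-trainable init_state variable\n",
"dense = hidden_l2 * vocab_size + vocab_size  # 1,290,000\n",
"\n",
"print(rnn_trainable + init_state)                      # 671088 (the rnn row)\n",
"print(embedding + rnn_trainable + dense)               # 3955088 trainable\n",
"print(embedding + rnn_trainable + init_state + dense)  # 3961088 total\n",
"```\n"
]
},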
{
"cell_type": "markdown",
"metadata": {
"button": false,
"new_sheet": false,
"run_control": {
"read_only": false
}
},
"source": [
"<h4>Trainable Variables</h4>\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"button": false,
"new_sheet": false,
"run_control": {
"read_only": false
}
},
"source": [
"When a variable is created with <i>trainable=True</i> (the default for layer weights), Keras tracks it as trainable, and the <i>model.trainable_variables</i> property returns all such variables of the model's layers. (In TensorFlow 1.x, the equivalent was the graph collection <b>GraphKeys.TRAINABLE_VARIABLES</b>, read with <i>tf.trainable_variables()</i>.)\n"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {
"button": false,
"new_sheet": false,
"run_control": {
"read_only": false
}
},
"outputs": [],
"source": [
"# Get all of the model's trainable variables; non-trainable ones such as init_state (and lr, which is not part of the model) are excluded\n",
"tvars = model.trainable_variables"
]
},
{
"cell_type": "markdown",
"metadata": {
"button": false,
"new_sheet": false,
"run_control": {
"read_only": false
}
},
"source": [
"Note: we can list the names (including their scopes) of all trainable variables:\n"
]
},
{
"cell_type": "code",
"execution_count": 47,
"metadata": {
"button": false,
"new_sheet": false,
"run_control": {
"read_only": false
}
},
"outputs": [
{
"data": {
"text/plain": [
"['embedding_vocab/embeddings:0',\n",
" 'rnn/stacked_rnn_cells/lstm_cell/kernel:0',\n",
" 'rnn/stacked_rnn_cells/lstm_cell/recurrent_kernel:0',\n",
" 'rnn/stacked_rnn_cells/lstm_cell/bias:0',\n",
" 'rnn/stacked_rnn_cells/lstm_cell_1/kernel:0',\n",
" 'rnn/stacked_rnn_cells/lstm_cell_1/recurrent_kernel:0',\n",
" 'rnn/stacked_rnn_cells/lstm_cell_1/bias:0',\n",
" 'dense/kernel:0',\n",
" 'dense/bias:0']"
]
},
"execution_count": 47,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"[v.name for v in tvars] "
]
},
{
"cell_type": "markdown",
"metadata": {
"button": false,
"new_sheet": false,
"run_control": {
"read_only": false
}
},
"source": [
"<h4>3. Calculate the gradients based on the loss function</h4>\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"button": false,
"deletable": true,
"new_sheet": false,
"run_control": {
"read_only": false
}
},
"source": [
"**Gradient**: The gradient of a function is the vector of its partial derivatives, one per input variable; each component is the rate of change (slope) of the function along that variable. The gradient points in the direction of greatest increase of the function, so gradient-based training moves the parameters in the opposite direction to reduce the loss.\n"
]
},
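{
"cell_type": "markdown",
"metadata": {},
"source": [
"This small NumPy sketch illustrates the idea with a hypothetical quadratic function whose minimum is at (1, -2). It uses central finite differences to approximate each partial derivative, and shows that a small step against the gradient lowers the function value:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def f(v):\n",
"    # Hypothetical toy function with its minimum at (1, -2)\n",
"    return (v[0] - 1.0) ** 2 + (v[1] + 2.0) ** 2\n",
"\n",
"def numerical_gradient(func, v, h=1e-6):\n",
"    # Central finite differences, one partial derivative per coordinate\n",
"    grad = np.zeros_like(v)\n",
"    for i in range(len(v)):\n",
"        step = np.zeros_like(v)\n",
"        step[i] = h\n",
"        grad[i] = (func(v + step) - func(v - step)) / (2 * h)\n",
"    return grad\n",
"\n",
"v = np.array([3.0, 0.0])\n",
"g = numerical_gradient(f, v)\n",
"print(np.round(g, 4))  # [4. 4.]\n",
"# A small step against the gradient lowers the function value\n",
"print(f(v - 0.1 * g) < f(v))  # True\n",
"```\n"
]
},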
{
"cell_type": "markdown",
"metadata": {},
"source": [
"First, let's recall how gradients are computed using a toy example:\n",
"$$ z = \\\\left(2x^2 + 3xy\\\\right)$$\n"
]
},
{
"cell_type": "code",
"execution_count": 48,
"metadata": {},
"outputs": [],
"source": [
"x = tf.constant(1.0)\n",
"y = tf.constant(2.0)\n",
"# persistent=True lets us call g.gradient() more than once on the same tape\n",
"with tf.GradientTape(persistent=True) as g:\n",
" g.watch(x)\n",
" g.watch(y)\n",
" func_test = 2 * x * x + 3 * x * y"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
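"<b>tf.GradientTape</b> records the operations executed inside its context. <b>g.gradient(target, sources)</b> then computes the gradient of <b>target</b> with respect to each tensor in <b>sources</b> (for a non-scalar target, the gradient of its sum). Constants such as <b>x</b> and <b>y</b> must be watched explicitly with <b>g.watch()</b>, while trainable variables are watched automatically. In TensorFlow 1.x, <b>tf.gradients(ys, xs)</b> played the same role.\n",
"\n",
"Now, let's look at the derivative w.r.t. <b>x</b>:\n",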
"$$ \\\\frac{\\\\partial \\\\:}{\\\\partial \\\\:x}\\\\left(2x^2 + 3xy\\\\right) = 4x + 3y $$\n"
]
},
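{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a sanity check of the two derivatives, this plain-Python sketch approximates them with central finite differences at the same point (x = 1, y = 2). The results should agree with the values 10.0 and 3.0 that <code>g.gradient</code> returns in this notebook:\n",
"\n",
"```python\n",
"def z(x, y):\n",
"    # The toy function z = 2x^2 + 3xy\n",
"    return 2 * x * x + 3 * x * y\n",
"\n",
"h = 1e-5\n",
"x0, y0 = 1.0, 2.0\n",
"# Central differences: (z(p + h) - z(p - h)) / 2h along each variable\n",
"dz_dx = (z(x0 + h, y0) - z(x0 - h, y0)) / (2 * h)  # analytic: 4x + 3y = 10\n",
"dz_dy = (z(x0, y0 + h) - z(x0, y0 - h)) / (2 * h)  # analytic: 3x = 3\n",
"print(round(dz_dx, 6), round(dz_dy, 6))  # 10.0 3.0\n",
"```\n"
]
},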
{
"cell_type": "code",
"execution_count": 49,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tf.Tensor(10.0, shape=(), dtype=float32)\n"
]
}
],
"source": [
"var_grad = g.gradient(func_test, x) # Will compute to 10.0\n",
"print(var_grad)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The derivative w.r.t. <b>y</b>:\n",
"$$ \\\\frac{\\\\partial \\\\:}{\\\\partial \\\\:y}\\\\left(2x^2 + 3xy\\\\right) = 3x $$\n"
]
},
{
"cell_type": "code",
"execution_count": 50,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tf.Tensor(3.0, shape=(), dtype=float32)\n"
]
}
],
"source": [
"var_grad = g.gradient(func_test, y) # Will compute to 3.0\n",
"print(var_grad)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we can compute the gradients of the cost with respect to all trainable variables of our model:\n"
]
},
{
"cell_type": "code",
"execution_count": 51,
"metadata": {},
"outputs": [],
"source": [
"with tf.GradientTape() as tape:\n",
" # Forward pass.\n",
" output_words_prob = model(_input_data)\n",
" # Loss value for this batch.\n",
" loss = crossentropy(_targets, output_words_prob)\n",
" # Scalar cost, computed the same way as earlier in the notebook\n",
" cost = tf.reduce_sum(loss) / batch_size"
]
},
{
"cell_type": "code",
"execution_count": 52,
"metadata": {},
"outputs": [],
"source": [
"# Get gradients of loss wrt the trainable variables.\n",
"grad_t_list = tape.gradient(cost, tvars)"
]
},
{
"cell_type": "code",
"execution_count": 53,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[<tensorflow.python.framework.indexed_slices.IndexedSlices object at 0x7f049baed780>, <tf.Tensor: shape=(200, 1024), dtype=float32, numpy=\n",
"array([[-1.5256627e-07, -6.7794031e-07, -8.7616563e-08, ...,\n",
" 5.3887391e-07, -6.7084341e-07, -1.9336210e-07],\n",
" [ 6.9934833e-07, -4.5459316e-07, -7.6316510e-08, ...,\n",
" -1.2728367e-07, 4.6999193e-07, 4.3857995e-08],\n",
" [ 2.9738970e-08, 1.4588701e-07, 5.3057556e-07, ...,\n",
" 2.6987630e-07, 3.5652761e-09, -4.3176478e-08],\n",
" ...,\n",
" [-4.5120191e-07, -6.6480175e-07, -1.1827770e-07, ...,\n",
" 6.3660934e-07, 1.6707975e-07, -1.4780258e-07],\n",
" [ 1.1878119e-06, 1.8076659e-07, 8.1836511e-08, ...,\n",
" -5.5646012e-08, -2.8694140e-07, 1.2399234e-07],\n",
" [ 8.6045688e-08, 1.0653308e-06, -2.0499257e-07, ...,\n",
" -5.7437262e-07, -3.3745656e-07, -4.2511931e-07]], dtype=float32)>, <tf.Tensor: shape=(256, 1024), dtype=float32, numpy=\n",
"array([[ 4.2339448e-08, -1.3613645e-07, -1.1695782e-09, ...,\n",
" 1.7090683e-07, 6.7923679e-08, 4.3531614e-09],\n",
" [ 8.9228593e-08, 5.4885692e-08, 2.2328859e-08, ...,\n",
" 4.7083506e-07, -5.5002836e-08, 1.8088105e-08],\n",
" [-3.1586396e-08, -9.3367575e-08, -4.9391662e-08, ...,\n",
" -4.2792667e-08, -6.9556641e-08, -7.3376555e-08],\n",
" ...,\n",
" [ 1.1441854e-07, -5.2786987e-08, 4.6940315e-08, ...,\n",
" 4.7034268e-07, -5.7223208e-08, 6.7549152e-08],\n",
" [ 3.8188105e-08, 3.8312010e-08, -4.8367259e-08, ...,\n",
" 9.4381036e-08, -2.5698074e-07, -1.0312381e-07],\n",
" [-2.1261064e-08, 1.0131200e-07, -7.8779578e-08, ...,\n",
" 3.4905204e-07, -8.7168971e-08, 2.5306417e-07]], dtype=float32)>, <tf.Tensor: shape=(1024,), dtype=float32, numpy=\n",
"array([ 3.0017754e-05, -1.3988034e-05, 1.7601337e-06, ...,\n",
" 6.2383711e-05, 1.1182047e-05, -8.6920409e-07], dtype=float32)>, <tf.Tensor: shape=(256, 512), dtype=float32, numpy=\n",
"array([[-2.8290195e-08, 1.5334706e-08, 1.1612169e-07, ...,\n",
" -5.9297982e-08, -1.0097165e-07, 1.9135928e-07],\n",
" [ 7.9073644e-08, 9.2087049e-08, -9.9100248e-08, ...,\n",
" -1.0497566e-07, -1.0579502e-07, 5.9344270e-08],\n",
" [-3.2640752e-09, -8.7425754e-08, 1.4985581e-07, ...,\n",
" 1.8573699e-07, -3.7578321e-08, -3.3039575e-07],\n",
" ...,\n",
" [-3.6049147e-07, -1.6131165e-07, 9.4717407e-08, ...,\n",
" 1.0768905e-07, -6.8524194e-08, 9.9016496e-08],\n",
" [ 1.4439294e-08, 6.4697218e-08, -1.4680919e-07, ...,\n",
" 8.5160416e-08, 6.7236094e-08, 2.6067036e-07],\n",
" [ 1.1350662e-07, -1.7743082e-07, -1.3867341e-07, ...,\n",
" 5.3598807e-08, -2.1517354e-07, 2.1897785e-08]], dtype=float32)>, <tf.Tensor: shape=(128, 512), dtype=float32, numpy=\n",
"array([[ 9.0920679e-08, -9.7806634e-08, -1.0596389e-07, ...,\n",
" -4.0080572e-08, 3.3786947e-09, -5.9041042e-08],\n",
" [-4.3531806e-08, 2.8909722e-07, -1.0749537e-07, ...,\n",
" -3.2792688e-08, 9.8477187e-08, -2.0046159e-07],\n",
" [-1.2692003e-07, -2.2698548e-07, 2.2944741e-07, ...,\n",
" -1.3208687e-07, -2.1840023e-08, 7.7354201e-08],\n",
" ...,\n",
" [-5.3174009e-08, 4.6163819e-09, 6.1483739e-08, ...,\n",
" 4.6193790e-07, -9.2429282e-09, 1.3388501e-07],\n",
" [ 8.8607351e-08, -1.7749313e-07, 7.1082020e-08, ...,\n",
" 1.4228182e-07, -9.3125578e-08, 4.7980748e-09],\n",
" [ 1.8544512e-08, 1.6310172e-07, 1.0322760e-08, ...,\n",
" -2.6089724e-08, -4.1588436e-08, -2.5750484e-07]], dtype=float32)>, <tf.Tensor: shape=(512,), dtype=float32, numpy=\n",
"array([-1.99110345e-05, -7.29465546e-06, 2.61911173e-05, 1.57037521e-05,\n",
" 4.71392450e-05, -2.40694553e-05, 3.97232907e-05, -2.41768757e-05,\n",
" 2.20949660e-05, 1.45677586e-05, 1.53334659e-05, -1.15410585e-05,\n",
" -1.78399441e-05, 7.72970543e-06, -3.04341957e-05, -1.23156778e-05,\n",
" -2.88985138e-05, -3.24518442e-05, -1.37811994e-05, -4.26956431e-05,\n",
" 2.68382955e-05, 5.45565399e-06, -4.25340431e-06, 1.34797592e-05,\n",
" 3.28052920e-05, -3.64782863e-06, -7.26121216e-05, -1.10095789e-05,\n",
" -2.20755610e-05, -5.45642615e-05, 6.67087152e-05, 4.31579392e-05,\n",
" -2.71007411e-05, -1.66865757e-05, 1.15991115e-05, 1.35062855e-05,\n",
" 5.31767218e-05, -8.22714428e-05, -1.52825942e-06, 1.00017196e-05,\n",
" 1.10279016e-05, -1.05835104e-04, -1.24400158e-05, 4.59890407e-06,\n",
" -6.26609108e-05, -2.25741060e-05, 9.35145436e-05, -8.17662112e-06,\n",
" -1.96223482e-06, 1.03931798e-05, 6.14793407e-06, -3.57266472e-05,\n",
" -2.66429197e-06, 6.07607362e-05, 2.20493257e-06, -5.11070357e-06,\n",
" -2.34731124e-05, -2.77198833e-05, -4.84299835e-06, 1.24287526e-05,\n",
" 4.04886887e-05, -9.61388014e-06, -2.38893244e-05, -1.97624195e-05,\n",
" 1.49732834e-04, -5.39622852e-06, -2.26899392e-05, 7.22518962e-07,\n",
" 2.41547041e-05, -2.52683276e-05, -4.31222943e-05, -4.30931759e-05,\n",
" -5.11325452e-05, -1.20022678e-05, -1.13347807e-04, -3.19706014e-05,\n",
" 2.77400686e-05, 1.15523808e-05, -3.50865739e-05, -3.18428574e-05,\n",
" -2.78922162e-05, 1.28757665e-05, -1.02233698e-05, 9.90332410e-06,\n",
" 1.31079069e-05, -5.32801678e-05, 1.87705227e-05, -1.26863042e-05,\n",
" -3.54922486e-05, 1.36968156e-04, -2.23922370e-06, -3.07610462e-05,\n",
" 1.22330648e-05, 1.26477562e-05, 1.43745383e-05, -7.98514011e-05,\n",
" -9.61628321e-05, -3.54778022e-05, -1.73661138e-05, -1.29420878e-04,\n",
" 5.61618035e-05, -8.48723175e-06, -1.58587463e-05, -7.84525037e-06,\n",
" 4.35119618e-05, 3.49604961e-05, 3.41538107e-05, -4.09324421e-05,\n",
" -3.71239803e-05, 2.63132315e-05, -8.98130929e-06, -1.15276180e-05,\n",
" 2.24412670e-05, 4.23215488e-06, 4.73320397e-05, 6.52694507e-06,\n",
" 2.72522266e-05, 9.80317873e-06, -1.62495126e-05, 4.84671946e-05,\n",
" 7.91536331e-06, -3.38005011e-05, 3.75219679e-05, -1.26622281e-05,\n",
" -1.43127882e-05, -1.01089681e-05, -5.82824214e-07, 4.73749606e-05,\n",
" -3.34400829e-05, -9.13100666e-06, 3.34885335e-05, 7.48173807e-06,\n",
" 5.00146343e-05, -4.63066608e-05, 5.54803737e-05, -6.93207694e-05,\n",
" 1.49135958e-05, 2.06972636e-05, 6.67572749e-07, -3.38796672e-05,\n",
" -2.45707542e-05, 5.00457736e-06, -7.11795874e-05, -3.28824171e-05,\n",
" -7.12445253e-05, -4.47741731e-05, -1.83749580e-05, -5.06311189e-05,\n",
" 2.51818274e-05, 1.93268388e-05, -3.03403503e-05, 2.56008680e-05,\n",
" 5.57523308e-05, -2.89987947e-05, -1.01858255e-04, -3.52753068e-06,\n",
" -2.11007937e-05, -6.58260324e-05, 8.68044590e-05, 5.96911414e-05,\n",
" -1.47248156e-05, -3.93647315e-06, 5.77755418e-06, 1.11619847e-05,\n",
" 1.06337742e-04, -7.39129755e-05, -1.76148787e-05, 2.84849120e-05,\n",
" 9.02289503e-06, -1.57410817e-04, -8.27398890e-06, 1.21711164e-05,\n",
" -7.93867002e-05, -4.40134572e-05, 1.02432459e-04, -3.75858945e-05,\n",
" -9.79200740e-06, -8.26310134e-06, 1.07381156e-05, -3.74959855e-05,\n",
" -6.88323780e-08, 1.00759702e-04, 1.37312236e-05, 2.63480179e-05,\n",
" -2.73081823e-05, -4.68854632e-05, 1.79791823e-05, 9.39827714e-06,\n",
" 5.23153649e-05, -1.15003641e-05, -4.02484293e-05, -3.32964737e-06,\n",
" 1.91177969e-04, -1.24483377e-05, -3.35657496e-05, 1.18089883e-05,\n",
" 3.86983993e-06, -2.82810270e-05, -5.05980242e-05, -9.17996877e-05,\n",
" -6.17987243e-05, -1.98910184e-05, -1.50251610e-04, -6.54812029e-05,\n",
" 1.50678688e-05, 1.77241236e-05, -5.29934405e-05, -2.31698523e-05,\n",
" -4.95133499e-05, -1.56820533e-07, -2.10566177e-05, 1.56480855e-05,\n",
" 2.50106350e-05, -6.91187961e-05, -1.54551053e-05, 2.49530731e-05,\n",
" -2.96732142e-05, 2.08491663e-04, -6.68619168e-06, -1.37702909e-05,\n",
" -5.03858864e-05, 4.12632726e-05, -2.86901486e-05, -8.94564000e-05,\n",
" -1.43599260e-04, -7.36482398e-05, -2.30675596e-06, -1.83999378e-04,\n",
" 8.15467065e-05, -2.82169822e-05, -3.25880064e-05, 2.02052215e-05,\n",
" 5.78480431e-05, 3.43770625e-06, 5.27895536e-05, -5.51045741e-05,\n",
" -4.11938563e-05, 3.63297222e-05, -1.12845282e-05, 5.06180186e-07,\n",
" 4.75956986e-05, -1.25265224e-05, 7.28672021e-05, 1.10194214e-05,\n",
" 2.42327988e-05, -3.21596599e-05, -2.95279242e-05, 7.33682391e-05,\n",
" -3.90976857e-06, -3.56802411e-05, 6.13520242e-05, -2.70030296e-06,\n",
" -3.46124507e-05, 1.15872890e-05, -2.52637255e-06, 7.42620468e-05,\n",
" -2.82356748e-03, 9.28777736e-03, 1.73751581e-02, -4.10499200e-02,\n",
" -4.51671481e-02, 3.32909301e-02, -3.30383927e-02, 5.03160767e-02,\n",
" -3.33478153e-02, -4.89518326e-03, -1.24594867e-02, -1.90301379e-03,\n",
" 7.89022632e-03, -3.43460143e-02, -4.65416834e-02, 2.02876348e-02,\n",
" -3.06195803e-02, 3.12450295e-03, -5.92209212e-02, 2.21785568e-02,\n",
" -3.97754312e-02, 3.58326547e-02, 5.22245467e-02, 1.27030537e-02,\n",
" -2.33417563e-02, 1.31659340e-02, 1.44711733e-02, 3.33931670e-02,\n",
" -1.63034862e-03, -1.68398954e-02, -4.64263670e-02, -2.31010932e-02,\n",
" 2.15429142e-02, -3.47415321e-02, 1.11877243e-03, -3.20874527e-02,\n",
" 5.04406765e-02, -4.71264534e-02, 2.37041665e-03, 4.25750390e-02,\n",
" -5.64210536e-03, -5.16877770e-02, -3.33238556e-03, -3.65043543e-02,\n",
" -4.31881547e-02, 2.77623609e-02, 4.35373001e-02, -9.80284158e-03,\n",
" -6.57556718e-03, 9.63302981e-03, 2.45699212e-02, -1.20348455e-02,\n",
" 3.58177796e-02, 4.03181985e-02, 1.08939670e-02, 4.60814917e-03,\n",
" -1.88708119e-02, -6.05110498e-03, 8.50436278e-04, 4.17488441e-02,\n",
" -2.86238641e-02, 1.25137642e-02, 9.99696460e-03, -2.48471536e-02,\n",
" -6.57288432e-02, -4.41054180e-02, 1.80869340e-03, -1.15264514e-02,\n",
" -5.49815595e-02, -3.83359101e-03, 5.45885488e-02, -1.82086844e-02,\n",
" 1.97352841e-02, -1.32854898e-02, 3.80803235e-02, -2.38714851e-02,\n",
" 1.27633018e-02, 2.73932256e-02, -3.79742645e-02, -2.63928697e-02,\n",
" -5.67285754e-02, -3.20242494e-02, 4.83035408e-02, 2.28642486e-02,\n",
" -3.45898569e-02, -2.30111182e-02, -7.41353817e-03, -2.42721885e-02,\n",
" -1.71943605e-02, 3.50073166e-02, 7.73749501e-03, 9.98023897e-03,\n",
" 5.36300708e-03, 4.23797593e-02, 2.52226219e-02, 2.60486584e-02,\n",
" 3.93506847e-02, 7.16058761e-02, -1.00544207e-02, -2.26410627e-02,\n",
" -2.88525205e-02, 9.60933790e-03, 1.30569823e-02, 4.85427156e-02,\n",
" 1.93330981e-02, -3.36245634e-02, 1.23050390e-03, 2.35427171e-02,\n",
" 3.31656858e-02, 2.35361233e-02, 2.96390150e-04, -1.67156989e-03,\n",
" -1.38984136e-02, -5.40724136e-02, 4.72925082e-02, -1.50090493e-02,\n",
" 2.23950148e-02, 2.23637652e-02, -5.58751523e-02, 3.77190337e-02,\n",
" 7.17466883e-03, 1.39882518e-02, -2.95721702e-02, 5.99282272e-02,\n",
" -2.90058665e-02, 5.81529103e-02, -6.22727163e-03, -3.30739133e-02,\n",
" -1.86167708e-05, -8.26393625e-06, 1.85363424e-05, 3.93697519e-06,\n",
" 6.22062507e-05, -3.35149198e-05, 4.28577405e-05, -3.29527393e-05,\n",
" 3.22472661e-05, 1.02839167e-05, 1.64196717e-05, -1.46906095e-05,\n",
" -1.88890153e-05, 3.56183682e-06, -3.06116017e-05, -1.29915452e-05,\n",
" -2.44489820e-05, -3.31477495e-05, -1.67895232e-05, -4.20412616e-05,\n",
" 2.38274915e-05, -4.84490101e-07, -7.07884328e-06, 1.06966863e-05,\n",
" 2.61116147e-05, -7.47097693e-06, -8.22395959e-05, -4.96298526e-07,\n",
" -2.12179802e-05, -6.15144818e-05, 7.81837807e-05, 4.90739403e-05,\n",
" -2.08170059e-05, -1.80483075e-05, 3.40487168e-06, 1.67522285e-05,\n",
" 5.50277036e-05, -8.44242168e-05, -8.92466232e-06, 1.96976725e-05,\n",
" 1.22851006e-05, -1.04505889e-04, -1.73575663e-05, 9.93075537e-06,\n",
" -6.43244784e-05, -2.50952125e-05, 8.69531650e-05, -4.96423127e-06,\n",
" -6.89434910e-06, 9.70117162e-06, 9.21935407e-06, -3.99091950e-05,\n",
" -2.90359822e-06, 7.03785408e-05, 1.44372348e-06, -2.17463894e-05,\n",
" -2.21670507e-05, -3.43847569e-05, 3.11448366e-06, 1.50533106e-05,\n",
" 5.36600637e-05, -1.49244943e-05, -2.03131785e-05, -1.57592058e-05,\n",
" 1.64054130e-04, -1.49498055e-05, -2.55462710e-05, 8.12182043e-06,\n",
" 1.87098176e-05, -1.58740186e-05, -4.62989774e-05, -6.41835504e-05,\n",
" -5.70351804e-05, -1.13773003e-05, -1.18867021e-04, -5.31088263e-05,\n",
" 2.54290953e-05, 1.49962407e-05, -4.32317320e-05, -2.96516337e-05,\n",
" -2.52453938e-05, 7.26332837e-06, -1.55389953e-05, 1.02725371e-05,\n",
" 2.61295572e-05, -4.55491536e-05, 1.16021420e-05, -6.67282984e-06,\n",
" -3.90158821e-05, 1.61193544e-04, -6.24900531e-06, -2.59799344e-05,\n",
" 6.33476748e-06, 2.56378080e-05, 1.49027719e-05, -6.58978679e-05,\n",
" -1.11316956e-04, -4.26451406e-05, -1.82335225e-05, -1.34126953e-04,\n",
" 6.26451074e-05, -1.22298261e-05, -1.43776870e-05, -3.52375218e-06,\n",
" 4.39129508e-05, 4.03065787e-05, 3.12449920e-05, -4.53516332e-05,\n",
" -4.20891847e-05, 2.80705517e-05, -1.27462572e-05, -7.86944929e-06,\n",
" 2.50178582e-05, -4.44879561e-06, 4.51270171e-05, 4.05591436e-06,\n",
" 2.70031233e-05, 1.36304516e-05, -2.12590166e-05, 5.47602613e-05,\n",
" 3.89774505e-06, -4.69338629e-05, 3.92884358e-05, -5.52044821e-06,\n",
" -3.36022604e-05, -6.58165573e-06, -4.14886381e-06, 5.84282279e-05],\n",
" dtype=float32)>, <tf.Tensor: shape=(128, 10000), dtype=float32, numpy=\n",
"array([[ 9.5183105e-04, 2.1646395e-03, 9.0878268e-05, ...,\n",
" -2.6323535e-06, -2.6330113e-06, -2.6322307e-06],\n",
" [ 7.3757558e-04, 1.5442743e-03, 1.6750256e-03, ...,\n",
" -2.0000080e-06, -1.9987749e-06, -1.9970939e-06],\n",
" [-2.1177637e-03, -2.8248103e-03, -2.9303601e-03, ...,\n",
" 4.9381433e-06, 4.9378909e-06, 4.9390701e-06],\n",
" ...,\n",
" [-3.6386732e-04, 3.3230041e-05, -1.1236004e-03, ...,\n",
" 3.1839045e-07, 3.1763255e-07, 3.1627582e-07],\n",
" [-1.2519001e-04, 4.3181507e-04, 4.7914428e-04, ...,\n",
" -9.6191116e-07, -9.6355552e-07, -9.6267138e-07],\n",
" [ 1.6367263e-03, 1.3170558e-03, 9.3035400e-04, ...,\n",
" -2.7387218e-06, -2.7385686e-06, -2.7389342e-06]], dtype=float32)>, <tf.Tensor: shape=(10000,), dtype=float32, numpy=\n",
"array([-0.7979981 , -1.0313315 , -1.0313318 , ..., 0.00199974,\n",
" 0.00199973, 0.00199951], dtype=float32)>]\n"
]
}
],
"source": [
"print(grad_t_list)"
]
},
{
"cell_type": "markdown",
"metadata": {
"button": false,
"new_sheet": false,
"run_control": {
"read_only": false
}
},
"source": [
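"Now we have a list of gradient tensors, <i>t_list</i> (here <code>grad_t_list</code>), and we can clip them. <b>clip_by_global_norm</b> rescales several tensors together: it computes their global norm, the square root of the sum of the squared L2 norms of all the tensors, and if that norm exceeds <code>clip_norm</code>, it multiplies every tensor by <code>clip_norm / global_norm</code>.\n",
"\n",
"<b>clip_by_global_norm</b> takes <i>t_list</i> and the threshold as input and returns two values:\n",
"\n",
"<ul>\n",
" <li>a list of clipped tensors, <i>list_clipped</i></li> \n",
" <li>the global norm (<i>global_norm</i>) of all tensors in <i>t_list</i></li> \n",
"</ul>\n"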
]
},
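{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a small illustration (a sketch with made-up toy tensors, not data from the PTB model), the snippet below shows the arithmetic behind <b>clip_by_global_norm</b>. The global norm of <code>[3, 4]</code> and <code>[12]</code> is sqrt(9 + 16 + 144) = 13; with an assumed threshold of 6.5, every tensor is multiplied by 6.5 / 13 = 0.5. When the global norm is already below the threshold, the tensors are returned unchanged, which is probably why the clipped gradients printed in the next cell look identical to <code>grad_t_list</code>.\n",
"\n",
"```python\n",
"import tensorflow as tf\n",
"\n",
"# Two toy 'gradient' tensors (hypothetical values chosen for easy arithmetic)\n",
"t_list = [tf.constant([3.0, 4.0]), tf.constant([12.0])]\n",
"\n",
"# global_norm = sqrt(3**2 + 4**2 + 12**2) = sqrt(169) = 13\n",
"list_clipped, global_norm = tf.clip_by_global_norm(t_list, clip_norm=6.5)\n",
"\n",
"print(global_norm.numpy())                # the global norm, 13\n",
"print([t.numpy() for t in list_clipped])  # each tensor scaled by 0.5: about [1.5, 2.0] and [6.0]\n",
"```\n"
]
},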
{
"cell_type": "code",
"execution_count": 54,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[<tensorflow.python.framework.indexed_slices.IndexedSlices at 0x7f049ba93eb8>,\n",
" <tf.Tensor: shape=(200, 1024), dtype=float32, numpy=\n",
" array([[-1.5256627e-07, -6.7794031e-07, -8.7616563e-08, ...,\n",
" 5.3887391e-07, -6.7084341e-07, -1.9336210e-07],\n",
" [ 6.9934833e-07, -4.5459316e-07, -7.6316510e-08, ...,\n",
" -1.2728367e-07, 4.6999193e-07, 4.3857995e-08],\n",
" [ 2.9738970e-08, 1.4588701e-07, 5.3057556e-07, ...,\n",
" 2.6987630e-07, 3.5652761e-09, -4.3176478e-08],\n",
" ...,\n",
" [-4.5120191e-07, -6.6480175e-07, -1.1827770e-07, ...,\n",
" 6.3660934e-07, 1.6707975e-07, -1.4780258e-07],\n",
" [ 1.1878119e-06, 1.8076659e-07, 8.1836511e-08, ...,\n",
" -5.5646012e-08, -2.8694140e-07, 1.2399234e-07],\n",
" [ 8.6045688e-08, 1.0653308e-06, -2.0499257e-07, ...,\n",
" -5.7437262e-07, -3.3745656e-07, -4.2511931e-07]], dtype=float32)>,\n",
" <tf.Tensor: shape=(256, 1024), dtype=float32, numpy=\n",
" array([[ 4.2339448e-08, -1.3613645e-07, -1.1695782e-09, ...,\n",
" 1.7090683e-07, 6.7923679e-08, 4.3531614e-09],\n",
" [ 8.9228593e-08, 5.4885692e-08, 2.2328859e-08, ...,\n",
" 4.7083506e-07, -5.5002836e-08, 1.8088105e-08],\n",
" [-3.1586396e-08, -9.3367575e-08, -4.9391662e-08, ...,\n",
" -4.2792667e-08, -6.9556641e-08, -7.3376555e-08],\n",
" ...,\n",
" [ 1.1441854e-07, -5.2786987e-08, 4.6940315e-08, ...,\n",
" 4.7034268e-07, -5.7223208e-08, 6.7549152e-08],\n",
" [ 3.8188105e-08, 3.8312010e-08, -4.8367259e-08, ...,\n",
" 9.4381036e-08, -2.5698074e-07, -1.0312381e-07],\n",
" [-2.1261064e-08, 1.0131200e-07, -7.8779578e-08, ...,\n",
" 3.4905204e-07, -8.7168971e-08, 2.5306417e-07]], dtype=float32)>,\n",
" <tf.Tensor: shape=(1024,), dtype=float32, numpy=\n",
" array([ 3.0017754e-05, -1.3988034e-05, 1.7601337e-06, ...,\n",
" 6.2383711e-05, 1.1182047e-05, -8.6920409e-07], dtype=float32)>,\n",
" <tf.Tensor: shape=(256, 512), dtype=float32, numpy=\n",
" array([[-2.8290195e-08, 1.5334706e-08, 1.1612169e-07, ...,\n",
" -5.9297982e-08, -1.0097165e-07, 1.9135928e-07],\n",
" [ 7.9073644e-08, 9.2087049e-08, -9.9100248e-08, ...,\n",
" -1.0497566e-07, -1.0579502e-07, 5.9344270e-08],\n",
" [-3.2640752e-09, -8.7425754e-08, 1.4985581e-07, ...,\n",
" 1.8573699e-07, -3.7578321e-08, -3.3039575e-07],\n",
" ...,\n",
" [-3.6049147e-07, -1.6131165e-07, 9.4717407e-08, ...,\n",
" 1.0768905e-07, -6.8524194e-08, 9.9016496e-08],\n",
" [ 1.4439294e-08, 6.4697218e-08, -1.4680919e-07, ...,\n",
" 8.5160416e-08, 6.7236094e-08, 2.6067036e-07],\n",
" [ 1.1350662e-07, -1.7743082e-07, -1.3867341e-07, ...,\n",
" 5.3598807e-08, -2.1517354e-07, 2.1897785e-08]], dtype=float32)>,\n",
" <tf.Tensor: shape=(128, 512), dtype=float32, numpy=\n",
" array([[ 9.0920679e-08, -9.7806634e-08, -1.0596389e-07, ...,\n",
" -4.0080572e-08, 3.3786947e-09, -5.9041042e-08],\n",
" [-4.3531806e-08, 2.8909722e-07, -1.0749537e-07, ...,\n",
" -3.2792688e-08, 9.8477187e-08, -2.0046159e-07],\n",
" [-1.2692003e-07, -2.2698548e-07, 2.2944741e-07, ...,\n",
" -1.3208687e-07, -2.1840023e-08, 7.7354201e-08],\n",
" ...,\n",
" [-5.3174009e-08, 4.6163819e-09, 6.1483739e-08, ...,\n",
" 4.6193790e-07, -9.2429282e-09, 1.3388501e-07],\n",
" [ 8.8607351e-08, -1.7749313e-07, 7.1082020e-08, ...,\n",
" 1.4228182e-07, -9.3125578e-08, 4.7980748e-09],\n",
" [ 1.8544512e-08, 1.6310172e-07, 1.0322760e-08, ...,\n",
" -2.6089724e-08, -4.1588436e-08, -2.5750484e-07]], dtype=float32)>,\n",
" <tf.Tensor: shape=(512,), dtype=float32, numpy=\n",
" array([-1.99110345e-05, -7.29465546e-06, 2.61911173e-05, 1.57037521e-05,\n",
" 4.71392450e-05, -2.40694553e-05, 3.97232907e-05, -2.41768757e-05,\n",
" 2.20949660e-05, 1.45677586e-05, 1.53334659e-05, -1.15410585e-05,\n",
" -1.78399441e-05, 7.72970543e-06, -3.04341957e-05, -1.23156778e-05,\n",
" -2.88985138e-05, -3.24518442e-05, -1.37811994e-05, -4.26956431e-05,\n",
" 2.68382955e-05, 5.45565399e-06, -4.25340431e-06, 1.34797592e-05,\n",
" 3.28052920e-05, -3.64782863e-06, -7.26121216e-05, -1.10095789e-05,\n",
" -2.20755610e-05, -5.45642615e-05, 6.67087152e-05, 4.31579392e-05,\n",
" -2.71007411e-05, -1.66865757e-05, 1.15991115e-05, 1.35062855e-05,\n",
" 5.31767218e-05, -8.22714428e-05, -1.52825942e-06, 1.00017196e-05,\n",
" 1.10279016e-05, -1.05835104e-04, -1.24400158e-05, 4.59890407e-06,\n",
" -6.26609108e-05, -2.25741060e-05, 9.35145436e-05, -8.17662112e-06,\n",
" -1.96223482e-06, 1.03931798e-05, 6.14793407e-06, -3.57266472e-05,\n",
" -2.66429197e-06, 6.07607362e-05, 2.20493257e-06, -5.11070357e-06,\n",
" -2.34731124e-05, -2.77198833e-05, -4.84299835e-06, 1.24287526e-05,\n",
" 4.04886887e-05, -9.61388014e-06, -2.38893244e-05, -1.97624195e-05,\n",
" 1.49732834e-04, -5.39622852e-06, -2.26899392e-05, 7.22518962e-07,\n",
" 2.41547041e-05, -2.52683276e-05, -4.31222943e-05, -4.30931759e-05,\n",
" -5.11325452e-05, -1.20022678e-05, -1.13347807e-04, -3.19706014e-05,\n",
" 2.77400686e-05, 1.15523808e-05, -3.50865739e-05, -3.18428574e-05,\n",
" -2.78922162e-05, 1.28757665e-05, -1.02233698e-05, 9.90332410e-06,\n",
" 1.31079069e-05, -5.32801678e-05, 1.87705227e-05, -1.26863042e-05,\n",
" -3.54922486e-05, 1.36968156e-04, -2.23922370e-06, -3.07610462e-05,\n",
" 1.22330648e-05, 1.26477562e-05, 1.43745383e-05, -7.98514011e-05,\n",
" -9.61628321e-05, -3.54778022e-05, -1.73661138e-05, -1.29420878e-04,\n",
" 5.61618035e-05, -8.48723175e-06, -1.58587463e-05, -7.84525037e-06,\n",
" 4.35119618e-05, 3.49604961e-05, 3.41538107e-05, -4.09324421e-05,\n",
" -3.71239803e-05, 2.63132315e-05, -8.98130929e-06, -1.15276180e-05,\n",
" 2.24412670e-05, 4.23215488e-06, 4.73320397e-05, 6.52694507e-06,\n",
" 2.72522266e-05, 9.80317873e-06, -1.62495126e-05, 4.84671946e-05,\n",
" 7.91536331e-06, -3.38005011e-05, 3.75219679e-05, -1.26622281e-05,\n",
" -1.43127882e-05, -1.01089681e-05, -5.82824214e-07, 4.73749606e-05,\n",
" -3.34400829e-05, -9.13100666e-06, 3.34885335e-05, 7.48173807e-06,\n",
" 5.00146343e-05, -4.63066608e-05, 5.54803737e-05, -6.93207694e-05,\n",
" 1.49135958e-05, 2.06972636e-05, 6.67572749e-07, -3.38796672e-05,\n",
" -2.45707542e-05, 5.00457736e-06, -7.11795874e-05, -3.28824171e-05,\n",
" -7.12445253e-05, -4.47741731e-05, -1.83749580e-05, -5.06311189e-05,\n",
" 2.51818274e-05, 1.93268388e-05, -3.03403503e-05, 2.56008680e-05,\n",
" 5.57523308e-05, -2.89987947e-05, -1.01858255e-04, -3.52753068e-06,\n",
" -2.11007937e-05, -6.58260324e-05, 8.68044590e-05, 5.96911414e-05,\n",
" -1.47248156e-05, -3.93647315e-06, 5.77755418e-06, 1.11619847e-05,\n",
" 1.06337742e-04, -7.39129755e-05, -1.76148787e-05, 2.84849120e-05,\n",
" 9.02289503e-06, -1.57410817e-04, -8.27398890e-06, 1.21711164e-05,\n",
" -7.93867002e-05, -4.40134572e-05, 1.02432459e-04, -3.75858945e-05,\n",
" -9.79200740e-06, -8.26310134e-06, 1.07381156e-05, -3.74959855e-05,\n",
" -6.88323780e-08, 1.00759702e-04, 1.37312236e-05, 2.63480179e-05,\n",
" -2.73081823e-05, -4.68854632e-05, 1.79791823e-05, 9.39827714e-06,\n",
" 5.23153649e-05, -1.15003641e-05, -4.02484293e-05, -3.32964737e-06,\n",
" 1.91177969e-04, -1.24483377e-05, -3.35657496e-05, 1.18089883e-05,\n",
" 3.86983993e-06, -2.82810270e-05, -5.05980242e-05, -9.17996877e-05,\n",
" -6.17987243e-05, -1.98910184e-05, -1.50251610e-04, -6.54812029e-05,\n",
" 1.50678688e-05, 1.77241236e-05, -5.29934405e-05, -2.31698523e-05,\n",
" -4.95133499e-05, -1.56820533e-07, -2.10566177e-05, 1.56480855e-05,\n",
" 2.50106350e-05, -6.91187961e-05, -1.54551053e-05, 2.49530731e-05,\n",
" -2.96732142e-05, 2.08491663e-04, -6.68619168e-06, -1.37702909e-05,\n",
" -5.03858864e-05, 4.12632726e-05, -2.86901486e-05, -8.94564000e-05,\n",
" -1.43599260e-04, -7.36482398e-05, -2.30675596e-06, -1.83999378e-04,\n",
" 8.15467065e-05, -2.82169822e-05, -3.25880064e-05, 2.02052215e-05,\n",
" 5.78480431e-05, 3.43770625e-06, 5.27895536e-05, -5.51045741e-05,\n",
" -4.11938563e-05, 3.63297222e-05, -1.12845282e-05, 5.06180186e-07,\n",
" 4.75956986e-05, -1.25265224e-05, 7.28672021e-05, 1.10194214e-05,\n",
" 2.42327988e-05, -3.21596599e-05, -2.95279242e-05, 7.33682391e-05,\n",
" -3.90976857e-06, -3.56802411e-05, 6.13520242e-05, -2.70030296e-06,\n",
" -3.46124507e-05, 1.15872890e-05, -2.52637255e-06, 7.42620468e-05,\n",
" -2.82356748e-03, 9.28777736e-03, 1.73751581e-02, -4.10499200e-02,\n",
" -4.51671481e-02, 3.32909301e-02, -3.30383927e-02, 5.03160767e-02,\n",
" -3.33478153e-02, -4.89518326e-03, -1.24594867e-02, -1.90301379e-03,\n",
" 7.89022632e-03, -3.43460143e-02, -4.65416834e-02, 2.02876348e-02,\n",
" -3.06195803e-02, 3.12450295e-03, -5.92209212e-02, 2.21785568e-02,\n",
" -3.97754312e-02, 3.58326547e-02, 5.22245467e-02, 1.27030537e-02,\n",
" -2.33417563e-02, 1.31659340e-02, 1.44711733e-02, 3.33931670e-02,\n",
" -1.63034862e-03, -1.68398954e-02, -4.64263670e-02, -2.31010932e-02,\n",
" 2.15429142e-02, -3.47415321e-02, 1.11877243e-03, -3.20874527e-02,\n",
" 5.04406765e-02, -4.71264534e-02, 2.37041665e-03, 4.25750390e-02,\n",
" -5.64210536e-03, -5.16877770e-02, -3.33238556e-03, -3.65043543e-02,\n",
" -4.31881547e-02, 2.77623609e-02, 4.35373001e-02, -9.80284158e-03,\n",
" -6.57556718e-03, 9.63302981e-03, 2.45699212e-02, -1.20348455e-02,\n",
" 3.58177796e-02, 4.03181985e-02, 1.08939670e-02, 4.60814917e-03,\n",
" -1.88708119e-02, -6.05110498e-03, 8.50436278e-04, 4.17488441e-02,\n",
" -2.86238641e-02, 1.25137642e-02, 9.99696460e-03, -2.48471536e-02,\n",
" -6.57288432e-02, -4.41054180e-02, 1.80869340e-03, -1.15264514e-02,\n",
" -5.49815595e-02, -3.83359101e-03, 5.45885488e-02, -1.82086844e-02,\n",
" 1.97352841e-02, -1.32854898e-02, 3.80803235e-02, -2.38714851e-02,\n",
" 1.27633018e-02, 2.73932256e-02, -3.79742645e-02, -2.63928697e-02,\n",
" -5.67285754e-02, -3.20242494e-02, 4.83035408e-02, 2.28642486e-02,\n",
" -3.45898569e-02, -2.30111182e-02, -7.41353817e-03, -2.42721885e-02,\n",
" -1.71943605e-02, 3.50073166e-02, 7.73749501e-03, 9.98023897e-03,\n",
" 5.36300708e-03, 4.23797593e-02, 2.52226219e-02, 2.60486584e-02,\n",
" 3.93506847e-02, 7.16058761e-02, -1.00544207e-02, -2.26410627e-02,\n",
" -2.88525205e-02, 9.60933790e-03, 1.30569823e-02, 4.85427156e-02,\n",
" 1.93330981e-02, -3.36245634e-02, 1.23050390e-03, 2.35427171e-02,\n",
" 3.31656858e-02, 2.35361233e-02, 2.96390150e-04, -1.67156989e-03,\n",
" -1.38984136e-02, -5.40724136e-02, 4.72925082e-02, -1.50090493e-02,\n",
" 2.23950148e-02, 2.23637652e-02, -5.58751523e-02, 3.77190337e-02,\n",
" 7.17466883e-03, 1.39882518e-02, -2.95721702e-02, 5.99282272e-02,\n",
" -2.90058665e-02, 5.81529103e-02, -6.22727163e-03, -3.30739133e-02,\n",
" -1.86167708e-05, -8.26393625e-06, 1.85363424e-05, 3.93697519e-06,\n",
" 6.22062507e-05, -3.35149198e-05, 4.28577405e-05, -3.29527393e-05,\n",
" 3.22472661e-05, 1.02839167e-05, 1.64196717e-05, -1.46906095e-05,\n",
" -1.88890153e-05, 3.56183682e-06, -3.06116017e-05, -1.29915452e-05,\n",
" -2.44489820e-05, -3.31477495e-05, -1.67895232e-05, -4.20412616e-05,\n",
" 2.38274915e-05, -4.84490101e-07, -7.07884328e-06, 1.06966863e-05,\n",
" 2.61116147e-05, -7.47097693e-06, -8.22395959e-05, -4.96298526e-07,\n",
" -2.12179802e-05, -6.15144818e-05, 7.81837807e-05, 4.90739403e-05,\n",
" -2.08170059e-05, -1.80483075e-05, 3.40487168e-06, 1.67522285e-05,\n",
" 5.50277036e-05, -8.44242168e-05, -8.92466232e-06, 1.96976725e-05,\n",
" 1.22851006e-05, -1.04505889e-04, -1.73575663e-05, 9.93075537e-06,\n",
" -6.43244784e-05, -2.50952125e-05, 8.69531650e-05, -4.96423127e-06,\n",
" -6.89434910e-06, 9.70117162e-06, 9.21935407e-06, -3.99091950e-05,\n",
" -2.90359822e-06, 7.03785408e-05, 1.44372348e-06, -2.17463894e-05,\n",
" -2.21670507e-05, -3.43847569e-05, 3.11448366e-06, 1.50533106e-05,\n",
" 5.36600637e-05, -1.49244943e-05, -2.03131785e-05, -1.57592058e-05,\n",
" 1.64054130e-04, -1.49498055e-05, -2.55462710e-05, 8.12182043e-06,\n",
" 1.87098176e-05, -1.58740186e-05, -4.62989774e-05, -6.41835504e-05,\n",
" -5.70351804e-05, -1.13773003e-05, -1.18867021e-04, -5.31088263e-05,\n",
" 2.54290953e-05, 1.49962407e-05, -4.32317320e-05, -2.96516337e-05,\n",
" -2.52453938e-05, 7.26332837e-06, -1.55389953e-05, 1.02725371e-05,\n",
" 2.61295572e-05, -4.55491536e-05, 1.16021420e-05, -6.67282984e-06,\n",
" -3.90158821e-05, 1.61193544e-04, -6.24900531e-06, -2.59799344e-05,\n",
" 6.33476748e-06, 2.56378080e-05, 1.49027719e-05, -6.58978679e-05,\n",
" -1.11316956e-04, -4.26451406e-05, -1.82335225e-05, -1.34126953e-04,\n",
" 6.26451074e-05, -1.22298261e-05, -1.43776870e-05, -3.52375218e-06,\n",
" 4.39129508e-05, 4.03065787e-05, 3.12449920e-05, -4.53516332e-05,\n",
" -4.20891847e-05, 2.80705517e-05, -1.27462572e-05, -7.86944929e-06,\n",
" 2.50178582e-05, -4.44879561e-06, 4.51270171e-05, 4.05591436e-06,\n",
" 2.70031233e-05, 1.36304516e-05, -2.12590166e-05, 5.47602613e-05,\n",
" 3.89774505e-06, -4.69338629e-05, 3.92884358e-05, -5.52044821e-06,\n",
" -3.36022604e-05, -6.58165573e-06, -4.14886381e-06, 5.84282279e-05],\n",
" dtype=float32)>,\n",
" <tf.Tensor: shape=(128, 10000), dtype=float32, numpy=\n",
" array([[ 9.5183105e-04, 2.1646395e-03, 9.0878268e-05, ...,\n",
" -2.6323535e-06, -2.6330113e-06, -2.6322307e-06],\n",
" [ 7.3757558e-04, 1.5442743e-03, 1.6750256e-03, ...,\n",
" -2.0000080e-06, -1.9987749e-06, -1.9970939e-06],\n",
" [-2.1177637e-03, -2.8248103e-03, -2.9303601e-03, ...,\n",
" 4.9381433e-06, 4.9378909e-06, 4.9390701e-06],\n",
" ...,\n",
" [-3.6386732e-04, 3.3230041e-05, -1.1236004e-03, ...,\n",
" 3.1839045e-07, 3.1763255e-07, 3.1627582e-07],\n",
" [-1.2519001e-04, 4.3181507e-04, 4.7914428e-04, ...,\n",
" -9.6191116e-07, -9.6355552e-07, -9.6267138e-07],\n",
" [ 1.6367263e-03, 1.3170558e-03, 9.3035400e-04, ...,\n",
" -2.7387218e-06, -2.7385686e-06, -2.7389342e-06]], dtype=float32)>,\n",
" <tf.Tensor: shape=(10000,), dtype=float32, numpy=\n",
" array([-0.7979981 , -1.0313315 , -1.0313318 , ..., 0.00199974,\n",
" 0.00199973, 0.00199951], dtype=float32)>]"
]
},
"execution_count": 54,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Clip the gradients so that their global norm is at most max_grad_norm\n",
"grads, _ = tf.clip_by_global_norm(grad_t_list, max_grad_norm)\n",
"grads"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<h4>4. Apply the optimizer to the (gradient, variable) pairs</h4>\n"
]
},
{
"cell_type": "code",
"execution_count": 55,
"metadata": {},
"outputs": [],
"source": [
"# Create the training TensorFlow Operation through our optimizer\n",
"train_op = optimizer.apply_gradients(zip(grads, tvars))"
]
},
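{
"cell_type": "markdown",
"metadata": {},
"source": [
"The gradient steps above can be sketched end to end on a toy problem. In this illustration the variable <code>w</code>, the loss <code>sum(w**2)</code>, the learning rate 0.1 and the clipping threshold 5.0 are all assumptions, not values from the notebook. It shows that <code>apply_gradients</code> consumes the (gradient, variable) pairs built by <code>zip(grads, tvars)</code> and performs one SGD update, <code>w - lr * grad</code>.\n",
"\n",
"```python\n",
"import tensorflow as tf\n",
"\n",
"# A single toy variable and optimizer (hypothetical, not part of the PTB model)\n",
"w = tf.Variable([1.0, -2.0])\n",
"opt = tf.keras.optimizers.SGD(learning_rate=0.1)\n",
"tvars = [w]\n",
"\n",
"with tf.GradientTape() as tape:\n",
"    cost = tf.reduce_sum(w ** 2)  # d(cost)/dw = 2 * w = [2.0, -4.0]\n",
"\n",
"grad_t_list = tape.gradient(cost, tvars)\n",
"# The global norm is sqrt(2**2 + 4**2), about 4.47 < 5.0, so clipping leaves it unchanged\n",
"grads, _ = tf.clip_by_global_norm(grad_t_list, 5.0)\n",
"\n",
"# apply_gradients expects (gradient, variable) pairs, which zip() builds\n",
"opt.apply_gradients(zip(grads, tvars))\n",
"print(w.numpy())  # w - 0.1 * grad, roughly [0.8, -1.6]\n",
"```\n"
]
},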
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<a id=\"ltsm\"></a>\n",
"\n",
"<h2>LSTM</h2>\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"button": false,
"deletable": true,
"new_sheet": false,
"run_control": {
"read_only": false
}
},
"source": [
"We have seen how the model is built step by step. Now let's create a class that represents our model. This class needs a few things:\n",
"\n",
"<ul>\n",
" <li>We have to create the model in accordance with our defined hyperparameters</li>\n",
" <li>We have to create the LSTM cell structure and connect them with our RNN structure</li>\n",
" <li>We have to create the word embeddings and point them to the input data</li>\n",
" <li>We have to create the input structure for our RNN</li>\n",
" <li>We need to create a softmax output layer that returns the probability of each word in the vocabulary</li>\n",
" <li>We need to create the loss and cost functions for our optimizer to work, and then create the optimizer</li>\n",
" <li>And finally, we need to create a training operation that can be run to actually train our model</li>\n",
"</ul>\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"button": false,
"deletable": true,
"new_sheet": false,
"run_control": {
"read_only": false
},
"tags": []
},
"outputs": [],
"source": [
"class PTBModel(object):\n",
"\n",
"\n",
" def __init__(self):\n",
" ######################################\n",
" # Setting parameters for ease of use #\n",
" ######################################\n",
" self.batch_size = batch_size\n",
" self.num_steps = num_steps\n",
" self.hidden_size_l1 = hidden_size_l1\n",
" self.hidden_size_l2 = hidden_size_l2\n",
" self.vocab_size = vocab_size\n",
" self.embeding_vector_size = embeding_vector_size\n",
" # Create a variable for the learning rate\n",
" self._lr = 1.0\n",
" \n",
" ###############################################################################\n",
" # Initializing the model using keras Sequential API #\n",
" ###############################################################################\n",
" \n",
" self._model = tf.keras.models.Sequential()\n",
" \n",
" ####################################################################\n",
" # Creating the word embeddings layer and adding it to the sequence #\n",
" ####################################################################\n",
" with tf.device(\"/cpu:0\"):\n",
" # Create the embeddings for our input data. The embedding matrix has shape [vocab_size x embeding_vector_size].\n",
" self._embedding_layer = tf.keras.layers.Embedding(self.vocab_size, self.embeding_vector_size,batch_input_shape=(self.batch_size, self.num_steps),trainable=True,name=\"embedding_vocab\") #[10000x200]\n",
" self._model.add(self._embedding_layer)\n",
" \n",
"\n",
" ##########################################################################\n",
" # Creating the LSTM cell structure and connect it with the RNN structure #\n",
" ##########################################################################\n",
" # Create the LSTM Cells. \n",
" # This creates only the structure for the LSTM and has to be associated with a RNN unit still.\n",
" # The argument of LSTMCell is size of hidden layer, that is, the number of hidden units of the LSTM (inside A). \n",
" # Each LSTM cell processes one word at a time and updates its hidden state; the Dense and softmax layers added below turn that state into probabilities for the next word.\n",
" lstm_cell_l1 = tf.keras.layers.LSTMCell(hidden_size_l1)\n",
" lstm_cell_l2 = tf.keras.layers.LSTMCell(hidden_size_l2)\n",
" \n",
"\n",
" \n",
" # StackedRNNCells wraps the two LSTM cells so that they behave as a single RNN cell:\n",
" # at each time step, the output of the first cell is the input to the second.\n",
" stacked_lstm = tf.keras.layers.StackedRNNCells([lstm_cell_l1, lstm_cell_l2])\n",
"\n",
"\n",
" \n",
"\n",
" ############################################\n",
" # Creating the input structure for our RNN #\n",
" ############################################\n",
" # Input structure is 20x[30x200], i.e. num_steps x [batch_size x embeding_vector_size]\n",
" # Each word is represented by a 200-dimensional vector and each batch holds 30 sequences, so every time step feeds a [30x200] matrix\n",
" # The input structure is fed from the embeddings, which are filled in by the input data\n",
" # Feeding a batch of b sentences to a RNN:\n",
" # In step 1, first word of each of the b sentences (in a batch) is input in parallel. \n",
" # In step 2, second word of each of the b sentences is input in parallel. \n",
" # The parallelism is only for efficiency. \n",
" # Each sentence in a batch is handled in parallel, but the network sees one word of a sentence at a time and does the computations accordingly. \n",
" # All the computations involving the words of all sentences in a batch at a given time step are done in parallel. \n",
"\n",
" ########################################################################################################\n",
" # Instantiating our RNN layer; stateful=True carries the final state of each batch over to the next batch #\n",
" ########################################################################################################\n",
" \n",
" self._RNNlayer = tf.keras.layers.RNN(stacked_lstm, return_sequences=True, return_state=False, stateful=True, trainable=True)\n",
" \n",
" # Define the initial state, i.e., the model state for the very first data point\n",
" # It initializes the state of the LSTM memory: the memory state starts as a vector of zeros and is updated after reading each word.\n",
" self._initial_state = tf.Variable(tf.zeros([batch_size,embeding_vector_size]),trainable=False)\n",
" self._RNNlayer.inital_state = self._initial_state\n",
" \n",
" ############################################\n",
" # Adding RNN layer to keras sequential API #\n",
" ############################################ \n",
" self._model.add(self._RNNlayer)\n",
" \n",
" #self._model.add(tf.keras.layers.LSTM(hidden_size_l1,return_sequences=True,stateful=True))\n",
" #self._model.add(tf.keras.layers.LSTM(hidden_size_l2,return_sequences=True))\n",
" \n",
" \n",
" ####################################################################################################\n",
" # Instantiating a Dense layer that connects the output to the vocab_size and adding layer to model#\n",
" ####################################################################################################\n",
" self._dense = tf.keras.layers.Dense(self.vocab_size)\n",
" self._model.add(self._dense)\n",
" \n",
" \n",
" ####################################################################################################\n",
" # Adding softmax activation layer and deriving probability to each class and adding layer to model #\n",
" ####################################################################################################\n",
" self._activation = tf.keras.layers.Activation('softmax')\n",
" self._model.add(self._activation)\n",
"\n",
" ##########################################################\n",
" # Instantiating the stochastic gradient descent optimizer #\n",
" ########################################################## \n",
" self._optimizer = tf.keras.optimizers.SGD(learning_rate=self._lr, clipnorm=max_grad_norm)\n",
" \n",
" \n",
" ##############################################################################\n",
" # Compiling and summarizing the model stacked using the keras sequential API #\n",
" ##############################################################################\n",
" self._model.compile(loss=self.crossentropy, optimizer=self._optimizer)\n",
" self._model.summary()\n",
"\n",
"\n",
" def crossentropy(self,y_true, y_pred):\n",
" return tf.keras.losses.sparse_categorical_crossentropy(y_true, y_pred)\n",
"\n",
" def train_batch(self,_input_data,_targets):\n",
" #################################################\n",
" # Creating the Training Operation for our Model #\n",
" #################################################\n",
" # Create a variable for the learning rate\n",
" self._lr = tf.Variable(0.0, trainable=False)\n",
" # Get all TensorFlow variables marked as \"trainable\" (i.e. all of them except _lr, which we just created)\n",
" tvars = self._model.trainable_variables\n",
" # Record the forward pass so that gradients can be computed\n",
" with tf.GradientTape() as tape:\n",
" # Forward pass.\n",
" output_words_prob = self._model(_input_data)\n",
" # Loss value for this batch.\n",
" loss = self.crossentropy(_targets, output_words_prob)\n",
" # average across batch and reduce sum\n",
" cost = tf.reduce_sum(loss/ self.batch_size)\n",
" # Get gradients of loss wrt the trainable variables.\n",
" grad_t_list = tape.gradient(cost, tvars)\n",
" # Clip the gradients so that their global norm is at most max_grad_norm\n",
" grads, _ = tf.clip_by_global_norm(grad_t_list, max_grad_norm)\n",
" # Create the training TensorFlow Operation through our optimizer\n",
" train_op = self._optimizer.apply_gradients(zip(grads, tvars))\n",
" return cost\n",
" \n",
" def test_batch(self,_input_data,_targets):\n",
" #################################################\n",
" # Creating the Testing Operation for our Model #\n",
" #################################################\n",
" output_words_prob = self._model(_input_data)\n",
" loss = self.crossentropy(_targets, output_words_prob)\n",
" # average across batch and reduce sum\n",
" cost = tf.reduce_sum(loss/ self.batch_size)\n",
"\n",
" return cost\n",
" @classmethod\n",
" def instance(cls) : \n",
" return PTBModel()"
]
},
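{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see how data flows through this architecture, the following sketch builds a scaled-down copy with tiny, made-up sizes and checks the output shape. It only approximates <code>PTBModel</code>: it omits <code>stateful=True</code> (which needs a fixed batch size) and folds the softmax into the <code>Dense</code> layer. The output holds one probability distribution over the vocabulary for every word position.\n",
"\n",
"```python\n",
"import numpy as np\n",
"import tensorflow as tf\n",
"\n",
"# Tiny, hypothetical sizes; the notebook itself uses much larger ones\n",
"batch_size, num_steps, vocab_size, emb_size = 2, 3, 50, 8\n",
"\n",
"stacked = tf.keras.layers.StackedRNNCells(\n",
"    [tf.keras.layers.LSTMCell(16), tf.keras.layers.LSTMCell(12)])\n",
"\n",
"toy_model = tf.keras.Sequential([\n",
"    tf.keras.layers.Embedding(vocab_size, emb_size),          # word ids -> vectors\n",
"    tf.keras.layers.RNN(stacked, return_sequences=True),      # one output per time step\n",
"    tf.keras.layers.Dense(vocab_size, activation='softmax'),  # distribution over the vocabulary\n",
"])\n",
"\n",
"x = np.random.randint(0, vocab_size, size=(batch_size, num_steps))\n",
"probs = toy_model(x)\n",
"print(probs.shape)  # (2, 3, 50)\n",
"# Each row over the last axis is a probability distribution\n",
"print(np.allclose(probs.numpy().sum(-1), 1.0, atol=1e-5))  # True\n",
"```\n"
]
},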
{
"cell_type": "markdown",
"metadata": {
"button": false,
"deletable": true,
"new_sheet": false,
"run_control": {
"read_only": false
}
},
"source": [
"With that, the structure of our Recurrent Neural Network with Long Short-Term Memory is complete. What remains is the code that runs the model over the data: the <code>run_one_epoch</code> function, called once per epoch, and a <code>main</code> script that ties everything together.\n",
"\n",
"Our <code>run_one_epoch</code> function feeds the input data to the model batch by batch, either training on it or only evaluating it, and returns the perplexity computed from the accumulated cost.\n"
]
},
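{
"cell_type": "markdown",
"metadata": {},
"source": [
"<code>run_one_epoch</code> reports perplexity, the exponential of the average per-word cross-entropy: <code>exp(costs / iters)</code>. Each batch cost is the cross-entropy summed over <code>num_steps</code> words and averaged over the batch, and <code>iters</code> grows by <code>num_steps</code> per batch. The worked example below uses hypothetical costs in which every word has cross-entropy ln(100); the resulting perplexity of about 100 means the model is as uncertain as a uniform guess among 100 words.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"num_steps = 20    # words per sequence, as counted by iters in run_one_epoch\n",
"num_batches = 5   # hypothetical number of batches\n",
"\n",
"# Hypothetical batch costs: every word has cross-entropy ln(100), summed over num_steps words\n",
"# and averaged over the batch, mirroring cost = reduce_sum(loss / batch_size)\n",
"batch_costs = [num_steps * np.log(100.0)] * num_batches\n",
"\n",
"costs = sum(batch_costs)\n",
"iters = num_steps * num_batches\n",
"print(np.exp(costs / iters))  # about 100\n",
"```\n"
]
},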
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"button": false,
"deletable": true,
"new_sheet": false,
"run_control": {
"read_only": false
}
},
"outputs": [],
"source": [
"\n",
"########################################################################################################################\n",
"# run_one_epoch takes as parameters the model instance, the data to be fed, training or testing mode and verbose info #\n",
"########################################################################################################################\n",
"def run_one_epoch(m, data, is_training=True, verbose=False):\n",
"\n",
" # Number of (batch_size x num_steps) mini-batches in one pass over the data\n",
" epoch_size = ((len(data) // m.batch_size) - 1) // m.num_steps\n",
" start_time = time.time()\n",
" costs = 0.\n",
" iters = 0\n",
" \n",
" # Reset the hidden and cell states of the stateful LSTM layers before each pass\n",
" m._model.reset_states()\n",
" \n",
" # Iterate over (input, target) mini-batches; targets are the inputs shifted by one word\n",
" for step, (x, y) in enumerate(reader.ptb_iterator(data, m.batch_size, m.num_steps)):\n",
" \n",
" # Run one training step (forward pass, backpropagation, update) or one evaluation step, and get its cost\n",
" if is_training : \n",
" loss= m.train_batch(x, y)\n",
" else :\n",
" loss = m.test_batch(x, y)\n",
" \n",
"\n",
" #Add returned cost to costs (which keeps track of the total costs for this epoch)\n",
" costs += loss\n",
" \n",
" #Add number of steps to iteration counter\n",
" iters += m.num_steps\n",
"\n",
" if verbose and step % (epoch_size // 10) == 10:\n",
" print(\"Itr %d of %d, perplexity: %.3f speed: %.0f wps\" % (step , epoch_size, np.exp(costs / iters), iters * m.batch_size / (time.time() - start_time)))\n",
" \n",
"\n",
"\n",
" # Returns the Perplexity rating for us to keep track of how the model is evolving\n",
" return np.exp(costs / iters)\n"
]
},
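{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the batching in <code>run_one_epoch</code> concrete, here is a minimal sketch of how a PTB iterator of this kind typically works. It assumes <code>reader.ptb_iterator</code> follows the classic TensorFlow PTB tutorial reader. The token stream is cut into <code>batch_size</code> contiguous rows. At each step, the iterator yields a window of <code>num_steps</code> columns as inputs, and the same window shifted one word to the right as targets. The function <code>ptb_iterator_sketch</code> is hypothetical and is shown for illustration only.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def ptb_iterator_sketch(raw_data, batch_size, num_steps):\n",
"    # Hypothetical stand-in for reader.ptb_iterator, for illustration only\n",
"    data = np.array(raw_data, dtype=np.int32)\n",
"    batch_len = len(data) // batch_size\n",
"    # Lay the token stream out as batch_size contiguous rows (leftover tokens are dropped)\n",
"    data = data[:batch_size * batch_len].reshape(batch_size, batch_len)\n",
"    # Same count as epoch_size in run_one_epoch\n",
"    epoch_size = (batch_len - 1) // num_steps\n",
"    for i in range(epoch_size):\n",
"        x = data[:, i * num_steps:(i + 1) * num_steps]\n",
"        # Targets: the same window shifted one token to the right\n",
"        y = data[:, i * num_steps + 1:(i + 1) * num_steps + 1]\n",
"        yield x, y\n",
"\n",
"batches = list(ptb_iterator_sketch(range(25), batch_size=2, num_steps=3))\n",
"print(len(batches))   # 3\n",
"print(batches[0][0])  # inputs: the rows start at tokens 0 and 12\n",
"print(batches[0][1])  # targets: the inputs shifted by one token\n",
"```\n",
"\n",
"The number of windows per pass, <code>(batch_len - 1) // num_steps</code>, is the same as the <code>epoch_size</code> formula in <code>run_one_epoch</code>. In the training log further down, each epoch has 1549 iterations. So <code>epoch_size // 10</code> is 154, and the progress line is printed at steps 10, 164, 318, and so on.\n"
]
},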
{
"cell_type": "markdown",
"metadata": {
"button": false,
"deletable": true,
"new_sheet": false,
"run_control": {
"read_only": false
}
},
"source": [
"Next, we tie everything together. The cell below uses the <code>reader</code> helper module to read the Penn Treebank data from <code>data_dir</code> and split it into training, validation and test sets. The cell after it trains the model, reports the validation perplexity after every epoch, and finally evaluates the model on the test set.\n"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"button": false,
"deletable": true,
"new_sheet": false,
"run_control": {
"read_only": false
},
"tags": []
},
"outputs": [],
"source": [
"# Reads the data and separates it into training data, validation data and testing data\n",
"raw_data = reader.ptb_raw_data(data_dir)\n",
"train_data, valid_data, test_data, _, _ = raw_data"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: \"sequential_1\"\n",
"_________________________________________________________________\n",
"Layer (type) Output Shape Param # \n",
"=================================================================\n",
"embedding_vocab (Embedding) (30, 20, 200) 2000000 \n",
"_________________________________________________________________\n",
"rnn_1 (RNN) (30, 20, 128) 671088 \n",
"_________________________________________________________________\n",
"dense_1 (Dense) (30, 20, 10000) 1290000 \n",
"_________________________________________________________________\n",
"activation_1 (Activation) (30, 20, 10000) 0 \n",
"=================================================================\n",
"Total params: 3,961,088\n",
"Trainable params: 3,955,088\n",
"Non-trainable params: 6,000\n",
"_________________________________________________________________\n",
"Epoch 1 : Learning rate: 1.000\n",
"Itr 10 of 1549, perplexity: 4715.241 speed: 1308 wps\n",
"Itr 164 of 1549, perplexity: 1093.509 speed: 1335 wps\n",
"Itr 318 of 1549, perplexity: 845.907 speed: 1330 wps\n",
"Itr 472 of 1549, perplexity: 699.792 speed: 1327 wps\n",
"Itr 626 of 1549, perplexity: 596.043 speed: 1329 wps\n",
"Itr 780 of 1549, perplexity: 529.537 speed: 1329 wps\n",
"Itr 934 of 1549, perplexity: 477.252 speed: 1330 wps\n",
"Itr 1088 of 1549, perplexity: 437.885 speed: 1329 wps\n",
"Itr 1242 of 1549, perplexity: 407.549 speed: 1325 wps\n",
"Itr 1396 of 1549, perplexity: 379.556 speed: 1324 wps\n",
"Epoch 1 : Train Perplexity: 357.373\n",
"Epoch 1 : Valid Perplexity: 209.852\n",
"Epoch 2 : Learning rate: 1.000\n",
"Itr 10 of 1549, perplexity: 237.770 speed: 1292 wps\n",
"Itr 164 of 1549, perplexity: 209.217 speed: 1299 wps\n",
"Itr 318 of 1549, perplexity: 200.021 speed: 1288 wps\n",
"Itr 472 of 1549, perplexity: 191.926 speed: 1297 wps\n",
"Itr 626 of 1549, perplexity: 183.176 speed: 1296 wps\n",
"Itr 780 of 1549, perplexity: 179.550 speed: 1294 wps\n",
"Itr 934 of 1549, perplexity: 175.636 speed: 1302 wps\n",
"Itr 1088 of 1549, perplexity: 172.254 speed: 1300 wps\n",
"Itr 1242 of 1549, perplexity: 169.849 speed: 1290 wps\n",
"Itr 1396 of 1549, perplexity: 165.902 speed: 1283 wps\n",
"Epoch 2 : Train Perplexity: 163.108\n",
"Epoch 2 : Valid Perplexity: 161.490\n",
"Epoch 3 : Learning rate: 1.000\n",
"Itr 10 of 1549, perplexity: 159.223 speed: 1328 wps\n",
"Itr 164 of 1549, perplexity: 145.948 speed: 1226 wps\n",
"Itr 318 of 1549, perplexity: 142.239 speed: 1273 wps\n",
"Itr 472 of 1549, perplexity: 137.880 speed: 1293 wps\n",
"Itr 626 of 1549, perplexity: 132.748 speed: 1304 wps\n",
"Itr 780 of 1549, perplexity: 131.634 speed: 1309 wps\n",
"Itr 934 of 1549, perplexity: 129.992 speed: 1309 wps\n",
"Itr 1088 of 1549, perplexity: 128.514 speed: 1311 wps\n",
"Itr 1242 of 1549, perplexity: 127.692 speed: 1312 wps\n",
"Itr 1396 of 1549, perplexity: 125.498 speed: 1313 wps\n",
"Epoch 3 : Train Perplexity: 124.223\n",
"Epoch 3 : Valid Perplexity: 148.629\n",
"Epoch 4 : Learning rate: 1.000\n",
"Itr 10 of 1549, perplexity: 127.194 speed: 1303 wps\n",
"Itr 164 of 1549, perplexity: 118.747 speed: 10 wps\n",
"Itr 318 of 1549, perplexity: 116.761 speed: 7 wps\n",
"Itr 472 of 1549, perplexity: 113.585 speed: 10 wps\n",
"Itr 626 of 1549, perplexity: 109.791 speed: 13 wps\n",
"Itr 780 of 1549, perplexity: 109.369 speed: 16 wps\n",
"Itr 934 of 1549, perplexity: 108.379 speed: 20 wps\n",
"Itr 1088 of 1549, perplexity: 107.456 speed: 23 wps\n",
"Itr 1242 of 1549, perplexity: 107.111 speed: 26 wps\n",
"Itr 1396 of 1549, perplexity: 105.538 speed: 29 wps\n",
"Epoch 4 : Train Perplexity: 104.802\n",
"Epoch 4 : Valid Perplexity: 138.192\n",
"Epoch 5 : Learning rate: 1.000\n",
"Itr 10 of 1549, perplexity: 109.613 speed: 1018 wps\n",
"Itr 164 of 1549, perplexity: 103.603 speed: 76 wps\n",
"Itr 318 of 1549, perplexity: 102.100 speed: 138 wps\n",
"Itr 472 of 1549, perplexity: 99.375 speed: 191 wps\n",
"Itr 626 of 1549, perplexity: 96.303 speed: 237 wps\n",
"Itr 780 of 1549, perplexity: 96.143 speed: 282 wps\n",
"Itr 934 of 1549, perplexity: 95.503 speed: 324 wps\n",
"Itr 1088 of 1549, perplexity: 94.909 speed: 362 wps\n",
"Itr 1242 of 1549, perplexity: 94.743 speed: 398 wps\n",
"Itr 1396 of 1549, perplexity: 93.473 speed: 431 wps\n",
"Epoch 5 : Train Perplexity: 92.975\n",
"Epoch 5 : Valid Perplexity: 134.348\n",
"Epoch 6 : Learning rate: 0.500\n",
"Itr 10 of 1549, perplexity: 97.073 speed: 1316 wps\n",
"Itr 164 of 1549, perplexity: 89.961 speed: 1327 wps\n",
"Itr 318 of 1549, perplexity: 87.432 speed: 1325 wps\n",
"Itr 472 of 1549, perplexity: 84.165 speed: 1328 wps\n",
"Itr 626 of 1549, perplexity: 80.676 speed: 605 wps\n",
"Itr 780 of 1549, perplexity: 79.990 speed: 667 wps\n",
"Itr 934 of 1549, perplexity: 78.880 speed: 716 wps\n",
"Itr 1088 of 1549, perplexity: 77.823 speed: 759 wps\n",
"Itr 1242 of 1549, perplexity: 77.091 speed: 794 wps\n",
"Itr 1396 of 1549, perplexity: 75.523 speed: 825 wps\n",
"Epoch 6 : Train Perplexity: 74.570\n",
"Epoch 6 : Valid Perplexity: 124.390\n",
"Epoch 7 : Learning rate: 0.250\n",
"Itr 10 of 1549, perplexity: 81.960 speed: 1150 wps\n",
"Itr 164 of 1549, perplexity: 77.454 speed: 1195 wps\n",
"Itr 318 of 1549, perplexity: 75.267 speed: 1183 wps\n",
"Itr 472 of 1549, perplexity: 72.486 speed: 1201 wps\n",
"Itr 626 of 1549, perplexity: 69.372 speed: 1227 wps\n",
"Itr 780 of 1549, perplexity: 68.726 speed: 1239 wps\n",
"Itr 934 of 1549, perplexity: 67.688 speed: 1241 wps\n",
"Itr 1088 of 1549, perplexity: 66.665 speed: 1248 wps\n",
"Itr 1242 of 1549, perplexity: 65.895 speed: 1253 wps\n",
"Itr 1396 of 1549, perplexity: 64.396 speed: 1257 wps\n",
"Epoch 7 : Train Perplexity: 63.422\n",
"Epoch 7 : Valid Perplexity: 122.348\n",
"Epoch 8 : Learning rate: 0.125\n",
"Itr 10 of 1549, perplexity: 73.971 speed: 1365 wps\n",
"Itr 164 of 1549, perplexity: 70.415 speed: 1311 wps\n",
"Itr 318 of 1549, perplexity: 68.497 speed: 1288 wps\n",
"Itr 472 of 1549, perplexity: 65.975 speed: 1284 wps\n",
"Itr 626 of 1549, perplexity: 63.115 speed: 1283 wps\n",
"Itr 780 of 1549, perplexity: 62.527 speed: 1293 wps\n",
"Itr 934 of 1549, perplexity: 61.596 speed: 1299 wps\n",
"Itr 1088 of 1549, perplexity: 60.614 speed: 1303 wps\n",
"Itr 1242 of 1549, perplexity: 59.858 speed: 1306 wps\n",
"Itr 1396 of 1549, perplexity: 58.435 speed: 1307 wps\n",
"Epoch 8 : Train Perplexity: 57.486\n",
"Epoch 8 : Valid Perplexity: 121.998\n",
"Epoch 9 : Learning rate: 0.062\n",
"Itr 10 of 1549, perplexity: 69.804 speed: 1302 wps\n",
"Itr 164 of 1549, perplexity: 66.727 speed: 1330 wps\n",
"Itr 318 of 1549, perplexity: 64.957 speed: 1316 wps\n",
"Itr 472 of 1549, perplexity: 62.584 speed: 1307 wps\n",
"Itr 626 of 1549, perplexity: 59.853 speed: 1303 wps\n",
"Itr 780 of 1549, perplexity: 59.293 speed: 1307 wps\n",
"Itr 934 of 1549, perplexity: 58.414 speed: 1312 wps\n",
"Itr 1088 of 1549, perplexity: 57.457 speed: 1311 wps\n",
"Itr 1242 of 1549, perplexity: 56.716 speed: 1303 wps\n",
"Itr 1396 of 1549, perplexity: 55.338 speed: 1299 wps\n",
"Epoch 9 : Train Perplexity: 54.409\n",
"Epoch 9 : Valid Perplexity: 121.805\n",
"Epoch 10 : Learning rate: 0.031\n",
"Itr 10 of 1549, perplexity: 67.716 speed: 1281 wps\n",
"Itr 164 of 1549, perplexity: 64.759 speed: 1237 wps\n",
"Itr 318 of 1549, perplexity: 63.069 speed: 1251 wps\n",
"Itr 472 of 1549, perplexity: 60.789 speed: 1239 wps\n",
"Itr 626 of 1549, perplexity: 58.134 speed: 1246 wps\n",
"Itr 780 of 1549, perplexity: 57.589 speed: 1251 wps\n",
"Itr 934 of 1549, perplexity: 56.734 speed: 1236 wps\n",
"Itr 1088 of 1549, perplexity: 55.788 speed: 1232 wps\n",
"Itr 1242 of 1549, perplexity: 55.053 speed: 1221 wps\n",
"Itr 1396 of 1549, perplexity: 53.699 speed: 1217 wps\n",
"Epoch 10 : Train Perplexity: 52.778\n",
"Epoch 10 : Valid Perplexity: 121.732\n",
"Epoch 11 : Learning rate: 0.016\n",
"Itr 10 of 1549, perplexity: 66.560 speed: 1260 wps\n",
"Itr 164 of 1549, perplexity: 63.676 speed: 1256 wps\n",
"Itr 318 of 1549, perplexity: 62.014 speed: 1268 wps\n",
"Itr 472 of 1549, perplexity: 59.791 speed: 1233 wps\n",
"Itr 626 of 1549, perplexity: 57.185 speed: 1238 wps\n",
"Itr 780 of 1549, perplexity: 56.648 speed: 1236 wps\n",
"Itr 934 of 1549, perplexity: 55.807 speed: 1226 wps\n",
"Itr 1088 of 1549, perplexity: 54.869 speed: 1215 wps\n",
"Itr 1242 of 1549, perplexity: 54.141 speed: 1206 wps\n",
"Itr 1396 of 1549, perplexity: 52.802 speed: 1201 wps\n",
"Epoch 11 : Train Perplexity: 51.884\n",
"Epoch 11 : Valid Perplexity: 121.586\n",
"Epoch 12 : Learning rate: 0.008\n",
"Itr 10 of 1549, perplexity: 65.875 speed: 1207 wps\n",
"Itr 164 of 1549, perplexity: 63.091 speed: 1203 wps\n",
"Itr 318 of 1549, perplexity: 61.445 speed: 1211 wps\n",
"Itr 472 of 1549, perplexity: 59.245 speed: 1159 wps\n",
"Itr 626 of 1549, perplexity: 56.662 speed: 1160 wps\n",
"Itr 780 of 1549, perplexity: 56.128 speed: 1148 wps\n",
"Itr 934 of 1549, perplexity: 55.293 speed: 1148 wps\n",
"Itr 1088 of 1549, perplexity: 54.363 speed: 1139 wps\n",
"Itr 1242 of 1549, perplexity: 53.638 speed: 1141 wps\n",
"Itr 1396 of 1549, perplexity: 52.310 speed: 1145 wps\n",
"Epoch 12 : Train Perplexity: 51.393\n",
"Epoch 12 : Valid Perplexity: 121.364\n",
"Epoch 13 : Learning rate: 0.004\n",
"Itr 10 of 1549, perplexity: 65.478 speed: 1251 wps\n",
"Itr 164 of 1549, perplexity: 62.757 speed: 1211 wps\n",
"Itr 318 of 1549, perplexity: 61.139 speed: 1218 wps\n",
"Itr 472 of 1549, perplexity: 58.956 speed: 1205 wps\n",
"Itr 626 of 1549, perplexity: 56.381 speed: 1203 wps\n",
"Itr 780 of 1549, perplexity: 55.847 speed: 1201 wps\n",
"Itr 934 of 1549, perplexity: 55.015 speed: 1202 wps\n",
"Itr 1088 of 1549, perplexity: 54.088 speed: 1207 wps\n",
"Itr 1242 of 1549, perplexity: 53.365 speed: 1207 wps\n",
"Itr 1396 of 1549, perplexity: 52.044 speed: 1109 wps\n",
"Epoch 13 : Train Perplexity: 51.128\n",
"Epoch 13 : Valid Perplexity: 121.125\n",
"Epoch 14 : Learning rate: 0.002\n",
"Itr 10 of 1549, perplexity: 65.263 speed: 1278 wps\n",
"Itr 164 of 1549, perplexity: 62.555 speed: 1222 wps\n",
"Itr 318 of 1549, perplexity: 60.959 speed: 1238 wps\n",
"Itr 472 of 1549, perplexity: 58.794 speed: 1255 wps\n",
"Itr 626 of 1549, perplexity: 56.224 speed: 1260 wps\n",
"Itr 780 of 1549, perplexity: 55.693 speed: 1272 wps\n",
"Itr 934 of 1549, perplexity: 54.864 speed: 1281 wps\n",
"Itr 1088 of 1549, perplexity: 53.939 speed: 1287 wps\n",
"Itr 1242 of 1549, perplexity: 53.217 speed: 1291 wps\n",
"Itr 1396 of 1549, perplexity: 51.900 speed: 1294 wps\n",
"Epoch 14 : Train Perplexity: 50.986\n",
"Epoch 14 : Valid Perplexity: 120.963\n",
"Epoch 15 : Learning rate: 0.001\n",
"Itr 10 of 1549, perplexity: 65.155 speed: 1320 wps\n",
"Itr 164 of 1549, perplexity: 62.440 speed: 1326 wps\n",
"Itr 318 of 1549, perplexity: 60.854 speed: 1320 wps\n",
"Itr 472 of 1549, perplexity: 58.702 speed: 1319 wps\n",
"Itr 626 of 1549, perplexity: 56.138 speed: 1322 wps\n",
"Itr 780 of 1549, perplexity: 55.609 speed: 1319 wps\n",
"Itr 934 of 1549, perplexity: 54.782 speed: 1319 wps\n",
"Itr 1088 of 1549, perplexity: 53.859 speed: 1321 wps\n",
"Itr 1242 of 1549, perplexity: 53.138 speed: 1324 wps\n",
"Itr 1396 of 1549, perplexity: 51.824 speed: 1325 wps\n",
"Epoch 15 : Train Perplexity: 50.912\n",
"Epoch 15 : Valid Perplexity: 120.879\n",
"Test Perplexity: 114.788\n"
]
}
],
"source": [
"# Instantiate the PTBModel class\n",
"m = PTBModel.instance()\n",
"K = tf.keras.backend\n",
"for i in range(max_epoch):\n",
" # Exponential decay: the rate stays at learning_rate while i <= max_epoch_decay_lr, then is multiplied by decay each epoch\n",
" lr_decay = decay ** max(i - max_epoch_decay_lr, 0.0)\n",
" dcr = learning_rate * lr_decay\n",
" m._lr = dcr\n",
" K.set_value(m._model.optimizer.learning_rate, m._lr)\n",
" print(\"Epoch %d : Learning rate: %.3f\" % (i + 1, m._model.optimizer.learning_rate))\n",
" # Run the loop for this epoch in the training mode\n",
" train_perplexity = run_one_epoch(m, train_data,is_training=True,verbose=True)\n",
" print(\"Epoch %d : Train Perplexity: %.3f\" % (i + 1, train_perplexity))\n",
" \n",
" # Run the loop for this epoch in the validation mode\n",
" valid_perplexity = run_one_epoch(m, valid_data,is_training=False,verbose=False)\n",
" print(\"Epoch %d : Valid Perplexity: %.3f\" % (i + 1, valid_perplexity))\n",
" \n",
"# Evaluate on the test set to measure how well the trained model generalizes\n",
"test_perplexity = run_one_epoch(m, test_data,is_training=False,verbose=False)\n",
"print(\"Test Perplexity: %.3f\" % test_perplexity)\n",
"\n"
]
},
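{
"cell_type": "markdown",
"metadata": {},
"source": [
"The learning-rate schedule in the loop above is a step-wise exponential decay. The rate stays at <code>learning_rate</code> while the epoch index <code>i</code> is at most <code>max_epoch_decay_lr</code>, and is then multiplied by <code>decay</code> once per epoch. The hyperparameters are defined earlier in the notebook. The values used in the sketch below (<code>learning_rate=1.0</code>, <code>decay=0.5</code>, <code>max_epoch_decay_lr=4</code>, <code>max_epoch=15</code>) are assumptions, chosen because they reproduce the learning rates printed in the log above.\n",
"\n",
"```python\n",
"def learning_rate_schedule(learning_rate, decay, max_epoch_decay_lr, max_epoch):\n",
"    # Same formula as the training loop, collected into a list for inspection\n",
"    rates = []\n",
"    for i in range(max_epoch):\n",
"        lr_decay = decay ** max(i - max_epoch_decay_lr, 0.0)\n",
"        rates.append(learning_rate * lr_decay)\n",
"    return rates\n",
"\n",
"# Assumed values that match the printed log (1.000 for epochs 1-5, then halving every epoch)\n",
"for epoch, lr in enumerate(learning_rate_schedule(1.0, 0.5, 4, 15), start=1):\n",
"    print('Epoch %d : Learning rate: %.3f' % (epoch, lr))\n",
"```\n",
"\n",
"Keeping the rate high for the first few epochs lets the model make fast initial progress. Shrinking the rate afterwards lets it settle into a minimum instead of oscillating around one.\n"
]
},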
{
"cell_type": "markdown",
"metadata": {
"button": false,
"deletable": true,
"new_sheet": false,
"run_control": {
"read_only": false
}
},
"source": [
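"As the log shows, perplexity drops quickly. Training perplexity falls from about 357 after the first epoch to about 51 after epoch 15, validation perplexity falls from about 210 to about 121, and the final test perplexity is about 115. As discussed earlier, <b>lower perplexity means the model assigns higher probability to the words that actually come next</b>, i.e. it is more certain about its predictions. However, validation perplexity barely improves after epoch 7 (about 122) while training perplexity keeps falling. This widening gap suggests the model is starting to overfit the training data.\n"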
]
},
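{
"cell_type": "markdown",
"metadata": {},
"source": [
"Perplexity is the exponential of the average cross-entropy per word. In <code>run_one_epoch</code>, each batch cost (as <code>test_batch</code> computes it) is the loss summed over the <code>num_steps</code> words of a sequence and averaged over the batch. Dividing the accumulated <code>costs</code> by <code>iters</code>, the number of steps processed, therefore gives the mean loss per word. The sketch below repeats that computation with plain NumPy. The helper name <code>perplexity</code> is an illustrative assumption.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def perplexity(batch_costs, num_steps):\n",
"    # batch_costs: per-batch losses summed over num_steps words and averaged over the batch\n",
"    iters = num_steps * len(batch_costs)\n",
"    # exp of the mean cross-entropy per word (natural log)\n",
"    return np.exp(sum(batch_costs) / iters)\n",
"\n",
"# A model that spreads probability uniformly over a 10,000-word vocabulary\n",
"# has a per-word loss of log(10000), and hence a perplexity of about 10,000\n",
"uniform_cost = 20 * np.log(10000)\n",
"print(perplexity([uniform_cost, uniform_cost], num_steps=20))\n",
"```\n",
"\n",
"Read this way, the final test perplexity of about 115 means the model is, on average, about as uncertain as if it were choosing uniformly among roughly 115 words. A uniform guess over the 10,000-word output vocabulary would give a perplexity of 10,000.\n"
]
},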
{
"cell_type": "markdown",
"metadata": {
"button": false,
"deletable": true,
"new_sheet": false,
"run_control": {
"read_only": false
}
},
"source": [
"* * *\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"button": false,
"deletable": true,
"new_sheet": false,
"run_control": {
"read_only": false
}
},
"source": [
"This is the end of the <b>Applying Recurrent Neural Networks/LSTM for Language Modeling</b> notebook. Hopefully you now have a better understanding of Recurrent Neural Networks and how to implement one using TensorFlow. Thank you for reading this notebook, and good luck on your studies.\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"button": false,
"deletable": true,
"new_sheet": false,
"run_control": {
"read_only": false
}
},
"source": [
"## Want to learn more?\n",
"\n",
"Running deep learning programs usually needs a high performance platform. **PowerAI** speeds up deep learning and AI. Built on IBM’s Power Systems, **PowerAI** is a scalable software platform that accelerates deep learning and AI with blazing performance for individual users or enterprises. The **PowerAI** platform supports popular machine learning libraries and dependencies including TensorFlow, Caffe, Torch, and Theano. You can use [PowerAI on IBM Cloud](https://cocl.us/ML0120EN_PAI).\n",
"\n",
"Also, you can use **Watson Studio** to run these notebooks faster with bigger datasets. **Watson Studio** is IBM’s leading cloud solution for data scientists, built by data scientists. With Jupyter notebooks, RStudio, Apache Spark and popular libraries pre-packaged in the cloud, **Watson Studio** enables data scientists to collaborate on their projects without having to install anything. Join the fast-growing community of **Watson Studio** users today with a free account at [Watson Studio](https://cocl.us/ML0120EN_DSX).\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"button": false,
"deletable": true,
"new_sheet": false,
"run_control": {
"read_only": false
}
},
"source": [
"### Thanks for completing this lesson!\n",
"\n",
"Notebook created by <a href=\"https://br.linkedin.com/in/walter-gomes-de-amorim-junior-624726121\">Walter Gomes de Amorim Junior</a>, <a href = \"https://linkedin.com/in/saeedaghabozorgi\"> Saeed Aghabozorgi </a>\n",
"\n",
"Updated to TF 2.X by <a href=\"https://www.linkedin.com/in/samaya-madhavan\"> Samaya Madhavan </a>\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"button": false,
"deletable": true,
"new_sheet": false,
"run_control": {
"read_only": false
}
},
"source": [
"<hr>\n",
"\n",
"Copyright © 2018 [Cognitive Class](https://cocl.us/DX0108EN_CC). This notebook and its source code are released under the terms of the [MIT License](https://bigdatauniversity.com/mit-license?cm_mmc=Email_Newsletter-_-Developer_Ed%2BTech-_-WW_WW-_-SkillsNetwork-Courses-IBMDeveloperSkillsNetwork-DL0120EN-SkillsNetwork-20629446&cm_mmca1=000026UJ&cm_mmca2=10006555&cm_mmca3=M12345678&cvosrc=email.Newsletter.M12345678&cvo_campaign=000026UJ&cm_mmc=Email_Newsletter-_-Developer_Ed%2BTech-_-WW_WW-_-SkillsNetwork-Courses-IBMDeveloperSkillsNetwork-DL0120EN-SkillsNetwork-20629446&cm_mmca1=000026UJ&cm_mmca2=10006555&cm_mmca3=M12345678&cvosrc=email.Newsletter.M12345678&cvo_campaign=000026UJ).\n"
]
}
],
"metadata": {
"anaconda-cloud": {},
"kernelspec": {
"display_name": "Python",
"language": "python",
"name": "conda-env-python-py"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.11"
},
"widgets": {
"state": {},
"version": "1.1.2"
}
},
"nbformat": 4,
"nbformat_minor": 4
}