-
-
Save YasuThompson/638c3999adf2f3961f89248f8ae49ffe to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
# Training-set sizing: shuffle over the full training set, batch in groups of 64.
BUFFER_SIZE = len(input_tensor_train)
BATCH_SIZE = 64
steps_per_epoch = len(input_tensor_train) // BATCH_SIZE

# Each input token id is embedded into a 256-dimensional dense vector
# (compressing the 9414-dimensional one-hot input space).
embedding_dim = 256
# Dimensionality of the RNN hidden state/vector.
units = 1024
# +1 because Keras word_index is 1-based and index 0 is reserved for padding.
vocab_inp_size = len(inp_lang.word_index) + 1  # 9414
vocab_tar_size = len(targ_lang.word_index) + 1  # 4935

# Build the input pipeline: shuffle the full set, then emit fixed-size batches.
# drop_remainder=True guarantees every batch is exactly BATCH_SIZE examples.
dataset = (
    tf.data.Dataset.from_tensor_slices((input_tensor_train, target_tensor_train))
    .shuffle(BUFFER_SIZE)
)
dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
'''
    During training, the Encoder receives a (batch, max_length) = (64, 16) tensor
    of token ids and returns a (64, 16, 1024) tensor: one `units`(=1024)-dimensional
    output vector per time step (i.e. per token), regardless of how many words the
    inputs have.
'''

class Encoder(tf.keras.Model):
    """Embedding + GRU encoder for the seq2seq translation model."""

    def __init__(self, vocab_size, embedding_dim, enc_units, batch_sz):
        super(Encoder, self).__init__()
        self.batch_sz = batch_sz    # 64
        # Dimension of the GRU hidden state/output: 1024.
        # (The original comment claiming "24000 // 64 = 375" confused this with
        # steps_per_epoch; enc_units receives `units`, not a step count.)
        self.enc_units = enc_units

        # Token ids over a `vocab_size` vocabulary (9414 for the input language)
        # are mapped to dense `embedding_dim`(=256)-dimensional vectors.
        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)

        # GRU over the embedded sequence. At each time step a cell consumes one
        # 256-dim embedding, emits an `enc_units`(=1024)-dim output vector, and
        # passes its hidden state to the next step. (NOTE: the output is NOT
        # 16-dimensional — 16 is the max sentence length, not the feature size.)
        self.gru = tf.keras.layers.GRU(self.enc_units,
                                       return_sequences=True,   # output for every time step
                                       return_state=True,       # also return final hidden state
                                       recurrent_initializer='glorot_uniform')

    def call(self, x, hidden):
        """Encode a batch of token-id sequences.

        Args:
            x: int tensor of shape (batch, timesteps) of token ids.
            hidden: initial hidden state, shape (batch, enc_units).

        Returns:
            Tuple (output, state): `output` has shape (batch, timesteps, enc_units),
            `state` is the final hidden state of shape (batch, enc_units).
        """
        x = self.embedding(x)
        # tf.keras.layers.GRU consumes [batch, timesteps, feature] tensors:
        # https://www.tensorflow.org/api_docs/python/tf/keras/layers/GRU
        output, state = self.gru(x, initial_state=hidden)
        return output, state

    def initialize_hidden_state(self):
        """Return an all-zero initial hidden state of shape (batch_sz, enc_units)."""
        return tf.zeros((self.batch_sz, self.enc_units))
"\n", | |
"\n", | |
'''
    Construct the Encoder. Each input token is a id over a 9414-word vocabulary
    (conceptually a 9414-dimensional one-hot vector) that the encoder embeds and
    runs through the GRU.
'''
encoder = Encoder(vocab_inp_size,  # 9414: input-language vocabulary size (+1 for padding)
                  embedding_dim,   # 256
                  units,           # 1024
                  BATCH_SIZE)      # 64 — NOT 24000; that figure is the dataset size
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.7.9" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 4 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment