TF_Forum_24149.ipynb
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"authorship_tag": "ABX9TyOq+W8P7APaj0hxTtt2NxR7",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/kiransair/96cab2a4fc926e33089a5d9b783b2bda/tf_forum_24149.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"id": "5pqRiP41JGhi"
},
"outputs": [],
"source": [
"import random  # needed by DecoderBase for the teacher-forcing coin flip\n",
"import tensorflow as tf\n",
"from tensorflow.keras.layers import Dropout, Bidirectional, LSTM, Dense, LSTMCell, Input\n",
"from tensorflow.keras import Model"
]
},
{
"cell_type": "code",
"source": [
"class LuongAttention(tf.keras.layers.Layer):\n",
"    def __init__(self, units):\n",
"        super(LuongAttention, self).__init__()\n",
"        self.W1 = tf.keras.layers.Dense(units)\n",
"        self.W2 = tf.keras.layers.Dense(units)\n",
"        self.V = tf.keras.layers.Dense(1)\n",
"\n",
"    def call(self, query, values):\n",
"        # query: decoder state (batch, units); values: encoder outputs (batch, time, units)\n",
"        query_with_time_axis = tf.expand_dims(query, 1)\n",
"        # Additive (Bahdanau-style) score, despite the class name\n",
"        score = self.V(tf.nn.tanh(\n",
"            self.W1(query_with_time_axis) + self.W2(values)))\n",
"        attention_weights = tf.nn.softmax(score, axis=1)\n",
"        context_vector = attention_weights * values\n",
"        context_vector = tf.reduce_sum(context_vector, axis=1)\n",
"        return context_vector, attention_weights"
],
"metadata": {
"id": "dMPPtRMbJQ8S"
},
"execution_count": 2,
"outputs": []
},
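{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick shape check of the attention layer (added for illustration, not part of the original gist): a `(batch, units)` query and a `(batch, time, units)` value sequence should yield a `(batch, units)` context vector and `(batch, time, 1)` weights. Note the score is additive (Bahdanau-style) even though the class is named `LuongAttention`."
]
},
{
"cell_type": "code",
"source": [
"# Illustrative sanity check with dummy tensors\n",
"attn = LuongAttention(32)\n",
"q = tf.zeros((4, 64))      # decoder hidden state\n",
"v = tf.zeros((4, 12, 64))  # encoder output sequence\n",
"ctx, w = attn(q, v)\n",
"print(ctx.shape, w.shape)  # (4, 64) (4, 12, 1)"
],
"metadata": {},
"execution_count": null,
"outputs": []
},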
{
"cell_type": "code",
"source": [
"class Encoder(tf.keras.layers.Layer):\n",
"    def __init__(self,\n",
"                 lstm_units,\n",
"                 dropout_rate,\n",
"                 l2_penalty,\n",
"                 num_features,\n",
"                 regularization):\n",
"        super().__init__()\n",
"        self.lstm_units = lstm_units\n",
"        self.dropout_rate = dropout_rate\n",
"        self.l2_penalty = l2_penalty\n",
"        self.dropout = Dropout(dropout_rate)\n",
"        if regularization:\n",
"            self.lstm1 = Bidirectional(LSTM(lstm_units[0],\n",
"                                            return_sequences = True,\n",
"                                            return_state = True,\n",
"                                            dropout = dropout_rate,\n",
"                                            recurrent_dropout = dropout_rate,\n",
"                                            kernel_initializer = tf.keras.initializers.GlorotNormal(),\n",
"                                            kernel_regularizer = tf.keras.regularizers.l2(l2_penalty)))\n",
"            self.lstm2 = Bidirectional(LSTM(lstm_units[0],\n",
"                                            return_sequences = True,\n",
"                                            return_state = True,\n",
"                                            dropout = dropout_rate,\n",
"                                            recurrent_dropout = dropout_rate,\n",
"                                            kernel_initializer = tf.keras.initializers.GlorotNormal(),\n",
"                                            kernel_regularizer = tf.keras.regularizers.l2(l2_penalty)))\n",
"            self.dense = Dense(num_features,\n",
"                               activation = 'relu',\n",
"                               kernel_initializer = tf.keras.initializers.HeNormal(),\n",
"                               kernel_regularizer = tf.keras.regularizers.l2(l2_penalty))\n",
"        else:\n",
"            self.lstm1 = Bidirectional(LSTM(lstm_units[0],\n",
"                                            return_sequences = True,\n",
"                                            return_state = True,\n",
"                                            dropout = dropout_rate,\n",
"                                            recurrent_dropout = dropout_rate,\n",
"                                            kernel_initializer = tf.keras.initializers.GlorotNormal()))\n",
"            self.lstm2 = Bidirectional(LSTM(lstm_units[0],\n",
"                                            return_sequences = True,\n",
"                                            return_state = True,\n",
"                                            dropout = dropout_rate,\n",
"                                            recurrent_dropout = dropout_rate,\n",
"                                            kernel_initializer = tf.keras.initializers.GlorotNormal()))\n",
"            self.dense = Dense(num_features,\n",
"                               activation = 'relu',\n",
"                               kernel_initializer = tf.keras.initializers.HeNormal())\n",
"\n",
"    def call(self,\n",
"             encoder_inputs,\n",
"             training = None):\n",
"        # Two stacked bidirectional LSTMs; the second is seeded with the first's states\n",
"        output_lstm1, forward_state_h_lstm1, forward_state_c_lstm1, backward_state_h_lstm1, backward_state_c_lstm1 = self.lstm1(encoder_inputs)\n",
"\n",
"        output_lstm2, forward_state_h_lstm2, forward_state_c_lstm2, backward_state_h_lstm2, backward_state_c_lstm2 = self.lstm2(\n",
"            output_lstm1,\n",
"            initial_state = [forward_state_h_lstm1,\n",
"                             forward_state_c_lstm1,\n",
"                             backward_state_h_lstm1,\n",
"                             backward_state_c_lstm1])\n",
"        # Concatenate forward/backward states so the decoder sees 2*lstm_units[0] dims\n",
"        state_h_lstm2 = tf.concat([forward_state_h_lstm2, backward_state_h_lstm2], axis = -1)\n",
"        state_c_lstm2 = tf.concat([forward_state_c_lstm2, backward_state_c_lstm2], axis = -1)\n",
"        state_lstm2 = [state_h_lstm2, state_c_lstm2]\n",
"        return output_lstm2, state_lstm2\n"
],
"metadata": {
"id": "SbOf14idJS-w"
},
"execution_count": 3,
"outputs": []
},
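{
"cell_type": "markdown",
"metadata": {},
"source": [
"Illustrative only (not in the original notebook): instantiating the encoder on dummy data confirms that the bidirectional outputs and the concatenated states both have `2 * lstm_units[0]` channels, which is why the decoder cell later is sized `lstm_units[1] = 2 * 32`."
]
},
{
"cell_type": "code",
"source": [
"# Hedged example: all values here are dummies chosen to match the later cells\n",
"enc = Encoder(lstm_units = [32, 64],\n",
"              dropout_rate = 0.3,\n",
"              l2_penalty = 0.01,\n",
"              num_features = 14,\n",
"              regularization = True)\n",
"enc_out, enc_states = enc(tf.zeros((4, 12, 14)))\n",
"print(enc_out.shape)                             # (4, 12, 64)\n",
"print(enc_states[0].shape, enc_states[1].shape)  # (4, 64) (4, 64)"
],
"metadata": {},
"execution_count": null,
"outputs": []
},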
{
"cell_type": "code",
"source": [
"class DecoderBase(tf.keras.layers.Layer):\n",
"    def __init__(self,\n",
"                 out_step,\n",
"                 dropout_rate,\n",
"                 l2_penalty,\n",
"                 regularization):\n",
"        super().__init__()\n",
"        self.out_step = out_step\n",
"        self.dropout_rate = dropout_rate\n",
"        self.base_dropout = Dropout(dropout_rate)\n",
"        if regularization:\n",
"            self.base_dense = Dense(1,\n",
"                                    activation = 'relu',\n",
"                                    kernel_initializer = tf.keras.initializers.HeNormal(),\n",
"                                    kernel_regularizer = tf.keras.regularizers.l2(l2_penalty))\n",
"        else:\n",
"            self.base_dense = Dense(1,\n",
"                                    activation = 'relu',\n",
"                                    kernel_initializer = tf.keras.initializers.HeNormal())\n",
"\n",
"    def run_single_recurrent_step(self,\n",
"                                  inputs,\n",
"                                  states,\n",
"                                  input_sequence_data,\n",
"                                  training):\n",
"        raise NotImplementedError()\n",
"\n",
"    def call(self,\n",
"             decoder_inputs,\n",
"             initial_inputs,\n",
"             initial_states,\n",
"             input_sequence_data,\n",
"             teacher_force_prob = None,\n",
"             training = None):\n",
"        predictions = []\n",
"        # Project the last encoder output down to one feature to seed the loop\n",
"        input_data = self.base_dropout(initial_inputs)\n",
"        input_data = self.base_dense(input_data)\n",
"        states = initial_states\n",
"        for t in range(self.out_step):\n",
"            inputs = input_data\n",
"            outputs, states_output = self.run_single_recurrent_step(inputs, states, input_sequence_data, training)\n",
"            predictions.append(outputs)\n",
"            # Teacher forcing: occasionally feed the ground truth instead of the prediction\n",
"            teacher_force = random.random() < teacher_force_prob if teacher_force_prob is not None else False\n",
"            if teacher_force:\n",
"                input_data = decoder_inputs[:, t, :]\n",
"            else:\n",
"                input_data = outputs\n",
"            states = states_output\n",
"\n",
"        # (out_step, batch, 1) -> (batch, out_step, 1)\n",
"        outputs_predictions = tf.stack(predictions)\n",
"        outputs_predictions = tf.transpose(outputs_predictions, [1, 0, 2])\n",
"        return outputs_predictions"
],
"metadata": {
"id": "vDm9ihwuJUtQ"
},
"execution_count": 4,
"outputs": []
},
{
"cell_type": "code",
"source": [
"class DecoderVanilla(DecoderBase):\n",
"    def __init__(self,\n",
"                 lstm_units,\n",
"                 out_step,\n",
"                 dropout_rate,\n",
"                 l2_penalty,\n",
"                 regularization):\n",
"        super().__init__(out_step,\n",
"                         dropout_rate,\n",
"                         l2_penalty,\n",
"                         regularization)\n",
"        self.lstm_units = lstm_units\n",
"        self.dropout = Dropout(dropout_rate)\n",
"        if regularization:\n",
"            self.lstm_cell = LSTMCell(lstm_units[1],\n",
"                                      dropout = dropout_rate,\n",
"                                      recurrent_dropout = dropout_rate,\n",
"                                      kernel_initializer = tf.keras.initializers.GlorotNormal(),\n",
"                                      kernel_regularizer = tf.keras.regularizers.l2(l2_penalty))\n",
"            self.dense = Dense(1,\n",
"                               kernel_regularizer = tf.keras.regularizers.l2(l2_penalty))\n",
"        else:\n",
"            self.lstm_cell = LSTMCell(lstm_units[1],\n",
"                                      dropout = dropout_rate,\n",
"                                      recurrent_dropout = dropout_rate,\n",
"                                      kernel_initializer = tf.keras.initializers.GlorotNormal())\n",
"            self.dense = Dense(1)\n",
"\n",
"    def run_single_recurrent_step(self,\n",
"                                  inputs,\n",
"                                  states,\n",
"                                  input_sequence_data,\n",
"                                  training):\n",
"        return_outputs, return_states = self.lstm_cell(inputs, states = states)\n",
"        return_outputs = self.dense(tf.concat([return_outputs, inputs], axis = -1))\n",
"\n",
"        return return_outputs, return_states"
],
"metadata": {
"id": "Na8vn9pnJW6M"
},
"execution_count": 5,
"outputs": []
},
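{
"cell_type": "markdown",
"metadata": {},
"source": [
"A single-step check of the vanilla decoder (illustrative, not from the original gist): the `LSTMCell` expects `[h, c]` states of width `lstm_units[1]`, which matches the concatenated encoder states returned above."
]
},
{
"cell_type": "code",
"source": [
"# Dummy one-step rollout; shapes are assumptions matching the cells above\n",
"dec = DecoderVanilla(lstm_units = [32, 64],\n",
"                     out_step = 12,\n",
"                     dropout_rate = 0.3,\n",
"                     l2_penalty = 0.01,\n",
"                     regularization = True)\n",
"step_in = tf.zeros((4, 1))\n",
"states = [tf.zeros((4, 64)), tf.zeros((4, 64))]\n",
"step_out, new_states = dec.run_single_recurrent_step(step_in, states, None, training = False)\n",
"print(step_out.shape)  # (4, 1)"
],
"metadata": {},
"execution_count": null,
"outputs": []
},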
{
"cell_type": "code",
"source": [
"class DecoderWithAttention(DecoderBase):\n",
"    def __init__(self,\n",
"                 lstm_units,\n",
"                 out_step,\n",
"                 dropout_rate,\n",
"                 l2_penalty,\n",
"                 regularization):\n",
"        super().__init__(out_step,\n",
"                         dropout_rate,\n",
"                         l2_penalty,\n",
"                         regularization)\n",
"        self.lstm_units = lstm_units\n",
"        self.dropout = Dropout(dropout_rate)\n",
"        self.attention = LuongAttention(32)\n",
"        if regularization:\n",
"            self.lstm_cell = LSTMCell(lstm_units[1],\n",
"                                      dropout = dropout_rate,\n",
"                                      recurrent_dropout = dropout_rate,\n",
"                                      kernel_initializer = tf.keras.initializers.GlorotNormal(),\n",
"                                      kernel_regularizer = tf.keras.regularizers.l2(l2_penalty))\n",
"            self.dense = Dense(1,\n",
"                               kernel_regularizer = tf.keras.regularizers.l2(l2_penalty))\n",
"        else:\n",
"            self.lstm_cell = LSTMCell(lstm_units[1],\n",
"                                      dropout = dropout_rate,\n",
"                                      recurrent_dropout = dropout_rate,\n",
"                                      kernel_initializer = tf.keras.initializers.GlorotNormal())\n",
"            self.dense = Dense(1)\n",
"\n",
"    def run_single_recurrent_step(self,\n",
"                                  inputs,\n",
"                                  states,\n",
"                                  input_sequence_data,\n",
"                                  training):\n",
"        query = states[0]\n",
"        values = input_sequence_data\n",
"        context_vector, attention_weights = self.attention(query, values)\n",
"        inputs_concat = tf.concat([context_vector, inputs], axis = -1)\n",
"        return_outputs, return_states = self.lstm_cell(inputs_concat, states = states)\n",
"        return_outputs = self.dense(tf.concat([return_outputs, inputs_concat, context_vector], axis = -1))\n",
"\n",
"        return return_outputs, return_states"
],
"metadata": {
"id": "S1NfSlyZJZP_"
},
"execution_count": 6,
"outputs": []
},
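{
"cell_type": "markdown",
"metadata": {},
"source": [
"The same one-step check for the attention decoder (illustrative): unlike the vanilla decoder, it also consumes the full encoder output sequence as `input_sequence_data`, from which the attention layer builds the context vector."
]
},
{
"cell_type": "code",
"source": [
"# Dummy one-step rollout with an assumed (4, 12, 64) encoder sequence\n",
"dec_attn = DecoderWithAttention(lstm_units = [32, 64],\n",
"                                out_step = 12,\n",
"                                dropout_rate = 0.3,\n",
"                                l2_penalty = 0.01,\n",
"                                regularization = True)\n",
"step_in = tf.zeros((4, 1))\n",
"states = [tf.zeros((4, 64)), tf.zeros((4, 64))]\n",
"enc_seq = tf.zeros((4, 12, 64))\n",
"step_out, new_states = dec_attn.run_single_recurrent_step(step_in, states, enc_seq, training = False)\n",
"print(step_out.shape)  # (4, 1)"
],
"metadata": {},
"execution_count": null,
"outputs": []
},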
{
"cell_type": "code",
"source": [
"def seq2seq(encoder_input_shape,\n",
"            decoder_input_shape,\n",
"            out_step,\n",
"            num_features,\n",
"            type_decoder,\n",
"            lstm_units,\n",
"            dropout_rate,\n",
"            l2_penalty,\n",
"            teacher_force_prob,\n",
"            regularization,\n",
"            training):\n",
"    encoder = Encoder(lstm_units,\n",
"                      dropout_rate,\n",
"                      l2_penalty,\n",
"                      num_features,\n",
"                      regularization)\n",
"    if type_decoder == 'Vanilla':\n",
"        decoder = DecoderVanilla(lstm_units,\n",
"                                 out_step,\n",
"                                 dropout_rate,\n",
"                                 l2_penalty,\n",
"                                 regularization)\n",
"    elif type_decoder == 'WithAttention':\n",
"        decoder = DecoderWithAttention(lstm_units,\n",
"                                       out_step,\n",
"                                       dropout_rate,\n",
"                                       l2_penalty,\n",
"                                       regularization)\n",
"    else:\n",
"        raise ValueError('type_decoder must be Vanilla or WithAttention')\n",
"    encoder_inputs = Input(encoder_input_shape)\n",
"    decoder_inputs = Input(decoder_input_shape)\n",
"    encoder_outputs, encoder_states = encoder(encoder_inputs, training)\n",
"    decoder_outputs = decoder(decoder_inputs, encoder_outputs[:, -1, :], encoder_states, encoder_outputs, teacher_force_prob, training)\n",
"    # Build model\n",
"    model = Model(inputs = [encoder_inputs, decoder_inputs], outputs = decoder_outputs, name = 'Seq2Seq')\n",
"\n",
"    return model"
],
"metadata": {
"id": "AGCZpngKJbQl"
},
"execution_count": 7,
"outputs": []
},
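{
"cell_type": "markdown",
"metadata": {},
"source": [
"For comparison (an assumed usage, mirroring the vanilla build below): the same builder with `type_decoder = 'WithAttention'` produces the attention variant."
]
},
{
"cell_type": "code",
"source": [
"# Hedged sketch: hyperparameters copied from the vanilla build below\n",
"model_attention = seq2seq((12, 14),\n",
"                          (12, 1),\n",
"                          12,\n",
"                          14,\n",
"                          'WithAttention',\n",
"                          [32, 64],\n",
"                          0.3,\n",
"                          0.01,\n",
"                          teacher_force_prob = None,\n",
"                          regularization = True,\n",
"                          training = True)\n",
"model_attention.summary()"
],
"metadata": {},
"execution_count": null,
"outputs": []
},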
{
"cell_type": "code",
"source": [
"lstm_units = [32,\n",
"              2*32]\n",
"n_features = 14\n",
"dropout_rate = 0.3\n",
"l2_penalty = 0.01\n",
"#l2_penalty = 0.001\n",
"batch_size = 512\n",
"n_epochs = 200\n",
"out_step = 12\n",
"regularization = True\n",
"training = True\n",
"\n",
"\n",
"encoder_input_shape = (12, 14)\n",
"decoder_input_shape = (12, 1)\n",
"\n",
"model_vanilla = seq2seq(encoder_input_shape,\n",
"                        decoder_input_shape,\n",
"                        out_step,\n",
"                        n_features,\n",
"                        'Vanilla',\n",
"                        lstm_units,\n",
"                        dropout_rate,\n",
"                        l2_penalty,\n",
"                        teacher_force_prob = None,\n",
"                        regularization = True,\n",
"                        training = True)"
],
"metadata": {
"id": "UHuFKDMTJdUV"
},
"execution_count": 8,
"outputs": []
},
{
"cell_type": "code",
"source": [
"x = tf.ones(shape = (800, 12, 14))\n",
"y = tf.ones(shape = (800, 12, 1))\n",
"out = model_vanilla((x, y))  # this model takes 2 inputs"
],
"metadata": {
"id": "yStOID6RJfVg"
},
"execution_count": 9,
"outputs": []
},
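{
"cell_type": "markdown",
"metadata": {},
"source": [
"Illustrative check (not in the original gist): the decoder emits one scalar per step, so the output should match the `(batch, out_step, 1)` target shape."
]
},
{
"cell_type": "code",
"source": [
"print(out.shape)  # expected: (800, 12, 1)"
],
"metadata": {},
"execution_count": null,
"outputs": []
},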
{
"cell_type": "code",
"source": [
"x = tf.ones(shape = (900, 12, 14))\n",
"y = tf.ones(shape = (900, 12, 1))"
],
"metadata": {
"id": "AyhLU1aEJmnS"
},
"execution_count": 12,
"outputs": []
},
{
"cell_type": "code",
"source": [
"out = model_vanilla((x, y))  # this model takes 2 inputs"
],
"metadata": {
"id": "ez91kY-mJp1v"
},
"execution_count": 13,
"outputs": []
},
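{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal training sketch (assumed setup, not from the original gist): compile with MSE and fit on the dummy tensors defined above. `epochs = 1` keeps the demo short; the hyperparameter cell suggests `n_epochs = 200` for a real run."
]
},
{
"cell_type": "code",
"source": [
"# Hedged sketch: the optimizer and loss are assumptions, not from the original answer\n",
"model_vanilla.compile(optimizer = 'adam', loss = 'mse')\n",
"model_vanilla.fit([x, y], y, batch_size = batch_size, epochs = 1)"
],
"metadata": {},
"execution_count": null,
"outputs": []
},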
{
"cell_type": "code",
"source": [],
"metadata": {
"id": "2ypp6-gCJs3X"
},
"execution_count": null,
"outputs": []
}
]
}