Skip to content

Instantly share code, notes, and snippets.

@ariG23498
Last active May 3, 2023 01:47
Show Gist options
  • Save ariG23498/fd76bf197f71cd044e40aa2ffe2b6aee to your computer and use it in GitHub Desktop.
Save ariG23498/fd76bf197f71cd044e40aa2ffe2b6aee to your computer and use it in GitHub Desktop.
Custom RNN logic
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/ariG23498/fd76bf197f71cd044e40aa2ffe2b6aee/scratchpad.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "lIYdn1woOS1n"
},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"from tensorflow import keras\n",
"from tensorflow.keras import layers"
]
},
{
"cell_type": "code",
"source": [
"class CustomCell(layers.Layer):\n",
" def __init__(self, num_heads, key_dim, dropout, token_size, units, **kwargs):\n",
" self.token_size = token_size\n",
" self.units = units\n",
" self.state_size = tf.TensorShape([token_size, units])\n",
" self.output_size = tf.TensorShape([token_size, units])\n",
"\n",
" # This is the point where we need to add our custom logic\n",
" # instead of the MLP\n",
" self.attention_module = layers.MultiHeadAttention(\n",
" num_heads,\n",
" key_dim,\n",
" dropout=dropout\n",
" )\n",
" \n",
" super().__init__(**kwargs)\n",
"\n",
" def call(self, inputs, states):\n",
" # inputs => (batch, token_size, dims)\n",
" # states => [(batch, token_size, units)]\n",
"\n",
" prev_state = states[0]\n",
"\n",
" outputs = self.attention_module(inputs, inputs) # Self-attention: query and value are both `inputs` (key defaults to value)\n",
" new_state = outputs + prev_state\n",
" \n",
" return outputs, [new_state]"
],
"metadata": {
"id": "Zco4Hq1rNlG-"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"keras.backend.clear_session()\n",
"\n",
"batch_size = 8\n",
"num_frames = 32\n",
"token_size = 543\n",
"units = 3\n",
"\n",
"inputs = tf.random.normal(\n",
" (batch_size, num_frames, token_size, units)\n",
")\n",
"\n",
"cell = CustomCell(num_heads=1, key_dim=3, dropout=0.1, token_size=token_size, units=3)\n",
"rnn = layers.RNN(cell)\n",
"rnn(inputs).shape # Output of the last time step (RNN default: return_sequences=False, return_state=False), not the accumulated state"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "6jiaDke1NuHz",
"outputId": "9af33977-7a75-4925-ca7a-4c7f14033381"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"TensorShape([8, 543, 3])"
]
},
"metadata": {},
"execution_count": 5
}
]
}
],
"metadata": {
"colab": {
"name": "scratchpad",
"provenance": [],
"include_colab_link": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment