@Hsankesara
Last active June 3, 2018 17:43
Basic RNN model in TensorFlow
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "RNN.ipynb",
"version": "0.3.2",
"views": {},
"default_view": {},
"provenance": [],
"collapsed_sections": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3",
"language": "python"
}
},
"cells": [
{
"metadata": {
"id": "Zh0NwKZcn_cU",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
},
"base_uri": "https://localhost:8080/",
"height": 35
},
"outputId": "4c733725-a849-46bb-c077-0dd1bab2ff8b",
"executionInfo": {
"status": "ok",
"timestamp": 1528032277418,
"user_tz": -330,
"elapsed": 2246,
"user": {
"displayName": "Heet Sankesara",
"photoUrl": "//lh5.googleusercontent.com/-mO3PS3oBtRQ/AAAAAAAAAAI/AAAAAAAAADs/_ic_rsddWU4/s50-c-k-no/photo.jpg",
"userId": "112240562069909648160"
}
}
},
"cell_type": "code",
"source": [
"!ls"
],
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"text": [
"datalab poetry.csv\r\n"
],
"name": "stdout"
}
]
},
{
"metadata": {
"id": "SrBj_rffoMyV",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
}
},
"cell_type": "code",
"source": [
"import pandas as pd\n",
"import numpy as np"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "trP-g463ocib",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
}
},
"cell_type": "code",
"source": [
"poetry = pd.read_csv('poetry.csv')"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "iV9_HxKcofkx",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
}
},
"cell_type": "code",
"source": [
"poem = poetry['content'][0]"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "qKbmBG2Voie-",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
}
},
"cell_type": "code",
"source": [
"poem = poem.replace('\\r\\n\\r\\n', '\\r\\n')"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "4K6E1qsSorpd",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
}
},
"cell_type": "code",
"source": [
"poem = poem.replace('\\r\\n', ' ')"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "wQKWX4mKo3O3",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
}
},
"cell_type": "code",
"source": [
"poem = poem.replace('\\'', '')"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "EFW3yHbipNoW",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
},
"base_uri": "https://localhost:8080/",
"height": 55
},
"outputId": "0f083c59-5972-4f34-c48e-6ceac6f675fd",
"executionInfo": {
"status": "ok",
"timestamp": 1528032289907,
"user_tz": -330,
"elapsed": 1059,
"user": {
"displayName": "Heet Sankesara",
"photoUrl": "//lh5.googleusercontent.com/-mO3PS3oBtRQ/AAAAAAAAAAI/AAAAAAAAADs/_ic_rsddWU4/s50-c-k-no/photo.jpg",
"userId": "112240562069909648160"
}
}
},
"cell_type": "code",
"source": [
"poem"
],
"execution_count": 8,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"'Let the bird of loudest lay On the sole Arabian tree Herald sad and trumpet be, To whose sound chaste wings obey. But thou shrieking harbinger, Foul precurrer of the fiend, Augur of the fevers end, To this troop come thou not near. From this session interdict Every fowl of tyrant wing, Save the eagle, featherd king; Keep the obsequy so strict. Let the priest in surplice white, That defunctive music can, Be the death-divining swan, Lest the requiem lack his right. And thou treble-dated crow, That thy sable gender makst With the breath thou givst and takst, Mongst our mourners shalt thou go. Here the anthem doth commence: Love and constancy is dead; Phoenix and the Turtle fled In a mutual flame from hence. So they lovd, as love in twain Had the essence but in one; Two distincts, division none: Number there in love was slain. Hearts remote, yet not asunder; Distance and no space was seen Twixt this Turtle and his queen: But in them it were a wonder. So between them love did shine That the Turtle saw his right Flaming in the Phoenix sight: Either was the others mine. Property was thus appalled That the self was not the same; Single natures double name Neither two nor one was called. Reason, in itself confounded, Saw division grow together, To themselves yet either neither, Simple were so well compounded; That it cried, \"How true a twain Seemeth this concordant one! Love has reason, reason none, If what parts can so remain.\" Whereupon it made this threne To the Phoenix and the Dove, Co-supremes and stars of love, As chorus to their tragic scene: threnos Beauty, truth, and rarity, Grace in all simplicity, Here enclosd, in cinders lie. Death is now the Phoenix nest, And the Turtles loyal breast To eternity doth rest, Leaving no posterity: Twas not their infirmity, It was married chastity. Truth may seem but cannot be; Beauty brag but tis not she; Truth and beauty buried be. To this urn let those repair That are either true or fair; For these dead birds sigh a prayer.'"
]
},
"metadata": {
"tags": []
},
"execution_count": 8
}
]
},
{
"metadata": {
"id": "7SqYLgMvpng5",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
}
},
"cell_type": "code",
"source": [
"poem = poem.split(' ')"
],
"execution_count": 0,
"outputs": []
},
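{
"metadata": {},
"cell_type": "code",
"source": [
"# Sketch (added for illustration, not part of the original pipeline): split(' ')\n",
"# keeps punctuation attached to words, so tokens such as 'be,', 'be;' and 'be.'\n",
"# become separate vocabulary entries. The regular-expression tokenizer below is\n",
"# one possible alternative, shown on a line from the poem. Its effect on this\n",
"# model has not been tested.\n",
"import re\n",
"\n",
"line = 'Herald sad and trumpet be, To whose sound chaste wings obey.'\n",
"print(line.split(' '))                      # punctuation stays glued to words\n",
"print(re.findall(r'[\\w-]+|[^\\w\\s]', line))  # punctuation becomes separate tokens"
],
"execution_count": null,
"outputs": []
},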
{
"metadata": {
"id": "MCKwmGYppvj1",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
}
},
"cell_type": "code",
"source": [
"words = sorted(set(poem))  # sorted so word ids are the same on every run"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "RkyA-h_q5m1K",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
},
"cellView": "code"
},
"cell_type": "code",
"source": [
"# Map each word to an integer id (strdict) and each id back to its word (revdict).\n",
"strdict = {}\n",
"revdict = {}\n",
"for i, word in enumerate(words):\n",
"    strdict[word] = i\n",
"    revdict[i] = word"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "ktb11a_j5qne",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
}
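},
"cell_type": "code",
"source": [
"def convert_to_df(start, size, seq, vsize):\n",
"    # Input: ids of the `size` words starting at `start`, shaped (1, size).\n",
"    inp = seq[start:start + size]\n",
"    inp = np.array([strdict[x] for x in inp])\n",
"    inp = inp.reshape((1, size))\n",
"    # Target: one-hot vector of length `vsize` marking the word right after the window.\n",
"    out = np.zeros(vsize)\n",
"    out[strdict[seq[start + size]]] = 1\n",
"    return inp, out"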
],
"execution_count": 0,
"outputs": []
},
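{
"metadata": {},
"cell_type": "code",
"source": [
"# Sketch (illustration only): the (context, next word) pairs that convert_to_df\n",
"# builds, shown on a hypothetical six-word sequence with its own word-to-id map\n",
"# (toy_ids) in place of the notebook's strdict.\n",
"import numpy as np\n",
"\n",
"toy_words = ['Let', 'the', 'bird', 'of', 'loudest', 'lay']\n",
"toy_ids = {w: i for i, w in enumerate(sorted(set(toy_words)))}\n",
"size = 4  # context length, playing the role of num_inp\n",
"\n",
"for start in range(len(toy_words) - size):\n",
"    inp = np.array([toy_ids[w] for w in toy_words[start:start + size]]).reshape((1, size))\n",
"    target = np.zeros(len(toy_ids))\n",
"    target[toy_ids[toy_words[start + size]]] = 1  # one-hot id of the next word\n",
"    print(toy_words[start:start + size], '->', toy_words[start + size], inp, target)"
],
"execution_count": null,
"outputs": []
},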
{
"metadata": {
"id": "I47QFPnyqTtG",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
}
},
"cell_type": "code",
"source": [
"import tensorflow as tf"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "KcBhjDxhxfy3",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
}
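},
"cell_type": "code",
"source": [
"def rnn_cell(W, U, b, prev_cell, curr):\n",
"    # One recurrence step: h_t = U x_t + W h_{t-1} + b, where U is (vocab_size, 1),\n",
"    # x_t is (1, 1), W is (vocab_size, vocab_size) and h is (vocab_size, 1).\n",
"    return tf.add(tf.matmul(U, curr), tf.matmul(W, prev_cell)) + b"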
],
"execution_count": 0,
"outputs": []
},
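{
"metadata": {},
"cell_type": "code",
"source": [
"# Sketch (illustration only): one rnn_cell step written in NumPy to make the\n",
"# shapes explicit. Like the notebook, it assumes each input is a single number\n",
"# (a word id) rather than a one-hot vector or embedding, so U is a\n",
"# (vocab_size, 1) column. The sizes and values here are hypothetical.\n",
"import numpy as np\n",
"\n",
"rng = np.random.RandomState(0)\n",
"vocab_demo = 5\n",
"W_demo = rng.normal(size=(vocab_demo, vocab_demo))  # hidden-to-hidden weights\n",
"U_demo = rng.normal(size=(vocab_demo, 1))           # input-to-hidden weights\n",
"b_demo = np.zeros((vocab_demo, 1))                  # bias\n",
"h_prev = np.zeros((vocab_demo, 1))                  # previous hidden state\n",
"x_t = np.array([[3.0]])                             # current input as a 1x1 matrix\n",
"\n",
"h_t = U_demo @ x_t + W_demo @ h_prev + b_demo       # mirrors rnn_cell(W, U, b, h_prev, x_t)\n",
"print(h_t.shape)  # (5, 1)"
],
"execution_count": null,
"outputs": []
},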
{
"metadata": {
"id": "1NMOgovPqnzr",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
}
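},
"cell_type": "code",
"source": [
"learning_rate = 0.0001\n",
"training_iters = 10000  # full passes (epochs) over the poem\n",
"display_step = 500      # print average cost and accuracy every display_step epochs\n",
"num_inp = 4             # context window: previous words fed to the RNN\n",
"m = len(poem)           # number of words in the poem\n",
"vocab_size = len(words) # number of distinct words"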
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "0tqele7awk0O",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
}
},
"cell_type": "code",
"source": [
"x = tf.placeholder(tf.float64, [1, num_inp])  # ids of the num_inp context words\n",
"y = tf.placeholder(tf.float64, [vocab_size])  # one-hot encoding of the next word"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "FhqUFQyMw3SG",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
}
},
"cell_type": "code",
"source": [
"W = tf.Variable(tf.random_normal([vocab_size, vocab_size]))\n",
"U = tf.Variable(tf.random_normal([vocab_size, 1]))\n",
"b = tf.Variable(tf.zeros([vocab_size, 1]))"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "BDpm960uxe74",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
},
"base_uri": "https://localhost:8080/",
"height": 35
},
"outputId": "aa284257-375a-45ff-ef2b-e148ee7c5a56",
"executionInfo": {
"status": "ok",
"timestamp": 1528032313271,
"user_tz": -330,
"elapsed": 1054,
"user": {
"displayName": "Heet Sankesara",
"photoUrl": "//lh5.googleusercontent.com/-mO3PS3oBtRQ/AAAAAAAAAAI/AAAAAAAAADs/_ic_rsddWU4/s50-c-k-no/photo.jpg",
"userId": "112240562069909648160"
}
}
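},
"cell_type": "code",
"source": [
"# tf.device() returns a context manager, so calling it on its own (as below) does not\n",
"# change where any op runs; ops are pinned only if they are created inside a\n",
"# `with tf.device(\"/device:GPU:0\"):` block.\n",
"tf.device(\"/device:GPU:0\")"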
],
"execution_count": 18,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<contextlib._GeneratorContextManager at 0x7fb01d974e80>"
]
},
"metadata": {
"tags": []
},
"execution_count": 18
}
]
},
{
"metadata": {
"id": "6xKblWN4yRsq",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
}
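},
"cell_type": "code",
"source": [
"# Unroll the RNN over the num_inp context words.\n",
"layer = tf.zeros((vocab_size, 1), dtype=tf.float32)  # initial hidden state\n",
"for i in range(num_inp):\n",
"    curr = tf.reshape(x[:, i], (1, 1))  # i-th context word id as a 1x1 input\n",
"    curr = tf.cast(curr, tf.float32)\n",
"    layer = rnn_cell(W, U, b, layer, curr)\n",
"    if i < num_inp - 1:  # tanh after every step except the last, whose output is the logits\n",
"        layer = tf.nn.tanh(layer)\n",
"layer = tf.reshape(layer, [-1])  # logits, shape [vocab_size]\n",
"out = tf.nn.softmax(layer)       # predicted next-word distribution"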
],
"execution_count": 0,
"outputs": []
},
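{
"metadata": {},
"cell_type": "code",
"source": [
"# Sketch (illustration only): the unrolled forward pass above restated in NumPy,\n",
"# which can help when checking shapes outside the TensorFlow graph. As in the\n",
"# TensorFlow cell, tanh follows every step except the last, whose linear output\n",
"# serves as the logits. Weights, sizes and input ids are hypothetical.\n",
"import numpy as np\n",
"\n",
"def forward_np(W, U, b, input_ids):\n",
"    h = np.zeros((W.shape[0], 1))  # initial hidden state\n",
"    for i, x_t in enumerate(input_ids):\n",
"        h = U @ np.array([[float(x_t)]]) + W @ h + b\n",
"        if i < len(input_ids) - 1:\n",
"            h = np.tanh(h)\n",
"    logits = h.ravel()\n",
"    probs = np.exp(logits - logits.max())  # shift by the max for numerical stability\n",
"    return logits, probs / probs.sum()\n",
"\n",
"rng = np.random.RandomState(0)\n",
"vocab_demo = 6\n",
"W_demo = rng.normal(size=(vocab_demo, vocab_demo))\n",
"U_demo = rng.normal(size=(vocab_demo, 1))\n",
"b_demo = np.zeros((vocab_demo, 1))\n",
"logits_demo, probs_demo = forward_np(W_demo, U_demo, b_demo, [2, 0, 5, 1])\n",
"print(logits_demo.shape, probs_demo.sum())  # (6,) and a sum of 1 up to rounding"
],
"execution_count": null,
"outputs": []
},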
{
"metadata": {
"id": "tbSNM77mzCL1",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
}
},
"cell_type": "code",
"source": [
"cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=layer, labels=y))\n",
"optimizer = tf.train.AdamOptimizer(learning_rate).minimize(cost)"
],
"execution_count": 0,
"outputs": []
},
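{
"metadata": {},
"cell_type": "code",
"source": [
"# Sketch (illustration only): what softmax_cross_entropy_with_logits_v2 computes\n",
"# for one example, using a numerically stable log-softmax. Because each training\n",
"# step feeds a single window, tf.reduce_mean in the cost cell averages a single\n",
"# value. The logits and label below are made up.\n",
"import numpy as np\n",
"\n",
"def softmax_cross_entropy(logits, one_hot):\n",
"    shifted = logits - logits.max()\n",
"    log_probs = shifted - np.log(np.exp(shifted).sum())\n",
"    return -(one_hot * log_probs).sum()\n",
"\n",
"logits_demo = np.array([2.0, 0.5, -1.0])\n",
"label_demo = np.array([1.0, 0.0, 0.0])  # the true next word is class 0\n",
"print(softmax_cross_entropy(logits_demo, label_demo))  # about 0.241"
],
"execution_count": null,
"outputs": []
},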
{
"metadata": {
"id": "K-4QkAIL2RxA",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
}
},
"cell_type": "code",
"source": [
"# Model evaluation\n",
"correct_pred = tf.equal(tf.argmax(out,0), tf.argmax(y,0))\n",
"accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "Pz0OUV1H4oms",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
}
},
"cell_type": "code",
"source": [
"saver = tf.train.Saver()\n",
"init = tf.global_variables_initializer()"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "e3AkyEMS4sha",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
},
"base_uri": "https://localhost:8080/",
"height": 458
},
"outputId": "753b15e1-6eb1-4c7d-b9cd-365bde434c47",
"executionInfo": {
"status": "ok",
"timestamp": 1528039089039,
"user_tz": -330,
"elapsed": 6369850,
"user": {
"displayName": "Heet Sankesara",
"photoUrl": "//lh5.googleusercontent.com/-mO3PS3oBtRQ/AAAAAAAAAAI/AAAAAAAAADs/_ic_rsddWU4/s50-c-k-no/photo.jpg",
"userId": "112240562069909648160"
}
}
},
"cell_type": "code",
"source": [
"with tf.Session() as sess:\n",
"    sess.run(init)\n",
"    window_size = m - num_inp  # number of (context, next word) training pairs\n",
"    for epoch in range(training_iters):\n",
"        cst_total = 0\n",
"        acc_total = 0\n",
"        for i in range(window_size):\n",
"            df_x, df_y = convert_to_df(i, num_inp, poem, vocab_size)\n",
"            _, cst, acc = sess.run([optimizer, cost, accuracy], feed_dict={x: df_x, y: df_y})\n",
"            cst_total += cst\n",
"            acc_total += acc\n",
"        if (epoch + 1) % display_step == 0:\n",
"            print('After ', (epoch + 1), 'iterations: Cost = ', cst_total / window_size, 'and Accuracy: ', acc_total / window_size)\n",
"    print('Optimization finished!!!')\n",
"    print(\"Let's test\")\n",
"    # Greedy generation: seed with the poem's first num_inp words, then repeatedly\n",
"    # append the most probable next word and slide the window forward by one.\n",
"    sentence = 'Let the bird of'\n",
"    sent = [strdict[s] for s in sentence.split(' ')]\n",
"    for _ in range(32):\n",
"        df_test = np.array(sent).reshape((1, num_inp))\n",
"        one_hot = sess.run(out, feed_dict={x: df_test})\n",
"        # one_hot is already a NumPy array; np.argmax avoids adding a new op to the graph each step\n",
"        index = int(np.argmax(one_hot))\n",
"        sentence = '%s %s' % (sentence, revdict[index])\n",
"        sent = sent[1:] + [index]\n",
"    print(sentence)"
],
"execution_count": 24,
"outputs": [
{
"output_type": "stream",
"text": [
"After 500 iterations: Cost = 20.123338758524493 and Accuracy: 0.24316939890710382\n",
"After 1000 iterations: Cost = 6.411410093315618 and Accuracy: 0.3770491803278688\n",
"After 1500 iterations: Cost = 2.764653643213434 and Accuracy: 0.592896174863388\n",
"After 2000 iterations: Cost = 1.8735815742342021 and Accuracy: 0.6092896174863388\n",
"After 2500 iterations: Cost = 0.9350996642665488 and Accuracy: 0.8060109289617486\n",
"After 3000 iterations: Cost = 0.8598392717645359 and Accuracy: 0.7868852459016393\n",
"After 3500 iterations: Cost = 0.6601651109438192 and Accuracy: 0.8005464480874317\n",
"After 4000 iterations: Cost = 0.3157847127642007 and Accuracy: 0.9153005464480874\n",
"After 4500 iterations: Cost = 0.2527449232082523 and Accuracy: 0.9344262295081968\n",
"After 5000 iterations: Cost = 0.2547022749178819 and Accuracy: 0.9262295081967213\n",
"After 5500 iterations: Cost = 0.15962486185046335 and Accuracy: 0.9590163934426229\n",
"After 6000 iterations: Cost = 0.18193453793148542 and Accuracy: 0.9426229508196722\n",
"After 6500 iterations: Cost = 0.22826607685378722 and Accuracy: 0.9234972677595629\n",
"After 7000 iterations: Cost = 0.08878657359810452 and Accuracy: 0.9699453551912568\n",
"After 7500 iterations: Cost = 0.1666746637715216 and Accuracy: 0.9426229508196722\n",
"After 8000 iterations: Cost = 0.22317534572637585 and Accuracy: 0.9289617486338798\n",
"After 8500 iterations: Cost = 0.06316598943300857 and Accuracy: 0.9726775956284153\n",
"After 9000 iterations: Cost = 0.060251965944464726 and Accuracy: 0.9754098360655737\n",
"After 9500 iterations: Cost = 0.06550918032982553 and Accuracy: 0.9726775956284153\n",
"After 10000 iterations: Cost = 0.10437052107914599 and Accuracy: 0.9699453551912568\n",
"Optimization finished!!!\n",
"Let's test\n",
"Let the bird of loudest lay On the sole Arabian tree Herald sad and trumpet be, To whose urn in name Neither were a wonder. So between them love did shine That the Turtle saw his\n"
],
"name": "stdout"
}
]
},
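{
"metadata": {},
"cell_type": "code",
"source": [
"# Sketch (illustration only): the greedy generation loop from the cell above,\n",
"# factored into a reusable function. predict_next is a hypothetical callable\n",
"# standing in for sess.run(out, feed_dict={x: ...}); it takes a (1, window)\n",
"# array of word ids and returns a probability vector over the vocabulary.\n",
"from collections import deque\n",
"import numpy as np\n",
"\n",
"def greedy_generate(seed_ids, predict_next, steps):\n",
"    window = deque(seed_ids, maxlen=len(seed_ids))  # holds only the latest ids\n",
"    generated = []\n",
"    for _ in range(steps):\n",
"        probs = predict_next(np.array(window).reshape((1, len(window))))\n",
"        next_id = int(np.argmax(probs))\n",
"        generated.append(next_id)\n",
"        window.append(next_id)  # the oldest id drops out automatically\n",
"    return generated\n",
"\n",
"def toy_predict(ids):\n",
"    # Hypothetical stand-in model over 5 words: always predicts (last id + 1) mod 5.\n",
"    return np.eye(5)[(int(ids[0, -1]) + 1) % 5]\n",
"\n",
"print(greedy_generate([0, 1, 2, 3], toy_predict, 3))  # [4, 0, 1]"
],
"execution_count": null,
"outputs": []
},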
{
"metadata": {
"id": "pksTKvi6TL-_",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
}
},
"cell_type": "code",
"source": [
""
],
"execution_count": 0,
"outputs": []
}
]
}