Forked from anonymous/split bidirectional rnns.ipynb
Created
January 5, 2017 13:30
-
-
Save braingineer/584a54fba813c1c073585061663b0bde to your computer and use it in GitHub Desktop.
split bidirectional rnns
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"metadata": { | |
"trusted": true, | |
"collapsed": false | |
}, | |
"cell_type": "code", | |
"source": "import keras as K\nfrom keras.layers import Bidirectional, LSTM, Lambda, Input, merge", | |
"execution_count": 1, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": "Using Theano backend.\n", | |
"name": "stderr" | |
}, | |
{ | |
"output_type": "stream", | |
"text": "Couldn't import dot_parser, loading of dot files will not be possible.\n", | |
"name": "stdout" | |
}, | |
{ | |
"output_type": "stream", | |
"text": "Using gpu device 0: GeForce GTX 980 (CNMeM is disabled, cuDNN 5105)\n", | |
"name": "stderr" | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"trusted": true, | |
"collapsed": true | |
}, | |
"cell_type": "code", | |
"source": "def SplitAt(x_in, i, outshape):\n # might have to deal with broadcasting issues here. If so, zeros_like just needs a 1-length 3rd dim\n return (Lambda(lambda x: x[:, i:]*K.backend.zeros_like(x[:, i:]), output_shape=outshape)(x_in), \n Lambda(lambda x: x[:, :i]*K.backend.zeros_like(x[:, :i]), output_shape=outshape)(x_in))", | |
"execution_count": 2, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true, | |
"collapsed": false | |
}, | |
"cell_type": "code", | |
"source": "batch = 10\nin_length = 10\nin_feat = 10\nout_feat = 10\n\nout_splitshape = (in_length, in_feat)\n\nF_leftrnn = LSTM(out_feat) # assuming returning last vector, not whole sequence\nF_rightrnn = LSTM(out_feat, go_backwards=True)\n\n\nx_in = Input(batch_shape=(batch, in_length, in_feat))\n# alternatively\n#x_in = Input(batch_size=(batch, in_length), dtype='int32')\n#x_emb = Embedding(in_feat)(x_in)\n\nall_splits = []\n\nfor i in range(in_length):\n xi_left, xi_right = SplitAt(x_in, i, out_splitshape)\n \n xi_left = F_leftrnn(xi_left)\n xi_right = F_rightrnn(xi_right)\n \n xi = merge([xi_left, xi_right], mode='concat', concat_axis=-1)\n all_splits.append(xi)\n\n# should be of size (in_length, batch, out_feat*2) ## out_feat * 2 because of the concat. \nstacked_splits = K.backend.stack(all_splits) \n# desired: (batch, in_length, out_feat*2) \nstacked_splits = K.backend.permute_dimensions(stacked_splits, [1, 0, 2])\n\n# the backend stack should have *args, **kwargs bit, but it doesn't. \n# if it did:\n# stacked_splits = K.backend.stack(all_splits, axis=1) \n# iirc, would create new dim at 1 and stack along. would have to play with to be sure.", | |
"execution_count": 3, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": "/home/cogniton/anaconda/envs/DL/lib/python2.7/site-packages/ipykernel/__main__.py:29: DeprecationWarning: stack(*tensors) interface is deprecated, use stack(tensors, axis=0) instead.\n", | |
"name": "stderr" | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"trusted": true, | |
"collapsed": true | |
}, | |
"cell_type": "code", | |
"source": "", | |
"execution_count": null, | |
"outputs": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"name": "conda-env-DL-py", | |
"display_name": "Python [conda env:DL]", | |
"language": "python" | |
}, | |
"anaconda-cloud": {}, | |
"language_info": { | |
"mimetype": "text/x-python", | |
"nbconvert_exporter": "python", | |
"name": "python", | |
"pygments_lexer": "ipython2", | |
"version": "2.7.11", | |
"file_extension": ".py", | |
"codemirror_mode": { | |
"version": 2, | |
"name": "ipython" | |
} | |
}, | |
"gist": { | |
"id": "", | |
"data": { | |
"description": "split bidirectional rnns", | |
"public": true | |
} | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 1 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment