split bidirectional rnns
{
"cells": [
{
"metadata": {
"trusted": true,
"collapsed": false
},
"cell_type": "code",
"source": "import keras as K\nfrom keras.layers import Bidirectional, LSTM, Lambda, Input, merge",
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"text": "Using Theano backend.\n",
"name": "stderr"
},
{
"output_type": "stream",
"text": "Couldn't import dot_parser, loading of dot files will not be possible.\n",
"name": "stdout"
},
{
"output_type": "stream",
"text": "Using gpu device 0: GeForce GTX 980 (CNMeM is disabled, cuDNN 5105)\n",
"name": "stderr"
}
]
},
{
"metadata": {
"trusted": true,
"collapsed": true
},
"cell_type": "code",
"source": "def SplitAt(x_in, i, outshape):\n # might have to deal with broadcasting issues here. If so, zeros_like just needs a 1-length 3rd dim\n return (Lambda(lambda x: x[:, i:]*K.backend.zeros_like(x[:, i:]), output_shape=outshape)(x_in), \n Lambda(lambda x: x[:, :i]*K.backend.zeros_like(x[:, :i]), output_shape=outshape)(x_in))",
"execution_count": 2,
"outputs": []
},
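{
"metadata": {},
"cell_type": "markdown",
"source": "A quick sanity check of `SplitAt` on a tiny concrete input: the left split should keep the timesteps before `i` and zero the rest, while the right split should zero the timesteps before `i` and keep the rest. This is a minimal sketch, assuming `K.backend.function` can compile the two Lambda outputs directly on the Theano backend.",
},
{
"metadata": {
"trusted": true,
"collapsed": false
},
"cell_type": "code",
"source": "import numpy as np\n\n# toy check of SplitAt: sequence length 4, split at i = 2\ntoy_in = Input(batch_shape=(1, 4, 3))\ntoy_left, toy_right = SplitAt(toy_in, 2, (4, 3))\ntoy_fn = K.backend.function([toy_in], [toy_left, toy_right])\n\ntoy_x = np.arange(12, dtype='float32').reshape(1, 4, 3)\nleft_val, right_val = toy_fn([toy_x])\nprint(left_val)   # timesteps 0-1 match toy_x, timesteps 2-3 are zeros\nprint(right_val)  # timesteps 0-1 are zeros, timesteps 2-3 match toy_x",
"execution_count": null,
"outputs": []
},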
{
"metadata": {
"trusted": true,
"collapsed": false
},
"cell_type": "code",
"source": "batch = 10\nin_length = 10\nin_feat = 10\nout_feat = 10\n\nout_splitshape = (in_length, in_feat)\n\nF_leftrnn = LSTM(out_feat) # assuming returning last vector, not whole sequence\nF_rightrnn = LSTM(out_feat, go_backwards=True)\n\n\nx_in = Input(batch_shape=(batch, in_length, in_feat))\n# alternatively\n#x_in = Input(batch_size=(batch, in_length), dtype='int32')\n#x_emb = Embedding(in_feat)(x_in)\n\nall_splits = []\n\nfor i in range(in_length):\n xi_left, xi_right = SplitAt(x_in, i, out_splitshape)\n \n xi_left = F_leftrnn(xi_left)\n xi_right = F_rightrnn(xi_right)\n \n xi = merge([xi_left, xi_right], mode='concat', concat_axis=-1)\n all_splits.append(xi)\n\n# should be of size (in_length, batch, out_feat*2) ## out_feat * 2 because of the concat. \nstacked_splits = K.backend.stack(all_splits) \n# desired: (batch, in_length, out_feat*2) \nstacked_splits = K.backend.permute_dimensions(stacked_splits, [1, 0, 2])\n\n# the backend stack should have *args, **kwargs bit, but it doesn't. \n# if it did:\n# stacked_splits = K.backend.stack(all_splits, axis=1) \n# iirc, would create new dim at 1 and stack along. would have to play with to be sure.",
"execution_count": 3,
"outputs": [
{
"output_type": "stream",
"text": "/home/cogniton/anaconda/envs/DL/lib/python2.7/site-packages/ipykernel/__main__.py:29: DeprecationWarning: stack(*tensors) interface is deprecated, use stack(tensors, axis=0) instead.\n",
"name": "stderr"
}
]
},
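{
"metadata": {},
"cell_type": "markdown",
"source": "A small numpy check of the axis question in the comment above (a sketch using plain arrays in place of the symbolic per-position outputs): `np.stack(..., axis=1)` produces the desired `(batch, in_length, out_feat*2)` layout and matches stacking along axis 0 followed by the `[1, 0, 2]` permutation.",
},
{
"metadata": {
"trusted": true,
"collapsed": false
},
"cell_type": "code",
"source": "import numpy as np\n\n# stand-ins for the per-position merged outputs, each of shape (batch, out_feat*2)\nfake_splits = [np.random.rand(batch, out_feat * 2) for _ in range(in_length)]\n\nstacked_0 = np.stack(fake_splits)                # (in_length, batch, out_feat*2)\nstacked_0 = np.transpose(stacked_0, [1, 0, 2])   # (batch, in_length, out_feat*2)\nstacked_1 = np.stack(fake_splits, axis=1)        # (batch, in_length, out_feat*2)\n\nprint(stacked_1.shape, np.allclose(stacked_0, stacked_1))",
"execution_count": null,
"outputs": []
},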
{
"metadata": {
"trusted": true,
"collapsed": true
},
"cell_type": "code",
"source": "",
"execution_count": null,
"outputs": []
}
],
"metadata": {
"kernelspec": {
"name": "conda-env-DL-py",
"display_name": "Python [conda env:DL]",
"language": "python"
},
"anaconda-cloud": {},
"language_info": {
"mimetype": "text/x-python",
"nbconvert_exporter": "python",
"name": "python",
"pygments_lexer": "ipython2",
"version": "2.7.11",
"file_extension": ".py",
"codemirror_mode": {
"version": 2,
"name": "ipython"
}
},
"gist": {
"id": "",
"data": {
"description": "split bidirectional rnns",
"public": true
}
}
},
"nbformat": 4,
"nbformat_minor": 1
}