@bzamecnik
Last active November 26, 2021 11:00
CuDNN-compatible GRU in Keras
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# CuDNN-compatible GRU in Keras\n",
"\n",
"https://github.com/keras-team/keras/pull/9112\n",
"\n",
"2018-01-19, Bohumír Zámečník <bohumir.zamecnik@gmail.com>, [Rossum.ai](https://rossum.ai/)\n",
"\n",
"We'd like to show how to train a model with GRU layer on GPU accelerated by CuDNN and then use it for inference without the need of GPU while getting the same predictions.\n",
"\n",
"- train a model quickly on GPU with `CuDNNGRU`\n",
"- save weights\n",
"- create model on CPU with plain CuDNN-compatible GRU (from this PR)\n",
"- load weights (and convert)\n",
"- make the same predictions (up to numerical precision)\n",
"\n",
"Model based on the [`imdb_lstm.py` Keras example](https://github.com/keras-team/keras/blob/master/examples/imdb_lstm.py).\n",
"\n",
"## Installing\n",
"\n",
"```\n",
"git clone git@github.com:bzamecnik/keras.git\n",
"cd keras/\n",
"git checkout cudnn-compatible-gru\n",
"pip install -e .\n",
"```\n",
"\n",
"## Basic usage examples\n",
"\n",
"### Creating the GRU layers\n",
"\n",
"We introduced a new parameter `reset_after=True` which makes GRU compatible with CuDNN GRU convention.\n",
"\n",
"Note that that default recurrent_activation in GRU is `'hard_sigmoid'`, but CuDNN supports only `'sigmoid'`, so we have to set it explicitly.\n",
"\n",
"```\n",
"from keras.layers import CuDNNGRU, GRU\n",
"```\n",
"\n",
"Basic GRU - runs on CPU/GPU, not compatible with CuDNNGRU:\n",
"\n",
"```\n",
"gru = GRU(n_units)\n",
"```\n",
"\n",
"CuDNN-accelerated GRU - runs only on GPU, different convention, much faster:\n",
"\n",
"```\n",
"gru_cudnn = CuDNNGRU(n_units)\n",
"```\n",
"\n",
"New CuDNN-compatible GRU - runs on CPU/GPU, compatible with CuDNNGRU:\n",
"```\n",
"gru_compatible = GRU(n_units, reset_after=True, recurrent_activation='sigmoid')\n",
"```\n",
"\n",
"Keras GRU has two implementations (`implementation=1` or `2`). The first one performs matrix multiplications separately for each projection matrix, the second one merges matrices together into a single multiplication, thus might be a bit faster on GPU. In the `reset_after` convention we can do one multiplication, in `reset_before` we can merge only two matrices and perform another one after applying the reset gate.\n",
"\n",
"In our measurement we see that for predictions implementation 2 is slightly faster on GPU, while slower on CPU.\n",
"\n",
"### Loading weights\n",
"\n",
"Training on GPU:\n",
"\n",
"```\n",
"gru_cudnn = CuDNNGRU(n_units)\n",
"model = ... make model with gru_cudnn ...\n",
"model.fit(...)\n",
"model.save_weights('weights_cudnn.h5')\n",
"```\n",
"\n",
"Predictions on CPU (uses default `implementation=1`):\n",
"\n",
"```\n",
"gru = GRU(n_units, reset_after=True, recurrent_activation='sigmoid')\n",
"model = ... make model with gru ...\n",
"model.load_weights('weights_cudnn.h5')\n",
"predictions = model.predict(x)\n",
"```\n",
"\n",
"The code for loading weights detects weights from CuDNNGRU and automatically converts them for usage in GRU."
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"env: CUDA_DEVICE_ORDER=PCI_BUS_ID\n",
"env: CUDA_VISIBLE_DEVICES=0\n"
]
}
],
"source": [
"# Use a GPU\n",
"%env CUDA_DEVICE_ORDER=PCI_BUS_ID\n",
"%env CUDA_VISIBLE_DEVICES=0"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Using TensorFlow backend.\n"
]
}
],
"source": [
"from __future__ import print_function\n",
"\n",
"import random\n",
"\n",
"import keras.backend as K\n",
"from keras.datasets import imdb\n",
"from keras.layers import CuDNNGRU, GRU, Dense, Embedding, Input\n",
"from keras.models import Sequential\n",
"from keras.preprocessing.sequence import pad_sequences\n",
"import numpy as np\n",
"import tensorflow as tf\n",
"from tensorflow.python.client import device_lib"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Check that we have a GPU available:"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[name: \"/device:CPU:0\"\n",
" device_type: \"CPU\"\n",
" memory_limit: 268435456\n",
" locality {\n",
" }\n",
" incarnation: 16772068897675471964, name: \"/device:GPU:0\"\n",
" device_type: \"GPU\"\n",
" memory_limit: 182583296\n",
" locality {\n",
" bus_id: 1\n",
" }\n",
" incarnation: 5111035405032839957\n",
" physical_device_desc: \"device: 0, name: GeForce GTX 1070, pci bus id: 0000:01:00.0, compute capability: 6.1\"]"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"device_lib.list_local_devices()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Prepare a simple dataset"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loading IMDB dataset...\n",
"25000 train sequences\n",
"25000 test sequences\n",
"Pad sequences (samples x time)\n",
"x_train shape: (25000, 80)\n",
"x_test shape: (25000, 80)\n",
"CPU times: user 49.4 s, sys: 402 ms, total: 49.8 s\n",
"Wall time: 49.7 s\n"
]
}
],
"source": [
"%%time\n",
"\n",
"max_features = 20000\n",
"maxlen = 80 # cut texts after this number of words (among top max_features most common words)\n",
"batch_size = 32\n",
"\n",
"print('Loading IMDB dataset...')\n",
"(x_train, y_train), (x_test, y_test) = imdb.load_data(num_words=max_features)\n",
"print(len(x_train), 'train sequences')\n",
"print(len(x_test), 'test sequences')\n",
"\n",
"print('Pad sequences (samples x time)')\n",
"x_train = pad_sequences(x_train, maxlen=maxlen)\n",
"x_test = pad_sequences(x_test, maxlen=maxlen)\n",
"print('x_train shape:', x_train.shape)\n",
"print('x_test shape:', x_test.shape)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Train a model with CuDNN"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"def set_seed(seed=42):\n",
" # trying to make the runs repeatable, unfortunately there's still some source of randomness :/\n",
" random.seed(seed)\n",
" np.random.seed(seed)\n",
" tf.set_random_seed(seed)\n",
"\n",
"def make_model(gru_class, **gru_kwawgs):\n",
" \"\"\"\n",
" Creates a simple text prediction model with given RNN layer class\n",
" and its arguments.\n",
" \"\"\"\n",
" print(gru_class.__name__, gru_kwawgs)\n",
" set_seed()\n",
" model = Sequential()\n",
" model.add(Embedding(max_features, 100))\n",
" model.add(gru_class(128, **gru_kwawgs))\n",
" model.add(Dense(1, activation='sigmoid'))\n",
"\n",
" # try using different optimizers and different optimizer configs\n",
" model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])\n",
" return model\n",
"\n",
"def train(model, subset_size=2*1024, epochs=2):\n",
" set_seed()\n",
" model.fit(x_train[:subset_size], y_train[:subset_size],\n",
" validation_data=(x_test[:subset_size], y_test[:subset_size]),\n",
" batch_size=batch_size, epochs=epochs)\n",
" return model\n",
"\n",
"def evaluate(model, subset_size=2*1024):\n",
" set_seed()\n",
" result = {}\n",
" result['train_loss'], result['train_acc'] = model.evaluate(x_train[:subset_size], y_train[:subset_size], batch_size=batch_size)\n",
" result['test_loss'], result['test_acc'] = model.evaluate(x_test[:subset_size], y_test[:subset_size], batch_size=batch_size)\n",
" print(result)\n",
" return [result['train_loss'], result['train_acc'], result['test_loss'], result['test_acc']]\n",
"\n",
"def train_eval(model):\n",
" return evaluate(train(model))"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CuDNNGRU {}\n",
"Train on 2048 samples, validate on 2048 samples\n",
"Epoch 1/2\n",
"2048/2048 [==============================] - 2s 752us/step - loss: 0.6785 - acc: 0.5718 - val_loss: 0.6156 - val_acc: 0.6880\n",
"Epoch 2/2\n",
"2048/2048 [==============================] - 1s 602us/step - loss: 0.4156 - acc: 0.8306 - val_loss: 0.5231 - val_acc: 0.7480\n",
"2048/2048 [==============================] - 0s 122us/step\n",
"2048/2048 [==============================] - 0s 123us/step\n",
"{'train_acc': 0.95556640625, 'train_loss': 0.16661585529800504, 'test_loss': 0.52313004108145833, 'test_acc': 0.748046875}\n"
]
}
],
"source": [
"model_cudnn = make_model(CuDNNGRU)\n",
"results_cudnn = train_eval(model_cudnn)\n",
"pred_cudnn = model_cudnn.predict(x_test)\n",
"model_cudnn.save_weights('gru_cudnn.h5')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load trained weights into a CuDNN-compatible GRU model\n",
"\n",
"We can easily reuse weights trained with CuDNN GRU. The prediction is slower but the benefit is that it's not dependent on GPU with CuDNN and thus can run on a CPU."
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"GRU {'implementation': 1, 'reset_after': True, 'recurrent_activation': 'sigmoid'}\n",
"2048/2048 [==============================] - 2s 808us/step\n",
"2048/2048 [==============================] - 2s 769us/step\n",
"{'train_acc': 0.95556640625, 'train_loss': 0.16661585564725101, 'test_loss': 0.52313003782182932, 'test_acc': 0.748046875}\n",
"Does it match with the CuDNN GRU model?\n",
"Evaluation results: True\n",
"Predictions: True\n",
"GRU {'implementation': 2, 'reset_after': True, 'recurrent_activation': 'sigmoid'}\n",
"2048/2048 [==============================] - 2s 735us/step\n",
"2048/2048 [==============================] - 2s 769us/step\n",
"{'train_acc': 0.95556640625, 'train_loss': 0.16661585564725101, 'test_loss': 0.52313003782182932, 'test_acc': 0.748046875}\n",
"Does it match with the CuDNN GRU model?\n",
"Evaluation results: True\n",
"Predictions: True\n"
]
}
],
"source": [
"def check_gru_compatible_model(model):\n",
" model.load_weights('gru_cudnn.h5')\n",
" eval_results = evaluate(model)\n",
" pred = model.predict(x_test, verbose=1)\n",
" print('Does it match with the CuDNN GRU model?')\n",
" print('Evaluation results:', np.allclose(results_cudnn, eval_results))\n",
" print('Predictions:', np.allclose(pred_cudnn, pred))\n",
"\n",
"for implementation in [1, 2]:\n",
" model = make_model(GRU,\n",
" reset_after=True,\n",
" recurrent_activation='sigmoid',\n",
" implementation=implementation)\n",
" check_gru_compatible_model(model)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Show that `reset_before` GRU and `hard_sigmoid` are not compatible"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Weights from CuDNN do not match GRU `reset_after=False` (single vs. double set of biases)."
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"GRU {'implementation': 1, 'variant': 'reset_before'}\n",
"Cannot load weights:\n",
"Dimension 0 in both shapes must be equal, but are 384 and 768 for 'Assign_79' (op: 'Assign') with input shapes: [384], [768].\n"
]
}
],
"source": [
"try:\n",
" check_gru_compatible_model(make_model(GRU, reset_after=False, implementation=1))\n",
"except ValueError as e:\n",
" print('Cannot load weights:')\n",
" print(e)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"`recurrent_activation='hard_sigmoid'` doesn't match exactly single CuDNN uses `'sigmoid'` inside."
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"GRU {'implementation': 1, 'reset_after': True}\n",
"2048/2048 [==============================] - 2s 963us/step\n",
"2048/2048 [==============================] - 2s 964us/step\n",
"{'train_acc': 0.95361328125, 'train_loss': 0.17014663806185126, 'test_loss': 0.52106391731649637, 'test_acc': 0.74658203125}\n",
"Does it match with the CuDNN GRU model?\n",
"Evaluation results: False\n",
"Predictions: False\n"
]
}
],
"source": [
"check_gru_compatible_model(make_model(GRU, reset_after=True, implementation=1))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Speedup"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Prediction"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### On GPU"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"25000/25000 [==============================] - 3s 102us/step\n"
]
}
],
"source": [
"model_cudnn.predict(x_test, verbose=1);"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"25000/25000 [==============================] - 19s 757us/step\n"
]
}
],
"source": [
"model_compatible = make_model(GRU, reset_after=True, recurrent_activation='sigmoid', implementation=1)\n",
"model_compatible.load_weights('gru_cudnn.h5')\n",
"model_compatible.predict(x_test, verbose=1);"
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"GRU {'implementation': 2, 'reset_after': True, 'recurrent_activation': 'sigmoid'}\n",
"25000/25000 [==============================] - 17s 692us/step\n"
]
}
],
"source": [
"model_compatible_2 = make_model(GRU, reset_after=True, recurrent_activation='sigmoid', implementation=2)\n",
"model_compatible_2.load_weights('gru_cudnn.h5')\n",
"model_compatible_2.predict(x_test, verbose=1);"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"|model|total time (s)|speedup|ms/step|speedup|\n",
"|-----|--------------|-------|-------|-------|\n",
"|plain impl 1|19|1x|0.757|1x|\n",
"|plain impl 2|17|1.12x|0.692|1.09x|\n",
"|CuDNN|3|6.3x|0.102|7.4x|\n",
"\n",
"In prediction we're getting ~7x speedup using CuDNN and around 10% speedup using `implementation=2` (merged matrix multiplication)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### On CPU\n",
"\n",
"Results when the code is ran without GPU. We compare `implementation` `1` and `2`."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"GRU {'implementation': 1, 'reset_after': True, 'recurrent_activation': 'sigmoid'}\n",
"25000/25000 [==============================] - 13s 519us/step\n"
]
}
],
"source": [
"model_compatible = make_model(GRU, reset_after=True, recurrent_activation='sigmoid', implementation=1)\n",
"model_compatible.load_weights('gru_cudnn.h5')\n",
"model_compatible.predict(x_test, verbose=1);"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"GRU {'implementation': 2, 'reset_after': True, 'recurrent_activation': 'sigmoid'}\n",
"25000/25000 [==============================] - 15s 594us/step\n"
]
}
],
"source": [
"model_compatible_2 = make_model(GRU, reset_after=True, recurrent_activation='sigmoid', implementation=2)\n",
"model_compatible_2.load_weights('gru_cudnn.h5')\n",
"model_compatible_2.predict(x_test, verbose=1);"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"|model|total time (s)|speedup|ms/step|speedup|\n",
"|-----|--------------|-------|-------|-------|\n",
"|plain impl 1|13|1x|519|1x|\n",
"|plain impl 2|15|0.87x|594|0.87x|\n",
"\n",
"It seems that prediction on this model without using CuDNN is slighly faster on CPU than GPU and that `implementation=1` is also slightly faster than `implementation=2`."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Training\n",
"\n",
"Let's compare speed of training of both implementation on a GPU."
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CuDNNGRU {}\n",
"Train on 25000 samples, validate on 25000 samples\n",
"Epoch 1/2\n",
"25000/25000 [==============================] - 16s 632us/step - loss: 0.4482 - acc: 0.7806 - val_loss: 0.3877 - val_acc: 0.8268\n",
"Epoch 2/2\n",
"25000/25000 [==============================] - 15s 616us/step - loss: 0.2472 - acc: 0.9008 - val_loss: 0.3596 - val_acc: 0.8446\n"
]
},
{
"data": {
"text/plain": [
"<keras.models.Sequential at 0x7f57891f6650>"
]
},
"execution_count": 39,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train(make_model(CuDNNGRU), len(x_train))"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"GRU {'implementation': 1, 'reset_after': True, 'recurrent_activation': 'sigmoid'}\n",
"Train on 25000 samples, validate on 25000 samples\n",
"Epoch 1/2\n",
"25000/25000 [==============================] - 88s 4ms/step - loss: 0.4369 - acc: 0.7890 - val_loss: 0.3722 - val_acc: 0.8457\n",
"Epoch 2/2\n",
"25000/25000 [==============================] - 88s 4ms/step - loss: 0.2355 - acc: 0.9076 - val_loss: 0.3582 - val_acc: 0.8449\n"
]
},
{
"data": {
"text/plain": [
"<keras.models.Sequential at 0x7f5789c6edd0>"
]
},
"execution_count": 40,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train(make_model(GRU,\n",
" reset_after=True,\n",
" recurrent_activation='sigmoid',\n",
" implementation=1), len(x_train))"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"GRU {'implementation': 2, 'reset_after': True, 'recurrent_activation': 'sigmoid'}\n",
"Train on 25000 samples, validate on 25000 samples\n",
"Epoch 1/2\n",
"25000/25000 [==============================] - 84s 3ms/step - loss: 0.4369 - acc: 0.7891 - val_loss: 0.3723 - val_acc: 0.8456\n",
"Epoch 2/2\n",
"25000/25000 [==============================] - 84s 3ms/step - loss: 0.2355 - acc: 0.9077 - val_loss: 0.3582 - val_acc: 0.8447\n"
]
},
{
"data": {
"text/plain": [
"<keras.models.Sequential at 0x7f57881c2a90>"
]
},
"execution_count": 41,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train(make_model(GRU,\n",
" reset_after=True,\n",
" recurrent_activation='sigmoid',\n",
" implementation=2), len(x_train))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Results\n",
"\n",
"- terminology: \"step\" ~ \"batch\"\n",
"- speedup to plain impl 1\n",
"\n",
"|model|epoch time (s)|speedup|ms/step|speedup|\n",
"|-----|--------------|-------|-------|-------|\n",
"|plain impl 1|88|1x|4.|1x|\n",
"|plain impl 2|84|1.05x|3.|1.3x|\n",
"|CuDNN |15|5.87x|0.616|6.49x|\n",
"\n",
"For training on GPU we get ~6x speedup using CuDNN and slight speedup using `implementation=2`."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Conclusion\n",
"\n",
"We showed that the newly implemented variant of `GRU` that's compatible with `CuDNNGRU` allows workflow with fast training on GPU (accelerated ~6x) and predictions on CPU when GPU is not available while providing the same results."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.13"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
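
The difference between the two GRU conventions discussed in the notebook can be sketched in plain NumPy. This is a minimal illustration, not the Keras or CuDNN implementation: the function `gru_step`, the gate order (update, reset, candidate) and the tiny random weights are assumptions made for the example. It also reproduces the bias sizes from the weight-loading error (384 vs. 768 for 128 units).

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def gru_step(x, h, W, U, b_in, b_rec=None):
    """One GRU step (hypothetical helper).

    Passing b_rec selects the reset_after (CuDNN-style) convention,
    omitting it selects reset_before.
    """
    n = h.shape[-1]
    xw = x @ W + b_in  # input projections for z, r and the candidate
    if b_rec is not None:
        # reset_after: one merged recurrent multiplication, gate applied afterwards
        hu = h @ U + b_rec
        z = sigmoid(xw[:n] + hu[:n])
        r = sigmoid(xw[n:2 * n] + hu[n:2 * n])
        cand = np.tanh(xw[2 * n:] + r * hu[2 * n:])
    else:
        # reset_before: two matrices merged, third multiplication after the gate
        hu = h @ U[:, :2 * n]
        z = sigmoid(xw[:n] + hu[:n])
        r = sigmoid(xw[n:2 * n] + hu[n:])
        cand = np.tanh(xw[2 * n:] + (r * h) @ U[:, 2 * n:])
    return z * h + (1.0 - z) * cand


rng = np.random.default_rng(0)
n_in, n = 4, 3
W = rng.normal(size=(n_in, 3 * n))
U = rng.normal(size=(n, 3 * n))
b_in = rng.normal(size=3 * n)
b_rec = np.zeros(3 * n)
x = rng.normal(size=n_in)
h = rng.normal(size=n)

h_after = gru_step(x, h, W, U, b_in, b_rec)
h_before = gru_step(x, h, W, U, b_in)

# Bias sizes for 128 units, matching the shapes in the weight-loading error
n_units = 128
bias_reset_before = 3 * n_units      # 384: one set of biases
bias_reset_after = 2 * 3 * n_units   # 768: input and recurrent biases
```

With a zero state and a zero recurrent bias both branches give the same output. They diverge once the state is non-zero, which is why weights trained in one convention do not reproduce predictions in the other.
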
@bzamecnik
Author

In the resulting PR it's even possible to convert GRU weights to CuDNNGRU and LSTM weights to CuDNNLSTM!

@SatishDivakarla

I am getting the following error when calling load_weights on weights saved from a CuDNNGRU model:

AttributeError: 'Dataset' object has no attribute 'reshape'

Architecture code on GPU machine:
input_layer = Input(shape=(sequence_length,))
embedding_layer = Embedding(embedding_matrix.shape[0], embedding_matrix.shape[1],
weights=[embedding_matrix], trainable=False)(input_layer)
x = Bidirectional(CuDNNGRU(recurrent_units, return_sequences=True))(embedding_layer)
x = Dropout(dropout_rate)(x)
x = Bidirectional(CuDNNGRU(recurrent_units, return_sequences=True))(x)
x_max = GlobalMaxPool1D()(x)
x_avg = GlobalAveragePooling1D()(x)
x = concatenate([x_max, x_avg])
# x = Dense(dense_size, activation="relu")(x)
output_layer = Dense(6, activation="sigmoid")(x)
model = Model(inputs=input_layer, outputs=output_layer)
model.compile(loss='binary_crossentropy', optimizer=RMSprop(clipvalue=1, clipnorm=1), metrics=['accuracy'])

Code to save the weights on GPU machine:
model.save_weights("model0_weights.h5")

Architecture code on CPU machine:
input_layer = Input(shape=(sequence_length,))
embedding_layer = Embedding(embedding_matrix.shape[0], embedding_matrix.shape[1],
weights=[embedding_matrix], trainable=False)(input_layer)
x = Bidirectional(GRU(recurrent_units, reset_after=True, recurrent_activation='sigmoid', return_sequences=True))(embedding_layer)
x = Dropout(dropout_rate)(x)
x = Bidirectional(GRU(recurrent_units, reset_after=True, recurrent_activation='sigmoid', return_sequences=True))(x)
x_max = GlobalMaxPool1D()(x)
x_avg = GlobalAveragePooling1D()(x)
x = concatenate([x_max, x_avg])
output_layer = Dense(6, activation="sigmoid")(x)
model = Model(inputs=input_layer, outputs=output_layer)
model.compile(loss='binary_crossentropy', optimizer=RMSprop(clipvalue=1, clipnorm=1), metrics=['accuracy'])

Code to reload the saved weights:
model_0_weights = model.load_weights("model0_weights.h5")

Please let me know if I am doing something wrong.

@SatishDivakarla

@bzamecnik, could you please respond to my comment? I have a model with CuDNNGRU trained on GPU, but I am not able to use it in a CPU environment for predictions. Thanks in advance.

@leowang16

Having the same error:
AttributeError: 'Dataset' object has no attribute 'reshape'

@bzamecnik
Author

@SatishDivakarla: Sorry, I missed your comment... Do you have the original stack trace? If I remember correctly, I saw something similar in the Keras issues, and the cause was that the weights were an h5py Dataset instead of a numpy array. I'm not sure about the resolution.

@bzamecnik
Author

bzamecnik commented Apr 23, 2018

@SatishDivakarla: Aah, keras-team/keras#9112 (comment). Fixed in keras-team/keras#9662 (2018-03-15). Will be released in Keras 2.1.6.
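
The failure mode mentioned above (an h5py Dataset where a numpy array is expected) can be illustrated without Keras or h5py. This is only a sketch under stated assumptions: `FakeDataset` is a hypothetical stand-in for an h5py Dataset, `to_ndarray` is a hypothetical helper, and the `(2, 3 * units)` target shape is chosen for illustration. It is not necessarily how the fix in keras-team/keras#9662 works.

```python
import numpy as np


class FakeDataset:
    """Hypothetical stand-in for an h5py Dataset: array-like, but no reshape()."""

    def __init__(self, values):
        self._values = np.array(values, dtype="float32")
        self.shape = self._values.shape

    def __array__(self, dtype=None, copy=None):
        # Lets np.asarray() read the stored values.
        if dtype is not None:
            return self._values.astype(dtype)
        return self._values


def to_ndarray(weight):
    """Return a real numpy array for an ndarray or any array-like object."""
    return np.asarray(weight)


units = 128
stored_bias = FakeDataset(np.zeros(6 * units))  # 768 values, as in the CuDNNGRU bias

try:
    stored_bias.reshape((2, 3 * units))
    error = None
except AttributeError as exc:
    # Same kind of failure as in the reported error message
    error = str(exc)

# Converting first gives an object that does support reshape()
bias = to_ndarray(stored_bias).reshape((2, 3 * units))
print(error)
print(bias.shape)
```

For anyone stuck on an older Keras release, the practical route suggested by the thread is to upgrade to a version that includes the fix instead of patching the weights by hand.
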
