@superjax
Created November 10, 2017 03:47
my implementation of a Gated Recurrent Unit in TensorFlow
from tensorflow.python.ops import math_ops
import tensorflow as tf


class myGRU(tf.nn.rnn_cell.RNNCell):
    """A Gated Recurrent Unit (Cho et al., 2014) implemented from scratch.

    For input x and previous state h, it computes:
        z  = sigmoid(W_z x + U_z h + b_z)              # update gate
        r  = sigmoid(W_r x + U_r h + b_r)              # reset gate
        h~ = activation(W_h x + U_h (r * h) + b_h)     # candidate state
        h' = z * h + (1 - z) * h~
    """

    def __init__(self, num_units, forget_bias=1.0,
                 state_is_tuple=True, activation=None, reuse=None):
        super(myGRU, self).__init__(_reuse=reuse)
        self._num_units = num_units
        self._forget_bias = forget_bias        # unused: a GRU has no forget gate
        self._state_is_tuple = state_is_tuple  # unused: the state is a single tensor
        self._activation = activation or math_ops.tanh
        # Weights are created lazily on the first call, once the input size is known.
        self._need_to_init = True
        self._W_z = None
        self._W_r = None
        self._W_h = None
        self._U_z = None
        self._U_r = None
        self._U_h = None
        self._b_z = None
        self._b_r = None
        self._b_h = None

    @property
    def state_size(self):
        return self._num_units

    @property
    def output_size(self):
        return self._num_units

    def __call__(self, inputs, state, scope=None):
        if self._need_to_init:
            input_shape = inputs.get_shape().as_list()
            self._need_to_init = False
            with tf.variable_scope("GRU"):
                # Input-to-hidden weights.
                self._W_z = tf.get_variable('Wz', shape=[self._num_units, input_shape[1]])
                self._W_r = tf.get_variable('Wr', shape=[self._num_units, input_shape[1]])
                self._W_h = tf.get_variable('Wh', shape=[self._num_units, input_shape[1]])
                # Hidden-to-hidden weights.
                self._U_z = tf.get_variable('Uz', shape=[self._num_units, self._num_units])
                self._U_r = tf.get_variable('Ur', shape=[self._num_units, self._num_units])
                self._U_h = tf.get_variable('Uh', shape=[self._num_units, self._num_units])
                # Biases are column vectors so they broadcast across the batch.
                self._b_z = tf.get_variable('bz', shape=[self._num_units, 1])
                self._b_r = tf.get_variable('br', shape=[self._num_units, 1])
                self._b_h = tf.get_variable('bh', shape=[self._num_units, 1])
        # Work in [num_units, batch] layout so the weights can left-multiply the data.
        x = tf.transpose(inputs)
        h = tf.transpose(state)
        z = tf.nn.sigmoid(tf.matmul(self._W_z, x) + tf.matmul(self._U_z, h) + self._b_z)
        r = tf.nn.sigmoid(tf.matmul(self._W_r, x) + tf.matmul(self._U_r, h) + self._b_r)
        candidate = self._activation(
            tf.matmul(self._W_h, x) + tf.matmul(self._U_h, r * h) + self._b_h)
        new_h = z * h + (1 - z) * candidate
        new_h = tf.transpose(new_h)  # back to [batch, num_units]
        return new_h, new_h
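
To exercise the cell end to end, here is a minimal usage sketch, assuming a TensorFlow 1.x graph-mode session; the shapes and names (batch_size, seq_len, input_dim) are illustrative, not part of the gist. tf.nn.dynamic_rnn feeds the cell one [batch, input_dim] slice per time step and threads the [batch, num_units] state through the sequence.

import numpy as np
import tensorflow as tf

batch_size, seq_len, input_dim, num_units = 4, 10, 8, 16

inputs = tf.placeholder(tf.float32, [batch_size, seq_len, input_dim])
cell = myGRU(num_units)
outputs, final_state = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    out = sess.run(outputs,
                   feed_dict={inputs: np.random.randn(batch_size, seq_len, input_dim)})
    print(out.shape)  # (4, 10, 16) == (batch, time, num_units)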