@8bit-pixies
Created November 10, 2019 20:50
This is an implementation of a GRU cell in Keras. It should allow for a bit more flexibility when not working within the "recurrent" framework.
"""
This is a manual implementaiton of grucell so that it will work in more
general envrionments...
"""
import tensorflow as tf

input_size = 64  # dimensionality of the input vector x_t
cell_size = 32   # dimensionality of the hidden state h_t

# One timestep of input and the previous hidden state are both explicit
# model inputs, so the cell can be stepped manually outside an RNN layer.
inputs = tf.keras.layers.Input(shape=(input_size,))
states = tf.keras.layers.Input(shape=(cell_size,))
# Input projections for the update (z), reset (r) and candidate (h) gates.
x_z = tf.keras.layers.Dense(cell_size, activation=None, name='z')(inputs)
x_r = tf.keras.layers.Dense(cell_size, activation=None, name='r')(inputs)
x_h = tf.keras.layers.Dense(cell_size, activation=None, name='h')(inputs)

# Recurrent projections of the previous hidden state for the z and r gates.
recurrent_z = tf.keras.layers.Dense(cell_size, activation=None, name='r_z')(states)
recurrent_r = tf.keras.layers.Dense(cell_size, activation=None, name='r_r')(states)

# Update gate: z = sigmoid(W_z x + U_z h_tm1)
z = tf.keras.layers.Activation('sigmoid')(tf.keras.layers.Add()([x_z, recurrent_z]))
# Reset gate: r = sigmoid(W_r x + U_r h_tm1)
r = tf.keras.layers.Activation('sigmoid')(tf.keras.layers.Add()([x_r, recurrent_r]))

# Candidate state: hh = tanh(W_h x + U_h (r * h_tm1))
recurrent_h = tf.keras.layers.Dense(cell_size, activation=None)(tf.keras.layers.Multiply()([r, states]))
hh = tf.keras.layers.Activation('tanh')(tf.keras.layers.Add()([x_h, recurrent_h]))

# New hidden state: h = z * h_tm1 + (1 - z) * hh
h = tf.keras.layers.Lambda(lambda tensors: tensors[0] * tensors[1] + (1 - tensors[0]) * tensors[2])([z, states, hh])

model = tf.keras.models.Model(inputs=[inputs, states], outputs=h)
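
# --- Usage sketch (not part of the original gist) ---
# A minimal example of stepping the cell manually over a sequence,
# assuming a batch of random inputs and a zero initial state.
# batch_size, time_steps, x_seq and state are illustrative names.
import numpy as np

batch_size = 8
time_steps = 5

x_seq = np.random.randn(batch_size, time_steps, input_size).astype("float32")
state = np.zeros((batch_size, cell_size), dtype="float32")

for t in range(time_steps):
    # Each call consumes one timestep and the previous state and
    # returns the next hidden state.
    state = model.predict([x_seq[:, t, :], state])

print(state.shape)  # (batch_size, cell_size)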