Skip to content

Instantly share code, notes, and snippets.

@nuric
Created March 27, 2018 11:24
Show Gist options
  • Save nuric/62addaa01de0d75a9a608d2c6978c771 to your computer and use it in GitHub Desktop.
Save nuric/62addaa01de0d75a9a608d2c6978c771 to your computer and use it in GitHub Desktop.
A wrapper for Keras GRU that skips timesteps if inputs for that timestep are all zeros.
"""ZeroGRU module."""
import keras.backend as K
import keras.layers as L
class ZeroGRUCell(L.GRUCell):
"""GRU Cell that skips timestep if inputs is zero as well."""
def call(self, inputs, states, training=None):
"""Step function of the cell."""
h_tm1 = states[0] # previous output
# Check if all inputs are zero for this timestep
cond = K.all(K.equal(inputs, 0), axis=-1)
new_output, new_states = super().call(inputs, states, training=training)
# Skip timestep based on the condition
curr_output = K.switch(cond, h_tm1, new_output)
curr_states = [K.switch(cond, states[i], new_states[i]) for i in range(len(states))]
return curr_output, curr_states
class ZeroGRU(L.GRU):
"""Layer wrapper for the ZeroGRUCell."""
# Just swap the GRUCell with ZeroGRUCell
def __init__(self, units,
activation='tanh',
recurrent_activation='hard_sigmoid',
use_bias=True,
kernel_initializer='glorot_uniform',
recurrent_initializer='orthogonal',
bias_initializer='zeros',
kernel_regularizer=None,
recurrent_regularizer=None,
bias_regularizer=None,
activity_regularizer=None,
kernel_constraint=None,
recurrent_constraint=None,
bias_constraint=None,
dropout=0.,
recurrent_dropout=0.,
implementation=1,
return_sequences=False,
return_state=False,
go_backwards=False,
stateful=False,
unroll=False,
reset_after=False,
**kwargs):
cell = ZeroGRUCell(units,
activation=activation,
recurrent_activation=recurrent_activation,
use_bias=use_bias,
kernel_initializer=kernel_initializer,
recurrent_initializer=recurrent_initializer,
bias_initializer=bias_initializer,
kernel_regularizer=kernel_regularizer,
recurrent_regularizer=recurrent_regularizer,
bias_regularizer=bias_regularizer,
kernel_constraint=kernel_constraint,
recurrent_constraint=recurrent_constraint,
bias_constraint=bias_constraint,
dropout=dropout,
recurrent_dropout=recurrent_dropout,
implementation=implementation,
reset_after=reset_after)
super(L.GRU, self).__init__(cell,
return_sequences=return_sequences,
return_state=return_state,
go_backwards=go_backwards,
stateful=stateful,
unroll=unroll,
**kwargs)
self.activity_regularizer = L.regularizers.get(activity_regularizer)
if __name__ == '__main__':
import numpy as np
from keras.models import Sequential
model = Sequential()
# Contextual embeddeding of symbols
onehot_weights = np.eye(4)
onehot_weights[0, 0] = 0 # Clear zero index
model.add(L.Embedding(4, 4,
trainable=False,
weights=[onehot_weights],
name='onehot'))
model.add(ZeroGRU(2, return_sequences=True))
x = np.array([[0, 2, 1, 0, 1 ,1, 0, 0]])
y = model.predict(x)
print(y.shape)
print(y)
# (1, 8, 2)
# [[[ 0. 0. ]
# [ 0.01389048 0.2353647 ]
# [-0.35381496 0.37560514]
# [-0.35381496 0.37560514]
# [-0.452064 0.48499036]
# [-0.46228996 0.56209606]
# [-0.46228996 0.56209606]
# [-0.46228996 0.56209606]]]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment