Skip to content

Instantly share code, notes, and snippets.

@paultsw
Created May 30, 2017 18:30
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save paultsw/8ad3285b913106c8d9b309ea43befc82 to your computer and use it in GitHub Desktop.
Save paultsw/8ad3285b913106c8d9b309ea43befc82 to your computer and use it in GitHub Desktop.
RNNCell (e.g. LSTM) wrapper that adds Gaussian noise to the input
class GaussianNoiseWrapper(tf.contrib.rnn.RNNCell):
    """RNNCell wrapper that adds Gaussian noise to the input of every step.

    Before each call of the wrapped cell, noise drawn from a normal
    distribution with the configured mean and standard deviation (same shape
    and dtype as the input) is added to the input tensor. The noisy input and
    the unchanged state are then passed to the wrapped cell. State size,
    output size and zero state are delegated to the wrapped cell.

    The structure of this wrapper follows the other wrappers in
    tf.contrib.rnn, such as InputProjectionWrapper, DropoutWrapper and
    OutputProjectionWrapper.
    """

    def __init__(self, cell, mean=0.0, stddev=0.1, seed=None):
        """Create a cell whose inputs get Gaussian noise added.

        This only validates and stores the wrapped cell and the noise
        parameters; the noise itself is created in __call__.

        Args:
          cell: the RNNCell to wrap.
          mean: mean of the Gaussian noise.
          stddev: standard deviation of the Gaussian noise.
          seed: optional op-level seed forwarded to tf.random_normal so the
            noise can be made reproducible. None (the default) leaves the
            op unseeded.

        Raises:
          TypeError: if `cell` is not an RNNCell.
        """
        if not isinstance(cell, tf.contrib.rnn.RNNCell):
            raise TypeError("The parameter cell is not RNNCell.")
        # Initialize the base class. On TF versions where RNNCell is a Layer,
        # its initializer sets up bookkeeping attributes that the rest of the
        # Layer API expects; on older versions this is a harmless no-op.
        super(GaussianNoiseWrapper, self).__init__()
        self._cell = cell
        self._mu = mean
        self._sigma = stddev
        self._seed = seed

    @property
    def state_size(self):
        """State size of the wrapped cell (the wrapper adds no state)."""
        return self._cell.state_size

    @property
    def output_size(self):
        """Output size of the wrapped cell (outputs are passed through)."""
        return self._cell.output_size

    def zero_state(self, batch_size, dtype):
        """Return the wrapped cell's zero state under a wrapper-named scope."""
        with tf.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
            return self._cell.zero_state(batch_size, dtype)

    def __call__(self, inputs, state, scope=None):
        """Add Gaussian noise to `inputs`, then run the wrapped cell.

        Args:
          inputs: input tensor for this step. Noise of the same shape and
            dtype is added to it.
          state: state passed unchanged to the wrapped cell.
          scope: optional variable scope; defaults to
            "gaussian_noise_wrapper".

        Returns:
          Whatever the wrapped cell returns when called with
          (noisy inputs, state).
        """
        with tf.variable_scope(scope or "gaussian_noise_wrapper"):
            inputs = tf.convert_to_tensor(inputs)
            # Gaussian noise comes from tf.random_normal (tf.random_uniform
            # takes minval/maxval, not mean/stddev). Drawing it in the input's
            # dtype keeps tf.add valid for non-float32 inputs.
            noise = tf.random_normal(
                tf.shape(inputs),
                mean=self._mu,
                stddev=self._sigma,
                dtype=inputs.dtype,
                seed=self._seed,
            )
            noisy_inputs = tf.add(inputs, noise)
            return self._cell(noisy_inputs, state)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment