@Cppowboy, forked from tam17aki/HyperLSTMCell.py, created October 10, 2018.
An implementation of HyperLSTM (https://arxiv.org/abs/1609.09106) in TensorFlow 1.x.
# -*- coding: utf-8 -*-
# Copyright (C) 2017 by Akira TAMAMORI
# Copyright (C) 2016 by hardmaru
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import tensorflow as tf
import numpy as np
# Orthogonal Initializer from
# https://github.com/OlavHN/bnlstm
def orthogonal(shape):
    flat_shape = (shape[0], np.prod(shape[1:]))
    a = np.random.normal(0.0, 1.0, flat_shape)
    u, _, v = np.linalg.svd(a, full_matrices=False)
    q = u if u.shape == flat_shape else v
    return q.reshape(shape)
def lstm_ortho_initializer(scale=1.0):
    def _initializer(shape, dtype=tf.float32, partition_info=None):
        size_x = shape[0]
        size_h = shape[1] // 4  # assumes an LSTM kernel with 4 gates.
        t = np.zeros(shape)
        t[:, :size_h] = orthogonal([size_x, size_h]) * scale
        t[:, size_h:size_h * 2] = orthogonal([size_x, size_h]) * scale
        t[:, size_h * 2:size_h * 3] = orthogonal([size_x, size_h]) * scale
        t[:, size_h * 3:] = orthogonal([size_x, size_h]) * scale
        return tf.constant(t, dtype)
    return _initializer
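# A commented sketch of how this initializer might be attached to a 4-gate
# LSTM kernel; the variable name and sizes here are illustrative assumptions,
# not values used elsewhere in this file:
#
#   W_example = tf.get_variable(
#       'W_example', [64, 4 * 128],
#       initializer=lstm_ortho_initializer(1.0))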
def layer_norm_all(h, batch_size, base, num_units, scope="layer_norm",
                   reuse=False, gamma_start=1.0, epsilon=1e-3, use_bias=True):
    # Layer Norm (faster version, but not using defun)
    #
    # Performs layer norm on multiple gates at once (i.e. i, j, f, o for lstm)
    #
    # Reshapes h in order to perform layer norm in parallel
    h_reshape = tf.reshape(h, [batch_size, base, num_units])
    mean = tf.reduce_mean(h_reshape, [2], keep_dims=True)
    var = tf.reduce_mean(tf.square(h_reshape - mean), [2], keep_dims=True)
    epsilon = tf.constant(epsilon)
    rstd = tf.rsqrt(var + epsilon)
    h_reshape = (h_reshape - mean) * rstd
    # reshape back to original
    h = tf.reshape(h_reshape, [batch_size, base * num_units])
    with tf.variable_scope(scope):
        if reuse is True:
            tf.get_variable_scope().reuse_variables()
        gamma = tf.get_variable(
            'ln_gamma', [4 * num_units],
            initializer=tf.constant_initializer(gamma_start))
        if use_bias:
            beta = tf.get_variable(
                'ln_beta', [4 * num_units],
                initializer=tf.constant_initializer(0.0))
    if use_bias:
        return gamma * h + beta
    return gamma * h
def layer_norm(x, num_units, scope="layer_norm", reuse=False, gamma_start=1.0,
               epsilon=1e-3, use_bias=True):
    axes = [1]
    mean = tf.reduce_mean(x, axes, keep_dims=True)
    x_shifted = x - mean
    var = tf.reduce_mean(tf.square(x_shifted), axes, keep_dims=True)
    inv_std = tf.rsqrt(var + epsilon)
    with tf.variable_scope(scope):
        if reuse is True:
            tf.get_variable_scope().reuse_variables()
        gamma = tf.get_variable(
            'ln_gamma', [num_units],
            initializer=tf.constant_initializer(gamma_start))
        if use_bias:
            beta = tf.get_variable(
                'ln_beta', [num_units],
                initializer=tf.constant_initializer(0.0))
    output = gamma * x_shifted * inv_std
    if use_bias:
        output = output + beta
    return output
def super_linear(x, output_size, scope=None, reuse=False,
                 init_w="ortho", weight_start=0.0, use_bias=True,
                 bias_start=0.0, input_size=None):
    # support function doing linear operation. uses ortho initializer defined
    # earlier.
    shape = x.get_shape().as_list()
    with tf.variable_scope(scope or "linear"):
        if reuse is True:
            tf.get_variable_scope().reuse_variables()
        w_init = None  # uniform
        if input_size is None:
            x_size = shape[1]
        else:
            x_size = input_size
        if init_w == "zeros":
            w_init = tf.constant_initializer(0.0)
        elif init_w == "constant":
            w_init = tf.constant_initializer(weight_start)
        elif init_w == "gaussian":
            w_init = tf.random_normal_initializer(stddev=weight_start)
        elif init_w == "ortho":
            w_init = lstm_ortho_initializer(1.0)
        w = tf.get_variable("super_linear_w",
                            [x_size, output_size],
                            tf.float32, initializer=w_init)
        if use_bias:
            b = tf.get_variable(
                "super_linear_b", [output_size], tf.float32,
                initializer=tf.constant_initializer(bias_start))
            return tf.matmul(x, w) + b
        return tf.matmul(x, w)
def hyper_norm(layer, hyper_output, embedding_size, num_units,
               scope="hyper", use_bias=True):
    '''
    HyperNetwork norm operator
    provides context-dependent weight scaling
    layer: layer to apply the operation on
    hyper_output: output of the hypernetwork cell at time t
    embedding_size: embedding size of the output vector (see paper)
    num_units: number of hidden units in the main rnn
    '''
    # recurrent batch norm init trick (https://arxiv.org/abs/1603.09025).
    init_gamma = 0.10  # as suggested by Cooijmans et al.
    with tf.variable_scope(scope):
        zw = super_linear(hyper_output, embedding_size, init_w="constant",
                          weight_start=0.00, use_bias=True,
                          bias_start=1.0, scope="zw")
        alpha = super_linear(zw, num_units, init_w="constant",
                             weight_start=init_gamma / embedding_size,
                             use_bias=False, scope="alpha")
        result = tf.multiply(alpha, layer)
    return result
def hyper_bias(layer, hyper_output, embedding_size, num_units,
               scope="hyper"):
    '''
    HyperNetwork bias operator
    provides context-dependent bias
    layer: layer to apply the operation on
    hyper_output: output of the hypernetwork cell at time t
    embedding_size: embedding size of the output vector (see paper)
    num_units: number of hidden units in the main rnn
    '''
    with tf.variable_scope(scope):
        zb = super_linear(hyper_output, embedding_size, init_w="gaussian",
                          weight_start=0.01, use_bias=False,
                          bias_start=0.0, scope="zb")
        beta = super_linear(zb, num_units, init_w="constant",
                            weight_start=0.00, use_bias=False, scope="beta")
    return layer + beta
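# A commented sketch of how hyper_norm and hyper_bias are combined on a single
# gate below in HyperLSTMCell (the names ix, ih, ib, hyper_out and the sizes
# are illustrative assumptions):
#
#   # ix, ih: [batch, num_units] input/recurrent contributions to a gate,
#   # ib: [num_units] bias slice, hyper_out: output of the hyper LSTM cell.
#   # i = (hyper_norm(ix, hyper_out, 16, 128, 'hyper_ix')
#   #      + hyper_norm(ih, hyper_out, 16, 128, 'hyper_ih')
#   #      + hyper_bias(ib, hyper_out, 16, 128, 'hyper_ib'))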
class LSTMCell(tf.contrib.rnn.RNNCell):
"""
Layer-Norm, with Ortho Initialization and
Recurrent Dropout without Memory Loss.
https://arxiv.org/abs/1607.06450 - Layer Norm
https://arxiv.org/abs/1603.05118 - Recurrent Dropout without Memory Loss
derived from
https://github.com/OlavHN/bnlstm
https://github.com/LeavesBreathe/tensorflow_with_latest_papers
"""
def __init__(self, num_units, forget_bias=1.0, use_layer_norm=False,
use_recurrent_dropout=False, dropout_keep_prob=0.90):
"""Initialize the Layer Norm LSTM cell.
Args:
num_units: int, The number of units in the LSTM cell.
forget_bias: float, The bias added to forget gates (default 1.0).
use_recurrent_dropout: float, Whether to use Recurrent Dropout
(default False)
dropout_keep_prob: float, dropout keep probability (default 0.90)
"""
        self.num_units = num_units
        self.forget_bias = forget_bias
        self.use_layer_norm = use_layer_norm
        self.use_recurrent_dropout = use_recurrent_dropout
        self.dropout_keep_prob = dropout_keep_prob

    @property
    def output_size(self):
        return self.num_units

    @property
    def state_size(self):
        return tf.contrib.rnn.LSTMStateTuple(self.num_units, self.num_units)

    def __call__(self, x, state, scope=None):
        with tf.variable_scope(scope or type(self).__name__):
            c, h = state
            batch_size = x.get_shape().as_list()[0]
            x_size = x.get_shape().as_list()[1]
            w_init = None  # uniform
            h_init = lstm_ortho_initializer()
            W_xh = tf.get_variable(
                'W_xh', [x_size, 4 * self.num_units], initializer=w_init)
            W_hh = tf.get_variable(
                'W_hh_i', [self.num_units, 4 * self.num_units],
                initializer=h_init)
            W_full = tf.concat([W_xh, W_hh], 0)
            bias = tf.get_variable(
                'bias', [4 * self.num_units],
                initializer=tf.constant_initializer(0.0))
            concat = tf.concat([x, h], 1)  # concat for speed.
            concat = tf.matmul(concat, W_full) + bias
            # new way of doing layer norm (faster)
            if self.use_layer_norm:
                concat = layer_norm_all(
                    concat, batch_size, 4, self.num_units, 'ln')
            # i = input_gate, j = new_input, f = forget_gate, o = output_gate
            i, j, f, o = tf.split(concat, 4, 1)
            if self.use_recurrent_dropout:
                g = tf.nn.dropout(tf.tanh(j), self.dropout_keep_prob)
            else:
                g = tf.tanh(j)
            new_c = c * tf.sigmoid(f + self.forget_bias) + tf.sigmoid(i) * g
            if self.use_layer_norm:
                new_h = tf.tanh(layer_norm(
                    new_c, self.num_units, 'ln_c')) * tf.sigmoid(o)
            else:
                new_h = tf.tanh(new_c) * tf.sigmoid(o)
        return new_h, tf.contrib.rnn.LSTMStateTuple(new_c, new_h)
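# A commented usage sketch of LSTMCell with tf.nn.dynamic_rnn; the shapes are
# illustrative assumptions, and a static batch size is needed because
# layer_norm_all reshapes by the static batch dimension:
#
#   # inputs: [batch, time, features], e.g. a tf.placeholder with batch=32.
#   # cell = LSTMCell(128, use_layer_norm=True)
#   # outputs, state = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32)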
class HyperLSTMCell(tf.contrib.rnn.RNNCell):
    '''
    HyperLSTM, with Ortho Initialization,
    Layer Norm and Recurrent Dropout without Memory Loss.
    https://arxiv.org/abs/1609.09106
    '''

    def __init__(self, num_units, forget_bias=1.0,
                 use_recurrent_dropout=False, dropout_keep_prob=0.90,
                 use_layer_norm=True,
                 hyper_num_units=128, hyper_embedding_size=16,
                 hyper_use_recurrent_dropout=False):
        '''Initialize the Layer Norm HyperLSTM cell.
        Args:
          num_units: int, The number of units in the LSTM cell.
          forget_bias: float, The bias added to forget gates (default 1.0).
          use_recurrent_dropout: boolean, Whether to use Recurrent Dropout
            (default False).
          dropout_keep_prob: float, dropout keep probability (default 0.90).
          use_layer_norm: boolean. (default True)
            Controls whether we use LayerNorm layers in the main LSTM and the
            HyperLSTM cell.
          hyper_num_units: int, number of units in the HyperLSTM cell.
            (default is 128, recommend experimenting with 256 for larger tasks)
          hyper_embedding_size: int, size of signals emitted from the HyperLSTM
            cell. (default is 16, recommend trying larger values, but larger
            is not always better)
          hyper_use_recurrent_dropout: boolean. (default False)
            Controls whether the HyperLSTM cell also uses recurrent dropout.
            (Not in the paper.)
            Recommend turning this on only if hyper_num_units becomes very
            large (>= 512).
        '''
        self.num_units = num_units
        self.forget_bias = forget_bias
        self.use_recurrent_dropout = use_recurrent_dropout
        self.dropout_keep_prob = dropout_keep_prob
        self.use_layer_norm = use_layer_norm
        self.hyper_num_units = hyper_num_units
        self.hyper_embedding_size = hyper_embedding_size
        self.hyper_use_recurrent_dropout = hyper_use_recurrent_dropout
        self.total_num_units = self.num_units + self.hyper_num_units
        self.hyper_cell = LSTMCell(
            hyper_num_units,
            use_recurrent_dropout=hyper_use_recurrent_dropout,
            use_layer_norm=use_layer_norm,
            dropout_keep_prob=dropout_keep_prob)

    @property
    def output_size(self):
        return self.num_units

    @property
    def state_size(self):
        return tf.contrib.rnn.LSTMStateTuple(
            self.num_units + self.hyper_num_units,
            self.num_units + self.hyper_num_units)
    def __call__(self, x, state, timestep=0, scope=None):
        with tf.variable_scope(scope or type(self).__name__):
            total_c, total_h = state
            c = total_c[:, 0:self.num_units]
            h = total_h[:, 0:self.num_units]
            hyper_state = tf.contrib.rnn.LSTMStateTuple(
                total_c[:, self.num_units:],
                total_h[:, self.num_units:])
            w_init = None  # uniform
            h_init = lstm_ortho_initializer(1.0)
            x_size = x.get_shape().as_list()[1]
            embedding_size = self.hyper_embedding_size
            num_units = self.num_units
            batch_size = x.get_shape().as_list()[0]
            W_xh = tf.get_variable('W_xh',
                                   [x_size, 4 * num_units],
                                   initializer=w_init)
            W_hh = tf.get_variable('W_hh',
                                   [num_units, 4 * num_units],
                                   initializer=h_init)
            bias = tf.get_variable('bias',
                                   [4 * num_units],
                                   initializer=tf.constant_initializer(0.0))
            # concatenate the input and hidden states for hyperlstm input
            hyper_input = tf.concat([x, h], 1)
            hyper_output, hyper_new_state = self.hyper_cell(
                hyper_input, hyper_state)
            xh = tf.matmul(x, W_xh)
            hh = tf.matmul(h, W_hh)
            # split Wxh contributions
            ix, jx, fx, ox = tf.split(xh, 4, 1)
            ix = hyper_norm(ix, hyper_output, embedding_size,
                            num_units, 'hyper_ix')
            jx = hyper_norm(jx, hyper_output, embedding_size,
                            num_units, 'hyper_jx')
            fx = hyper_norm(fx, hyper_output, embedding_size,
                            num_units, 'hyper_fx')
            ox = hyper_norm(ox, hyper_output, embedding_size,
                            num_units, 'hyper_ox')
            # split Whh contributions
            ih, jh, fh, oh = tf.split(hh, 4, 1)
            ih = hyper_norm(ih, hyper_output, embedding_size,
                            num_units, 'hyper_ih')
            jh = hyper_norm(jh, hyper_output, embedding_size,
                            num_units, 'hyper_jh')
            fh = hyper_norm(fh, hyper_output, embedding_size,
                            num_units, 'hyper_fh')
            oh = hyper_norm(oh, hyper_output, embedding_size,
                            num_units, 'hyper_oh')
            # split bias
            ib, jb, fb, ob = tf.split(bias, 4, 0)  # bias is to be broadcasted.
            ib = hyper_bias(ib, hyper_output, embedding_size,
                            num_units, 'hyper_ib')
            jb = hyper_bias(jb, hyper_output, embedding_size,
                            num_units, 'hyper_jb')
            fb = hyper_bias(fb, hyper_output, embedding_size,
                            num_units, 'hyper_fb')
            ob = hyper_bias(ob, hyper_output, embedding_size,
                            num_units, 'hyper_ob')
            # i = input_gate, j = new_input, f = forget_gate, o = output_gate
            i = ix + ih + ib
            j = jx + jh + jb
            f = fx + fh + fb
            o = ox + oh + ob
            if self.use_layer_norm:
                concat = tf.concat([i, j, f, o], 1)
                concat = layer_norm_all(
                    concat, batch_size, 4, num_units, 'ln_all')
                i, j, f, o = tf.split(concat, 4, 1)
            if self.use_recurrent_dropout:
                g = tf.nn.dropout(tf.tanh(j), self.dropout_keep_prob)
            else:
                g = tf.tanh(j)
            new_c = c * tf.sigmoid(f + self.forget_bias) + tf.sigmoid(i) * g
            if self.use_layer_norm:
                new_h = tf.tanh(layer_norm(
                    new_c, num_units, 'ln_c')) * tf.sigmoid(o)
            else:
                new_h = tf.tanh(new_c) * tf.sigmoid(o)
            hyper_c, hyper_h = hyper_new_state
            new_total_c = tf.concat([new_c, hyper_c], 1)
            new_total_h = tf.concat([new_h, hyper_h], 1)
        return new_h, tf.contrib.rnn.LSTMStateTuple(new_total_c, new_total_h)
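# A minimal, hedged usage sketch appended for illustration (not part of the
# original gist): build a HyperLSTMCell, unroll it with tf.nn.dynamic_rnn on
# random inputs, and print the output shape.  The batch/time/feature sizes
# are illustrative assumptions; a static batch size is required because
# layer_norm_all reshapes by the static batch dimension.
if __name__ == '__main__':
    batch_size, seq_len, input_dim = 16, 10, 8
    inputs = tf.placeholder(tf.float32, [batch_size, seq_len, input_dim])
    cell = HyperLSTMCell(num_units=64, hyper_num_units=32,
                         hyper_embedding_size=8, use_layer_norm=True)
    # outputs has shape [batch_size, seq_len, num_units]
    outputs, final_state = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        out = sess.run(outputs, feed_dict={
            inputs: np.random.randn(
                batch_size, seq_len, input_dim).astype(np.float32)})
        print(out.shape)  # expected: (16, 10, 64)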