LSTM states wiring in Keras
"""
When using Keras' LSTM with a batch of samples for the first
time, the wiring of the states between elements in one sequence,
and between sequences in one batch, is not an obvious thing.
This document attempts to clarify it.

Assumes that keras.layers.LSTM's stateful parameter is set
to False.

Put simply, in the forward pass every sequence (of one batch)
starts with its own initial state (initialized at every batch,
not just at the first batch of the training loop), and that
state evolves independently afterwards.

That is, if a batch has three samples [seq1, seq2, seq3], then
the total states needed by the LSTM are:

[ (h1, c1) # for sequence seq1
, (h2, c2) # for sequence seq2
, (h3, c3) # for sequence seq3
]

where hi = hj at time 0, and the same for ci.
See the following code for more.
"""
import keras
import keras.backend as K
from keras.models import Model
from keras.layers import Dense, LSTM, Input
import numpy as np
import scipy.special

# Relies on some implementation details of Keras;
# other versions may not work, but you can try.
assert keras.__version__ == '2.0.8'
# %% -------------------------------------
# Generate some toy data
# Call the elements of the sequence "words", for lack of
# a better name. word_size is what Keras calls input_dim
word_size = 4
# Sequence length, aka. time steps, steps to unroll
seq_len = 2
batch_size = 3
# LSTM's hidden/carry state size
h_size = 5
total = 13
raw = np.array([np.random.randint(word_size)
                for _ in range(total)])
print(f"raw {raw}")
# raw [3 1 1 3 0 0 3 1 2 3 1 3 0]
# Sequences and their targets, e.g.
# xs[0] = [3, 1] # will be one hot encoded
# ys[0] = 1 # will be one hot encoded
xs = []
ys = []
def one_hot(index, length):
    a = np.zeros(shape=(length,))
    a[index] = 1
    return a
for i in range(0, total, seq_len):
    if i + seq_len < total:
        x = raw[i:i+seq_len]
        y = raw[i+seq_len]
        xs.append([one_hot(z, word_size) for z in x])
        ys.append(one_hot(y, word_size))
xs = np.array(xs)
ys = np.array(ys)
print(f"xs.shape {xs.shape}")
print(f"ys.shape {ys.shape}")
# xs.shape (6, 2, 4) # sequence length is 2
# ys.shape (6, 4)
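# For the example raw data shown above, the first sequence is [3, 1], so
# xs[0] = [[0, 0, 0, 1],
#          [0, 1, 0, 0]]
# and its target ys[0] is the one-hot encoding of 1, i.e. [0, 1, 0, 0].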
# %% -------------------------------------
# Keras model
ipt = Input(batch_shape=(batch_size, seq_len, word_size))
lstm_layer = LSTM(units=h_size,
                  stateful=False,
                  return_sequences=False,
                  return_state=False,
                  unroll=False,
                  implementation=1,
                  # since we want to check the implementation,
                  # we initialize the biases to be non-zero
                  bias_initializer=keras.initializers.RandomNormal(),
                  recurrent_activation=keras.activations.sigmoid,
                  activation=keras.activations.tanh)
lstm = lstm_layer(ipt)
dense_layer = Dense(word_size,
                    bias_initializer=keras.initializers.RandomNormal())
dense = dense_layer(lstm)
model = Model(inputs=ipt, outputs=dense)
model.compile(loss=keras.losses.categorical_crossentropy,
              optimizer=keras.optimizers.SGD())
# %% -------------------------------------
# Our implementation, to check our understanding of how the LSTM's states are wired
def array_to_row_vec(x):
    assert len(x.shape) == 1
    return x.reshape((1, x.shape[0]))
def one_sequence(sequence,            # (seq_len, word_size)
                 hprev, cprev,        # (1, h_size), hidden state and carry state
                 wi, wf, wc, wo,      # (word_size, h_size), weights for the input
                 rwi, rwf, rwc, rwo,  # (h_size, h_size), weights for hprev
                 bi, bf, bc, bo,      # (1, h_size), biases
                 a, ra,               # activation and recurrent activation of the LSTM
                 wd, bd               # (h_size, word_size) and (1, word_size), weights and bias of the dense layer, no activation
                 ):                   # returns (1, word_size), computed from the last hidden state only
    """
    Unlike Keras, we process one sequence at a time.
    See one_batch for more.
    """
    h, c = hprev.copy(), cprev.copy()
    assert h is not hprev
    assert c is not cprev

    # map function, modifies h and c as a side effect
    def f(word):
        nonlocal h, c
        word = array_to_row_vec(word)
        # these six lines define an LSTM cell
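        # In standard notation (row-vector convention, matching np.dot below):
        #   f_t  = ra(x_t.Wf + h_{t-1}.Rf + bf)    forget gate
        #   i_t  = ra(x_t.Wi + h_{t-1}.Ri + bi)    input gate
        #   o_t  = ra(x_t.Wo + h_{t-1}.Ro + bo)    output gate
        #   c~_t = a (x_t.Wc + h_{t-1}.Rc + bc)    candidate carry state
        #   c_t  = f_t * c_{t-1} + i_t * c~_t      (elementwise)
        #   h_t  = o_t * a(c_t)
        # where ra is the recurrent_activation (sigmoid) and a is the
        # activation (tanh) chosen when constructing the Keras layer above.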
        fgate = ra(np.dot(word, wf) + bf + np.dot(h, rwf))
        igate = ra(np.dot(word, wi) + bi + np.dot(h, rwi))
        ogate = ra(np.dot(word, wo) + bo + np.dot(h, rwo))
        ctilde = a(np.dot(word, wc) + bc + np.dot(h, rwc))
        new_c = fgate * c + igate * ctilde
        new_h = ogate * a(new_c)
        h, c = new_h, new_c
        assert c.shape == (1, h_size)
        assert h.shape == (1, h_size)
        return h

    # lstm layer
    outputs = list(map(f, sequence))
    last_h = outputs[-1]
    # dense layer
    dense = np.dot(last_h, wd) + bd
    return dense
def one_batch(sequences, hprevs, cprevs, *wrbas):
    """
    Process one batch of samples.

    This implementation is organized differently from Keras'.
    Given a batch:

        A B C D
        C D E F
        E F G H    (batch_size, seq_len)

    This implementation:
        for each row (one sequence):
            move from left to right (along the "time" axis),
            producing a new hidden state and a new carry state (discarded at the end).

    Keras:
        also moves from left to right,
        but it processes all three rows at a time, which is more efficient.

    In both cases, all three (batch_size) rows start from the same initial
    (zero) state, that is, (h=0, c=0) is copied three times, and the states
    (h and c) of different sequences (rows in this example) don't affect
    each other in the forward pass. Put graphically:

        (h1=0, c1=0) -> A B C D -> (h1, c1)
        (h2=0, c2=0) -> C D E F -> (h2, c2)
        (h3=0, c3=0) -> E F G H -> (h3, c3)

    (hi, ci are mutated as the sequence is consumed.)
    During the forward pass, h1, h2 and h3 don't affect each other,
    and the same holds for c1, c2, c3.
    """
    assert sequences.shape == (batch_size, seq_len, word_size)

    def f(arg):
        (sequence, hprev, cprev) = arg
        hprev = array_to_row_vec(hprev)
        cprev = array_to_row_vec(cprev)
        return one_sequence(sequence, hprev, cprev, *wrbas)

    outputs = map(f, zip(sequences, hprevs, cprevs))
    return np.concatenate(list(outputs), axis=0)
# Get the values of parameters used by Keras
# lstm layer
initial_hs = np.zeros(shape=(batch_size, h_size))
initial_cs = np.zeros(shape=(batch_size, h_size))
wi = K.get_value(lstm_layer.kernel_i)
wf = K.get_value(lstm_layer.kernel_f)
wc = K.get_value(lstm_layer.kernel_c)
wo = K.get_value(lstm_layer.kernel_o)
rwi = K.get_value(lstm_layer.recurrent_kernel_i)
rwf = K.get_value(lstm_layer.recurrent_kernel_f)
rwc = K.get_value(lstm_layer.recurrent_kernel_c)
rwo = K.get_value(lstm_layer.recurrent_kernel_o)
bi = K.get_value(lstm_layer.bias_i)
bf = K.get_value(lstm_layer.bias_f)
bc = K.get_value(lstm_layer.bias_c)
bo = K.get_value(lstm_layer.bias_o)
assert wi.shape == (word_size, h_size)
assert rwi.shape == (h_size, h_size)
assert bi.shape == (h_size,)
# dense layer
wd = K.get_value(dense_layer.kernel)
bd = K.get_value(dense_layer.bias)
assert wd.shape == (h_size, word_size)
assert bd.shape == (word_size,)
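# Optional sanity check (a sketch, assuming Keras 2.0.8 internals): the
# per-gate kernels above appear to be column slices of the layer's full
# kernel, concatenated in i, f, c, o order. Remove this block if it does
# not hold for your version.
full_kernel = K.get_value(lstm_layer.kernel)
assert full_kernel.shape == (word_size, 4 * h_size)
assert np.allclose(full_kernel[:, 0 * h_size:1 * h_size], wi)
assert np.allclose(full_kernel[:, 1 * h_size:2 * h_size], wf)
assert np.allclose(full_kernel[:, 2 * h_size:3 * h_size], wc)
assert np.allclose(full_kernel[:, 3 * h_size:4 * h_size], wo)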
# %% -------------------------------------
# Compare two implementations
pred_batch = xs[:batch_size]
keras_predicts = model.predict(pred_batch)
our_predicts = one_batch(pred_batch,
                         initial_hs, initial_cs,
                         wi, wf, wc, wo,
                         rwi, rwf, rwc, rwo,
                         array_to_row_vec(bi),
                         array_to_row_vec(bf),
                         array_to_row_vec(bc),
                         array_to_row_vec(bo),
                         np.tanh, scipy.special.expit,
                         wd, array_to_row_vec(bd))
print("keras' predicts:")
print(keras_predicts)
print()
print("our predicts:")
print(our_predicts)
print()
print(f"norm of keras:\n{np.linalg.norm(keras_predicts, ord='fro')}")
print()
print(f"norm of diff (this should be very small):\n"
f"{np.linalg.norm(keras_predicts-our_predicts, ord='fro')}")