LSTM states wiring in Keras
""" | |
When using Keras' LSTM with a batch of samples for the first | |
time, the wiring of the states between elements in one sequence, | |
and between sequences in one batch, is not a obvious thing. | |
This document attempts to clarify this. | |
Assumes that keras.recurrent.LSTM's stateful parameter is set | |
to False. | |
Put simply, in the forward pass, every sequence (of one batch) | |
starts with its own initial state, (initilized at every batch, | |
not just at the first batch of the training loop), and mutate | |
independently afterwards. | |
That is, if a batch has three samples [seq1, seq2, seq3], then | |
the total states needed by the lstm would be: | |
[ (h1, c1) # for sequence seq1 | |
, (h2, c2) # for sequence seq2 | |
, (h3, c3) # for sequence seq3 | |
] | |
hi = hj at time 0, and the same for ci. | |
See the following code for more. | |
""" | |
import keras
import keras.backend as K
from keras.models import Model
from keras.layers import Dense, LSTM, Input
import numpy as np
import scipy.special

# Relies on some implementation details of Keras;
# other versions may not work, but you can try.
assert keras.__version__ == '2.0.8'
# %% -------------------------------------
# Generate some toy data.
# Call the elements of the sequence "words", for lack of
# a better name. In Keras, word_size is called input_dim.
word_size = 4
# Sequence length, aka time steps, i.e. the steps to unroll
seq_len = 2
batch_size = 3
# LSTM's hidden/carry state size
h_size = 5
total = 13
raw = np.array([np.random.randint(word_size)
                for _ in range(total)])
print(f"raw {raw}")
# raw [3 1 1 3 0 0 3 1 2 3 1 3 0]
# Sequences and their targets, e.g.
# xs[0] = [3, 1]  # will be one-hot encoded
# ys[0] = 1       # will be one-hot encoded
xs = []
ys = []


def one_hot(index, length):
    a = np.zeros(shape=(length,))
    a[index] = 1
    return a
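# For example (illustrative only):
#   one_hot(2, 4) -> array([0., 0., 1., 0.])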
for i in range(0, total, seq_len):
    if i + seq_len < total:
        x = raw[i:i+seq_len]
        y = raw[i+seq_len]
        xs.append([one_hot(z, word_size) for z in x])
        ys.append(one_hot(y, word_size))

xs = np.array(xs)
ys = np.array(ys)
print(f"xs.shape {xs.shape}")
print(f"ys.shape {ys.shape}")
# xs.shape (6, 2, 4) # sequence length is 2
# ys.shape (6, 4)
# %% -------------------------------------
# Keras model
ipt = Input(batch_shape=(batch_size, seq_len, word_size))
lstm_layer = LSTM(units=h_size,
                  stateful=False,
                  return_sequences=False,
                  return_state=False,
                  unroll=False,
                  implementation=1,
                  # since we want to check the implementation,
                  # we initialize the biases to be non-zero
                  bias_initializer=keras.initializers.RandomNormal(),
                  recurrent_activation=keras.activations.sigmoid,
                  activation=keras.activations.tanh)
lstm = lstm_layer(ipt)
dense_layer = Dense(word_size,
                    bias_initializer=keras.initializers.RandomNormal())
dense = dense_layer(lstm)
model = Model(inputs=ipt, outputs=dense)
model.compile(loss=keras.losses.categorical_crossentropy,
              optimizer=keras.optimizers.SGD())
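# Note: batch_shape (rather than shape) pins the batch size, so the
# prediction below is made on exactly batch_size sequences at a time, which
# keeps Keras' per-sequence states and our hand-built (batch_size, h_size)
# initial states aligned one row per sequence.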
# %% -------------------------------------
# Our implementation, to check our understanding of the wiring of the LSTM's states


def array_to_row_vec(x):
    assert len(x.shape) == 1
    return x.reshape((1, x.shape[0]))


def one_sequence(sequence,            # (seq_len, word_size)
                 hprev, cprev,        # (1, h_size), hidden state and carry state
                 wi, wf, wc, wo,      # (word_size, h_size), weights for the input
                 rwi, rwf, rwc, rwo,  # (h_size, h_size), weights for hprev
                 bi, bf, bc, bo,      # (1, h_size), biases
                 a, ra,               # activation and recurrent activation of the LSTM
                 wd, bd               # (h_size, word_size) and (1, word_size), weights and bias of the dense layer (no activation)
                 ):                   # returns (1, word_size), the dense output of the last hidden state
    """
    Unlike Keras, we process one sequence at a time.
    See one_batch for more.
    """
    h, c = hprev.copy(), cprev.copy()
    assert h is not hprev
    assert c is not cprev

    # map function, modifies h and c as a side effect
    def f(word):
        nonlocal h, c
        word = array_to_row_vec(word)
        # these six lines define an LSTM cell
        fgate = ra(np.dot(word, wf) + bf + np.dot(h, rwf))
        igate = ra(np.dot(word, wi) + bi + np.dot(h, rwi))
        ogate = ra(np.dot(word, wo) + bo + np.dot(h, rwo))
        ctilde = a(np.dot(word, wc) + bc + np.dot(h, rwc))
        new_c = fgate * c + igate * ctilde
        new_h = ogate * a(new_c)
        h, c = new_h, new_c
        assert c.shape == (1, h_size)
        assert h.shape == (1, h_size)
        return h

    # lstm layer
    outputs = list(map(f, sequence))
    last_h = outputs[-1]
    # dense layer
    dense = np.dot(last_h, wd) + bd
    return dense
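# For reference, the six lines inside f() compute the usual LSTM recurrences
# (matrix products written as x·W, element-wise products as *):
#
#     f_t      = ra(x_t·Wf + h_{t-1}·Rf + bf)   # forget gate
#     i_t      = ra(x_t·Wi + h_{t-1}·Ri + bi)   # input gate
#     o_t      = ra(x_t·Wo + h_{t-1}·Ro + bo)   # output gate
#     ctilde_t = a(x_t·Wc + h_{t-1}·Rc + bc)    # candidate carry
#     c_t      = f_t * c_{t-1} + i_t * ctilde_t
#     h_t      = o_t * a(c_t)
#
# with ra = sigmoid and a = tanh, as configured on the Keras layer above.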
def one_batch(sequences, hprevs, cprevs, *wrbas):
    """
    Process one batch of samples.

    This implementation differs from Keras'. Given a batch:

        A B C D
        C D E F
        E F G H    (batch_size, seq_len)

    this implementation:
        for each row (one sequence):
            move from left to right (along the "time" axis),
            producing a new hidden state and a new carry state
            (both discarded at the end).
    Keras:
        also moves from left to right,
        but it processes all three rows at once, which is more efficient
        (see one_batch_vectorized below for a sketch of that).

    In both cases, all three (batch_size) rows start from an initial zero
    state (h=0, c=0), that is, (h=0, c=0) is copied three times, and the
    states (h and c) of different sequences (rows in this example) do not
    affect each other in the forward pass. Put graphically:

        (h1=0, c1=0) -> A B C D -> (h1, c1)
        (h2=0, c2=0) -> C D E F -> (h2, c2)
        (h3=0, c3=0) -> E F G H -> (h3, c3)

    (hi, ci are mutable.)

    During the forward pass, h1, h2 and h3 do not affect each other,
    and the same holds for c1, c2 and c3.
    """
    assert sequences.shape == (batch_size, seq_len, word_size)

    def f(arg):
        (sequence, hprev, cprev) = arg
        hprev = array_to_row_vec(hprev)
        cprev = array_to_row_vec(cprev)
        return one_sequence(sequence, hprev, cprev, *wrbas)

    outputs = map(f, zip(sequences, hprevs, cprevs))
    return np.concatenate(list(outputs), axis=0)
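# A sketch (ours, not Keras' actual code) of the vectorized update that the
# docstring above attributes to Keras: all rows of the batch are advanced
# together, but since every operation acts row-wise, row i of h and c never
# depends on row j -- exactly the independent-states wiring described at the
# top of this file. Called with the same arguments as one_batch at the bottom
# of this script, it should produce the same result.
def one_batch_vectorized(sequences, hprevs, cprevs,
                         wi, wf, wc, wo,
                         rwi, rwf, rwc, rwo,
                         bi, bf, bc, bo,
                         a, ra, wd, bd):
    h, c = hprevs.copy(), cprevs.copy()       # both (batch_size, h_size)
    for t in range(sequences.shape[1]):
        word = sequences[:, t, :]             # (batch_size, word_size)
        fgate = ra(np.dot(word, wf) + bf + np.dot(h, rwf))
        igate = ra(np.dot(word, wi) + bi + np.dot(h, rwi))
        ogate = ra(np.dot(word, wo) + bo + np.dot(h, rwo))
        ctilde = a(np.dot(word, wc) + bc + np.dot(h, rwc))
        c = fgate * c + igate * ctilde
        h = ogate * a(c)
    return np.dot(h, wd) + bd                 # (batch_size, word_size)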
# Get the values of the parameters used by Keras
# lstm layer
initial_hs = np.zeros(shape=(batch_size, h_size))
initial_cs = np.zeros(shape=(batch_size, h_size))
wi = K.get_value(lstm_layer.kernel_i)
wf = K.get_value(lstm_layer.kernel_f)
wc = K.get_value(lstm_layer.kernel_c)
wo = K.get_value(lstm_layer.kernel_o)
rwi = K.get_value(lstm_layer.recurrent_kernel_i)
rwf = K.get_value(lstm_layer.recurrent_kernel_f)
rwc = K.get_value(lstm_layer.recurrent_kernel_c)
rwo = K.get_value(lstm_layer.recurrent_kernel_o)
bi = K.get_value(lstm_layer.bias_i)
bf = K.get_value(lstm_layer.bias_f)
bc = K.get_value(lstm_layer.bias_c)
bo = K.get_value(lstm_layer.bias_o)
assert wi.shape == (word_size, h_size)
assert rwi.shape == (h_size, h_size)
assert bi.shape == (h_size,)
# dense layer
wd = K.get_value(dense_layer.kernel)
bd = K.get_value(dense_layer.bias)
assert wd.shape == (h_size, word_size)
assert bd.shape == (word_size,)
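# (In this Keras version the per-gate attributes above appear to be slices of
# a single combined kernel of shape (word_size, 4 * h_size), and similarly
# for the recurrent kernel and the bias; this is the kind of implementation
# detail that the version assert at the top guards against.)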
# %% -------------------------------------
# Compare the two implementations
pred_batch = xs[:batch_size]
keras_predicts = model.predict(pred_batch)
our_predicts = one_batch(pred_batch,
                         initial_hs, initial_cs,
                         wi, wf, wc, wo,
                         rwi, rwf, rwc, rwo,
                         array_to_row_vec(bi),
                         array_to_row_vec(bf),
                         array_to_row_vec(bc),
                         array_to_row_vec(bo),
                         np.tanh, scipy.special.expit,
                         wd, array_to_row_vec(bd))
print("keras' predicts:")
print(keras_predicts)
print()
print("our predicts:")
print(our_predicts)
print()
print(f"norm of keras:\n{np.linalg.norm(keras_predicts, ord='fro')}")
print()
print(f"norm of diff (this should be very small):\n"
      f"{np.linalg.norm(keras_predicts-our_predicts, ord='fro')}")