Skip to content

Instantly share code, notes, and snippets.

@pekaalto
Created January 9, 2020 05:37
Show Gist options
  • Save pekaalto/026d0248b7a5477380dd21c4ca637c09 to your computer and use it in GitHub Desktop.
Save pekaalto/026d0248b7a5477380dd21c4ca637c09 to your computer and use it in GitHub Desktop.
Investigate keras-lstm inputs, outputs and weights.
"""
Investigate keras-lstm inputs, outputs and weights.
Needs tensorflow 2.0
Note: The explanation of the weights matches the CPU implementation of the LSTM layer.
In the GPU implementation the weights are organized slightly differently.
"""
import numpy as np
import tensorflow as tf
from scipy.special import expit as sigmoid
# Problem sizes used throughout the script.
LSTM_UNITS = 5
TIME_STEPS = 3
INPUT_DIM = 6
BATCH_SIZE = 2

# return_sequences=True -> the h-state of every time step is returned,
# return_state=True     -> the last h-state and last c-state are returned too.
# A non-default bias initializer is used; presumably so that the bias terms
# are non-trivial in the comparison against the manual implementation.
lstm_layer = tf.keras.layers.LSTM(
    units=LSTM_UNITS,
    return_sequences=True,
    return_state=True,
    use_bias=True,
    bias_initializer="uniform",
)

# Random batch-major input and random initial states.
input_x_tensor = tf.random.normal(shape=(BATCH_SIZE, TIME_STEPS, INPUT_DIM))
initial_c_state_tensor = tf.random.normal(shape=(BATCH_SIZE, LSTM_UNITS))
initial_h_state_tensor = tf.random.normal(shape=(BATCH_SIZE, LSTM_UNITS))

# Keras expects the initial state as [h, c].
h_state_sequence, h_state_last, cell_state_last = lstm_layer(
    input_x_tensor, initial_state=[initial_h_state_tensor, initial_c_state_tensor]
)

# The last element of the returned sequence is the returned last h-state.
np.testing.assert_array_equal(h_state_sequence[:, -1, :], h_state_last)
assert h_state_sequence.shape == (BATCH_SIZE, TIME_STEPS, LSTM_UNITS)
assert h_state_last.shape == (BATCH_SIZE, LSTM_UNITS)
assert cell_state_last.shape == (BATCH_SIZE, LSTM_UNITS)

# The layer holds three arrays: kernel (acts on x), recurrent kernel
# (acts on h) and bias. Each stacks the four gates along the last axis.
lstm_weights = lstm_layer.get_weights()
assert [w.shape for w in lstm_weights] == [
    (INPUT_DIM, 4 * LSTM_UNITS),
    (LSTM_UNITS, 4 * LSTM_UNITS),
    (4 * LSTM_UNITS,),
]
kernel, recurrent_kernel, bias = lstm_weights

# Stack the recurrent rows first so that, after transposing, the columns of
# each gate matrix line up with a concatenated [h, x] input vector.
big_w = np.concatenate([recurrent_kernel, kernel], axis=0).T

# Split into the per-gate pieces, in the order i, f, c, o.
W_i, W_f, W_c, W_o = np.split(big_w, indices_or_sections=4, axis=0)
b_i, b_f, b_c, b_o = np.split(bias, indices_or_sections=4, axis=0)

for w in [W_i, W_f, W_c, W_o]:
    assert w.shape == (LSTM_UNITS, INPUT_DIM + LSTM_UNITS)
for b in [b_i, b_f, b_c, b_o]:
    assert b.shape == (LSTM_UNITS,)
class LstmSimpleForward:
    """
    Plain-numpy forward pass of an LSTM, written out gate by gate.

    Every gate (i, f, c, o) has its own weight matrix and bias vector.
    The weight matrices act on the concatenation [h, x], so their first
    LSTM_UNITS columns multiply the previous h-state and the remaining
    INPUT_DIM columns multiply the current input.
    """

    def __init__(self, W_i, W_f, W_c, W_o, b_i, b_f, b_c, b_o):
        """
        W's have shape (LSTM_UNITS, INPUT_DIM + LSTM_UNITS)
        b's have shape (LSTM_UNITS,)
        """
        self.W_i = W_i
        self.W_f = W_f
        self.W_c = W_c
        self.W_o = W_o
        self.b_i = b_i
        self.b_f = b_f
        self.b_c = b_c
        self.b_o = b_o

    def step_one(self, h_t1, c_t1, xt):
        """
        Calculates one time-step in lstm
        :param h_t1: previous h-state, shape [BATCH_SIZE, LSTM_UNITS]
        :param c_t1: previous c-state, shape [BATCH_SIZE, LSTM_UNITS]
        :param xt: current input, shape [BATCH_SIZE, INPUT_DIM]
        :return: new [h-state, c-state] pair (a two-element list),
            each of shape [BATCH_SIZE, LSTM_UNITS]
        """
        # hx shape: [BATCH_SIZE, LSTM_UNITS + INPUT_DIM]
        hx = np.concatenate([h_t1, xt], axis=-1)

        # Note that we could also concatenate the weights
        # into one big matrix and split the result.
        # That would be a cleaner implementation
        # but we want to align here with the operations described in
        # https://colah.github.io/posts/2015-08-Understanding-LSTMs/
        i_raw, f_raw, c_hat_raw, o_raw = [
            (np.dot(hx, W.T) + b)
            for (W, b) in zip(
                [self.W_i, self.W_f, self.W_c, self.W_o],
                [self.b_i, self.b_f, self.b_c, self.b_o],
            )
        ]

        i = sigmoid(i_raw)  # input gate
        f = sigmoid(f_raw)  # forget gate
        o = sigmoid(o_raw)  # output gate
        c_hat = np.tanh(c_hat_raw)  # candidate cell state

        # Keep part of the old cell state, add part of the candidate.
        c = f * c_t1 + i * c_hat
        h = o * np.tanh(c)
        return [np.array(t) for t in [h, c]]

    def step_all(self, input_x, initial_h_state, initial_c_state):
        """
        Runs step_one over every time-step of a batch-major input
        :param input_x: shape [BATCH_SIZE, TIME_STEPS, INPUT_DIM]
        :param initial_h_state: shape [BATCH_SIZE, LSTM_UNITS]
        :param initial_c_state: shape [BATCH_SIZE, LSTM_UNITS]
        :return: (h-state sequence of shape
            [BATCH_SIZE, TIME_STEPS, LSTM_UNITS], last c-state) -pair.
            With TIME_STEPS == 0 the sequence is empty and the
            c-state is the initial one, unchanged.
        """
        timesteps = input_x.shape[1]
        h_state_sequence = []
        h_state, c_state = initial_h_state, initial_c_state
        for t in range(timesteps):
            h_state, c_state = self.step_one(h_state, c_state, input_x[:, t, :])
            h_state_sequence.append(h_state)

        if not h_state_sequence:
            # No time-steps: there is nothing to stack, so build the empty
            # [BATCH_SIZE, 0, LSTM_UNITS] result explicitly instead of failing.
            h_initial = np.asarray(initial_h_state)
            empty_sequence = np.empty(
                (input_x.shape[0], 0, h_initial.shape[-1]), dtype=h_initial.dtype
            )
            return empty_sequence, c_state

        # Stack along axis 1 so the result is batch-major, not time-major.
        return np.stack(h_state_sequence, axis=1), c_state
# Re-run the same inputs and initial states through the manual numpy
# implementation, using the per-gate weights taken from the Keras layer.
h_state_sequence_2, cell_state_manual_2 = LstmSimpleForward(
    W_i, W_f, W_c, W_o, b_i, b_f, b_c, b_o
).step_all(
    input_x=input_x_tensor.numpy(),
    initial_h_state=initial_h_state_tensor.numpy(),
    initial_c_state=initial_c_state_tensor.numpy(),
)

# The manual result must match what the Keras layer produced.
# The values are float32 under TensorFlow's defaults, so compare with an
# explicit relative/absolute tolerance: a fixed 7-decimal absolute check is
# roughly one float32 rounding step and can fail on harmless differences.
np.testing.assert_allclose(
    cell_state_manual_2, np.asarray(cell_state_last), rtol=1e-5, atol=1e-6
)
np.testing.assert_allclose(
    h_state_sequence_2, np.asarray(h_state_sequence), rtol=1e-5, atol=1e-6
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment