Skip to content

Instantly share code, notes, and snippets.

@milhidaka
Created June 19, 2017 09:00
Show Gist options
  • Save milhidaka/bc7899b536f2a54778c3b78e0191fb96 to your computer and use it in GitHub Desktop.
Save milhidaka/bc7899b536f2a54778c3b78e0191fb96 to your computer and use it in GitHub Desktop.
Reproduce the same results as the imdb_lstm Keras model using NumPy.
# reproduce lstm prediction by basic numpy operations
# model trained on imdb_lstm.py
# based on https://github.com/fchollet/keras/blob/master/keras/layers/recurrent.py#L1130
import numpy as np
from scipy.special import expit # logistic function
import h5py
# Configuration of the trained network (Sequential: Embedding -> LSTM ->
# Dense), kept purely as reference documentation: this bare string literal is
# not assigned to a name and has no effect at runtime.
# The NumPy code below relies on these settings:
#   - LSTM: units 128, activation 'tanh', recurrent_activation 'hard_sigmoid',
#     use_bias True.
#   - Dense: units 1, activation 'sigmoid'.
# The LSTM 'dropout' / 'recurrent_dropout' values (0.2) are not reproduced by
# the NumPy code; only inference is reimplemented.
"""
{'class_name': 'Sequential',
'config': [{'class_name': 'Embedding',
'config': {'activity_regularizer': None,
'batch_input_shape': [None, None],
'dtype': 'int32',
'embeddings_constraint': None,
'embeddings_initializer': {'class_name': 'RandomUniform',
'config': {'maxval': 0.05, 'minval': -0.05, 'seed': None}},
'embeddings_regularizer': None,
'input_dim': 20000,
'input_length': None,
'mask_zero': False,
'name': 'embedding_1',
'output_dim': 128,
'trainable': True}},
{'class_name': 'LSTM',
'config': {'activation': 'tanh',
'activity_regularizer': None,
'bias_constraint': None,
'bias_initializer': {'class_name': 'Zeros', 'config': {}},
'bias_regularizer': None,
'dropout': 0.2,
'go_backwards': False,
'implementation': 0,
'kernel_constraint': None,
'kernel_initializer': {'class_name': 'VarianceScaling',
'config': {'distribution': 'uniform',
'mode': 'fan_avg',
'scale': 1.0,
'seed': None}},
'kernel_regularizer': None,
'name': 'lstm_1',
'recurrent_activation': 'hard_sigmoid',
'recurrent_constraint': None,
'recurrent_dropout': 0.2,
'recurrent_initializer': {'class_name': 'Orthogonal',
'config': {'gain': 1.0, 'seed': None}},
'recurrent_regularizer': None,
'return_sequences': False,
'stateful': False,
'trainable': True,
'unit_forget_bias': True,
'units': 128,
'unroll': False,
'use_bias': True}},
{'class_name': 'Dense',
'config': {'activation': 'sigmoid',
'activity_regularizer': None,
'bias_constraint': None,
'bias_initializer': {'class_name': 'Zeros', 'config': {}},
'bias_regularizer': None,
'kernel_constraint': None,
'kernel_initializer': {'class_name': 'VarianceScaling',
'config': {'distribution': 'uniform',
'mode': 'fan_avg',
'scale': 1.0,
'seed': None}},
'kernel_regularizer': None,
'name': 'dense_1',
'trainable': True,
'units': 1,
'use_bias': True}}]}
"""
def hard_sigmoid(x):
    """Piecewise-linear approximation of the logistic sigmoid.

    Used for the LSTM gates, whose configured ``recurrent_activation`` is
    ``'hard_sigmoid'``. Computes ``clip(0.2 * x + 0.5, 0, 1)`` element-wise:
    0 for x <= -2.5, 1 for x >= 2.5, linear in between.

    Args:
        x: scalar or NumPy array of pre-activations.

    Returns:
        Array (or NumPy scalar) of the same shape with values in [0, 1].
    """
    slope, offset = 0.2, 0.5
    linear = offset + slope * x
    return np.clip(linear, 0.0, 1.0)
# Hyperparameters presumably mirrored from the imdb_lstm training script.
# In the visible code only hidden_dim is used; the others are kept for
# reference.
max_features = 20000  # vocabulary size (matches Embedding 'input_dim')
# cut texts after this number of words (among the top max_features most
# common words)
maxlen = 80
batch_size = 32
hidden_dim = 128  # LSTM width (matches LSTM 'units')

# Test inputs, labels and the stored reference predictions.
# np.load on an .npz returns an NpzFile that keeps the archive open. Reading
# inside a `with` block closes it once the arrays are loaded into memory.
with np.load("imdb_lstm_result.npz") as result:
    x_test = result["x_test"]
    y_test = result["y_test"]
    pred_test = result["pred_test"]
# Load the trained Keras weights into in-memory NumPy arrays.
# - Open read-only: h5py < 3 defaulted to mode 'a', which silently creates an
#   empty file when the path does not exist.
# - Read each dataset with `[()]`: `Dataset.value` was removed in h5py 3.0.
# - The `with` block closes the file; the arrays stay valid afterwards.
with h5py.File("imdb_lstm.h5", "r") as model_data:
    # (20000, 128)
    w_embedding = model_data["model_weights/embedding_1/embedding_1/embeddings:0"][()]
    # (128, 512)
    w_lstm_kernel = model_data["model_weights/lstm_1/lstm_1/kernel:0"][()]
    # (128, 512)
    w_lstm_recurrent_kernel = model_data["model_weights/lstm_1/lstm_1/recurrent_kernel:0"][()]
    # (512, )
    w_lstm_bias = model_data["model_weights/lstm_1/lstm_1/bias:0"][()]
    # (128, 1)
    w_dense_kernel = model_data["model_weights/dense_1/dense_1/kernel:0"][()]
    # (1, )
    w_dense_bias = model_data["model_weights/dense_1/dense_1/bias:0"][()]
def _lstm_final_hidden(x_seq, kernel, recurrent_kernel, bias, units):
    """Run the LSTM recurrence over one sequence and return the last hidden state.

    The fused pre-activation vector (length ``4 * units``) is sliced into the
    input, forget, cell-candidate and output gates, in that order, following
    the Keras LSTM column layout. Gates use ``hard_sigmoid``; the cell
    candidate and cell output use ``tanh``. No dropout is applied, so this is
    the inference-time computation.

    Args:
        x_seq: embedded input sequence; each row is one timestep
            (presumably (time, input_dim)).
        kernel: input weights, (input_dim, 4 * units).
        recurrent_kernel: recurrent weights, (units, 4 * units).
        bias: gate biases, (4 * units,).
        units: number of LSTM units.

    Returns:
        Hidden state after the final timestep, shape (units,). This is all
        zeros for an empty sequence.
    """
    h = np.zeros((units, ), dtype=np.float32)
    c = np.zeros((units, ), dtype=np.float32)
    for x_t in x_seq:
        z = np.dot(x_t, kernel) + np.dot(h, recurrent_kernel) + bias
        gate_i = hard_sigmoid(z[units * 0:units * 1])
        gate_f = hard_sigmoid(z[units * 1:units * 2])
        c_candidate = np.tanh(z[units * 2:units * 3])
        gate_o = hard_sigmoid(z[units * 3:units * 4])
        c = c_candidate * gate_i + c * gate_f
        h = np.tanh(c) * gate_o
    return h


# Compare the NumPy reimplementation with the stored predictions in pred_test
# (presumably produced by the Keras model itself) for the first few samples.
# min() avoids an IndexError when the test set has fewer than 10 samples.
# Set n_check = len(x_test) to check every sample.
n_check = min(10, len(x_test))
for i in range(n_check):
    x_embedded = w_embedding[x_test[i]]  # (time, 128)
    lstm_h = _lstm_final_hidden(x_embedded, w_lstm_kernel,
                                w_lstm_recurrent_kernel, w_lstm_bias,
                                hidden_dim)
    # Dense(1) layer with sigmoid activation on the final hidden state.
    pred = expit(np.dot(lstm_h, w_dense_kernel) + w_dense_bias)
    print(pred, pred_test[i])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment