Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
LSTM Keras Stateful for variable length inputs
from __future__ import print_function
import numpy as np
from keras.layers import Dense
from keras.layers import LSTM
from keras.models import Sequential
from numpy.random import choice
# Experiment switches. Neither flag is read in the visible part of this
# script, which only builds the stateful, one-timestep-per-call variant.
USE_SEQUENCES = False
USE_STATELESS_MODEL = False
# You can try all four possible combinations of
# USE_SEQUENCES and USE_STATELESS_MODEL.

max_len = 20
N_train = 100
N_test = 10
N = N_train + N_test

# One random length per sample, drawn (with replacement) from
# 1 .. max_len - 1: range() excludes its upper bound, so no sequence
# actually reaches max_len.
var_length_arr = choice(a=range(1, max_len), size=N, replace=True)

# Every sample starts as an all-zero column of shape (length, 1);
# every label starts as 0.
x = [np.zeros((length, 1)) for length in var_length_arr]
y = np.zeros((N, 1))

# Turn a random half of the samples into positives: the label becomes 1 and
# the FIRST timestep of the sequence becomes 1. The class is therefore only
# visible at step 0 and has to be remembered until the end of the sequence.
one_indexes = choice(a=N, size=int(0.5 * N), replace=False)
for i in one_indexes:
    x[i][0] = 1
    y[i] = 1

# Positional split: the first N_train samples train, the remainder test.
X_train = x[:N_train]
X_test = x[N_train:]
y_train = y[:N_train]
y_test = y[N_train:]
# Stateful model: a 10-unit LSTM followed by a single sigmoid output unit.
# batch_input_shape=(1, 1, 1) pins every call to one sample, one timestep and
# one feature; stateful=True asks Keras to carry the LSTM state over from one
# batch to the next until reset_states() is called.
print('Build STATEFUL model...')
model = Sequential([
    LSTM(
        10,
        batch_input_shape=(1, 1, 1),
        return_sequences=False,
        stateful=True,
    ),
    Dense(1, activation='sigmoid'),
])
model.compile(
    loss='binary_crossentropy',
    optimizer='adam',
    metrics=['accuracy'],
)
# Train and evaluate for 15 epochs. Each sequence is fed to the stateful model
# one timestep per batch; the LSTM state is reset only after the last timestep
# so that it accumulates over the whole sequence but never leaks into the
# next one.
print('Train...')
for epoch in range(15):
    # ---- training pass -------------------------------------------------
    mean_tr_acc = []
    mean_tr_loss = []
    for sequence, y_true in zip(X_train, y_train):
        # The sequence-level label is used as the target at every timestep.
        target = np.array([y_true])
        for step in sequence:
            tr_loss, tr_acc = model.train_on_batch(np.reshape(step, (1, 1, 1)), target)
            mean_tr_acc.append(tr_acc)
            mean_tr_loss.append(tr_loss)
        model.reset_states()

    # Metrics are averaged over every timestep of every sequence.
    print('accuracy training = {}'.format(np.mean(mean_tr_acc)))
    print('loss training = {}'.format(np.mean(mean_tr_loss)))
    print('___________________________________')

    # ---- evaluation pass -----------------------------------------------
    mean_te_acc = []
    mean_te_loss = []
    for sequence, y_true in zip(X_test, y_test):
        # Same (1, 1) target layout as in training.
        target = np.array([y_true])
        for step in sequence:
            te_loss, te_acc = model.test_on_batch(np.reshape(step, (1, 1, 1)), target)
            mean_te_acc.append(te_acc)
            mean_te_loss.append(te_loss)
        model.reset_states()

        # Inference demo: replay the sequence step by step. y_pred is
        # overwritten on each iteration, so after the loop it holds the
        # output for the final timestep. It is not used further in the
        # visible part of this script.
        for step in sequence:
            y_pred = model.predict_on_batch(np.reshape(step, (1, 1, 1)))
        model.reset_states()

    print('accuracy testing = {}'.format(np.mean(mean_te_acc)))
    print('loss testing = {}'.format(np.mean(mean_te_loss)))
    print('___________________________________')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.