Created
May 27, 2019 10:45
-
-
Save pocokhc/dbc897c94380839a1c09faef0d079c52 to your computer and use it in GitHub Desktop.
KerasのステートレスLSTMとステートフルLSTMを検証した時のコードです。
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from keras.models import Sequential | |
from keras.layers import * | |
from keras.optimizers import Adam | |
from keras.preprocessing.sequence import TimeseriesGenerator | |
from keras.utils import np_utils | |
from keras import backend as K | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import random | |
import time | |
# http://torch.classcat.com/2018/06/26/keras-ex-tutorials-stateful-lstm/ | |
def main(seq_length, batch_size ,model_type, shape, epochs, shuffle, test_every_reset, hidden_state=False):
    """Compare stateless vs. stateful Keras LSTMs on an alphabet-prediction task.

    Trains a model to predict the next letter of the alphabet from a window of
    ``seq_length`` preceding letters, then evaluates it several ways.

    Args:
        seq_length: length of the input character window.
        batch_size: training/prediction batch size (fixed for the stateful LSTM).
        model_type: "dense", "lstm" (stateless) or "lstm_ful" (stateful).
        shape: per-sample input shape, e.g. (timesteps, features).
        epochs: number of training epochs.
        shuffle: whether to shuffle samples during fit().
        test_every_reset: if True, reset the stateful LSTM before every predict.
        hidden_state: if True, additionally run the hidden-state transfer
            experiment (train a second model seeded with a copied LSTM state).
    """
    # Define the raw dataset
    alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
    alphabet_int = [ i for i in range(len(alphabet))]
    # Create a mapping of characters to integers (0-25) and its inverse.
    char_to_int = dict((c, i) for i, c in enumerate(alphabet))
    int_to_char = dict((i, c) for i, c in enumerate(alphabet))
    def int_to_char_seq(seq):
        # Decode a normalized input window back into its alphabet string.
        seq = seq.reshape(seq_length)
        s = ""
        for c in seq:
            # NOTE(review): int() truncates; float rounding could make this off
            # by one for some values — int(round(...)) would be safer.
            c = int(c * float(len(alphabet)))
            s += int_to_char[c]
        return s
    #-------
    # https://keras.io/ja/preprocessing/sequence/
    data = TimeseriesGenerator(alphabet_int, alphabet_int, length=seq_length)[0]
    x_data = data[0]
    y_data = data[1]
    print(x_data)
    print(y_data)
    # normalize inputs to [0, 1)
    x_data = x_data / float(len(alphabet))
    x_data = np.reshape(x_data, (len(x_data),) + shape ) #(batch_size,len,data)
    # one hot encode the output variable
    y_data = np_utils.to_categorical(y_data)
    print(x_data.shape)
    print(y_data.shape)
    #-------------------
    # model
    model = Sequential()
    if model_type == "dense":
        model.add(Flatten(input_shape=shape))
        model.add(Dense(16))
    elif model_type == "lstm":
        model.add(LSTM(16, input_shape=shape))
    elif model_type == "lstm_ful":
        # a stateful LSTM requires a fixed batch size baked into its input shape
        model.add(LSTM(16, batch_input_shape=(batch_size,) + shape, stateful=True, name="lstm"))
    else:
        raise ValueError("model type error.")
    model.add(Dense(y_data.shape[1], activation="softmax"))
    model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
    #----------------------
    # train
    t0 = time.time()
    if model_type == "lstm_ful":
        # stateful: the LSTM state must be reset manually at each epoch boundary
        for _ in range(epochs):
            model.reset_states()
            model.fit(np.asarray(x_data), np.asarray(y_data), epochs=1, batch_size=batch_size, shuffle=shuffle, verbose=0)
    else:
        model.fit(np.asarray(x_data), np.asarray(y_data), epochs=epochs, batch_size=batch_size, shuffle=shuffle, verbose=0)
    print("fit time : ", time.time()-t0)
    #----------------------
    # summarize performance of the model
    t0 = time.time()
    if model_type == "lstm_ful":
        model.reset_states()
    scores = model.evaluate(np.asarray(x_data), np.asarray(y_data), batch_size=batch_size, verbose=0)
    print("Model Accuracy: %.2f%%" % (scores[1]*100))
    #----------------------
    # test1: predict each window, iterating the data in REVERSE order
    pred1_ok = 0
    if model_type == "lstm_ful":
        model.reset_states()
    for i in reversed(range(len(x_data))):
        x = x_data[i]
        y = y_data[i]
        # replicate the sample across the batch (the stateful LSTM needs full batches)
        t = np.asarray([x for _ in range(batch_size)])
        if test_every_reset:
            model.reset_states()
        pre = model.predict(t)[0]
        if np.argmax(pre) == np.argmax(y):
            pred1_ok += 1
        print(int_to_char_seq(x), "->", int_to_char[np.argmax(pre)])
    print("Test1 Accuracy: %.2f%%" % (pred1_ok/len(x_data)*100))
    #----------------------
    # test2: closed loop — feed the model its own predictions
    if model_type == "lstm_ful":
        model.reset_states()
    pred2_ok = 0
    x = x_data[0]
    for i in range(len(x_data)):
        y = y_data[i]
        t = np.asarray([x for _ in range(batch_size)])
        if test_every_reset:
            model.reset_states()
        pre = model.predict(t)[0]
        print(int_to_char_seq(x), "->", int_to_char[np.argmax(pre)])
        if np.argmax(pre) == np.argmax(y):
            pred2_ok += 1
        # slide the window: drop the oldest char, append the normalized prediction
        x = x.reshape(seq_length)
        x = np.delete(x, 0)
        x = np.append(x, np.argmax(pre) / float(len(alphabet)))
        x = x.reshape(shape)
    print("Test2 Accuracy: %.2f%%" % (pred2_ok/len(x_data)*100))
    print("test time : ", time.time()-t0)
    #--------------------------------------
    if not hidden_state:
        return
    # Build a second, identical model
    model2 = Sequential()
    model2.add(LSTM(16, batch_input_shape=(batch_size,) + shape, stateful=True, name="lstm"))
    model2.add(Dense(y_data.shape[1], activation="softmax"))
    model2.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
    #-------------------------------
    # train
    t0 = time.time()
    for _ in range(epochs):
        # First advance model's short-term memory through the first 11 samples
        # (i.e. up to the letter "K").
        model.reset_states()
        for i in range(len(x_data)):
            if i >= 11:
                break
            x = x_data[i]
            model.predict(np.asarray([x]))[0]
        # Fetch the hidden state (states[0] = h, states[1] = c for an LSTM layer)
        lstm = model.get_layer("lstm")
        state = [K.get_value(lstm.states[0]), K.get_value(lstm.states[1])]
        # Make `state` the initial state of model2's LSTM
        model2.get_layer("lstm").reset_states(state)
        # train
        model2.fit(np.asarray(x_data), np.asarray(y_data), epochs=1, batch_size=1, shuffle=False, verbose=0)
    print("fit time : ", time.time()-t0)
    #-------------------------------
    # test1: starting from a fresh (zero) state
    pred1_ok = 0
    model2.reset_states()
    for i in range(len(x_data)):
        x = x_data[i]
        y = y_data[i]
        pre = model2.predict(np.asarray([x]))[0]
        if np.argmax(pre) == np.argmax(y):
            pred1_ok += 1
        print(int_to_char_seq(x), "->", int_to_char[np.argmax(pre)])
    print("Test1 Accuracy: %.2f%%" % (pred1_ok/len(x_data)*100))
    #-------------------------------
    # test2: starting from "K" (skip the first 10 samples)
    pred2_ok = 0
    model2.reset_states()
    for i in range(10, len(x_data)):
        x = x_data[i]
        y = y_data[i]
        pre = model2.predict(np.asarray([x]))[0]
        if np.argmax(pre) == np.argmax(y):
            pred2_ok += 1
        print(int_to_char_seq(x), "->", int_to_char[np.argmax(pre)])
    print("Test2 Accuracy: %.2f%%" % (pred2_ok/len(x_data)*100))
    #-------------------------------
    # test3: same conditions as training (state transferred from `model`)
    pred3_ok = 0
    # First advance model's short-term memory up to "K".
    model.reset_states()
    for i in range(len(x_data)):
        if i >= 11:
            break
        x = x_data[i]
        model.predict(np.asarray([x]))[0]
    # Fetch the hidden state
    lstm = model.get_layer("lstm")
    state = [K.get_value(lstm.states[0]), K.get_value(lstm.states[1])]
    # Make `state` the initial state of model2's LSTM
    model2.get_layer("lstm").reset_states(state)
    for i in range(len(x_data)):
        x = x_data[i]
        y = y_data[i]
        pre = model2.predict(np.asarray([x]))[0]
        if np.argmax(pre) == np.argmax(y):
            pred3_ok += 1
        print(int_to_char_seq(x), "->", int_to_char[np.argmax(pre)])
    print("Test3 Accuracy: %.2f%%" % (pred3_ok/len(x_data)*100))
#------
# Experiment configurations kept for reference (uncomment one to run it).
# Experiment 1
#main(seq_length=1, batch_size=1, model_type="dense", shape=(1, 1), epochs=500, shuffle=True, test_every_reset=False)
#main(seq_length=1, batch_size=1, model_type="lstm", shape=(1, 1), epochs=500, shuffle=True, test_every_reset=False)
#main(seq_length=1, batch_size=1, model_type="lstm_ful", shape=(1, 1), epochs=500, shuffle=False, test_every_reset=False)
#main(seq_length=1, batch_size=1, model_type="lstm_ful", shape=(1, 1), epochs=500, shuffle=True, test_every_reset=False)
# Experiment 2
#main(seq_length=3, batch_size=1, model_type="dense", shape=(1, 3), epochs=500, shuffle=True, test_every_reset=False)
#main(seq_length=3, batch_size=1, model_type="lstm", shape=(1, 3), epochs=500, shuffle=True, test_every_reset=False)
#main(seq_length=3, batch_size=1, model_type="lstm_ful", shape=(1, 3), epochs=500, shuffle=False, test_every_reset=False)
# Experiment 3
#main(seq_length=3, batch_size=1, model_type="dense", shape=(3, 1), epochs=500, shuffle=True, test_every_reset=False)
#main(seq_length=3, batch_size=1, model_type="lstm", shape=(3, 1), epochs=500, shuffle=True, test_every_reset=False)
#main(seq_length=3, batch_size=1, model_type="lstm_ful", shape=(3, 1), epochs=500, shuffle=False, test_every_reset=False)
# Experiment 4
#main(seq_length=3, batch_size=23, model_type="lstm", shape=(3, 1), epochs=1000, shuffle=True, test_every_reset=False)
#main(seq_length=3, batch_size=23, model_type="lstm_ful", shape=(3, 1), epochs=1000, shuffle=True, test_every_reset=True)
#main(seq_length=3, batch_size=23, model_type="lstm_ful", shape=(3, 1), epochs=1000, shuffle=False, test_every_reset=False)
# hidden state experiment
# Guard the entry point so importing this module does not kick off a
# 1000-epoch training run as a side effect.
if __name__ == "__main__":
    main(seq_length=3, batch_size=1, model_type="lstm_ful", shape=(3, 1), epochs=1000, shuffle=False, test_every_reset=False, hidden_state=True)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment