Create a gist now

Instantly share code, notes, and snippets.

Embed
keras LSTM, sample
# encoding: utf-8
from keras.models import Sequential
from keras.layers.core import Dense, Activation
from keras.layers.recurrent import LSTM
from keras.optimizers import Adam
from keras.callbacks import EarlyStopping
import sys
import numpy as np
from numpy.random import *
import matplotlib.pyplot as plt
import pandas as pd
import math
import random
#
def toy_problem(T=100, ampl=0.05):
    """Return a toy 1-D series of 201 uniform random values in [0, 30).

    Values come from numpy's global random state, so call
    ``numpy.random.seed`` beforehand for reproducible output.

    ``T`` and ``ampl`` are accepted but not used by this implementation.
    NOTE(review): they look like leftovers from a sine-wave generator.
    """
    n_points = 200 + 1
    scale = 30
    samples = rand(n_points)  # uniform on [0, 1)
    return samples * scale
#
def make_dataset(low_data, n_prev=100, maxlen=25):
    """Turn a 1-D series into (window, next value) pairs for sequence training.

    Every window of ``maxlen`` consecutive values becomes one input sample.
    The value immediately after the window is that sample's target.

    Args:
        low_data: 1-D sequence of numbers.
        n_prev: accepted for backward compatibility but ignored. The window
            length has always been controlled separately; see ``maxlen``.
        maxlen: number of time steps per input window (previously hard-coded
            to 25, which remains the default).

    Returns:
        ``(re_data, re_target)``:
        - ``re_data`` has shape ``(n, maxlen, 1)``, where
          ``n = max(len(low_data) - maxlen, 0)``.
        - ``re_target`` has shape ``(n, 1)``.
        - ``re_target[i]`` is ``low_data[i + maxlen]``.

    Raises:
        ValueError: if ``maxlen`` is less than 1.
    """
    if maxlen < 1:
        raise ValueError(f"maxlen must be >= 1, got {maxlen}")
    series = np.asarray(low_data)
    # A series no longer than one window yields zero samples instead of failing.
    n_windows = max(len(series) - maxlen, 0)
    data = [series[i:i + maxlen] for i in range(n_windows)]
    target = [series[i + maxlen] for i in range(n_windows)]
    # The trailing axis of size 1 is the per-step feature dimension the LSTM expects.
    re_data = np.array(data).reshape(n_windows, maxlen, 1)
    re_target = np.array(target).reshape(n_windows, 1)
    return re_data, re_target
# main
f = toy_problem()
g, h = make_dataset(f)
print(g)
print(len(g))

# Number of time steps in one training sample (make_dataset's window, 25 by default).
length_of_sequence = g.shape[1]
time_length = length_of_sequence
in_out_neurons = 1
n_hidden = 300
# How many steps to forecast autoregressively past the end of the data.
n_future_steps = 100

# Build the model: one LSTM layer feeding a linear Dense output.
model = Sequential()
model.add(LSTM(n_hidden, batch_input_shape=(None, length_of_sequence, in_out_neurons), return_sequences=False))
model.add(Dense(in_out_neurons))
model.add(Activation("linear"))
# NOTE: newer Keras (>= 2.3) calls this argument `learning_rate`. `lr` is the
# older, deprecated spelling. It is kept because the module paths imported
# above (keras.layers.core / keras.layers.recurrent) point at an older Keras.
optimizer = Adam(lr=0.001)
model.compile(loss="mean_squared_error", optimizer=optimizer)

# Train. Stop once validation loss has not improved for 20 epochs.
early_stopping = EarlyStopping(monitor='val_loss', mode='min', patience=20)
model.fit(g, h,
          batch_size=300,
          epochs=100,
          validation_split=0.1,
          callbacks=[early_stopping])

# One-step-ahead predictions for every training window.
# predicted[i] is the model's estimate of f[i + length_of_sequence].
predicted = model.predict(g)

# Forecast the future autoregressively.
# Seed with the LAST length_of_sequence raw values, so the first forecast is
# for index len(f). That is where the "future" curve is plotted below.
# The last training window g[-1] ends one sample earlier: its target h[-1] is
# f[-1], which is already known. Seeding from g[-1] would shift the whole
# forecast one step to the right.
future_test = f[-length_of_sequence:]
future_result = np.empty((0))
for _ in range(n_future_steps):
    test_data = np.reshape(future_test, (1, time_length, 1))
    batch_predict = model.predict(test_data)
    # Slide the window: drop the oldest value and append the newest prediction.
    future_test = np.append(future_test[1:], batch_predict)
    future_result = np.append(future_result, batch_predict)
print("#future_result")
print(future_result)

# Plot three curves:
# - the raw series (uniform noise from toy_problem, not a sine wave);
# - the in-sample one-step predictions, aligned with the values they predict;
# - the autoregressive forecast past the end of the data.
plt.figure()
plt.plot(range(length_of_sequence, len(predicted) + length_of_sequence), predicted, color="r", label="predict")
plt.plot(range(len(f)), f, color="b", label="raw")
plt.plot(range(len(f), len(f) + len(future_result)), future_result, color="g", label="future")
plt.legend()
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment