Skip to content

Instantly share code, notes, and snippets.

@helve2017
Last active November 3, 2019 05:01
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save helve2017/db03d15779a87063d81ed61ea3ae449c to your computer and use it in GitHub Desktop.
Save helve2017/db03d15779a87063d81ed61ea3ae449c to your computer and use it in GitHub Desktop.
KerasでステートフルRNNを使ったサンプルコード
# -*- coding: utf-8 -*-
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Input, SimpleRNN, GRU, LSTM
#%% data preparation
# Load the temperature series; the first CSV column becomes the (parsed
# datetime) index.
df = pd.read_csv("osaka_temperature2009_2018.csv",
                 index_col=0, parse_dates=True)
# Fill gaps by linear interpolation. limit_direction="both" also fills NaNs
# that precede the first valid reading; the default (forward-only) leaves
# them in place, and they would then flow through the scaler into training.
df = df.interpolate(method="linear", limit_direction="both")
# df.plot()  # uncomment to eyeball the interpolated series
# Standardize each column to zero mean / unit variance and cast to float32,
# the dtype used for the model's training arrays.
ss = StandardScaler()
std = ss.fit_transform(df).astype(np.float32)
#%% data arranging
# Turn the standardized series into supervised samples: each input is a
# window of `timesteps` consecutive rows, and its target is the row that
# immediately follows that window.
timesteps = 6
# batch_size == timesteps is what makes the stateful model meaningful.
# Windows start at consecutive indices, so row r of batch k+1 starts at
# (k+1)*timesteps + r, which is exactly where row r of batch k (start
# k*timesteps + r) ends. The state carried for each row therefore stays
# aligned with the same stretch of the series.
batch_size = timesteps
n_samples = len(std) - timesteps
# Stateful training needs full batches only, so round down to a multiple of
# batch_size. Fail early if not even one full batch can be built.
data_len = batch_size * (max(n_samples, 0) // batch_size)
if data_len == 0:
    raise ValueError(
        f"need at least {timesteps + batch_size} rows to build one full "
        f"batch of windows, got {len(std)}")
# Vectorized sliding window in place of a per-sample Python loop:
# row j of window_idx holds j, j+1, ..., j+timesteps-1.
window_idx = np.arange(n_samples)[:, None] + np.arange(timesteps)
x = std[window_idx].astype(np.float32, copy=False)  # (n_samples, timesteps, n_cols)
y = std[timesteps:].astype(np.float32, copy=False)  # (n_samples, n_cols)
x = x[:data_len].reshape(data_len, timesteps, -1)
y = y[:data_len].reshape(data_len, -1)
# Training hyperparameters shared by the stateless and stateful models.
N_EPOCH = 3        # number of passes over the training windows
actfunc = "tanh"   # activation for the recurrent layer and hidden Dense layer
#%% stateless model
# Baseline: an ordinary (stateful=False) RNN, so every window is processed
# from a fresh hidden state, independently of the previous batch.
model = Sequential([
    # Declare the input with an Input layer rather than `input_shape=` on the
    # first layer; Keras 3 deprecates that keyword and emits a warning.
    Input(shape=(timesteps, 1)),
    SimpleRNN(10, activation=actfunc, stateful=False),
    Dense(10, activation=actfunc),
    Dense(1),
])
model.compile(optimizer='RMSprop', loss='mean_squared_error')
# shuffle=False keeps chronological sample order, matching the stateful run.
history = model.fit(x, y, epochs=N_EPOCH, batch_size=batch_size,
                    verbose=1, shuffle=False)
#%% stateful model
# Same architecture, but stateful: the RNN's final state for row r of a batch
# is used as the initial state for row r of the next batch. This relies on
# batches arriving in order (shuffle=False) and on batch_size == timesteps
# when the windows are built, so row r keeps continuing the same stretch.
model = Sequential([
    # A stateful RNN needs a fixed batch size. Declare it on the Input layer:
    # Keras 3 rejects a `batch_size=` keyword passed to the SimpleRNN
    # constructor, while tf.keras 2.x accepts this form as well.
    Input(shape=(timesteps, 1), batch_size=batch_size),
    SimpleRNN(10, activation=actfunc, stateful=True),
    Dense(10, activation=actfunc),
    Dense(1),
])
model.compile(optimizer='RMSprop', loss='mean_squared_error')
# Train one epoch at a time so the carried state can be cleared between
# epochs; otherwise epoch k+1 would start from the state left by epoch k.
# Note: `history` only keeps the last epoch's record.
for _ in range(N_EPOCH):
    history = model.fit(x, y, epochs=1, batch_size=batch_size,
                        verbose=1, shuffle=False)
    # Reset the stateful layers directly rather than via Model.reset_states(),
    # which may be unavailable on Keras 3 models; RNN layers expose
    # reset_states() in tf.keras 2.x and Keras 3.
    for layer in model.layers:
        if getattr(layer, "stateful", False):
            layer.reset_states()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment