Skip to content

Instantly share code, notes, and snippets.

@darden1
Last active November 9, 2018 15:46
Show Gist options
  • Save darden1/2b0fcb8c3811d9911b95696d8e24a3de to your computer and use it in GitHub Desktop.
Save darden1/2b0fcb8c3811d9911b95696d8e24a3de to your computer and use it in GitHub Desktop.
keras_simplernn_return_sequences_false.py
# Use only the final time step of each target sequence as the training target.
Y_train_rsf = Y_train[:, -1, :]
Y_val_rsf = Y_val[:, -1, :]

# Recurrent layer configured with return_sequences=False, followed by a
# linear Dense read-out with n_classes units.
model_rsf = Sequential([
    SimpleRNN(
        rnn_units,
        input_shape=(n_sequence, n_features),
        return_sequences=False,
    ),
    Dense(n_classes, activation="linear"),
])
model_rsf.compile(optimizer=SGD(lr), loss="mean_squared_error")

# Train on the last-step targets; the returned History object is kept so the
# loss curves can be plotted afterwards.
history_rsf = model_rsf.fit(
    X_train,
    Y_train_rsf,
    validation_data=(X_val, Y_val_rsf),
    epochs=n_epochs,
    batch_size=batch_size,
    shuffle=True,
    verbose=2,
)
# Plot training and validation loss per epoch for both models on one figure.
# NOTE(review): `indices` and `history_rst` are defined outside this section;
# presumably `history_rst` is the fit history of a return_sequences=True
# model -- confirm against the earlier part of the script.
plt.plot(indices, history_rst.history["loss"], label="loss (return_sequences=True)")
plt.plot(indices, history_rst.history["val_loss"], label="val_loss (return_sequences=True)")
# Legend labels previously read "Frue"; this model uses return_sequences=False.
plt.plot(indices, history_rsf.history["loss"], label="loss (return_sequences=False)")
plt.plot(indices, history_rsf.history["val_loss"], label="val_loss (return_sequences=False)")
plt.legend(loc="best")
plt.title("train history")
plt.xlabel("epochs")
plt.ylabel("loss")
plt.grid(True)
plt.show()
# Run the return_sequences=False model on X to get its predictions.
Y_pred_rsf = model_rsf.predict(X)

# Draw the reference signal and both models' predictions against T.
# Y and Y_pred_rst are sliced at their last index along axis 1, while
# Y_pred_rsf is plotted as returned by predict().
# NOTE(review): Y, T and Y_pred_rst are defined outside this section.
series_to_plot = (
    (Y[:, -1, :], "true"),
    (Y_pred_rst[:, -1, :], "pred (return_sequences=True)"),
    (Y_pred_rsf, "pred (return_sequences=False)"),
)
for values, legend_label in series_to_plot:
    plt.plot(T, values, label=legend_label)

# Axis decoration: fixed view window of [0, 1] in time and [-2, 2] in amplitude.
plt.title("true and pred")
plt.xlabel("time")
plt.ylabel("amplitude")
plt.xlim([0, 1])
plt.ylim([-2, 2])
plt.grid(True)
plt.legend(loc="best")
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment