Skip to content

Instantly share code, notes, and snippets.

@darden1
Last active November 8, 2018 12:58
Show Gist options
  • Save darden1/4ce1248902ca952a8ee0f2c5fed859e9 to your computer and use it in GitHub Desktop.
Save darden1/4ce1248902ca952a8ee0f2c5fed859e9 to your computer and use it in GitHub Desktop.
keras_simplernn_return_sequences_trye.py
from keras.models import Sequential
from keras.layers import Dense, SimpleRNN
from keras.optimizers import SGD
# Train/validation split: the first `n_train` samples are used for training,
# every remaining sample is held out for validation.
n_train = 80
X_train = X[:n_train]
X_val = X[n_train:]
Y_train = Y[:n_train]
Y_val = Y[n_train:]

# Optimisation hyperparameters for model_rst.fit / SGD.
batch_size, n_epochs, lr = 10, 200, 0.001
rnn_units = 128  # number of units in the SimpleRNN layer

# Model input/output widths are read off the trailing axis of the data arrays.
# NOTE(review): assumes X and Y are (samples, timesteps, features) arrays — confirm.
n_features, n_classes = X.shape[-1], Y.shape[-1]
# Sequence-to-sequence regressor: with return_sequences=True the SimpleRNN
# emits its hidden state at every timestep, and the linear Dense head maps
# each of those states to `n_classes` outputs.
model_rst = Sequential([
    SimpleRNN(
        rnn_units,
        input_shape=(n_sequence, n_features),
        return_sequences=True,
    ),
    Dense(n_classes, activation="linear"),
])
model_rst.compile(optimizer=SGD(lr), loss='mean_squared_error')

# Train on the first n_train samples and report loss on the held-out rest
# after every epoch; the returned History object is plotted below.
history_rst = model_rst.fit(
    X_train,
    Y_train,
    batch_size=batch_size,
    epochs=n_epochs,
    validation_data=(X_val, Y_val),
    shuffle=True,
    verbose=2,
)
# Plot per-epoch training and validation loss recorded by model_rst.fit().
# The x-axis is derived from the number of epochs actually recorded rather
# than from n_epochs, so the x and y lengths always agree (e.g. if an
# EarlyStopping callback ends training early, or n_epochs is edited after
# fit() ran); a length mismatch would make plt.plot raise.
indices = range(len(history_rst.history["loss"]))
plt.plot(indices, history_rst.history["loss"], label="loss")
plt.plot(indices, history_rst.history["val_loss"], label="val_loss")
plt.legend(loc="best")
plt.title("train history")
plt.xlabel("epochs")
plt.ylabel("loss")
plt.grid(True)
plt.show()
# Predict on the full dataset (training and validation samples together) and
# compare the last timestep of each predicted sequence with the ground truth.
# NOTE(review): assumes T holds one time value per sample of X — confirm.
Y_pred_rst = model_rst.predict(X)
for series, label in ((Y, "true"), (Y_pred_rst, "pred")):
    plt.plot(T, series[:, -1, :], label=label)
plt.legend(loc='best')
plt.title("true and pred")
plt.xlabel("time")
plt.ylabel("amplitude")
# Fixed axis limits; presumably chosen to match the generated data's range — verify.
plt.xlim([0, 1])
plt.ylim([-2, 2])
plt.grid(True)
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment