Skip to content

Instantly share code, notes, and snippets.

@darden1
Created November 8, 2018 13:31
Show Gist options
Save darden1/a6131e98b5d586b7de0d7f31aee5fa49 to your computer and use it in GitHub Desktop.
myrnn_retur_sequences_true.py
# Build the hand-written RNN with return_sequences=True and train it on the
# training split, validating on (X_val, Y_val) each epoch.
# NOTE(review): the result of fit() is not kept; later in this script the
# per-epoch losses are read back as model_myrnn_rst.loss / .val_loss, so fit()
# presumably records them on the model — confirm in RecurrentNeuralNetwork.
model_myrnn_rst = RecurrentNeuralNetwork(rnn_units, return_sequences=True)
model_myrnn_rst.fit(
    X_train,
    Y_train,
    batch_size=batch_size,
    epochs=n_epochs,
    mu=lr,  # presumably the learning rate; this fit() spells it `mu`
    validation_data=(X_val, Y_val),
    verbose=1,
)
# Training-history figure: the Keras model's loss/val_loss (from history_rst)
# and the hand-written RNN's loss/val_loss drawn on the same axes against
# `indices` (assumed to be per-epoch x positions — defined elsewhere).
# The plot() calls stay in this order because it fixes which colour each
# curve gets from the default colour cycle.
plt.plot(indices, history_rst.history['loss'], label='loss (keras)')
plt.plot(indices, history_rst.history['val_loss'], label='val_loss (keras)')
plt.plot(indices, model_myrnn_rst.loss, label='loss (my rnn)')
plt.plot(indices, model_myrnn_rst.val_loss, label='val_loss (my rnn)')

# Axis decoration (independent calls, any order after plotting).
plt.xlabel('epochs')
plt.ylabel('loss')
plt.title('train history')
plt.grid(True)
plt.legend(loc='best')
plt.show()
# Predict with the hand-written RNN on the full input X, then compare it with
# the ground truth Y and the Keras predictions Y_pred_rst (computed elsewhere).
Y_pred_myrnn_rst = model_myrnn_rst.predict(X)

# Each series is sliced with [:, -1, :], i.e. index -1 on axis 1 of a rank-3
# array — presumably the final time step of every output sequence.
# Plot order fixes the colour-cycle assignment, so it is kept as-is.
plt.plot(T, Y[:, -1, :], label='true')
plt.plot(T, Y_pred_rst[:, -1, :], label='pred (keras)')
plt.plot(T, Y_pred_myrnn_rst[:, -1, :], label='pred (my rnn)')

# Fixed viewing window, set after the data is drawn.
plt.xlim([0, 1])
plt.ylim([-2, 2])

plt.xlabel('time')
plt.ylabel('amplitude')
plt.title('true and pred')
plt.grid(True)
plt.legend(loc="best")
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment