Skip to content

Instantly share code, notes, and snippets.

@darden1
Created November 8, 2018 13:32
Show Gist options
  • Save darden1/5dd523d138b12f4870529090f0a7a18a to your computer and use it in GitHub Desktop.
Save darden1/5dd523d138b12f4870529090f0a7a18a to your computer and use it in GitHub Desktop.
myrnn_retur_sequences_false.py
# Build the hand-written RNN configured to emit only a single output per
# sequence (return_sequences=False), then train it on X_train / Y_train_rsf
# while tracking validation loss on (X_val, Y_val_rsf).
# NOTE(review): `mu` receives `lr`, so it is presumably the learning rate of
# RecurrentNeuralNetwork.fit — confirm against that class's definition.
model_myrnn_rsf = RecurrentNeuralNetwork(
    rnn_units,
    return_sequences=False,
)
model_myrnn_rsf.fit(
    X_train,
    Y_train_rsf,
    batch_size=batch_size,
    epochs=n_epochs,
    mu=lr,
    validation_data=(X_val, Y_val_rsf),
    verbose=1,
)
# Overlay training and validation loss curves from the Keras run
# (history_rsf.history) and from the hand-written RNN (model_myrnn_rsf)
# on a single figure, x-axis = `indices`.
# NOTE(review): assumes `indices` has the same length as every loss series
# plotted here — confirm where `indices` is defined.
for _key, _label in (("loss", "loss (keras)"), ("val_loss", "val_loss (keras)")):
    plt.plot(indices, history_rsf.history[_key], label=_label)
for _attr, _label in (("loss", "loss (my rnn)"), ("val_loss", "val_loss (my rnn)")):
    plt.plot(indices, getattr(model_myrnn_rsf, _attr), label=_label)
# Drop the loop temporaries so they do not linger as module-level names.
del _key, _attr, _label

plt.legend(loc="best")
plt.title("train history")
plt.xlabel("epochs")
plt.ylabel("loss")
plt.grid(True)
plt.show()
# Run the hand-written RNN on the full input X; the result stays available
# as the module-level name Y_pred_myrnn_rsf.
Y_pred_myrnn_rsf = model_myrnn_rsf.predict(X)

# Overlay the reference target with the Keras prediction (Y_pred_rsf) and
# this model's prediction, all against T.
# NOTE(review): Y[:, -1, :] takes the last index along axis 1, presumably the
# final time step of each target sequence — confirm Y's layout.
_y_true_last = Y[:, -1, :]
plt.plot(T, _y_true_last, label="true")
plt.plot(T, Y_pred_rsf, label="pred (keras)")
plt.plot(T, Y_pred_myrnn_rsf, label="pred (my rnn)")
# Temporary only needed for plotting; avoid leaving a new module-level name.
del _y_true_last

# Figure cosmetics with a fixed viewing window on both axes.
plt.legend(loc="best")
plt.title("true and pred")
plt.xlabel("time")
plt.ylabel("amplitude")
plt.xlim([0, 1])
plt.ylim([-2, 2])
plt.grid(True)
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment