@michelkana
Last active May 30, 2022 03:52
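# imports used below (r2_score was missing from the original snippet)
import matplotlib.pyplot as plt
from sklearn.metrics import r2_score

# prediction with both trained models on the test inputs
y_pred_without_dropout = model_without_dropout.predict(x_test)
y_pred_with_dropout = model_with_dropout.predict(x_test)

# R2 of each model; the test data is plotted as the line y = x, so x_test serves as the ground truth
r2_without_dropout = r2_score(x_test, y_pred_without_dropout)
r2_with_dropout = r2_score(x_test, y_pred_with_dropout)

# plotting: training points, the test line, and both predicted curves
fig, ax = plt.subplots(1, 1, figsize=(10, 5))
ax.scatter(x_train, y_train, s=10, label='train data')
ax.plot(x_test, x_test, ls='--', label='test data', color='green')
ax.plot(x_test, y_pred_without_dropout, label='predicted ANN - R2 {:.2f}'.format(r2_without_dropout), color='red')
ax.plot(x_test, y_pred_with_dropout, label='predicted ANN Dropout - R2 {:.2f}'.format(r2_with_dropout), color='black')
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.legend()
ax.set_title('test data');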
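The legend labels report the coefficient of determination from scikit-learn's `r2_score`, defined as R² = 1 − SS_res / SS_tot. A minimal sketch of that computation in NumPy, checked against `r2_score`, follows. The helper `r2_manual` and the data (test inputs on y = x and a slightly biased "prediction") are hypothetical illustrations, not part of the gist. The prediction is shaped as an (n, 1) column on the assumption that it mirrors the 2-D output of a Keras-style `predict` call.

```python
import numpy as np
from sklearn.metrics import r2_score


def r2_manual(y_true, y_pred):
    """Coefficient of determination: 1 - SS_res / SS_tot (hypothetical helper)."""
    y_true = np.ravel(y_true)  # flatten (n, 1) column outputs to (n,)
    y_pred = np.ravel(y_pred)
    ss_res = np.sum((y_true - y_pred) ** 2)          # residual sum of squares
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)   # total sum of squares around the mean
    return 1.0 - ss_res / ss_tot


# hypothetical test inputs on the line y = x, so the inputs double as the targets
x_test = np.linspace(-1.0, 1.0, 50)
# hypothetical prediction: slightly wrong slope and offset, as an (n, 1) column
y_pred = (0.9 * x_test + 0.05).reshape(-1, 1)

manual = r2_manual(x_test, y_pred)
library = r2_score(x_test, y_pred)
print(f"manual R2 = {manual:.4f}, sklearn R2 = {library:.4f}")
```

A perfect prediction gives R² = 1, and a model that always predicts the mean of the targets gives R² = 0, so the two legend values can be compared on that scale.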
@sonineha2191
An import line is missing at the start of the snippet:
from sklearn.metrics import r2_score