Created May 16, 2019 20:15
Save L-Lewis/091ae3acd748a7d63596d63f71b51c01 to your computer and use it in GitHub Desktop.
Function for evaluating a neural network for regression
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def _resolve_default(value, name):
    """Return ``value``, or the module-level variable ``name`` when ``value`` is None.

    Looking the global up at call time (instead of binding it as a default argument at
    definition time) avoids a NameError when the function is defined before the data
    exists, and avoids silently using stale data if the global is later reassigned.
    """
    if value is not None:
        return value
    try:
        return globals()[name]
    except KeyError:
        raise NameError(
            f"{name} was not passed and no module-level {name} is defined"
        ) from None


def nn_model_evaluation(model, skip_epochs=0, X_train=None, X_test=None, y_train=None, y_test=None):
    """
    Evaluate a neural network regression model that has already been fit.

    Prints the MSE and r squared values for the train and test sets, plots a line graph
    of the loss at each epoch, and draws side-by-side scatterplots of predicted vs.
    actual values with a dashed line marking where predicted = actual.

    Parameters
    ----------
    model : fitted model
        Must expose ``predict(X)`` and ``history.history``, a dict holding a ``'loss'``
        list and, if validation data was used during fitting, a ``'val_loss'`` list
        (presumably a Keras model -- this is the shape Keras' History callback records).
    skip_epochs : int, default 0
        Number of leading epochs to omit from the loss graph. Useful when the loss in
        the first epoch is orders of magnitude larger than in later epochs.
    X_train, X_test, y_train, y_test : array-like, optional
        Data to evaluate on. When omitted, the module-level variables of the same names
        are used (looked up when the function is called).

    Raises
    ------
    ValueError
        If ``skip_epochs`` is negative.
    NameError
        If a data argument is omitted and no module-level variable of that name exists.
    """
    if skip_epochs < 0:
        raise ValueError(f"skip_epochs must be non-negative, got {skip_epochs}")

    X_train = _resolve_default(X_train, 'X_train')
    X_test = _resolve_default(X_test, 'X_test')
    y_train = _resolve_default(y_train, 'y_train')
    y_test = _resolve_default(y_test, 'y_test')

    # MSE and r squared values
    y_test_pred = model.predict(X_test)
    y_train_pred = model.predict(X_train)
    print("Training MSE:", round(mean_squared_error(y_train, y_train_pred), 4))
    print("Validation MSE:", round(mean_squared_error(y_test, y_test_pred), 4))
    print("\nTraining r2:", round(r2_score(y_train, y_train_pred), 4))
    print("Validation r2:", round(r2_score(y_test, y_test_pred), 4))

    # Line graph of losses; epochs are numbered from 1, so the first plotted epoch is
    # skip_epochs + 1 and the x values always match the length of the sliced losses.
    history = model.history.history
    train_loss = history['loss'][skip_epochs:]
    plt.plot(list(range(skip_epochs + 1, skip_epochs + 1 + len(train_loss))),
             train_loss, label='Train')
    # 'val_loss' is only recorded when the model was fit with validation data.
    if 'val_loss' in history:
        val_loss = history['val_loss'][skip_epochs:]
        plt.plot(list(range(skip_epochs + 1, skip_epochs + 1 + len(val_loss))),
                 val_loss, label='Test', color='green')
    plt.legend()
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training and test loss at each epoch', fontsize=14)
    plt.show()

    # Scatterplot of predicted vs. actual values. The panels share a y-axis so that,
    # with no gap between them, their vertical scales are genuinely comparable
    # (sharey also hides the inner panel's redundant tick labels).
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4), sharey=True)
    fig.suptitle('Predicted vs. actual values', fontsize=14, y=1)
    plt.subplots_adjust(top=0.93, wspace=0)
    panels = ((ax1, y_test, y_test_pred, 'Test set'),
              (ax2, y_train, y_train_pred, 'Train set'))
    for ax, actual, predicted, title in panels:
        ax.scatter(actual, predicted, s=2, alpha=0.7)
        ax.set_title(title)
        ax.set_xlabel('Actual values')
    ax1.set_ylabel('Predicted values')

    # Span the predicted = actual reference line over the full range of the plotted
    # data (rather than a fixed range) so it is correct for any target scale.
    limits = [lim for ax in (ax1, ax2) for lim in (*ax.get_xlim(), *ax.get_ylim())]
    lo, hi = min(limits), max(limits)
    for ax in (ax1, ax2):
        ax.plot([lo, hi], [lo, hi], color='black', linestyle='--')
    plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.