Skip to content

Instantly share code, notes, and snippets.

@L-Lewis
Created May 16, 2019 20:15
Show Gist options
  • Save L-Lewis/091ae3acd748a7d63596d63f71b51c01 to your computer and use it in GitHub Desktop.
Save L-Lewis/091ae3acd748a7d63596d63f71b51c01 to your computer and use it in GitHub Desktop.
Function for evaluating a neural network for regression
def nn_model_evaluation(model, skip_epochs=0, X_train=X_train, X_test=X_test, y_train=y_train, y_test=y_test):
    """
    Report fit quality for an already-fitted regression neural network.

    For the train and test sets this prints the MSE and r squared values, then shows
    a line graph of the loss at each epoch and side-by-side scatterplots of predicted
    vs. actual values, each with a dashed line marking where predicted = actual.

    Parameters
    ----------
    model :
        A fitted model exposing ``predict(X)`` and ``history.history`` (a dict of
        per-epoch loss lists keyed ``'loss'`` and, optionally, ``'val_loss'``), as a
        Keras model has after ``fit()`` -- assumed; confirm for other frameworks.
    skip_epochs : int, default 0
        Number of leading epochs to leave out of the loss graph (useful when the
        first epoch's loss is orders of magnitude larger than later ones).
    X_train, X_test, y_train, y_test :
        Features and targets to evaluate. The defaults are the module-level
        variables of the same names as they were bound when this function was
        *defined* (Python evaluates default values once); pass them explicitly if
        those globals have been reassigned since.

    Returns
    -------
    None. Results are printed and displayed with ``plt.show()``.
    """
    # MSE and r squared values
    y_test_pred = model.predict(X_test)
    y_train_pred = model.predict(X_train)
    print("Training MSE:", round(mean_squared_error(y_train, y_train_pred), 4))
    print("Validation MSE:", round(mean_squared_error(y_test, y_test_pred), 4))
    print("\nTraining r2:", round(r2_score(y_train, y_train_pred), 4))
    print("Validation r2:", round(r2_score(y_test, y_test_pred), 4))

    _plot_loss_history(model.history.history, skip_epochs)
    _plot_predicted_vs_actual(y_train, y_train_pred, y_test, y_test_pred)


def _plot_loss_history(history, skip_epochs=0):
    """Line graph of per-epoch training loss (and validation loss, if recorded), omitting the first `skip_epochs` epochs."""
    train_loss = history['loss']
    # Epochs are numbered from 1, so the first plotted epoch is skip_epochs + 1.
    plt.plot(list(range(skip_epochs + 1, len(train_loss) + 1)), train_loss[skip_epochs:], label='Train')
    # 'val_loss' only exists when the model was fit with validation data; skip it rather than raise KeyError.
    if 'val_loss' in history:
        val_loss = history['val_loss']
        plt.plot(list(range(skip_epochs + 1, len(val_loss) + 1)), val_loss[skip_epochs:],
                 label='Test', color='green')
    plt.legend()
    plt.title('Training and test loss at each epoch', fontsize=14)
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.show()


def _plot_predicted_vs_actual(y_train, y_train_pred, y_test, y_test_pred):
    """Side-by-side scatterplots of predicted vs. actual values (test left, train right) with a predicted = actual line."""
    import numpy as np

    # Span the reference line over the full range of the plotted data instead of a
    # hard-coded range, so it is correct for any target scale.
    all_values = np.concatenate([np.ravel(v) for v in (y_test, y_test_pred, y_train, y_train_pred)])
    lo, hi = all_values.min(), all_values.max()

    # sharey=True: the panels touch (wspace=0) and the right one hides its y tick
    # labels, so both must share one y scale for the plot to be read correctly.
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4), sharey=True)
    fig.suptitle('Predicted vs. actual values', fontsize=14, y=1)
    plt.subplots_adjust(top=0.93, wspace=0)
    panels = ((ax1, y_test, y_test_pred, 'Test set'), (ax2, y_train, y_train_pred, 'Train set'))
    for ax, actual, predicted, title in panels:
        ax.scatter(actual, predicted, s=2, alpha=0.7)
        ax.plot([lo, hi], [lo, hi], color='black', linestyle='--')
        ax.set_title(title)
        ax.set_xlabel('Actual values')
    ax1.set_ylabel('Predicted values')
    ax2.tick_params(labelleft=False)
    plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment