Skip to content

Instantly share code, notes, and snippets.

@mertbozkir
Last active May 14, 2022 12:58
Show Gist options
  • Save mertbozkir/57437162fec3961ce777692549d88a46 to your computer and use it in GitHub Desktop.
Save mertbozkir/57437162fec3961ce777692549d88a46 to your computer and use it in GitHub Desktop.
Validation Curve Plotting Function
def val_curve_params(model, X, y, param_name, param_range, scoring='roc_auc', cv=10):
    """Plot mean training and validation scores across a hyperparameter range.

    Calls scikit-learn's ``validation_curve`` with ``param_name`` swept over
    ``param_range``. It then draws the per-value mean training score (blue)
    and mean validation score (green) as two lines on the current pyplot
    axes, and shows the figure.

    This function expects the module to have imported ``np`` (numpy),
    ``plt`` (matplotlib.pyplot) and ``validation_curve``
    (sklearn.model_selection). None of them is defined here.

    Args:
        model: Estimator passed to ``validation_curve``. Its class name is
            used in the plot title.
        X: Feature data passed to ``validation_curve``.
        y: Target data passed to ``validation_curve``.
        param_name: Name of the estimator parameter to vary, e.g.
            ``'max_depth'``. Also used as the x-axis label.
        param_range: Values to try for ``param_name``. Also used as the
            x-coordinates of both lines.
        scoring: Scoring spec passed to ``validation_curve``. Defaults to
            ``'roc_auc'``. Also used as the y-axis label.
        cv: Cross-validation spec passed to ``validation_curve``. Defaults
            to 10.
    """
    train_score, test_score = validation_curve(
        model, X=X, y=y,
        param_name=param_name, param_range=param_range,
        scoring=scoring, cv=cv,
    )

    # validation_curve returns one row per param value and one column per
    # CV split; average across splits (axis=1) to get one point per value.
    mean_train_score = np.mean(train_score, axis=1)
    mean_test_score = np.mean(test_score, axis=1)

    plt.plot(param_range, mean_train_score, label='Training Score', color='b')
    plt.plot(param_range, mean_test_score, label='Validation Score', color='g')

    plt.title(f'Validation Curve for {type(model).__name__}')
    # Label with the parameter name itself: "Number of ..." was only correct
    # for count-like parameters (n_estimators, max_depth), not e.g. C or
    # learning_rate.
    plt.xlabel(f'{param_name}')
    plt.ylabel(f'{scoring}')
    # Add the legend before tight_layout so the layout pass sees the final
    # set of artists.
    plt.legend(loc='best')
    plt.tight_layout()
    plt.show(block=True)
# Example: val_curve_params(random_forest, X, y, 'max_depth', range(1, 11), scoring = 'roc_auc')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment