@SauravMaheshkar
Created November 17, 2021 18:54
W&B 🔥: K-Fold Cross Validation
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "W&B 🔥: K-Fold Cross Validation",
"provenance": [],
"collapsed_sections": [],
"authorship_tag": "ABX9TyNtQN06rmH289Xfm/95XefY",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/SauravMaheshkar/3acdbfdb5ae7ca520783606a51b9bc2f/w-b-k-fold-cross-validation.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "euYUBxAHLNHR"
},
"source": [
"# Author: [@SauravMaheshkar](https://twitter.com/MaheshkarSaurav)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0J_WzJU7LSAb"
},
"source": [
"# Packages 📦 and Basic Setup\n",
"\n",
"---"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "j0Zyy3i7LTOy"
},
"source": [
"## Install Packages and Import Libraries\n"
]
},
{
"cell_type": "code",
"metadata": {
"id": "VmZTdUxJao_1"
},
"source": [
"%%capture\n",
"\n",
"## Install the latest version of wandb client 🔥🔥\n",
"!pip install -q --upgrade wandb\n",
"\n",
"import numpy as np\n",
"from sklearn import svm\n",
"from sklearn import datasets"
],
"execution_count": 38,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "swvAOJ8dLYIt"
},
"source": [
"## Project Configuration using **`wandb.config`**"
]
},
{
"cell_type": "code",
"metadata": {
"id": "Ueyfw3G2Lggm"
},
"source": [
"import os\n",
"import wandb\n",
"\n",
"# Paste your W&B API key here\n",
"os.environ[\"WANDB_API_KEY\"] = '...'\n",
"\n",
"# Initialize the run\n",
"run = wandb.init(project='...', entity='...')\n",
"\n",
"# Feel free to change these values and experiment!\n",
"config = wandb.config\n",
"config.dataset = \"Iris\"\n",
"config.n_splits = 5\n",
"config.random_state = 21\n",
"config.shuffle = True\n",
"config.test_size = 0.4"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "rrcKTKl1MMsg"
},
"source": [
"# The Simpler (but Not Always Better) Approach: Train-Test Split\n",
"---"
]
},
{
"cell_type": "code",
"metadata": {
"id": "Xvj1Cx0odzWD"
},
"source": [
"from sklearn.model_selection import train_test_split\n",
"\n",
"# Load the Iris dataset (bundled with scikit-learn, so nothing is downloaded)\n",
"X, y = datasets.load_iris(return_X_y=True)\n",
"\n",
"# Split using train_test_split\n",
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=config.test_size, random_state=config.random_state)\n",
"\n",
"# Create and train an SVC model with a linear kernel\n",
"clf = svm.SVC(kernel='linear', C=1)\n",
"clf.fit(X_train, y_train)\n",
"\n",
"# Evaluate on the held-out test set and print the accuracy\n",
"score = clf.score(X_test, y_test)\n",
"print(score)\n",
"\n",
"wandb.log({\"train_test_split_score\": score})"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "oBOrrb_WMUTl"
},
"source": [
"# K-Fold Cross Validation\n",
"---"
]
},
{
"cell_type": "code",
"metadata": {
"id": "xQz1ChUZHTk3"
},
"source": [
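"from sklearn.model_selection import KFold\n",
"\n",
"# Create the KFold splitter\n",
"kfold = KFold(n_splits=config.n_splits, shuffle=config.shuffle, random_state=config.random_state)\n",
"\n",
"# Create an SVC model; calling fit() refits it from scratch on each fold\n",
"clf = svm.SVC(kernel='linear', C=1)\n",
"\n",
"# Iterate over the folds, keeping each fold's score\n",
"fold_scores = []\n",
"for fold, (train_index, test_index) in enumerate(kfold.split(X)):\n",
"\n",
"    # Split the dataset into this fold's train and test parts\n",
"    X_train, X_test = X[train_index], X[test_index]\n",
"    y_train, y_test = y[train_index], y[test_index]\n",
"\n",
"    clf.fit(X_train, y_train)\n",
"\n",
"    # Print and log this fold's accuracy\n",
"    score = clf.score(X_test, y_test)\n",
"    print(score)\n",
"    wandb.log({\"kfold_score\": score, \"fold\": fold})\n",
"    fold_scores.append(score)\n",
"\n",
"# Log the mean accuracy across all folds\n",
"wandb.log({\"kfold_mean_score\": float(np.mean(fold_scores))})\n",
"\n",
"# Plot the learning curve (uses the last fold's training split)\n",
"wandb.sklearn.plot_learning_curve(clf, X_train, y_train)\n",
"\n",
"wandb.finish()"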
],
"execution_count": null,
"outputs": []
}
]
}
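
The manual fold loop in the notebook can also be written more compactly with scikit-learn's `cross_val_score`. The sketch below is not part of the original notebook: it hardcodes the values the notebook stores in `wandb.config` (`n_splits=5`, `shuffle=True`, `random_state=21`) and leaves out W&B logging so it runs offline. Treat it as one possible shorthand under those assumptions, not as the author's method.

```python
from sklearn import datasets, svm
from sklearn.model_selection import KFold, cross_val_score

# Load the Iris dataset, as in the notebook
X, y = datasets.load_iris(return_X_y=True)

# Assumption: same splitter settings as the notebook's wandb.config values
kfold = KFold(n_splits=5, shuffle=True, random_state=21)

# Same model as the notebook: a linear-kernel SVC with C=1
clf = svm.SVC(kernel="linear", C=1)

# cross_val_score clones and fits the model once per fold
# and returns one accuracy value per fold
scores = cross_val_score(clf, X, y, cv=kfold)

print("Per-fold accuracy:", scores)
print("Mean accuracy:", scores.mean())
print("Standard deviation:", scores.std())
```

Reporting the mean together with the spread across folds gives a steadier estimate than the single number from one train-test split, which depends on which rows happen to land in the test set.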