Skip to content

Instantly share code, notes, and snippets.

@angeligareta
Created July 21, 2021 06:46
Show Gist options
  • Save angeligareta/b42785185ee245e846455cf2d6b343ff to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "stratified_sampling.ipynb",
"provenance": [],
"collapsed_sections": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "code",
"metadata": {
"id": "SPVHS92rqQ7L"
},
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"from sklearn import datasets\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"def get_dataset_partitions_pd(df, train_split=0.8, val_split=0.1, test_split=0.1, target_variable=None):\n",
" assert (train_split + test_split + val_split) == 1\n",
" \n",
" # Only allows for equal validation and test splits\n",
" assert val_split == test_split \n",
"\n",
" # Shuffle\n",
" df_sample = df.sample(frac=1, random_state=12)\n",
"\n",
" # Specify seed to always have the same split distribution between runs\n",
" # If target variable is provided, generate stratified sets\n",
" if target_variable is not None:\n",
" grouped_df = df_sample.groupby(target_variable)\n",
" arr_list = [np.split(g, [int(train_split * len(g)), int((1 - val_split) * len(g))]) for i, g in grouped_df]\n",
"\n",
" train_ds = pd.concat([t[0] for t in arr_list])\n",
" val_ds = pd.concat([t[1] for t in arr_list])\n",
" test_ds = pd.concat([v[2] for v in arr_list])\n",
"\n",
" else:\n",
" indices_or_sections = [int(train_split * len(df)), int((1 - val_split) * len(df))]\n",
" train_ds, val_ds, test_ds = np.split(df_sample, indices_or_sections)\n",
" \n",
" return train_ds, val_ds, test_ds"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "_HFvDin4n4Xd",
"outputId": "03aebdc2-d739-4380-e8e6-370181d1a9a9"
},
"source": [
"dataset = datasets.load_iris()\n",
"X = pd.DataFrame(dataset.data)\n",
"y = pd.DataFrame(dataset.target)\n",
"print(f'Distribution in original set: \\n{y[0].value_counts().sort_index() / len(y)}')"
],
"execution_count": 72,
"outputs": [
{
"output_type": "stream",
"text": [
"Distribution in original set: \n",
"0 0.333333\n",
"1 0.333333\n",
"2 0.333333\n",
"Name: 0, dtype: float64\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "AeUTn588tsWK",
"outputId": "bea9e0eb-69e6-4953-d9c9-24831849bab6"
},
"source": [
"train_ds, val_ds, test_ds = get_dataset_partitions_pd(y)\n",
"print(f'Distribution in training set: \\n{train_ds[0].value_counts().sort_index() / len(train_ds)}\\n\\n'+\n",
" f'Distribution in validation set: \\n{val_ds[0].value_counts().sort_index() / len(val_ds)}\\n\\n'+\n",
" f'Distribution in testing set: \\n{test_ds[0].value_counts().sort_index() / len(test_ds)}')"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"Distribution in training set: \n",
"0 0.341667\n",
"1 0.358333\n",
"2 0.300000\n",
"Name: 0, dtype: float64\n",
"\n",
"Distribution in validation set: \n",
"0 0.333333\n",
"1 0.266667\n",
"2 0.400000\n",
"Name: 0, dtype: float64\n",
"\n",
"Distribution in testing set: \n",
"0 0.266667\n",
"1 0.200000\n",
"2 0.533333\n",
"Name: 0, dtype: float64\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Rd0RAVJBo3XU",
"outputId": "84760724-16b6-424d-a5bc-f0ba99f5f0aa"
},
"source": [
"train_ds, val_ds, test_ds = get_dataset_partitions_pd(y, target_variable=0)\n",
"print(f'Distribution in training set: \\n{train_ds[0].value_counts().sort_index() / len(train_ds)}\\n\\n'+\n",
" f'Distribution in validation set: \\n{val_ds[0].value_counts().sort_index() / len(val_ds)}\\n\\n'+\n",
" f'Distribution in testing set: \\n{test_ds[0].value_counts().sort_index() / len(test_ds)}')"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"Distribution in training set: \n",
"0 0.333333\n",
"1 0.333333\n",
"2 0.333333\n",
"Name: 0, dtype: float64\n",
"\n",
"Distribution in validation set: \n",
"0 0.333333\n",
"1 0.333333\n",
"2 0.333333\n",
"Name: 0, dtype: float64\n",
"\n",
"Distribution in testing set: \n",
"0 0.333333\n",
"1 0.333333\n",
"2 0.333333\n",
"Name: 0, dtype: float64\n"
],
"name": "stdout"
}
]
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment