@mahmoodm2
Last active January 1, 2023 22:58
Synthetic data generation using GAN
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "Synthetic data generation using GAN",
"provenance": [],
"collapsed_sections": [],
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/mahmoodm2/6a18161ade47f7d53880332b43d516f0/synthetic-data-generation-using-gan.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "6uKM9n6k5FIO"
},
"source": [
"### This Jupyter notebook demonstrates how to use a Generative Adversarial Network (GAN) to generate synthetic tabular data."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "MkdDSNyY4-za"
},
"source": [
"### Downloading the Dataset"
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Zz309b-ftcIO",
"outputId": "99f8a832-02fb-4243-f019-bfa946f3aae4"
},
"source": [
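"# -nc skips the download if the zip already exists; -o lets unzip overwrite files without an interactive prompt on re-runs\n",
"!wget -nc https://storage.googleapis.com/synthea-public/synthea_sample_data_csv_apr2020.zip\n",
"!unzip -o synthea_sample_data_csv_apr2020.zip"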
],
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"text": [
"--2021-05-24 05:56:54-- https://storage.googleapis.com/synthea-public/synthea_sample_data_csv_apr2020.zip\n",
"Resolving storage.googleapis.com (storage.googleapis.com)... 142.250.128.128, 142.251.6.128, 74.125.126.128, ...\n",
"Connecting to storage.googleapis.com (storage.googleapis.com)|142.250.128.128|:443... connected.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 8982431 (8.6M) [application/zip]\n",
"Saving to: ‘synthea_sample_data_csv_apr2020.zip.1’\n",
"\n",
"\r synthea_s 0%[ ] 0 --.-KB/s \rsynthea_sample_data 100%[===================>] 8.57M --.-KB/s in 0.05s \n",
"\n",
"2021-05-24 05:56:54 (183 MB/s) - ‘synthea_sample_data_csv_apr2020.zip.1’ saved [8982431/8982431]\n",
"\n",
"Archive: synthea_sample_data_csv_apr2020.zip\n",
"replace csv/medications.csv? [y]es, [n]o, [A]ll, [N]one, [r]ename: A\n",
" inflating: csv/medications.csv \n",
" inflating: csv/providers.csv \n",
" inflating: csv/payer_transitions.csv \n",
" inflating: csv/imaging_studies.csv \n",
" inflating: csv/supplies.csv \n",
" inflating: csv/payers.csv \n",
" inflating: csv/allergies.csv \n",
" inflating: csv/procedures.csv \n",
" inflating: csv/organizations.csv \n",
" inflating: csv/conditions.csv \n",
" inflating: csv/careplans.csv \n",
" inflating: csv/encounters.csv \n",
" inflating: csv/devices.csv \n",
" inflating: csv/immunizations.csv \n",
" inflating: csv/patients.csv \n",
" inflating: csv/observations.csv \n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "gi0jEinivij2",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "2df905bf-17b9-4502-80bc-f9b4e30976ff"
},
"source": [
"!pip install ctgan"
],
"execution_count": 2,
"outputs": [
{
"output_type": "stream",
"text": [
"Requirement already satisfied: ctgan in /usr/local/lib/python3.7/dist-packages (0.4.2)\n",
"Requirement already satisfied: torch<2,>=1.4 in /usr/local/lib/python3.7/dist-packages (from ctgan) (1.8.1+cu101)\n",
"Requirement already satisfied: rdt<0.5,>=0.4.1 in /usr/local/lib/python3.7/dist-packages (from ctgan) (0.4.1)\n",
"Requirement already satisfied: numpy<2,>=1.18.0 in /usr/local/lib/python3.7/dist-packages (from ctgan) (1.19.5)\n",
"Requirement already satisfied: packaging in /usr/local/lib/python3.7/dist-packages (from ctgan) (20.9)\n",
"Requirement already satisfied: torchvision<1,>=0.5.0 in /usr/local/lib/python3.7/dist-packages (from ctgan) (0.9.1+cu101)\n",
"Requirement already satisfied: pandas<1.1.5,>=1.1 in /usr/local/lib/python3.7/dist-packages (from ctgan) (1.1.4)\n",
"Requirement already satisfied: scikit-learn<1,>=0.23 in /usr/local/lib/python3.7/dist-packages (from ctgan) (0.24.2)\n",
"Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from torch<2,>=1.4->ctgan) (3.7.4.3)\n",
"Requirement already satisfied: scipy<2,>=1.4.1 in /usr/local/lib/python3.7/dist-packages (from rdt<0.5,>=0.4.1->ctgan) (1.4.1)\n",
"Requirement already satisfied: pyparsing>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging->ctgan) (2.4.7)\n",
"Requirement already satisfied: pillow>=4.1.1 in /usr/local/lib/python3.7/dist-packages (from torchvision<1,>=0.5.0->ctgan) (7.1.2)\n",
"Requirement already satisfied: python-dateutil>=2.7.3 in /usr/local/lib/python3.7/dist-packages (from pandas<1.1.5,>=1.1->ctgan) (2.8.1)\n",
"Requirement already satisfied: pytz>=2017.2 in /usr/local/lib/python3.7/dist-packages (from pandas<1.1.5,>=1.1->ctgan) (2018.9)\n",
"Requirement already satisfied: joblib>=0.11 in /usr/local/lib/python3.7/dist-packages (from scikit-learn<1,>=0.23->ctgan) (1.0.1)\n",
"Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.7/dist-packages (from scikit-learn<1,>=0.23->ctgan) (2.1.0)\n",
"Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.7/dist-packages (from python-dateutil>=2.7.3->pandas<1.1.5,>=1.1->ctgan) (1.15.0)\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "AVMSMTEKtojo",
"outputId": "ece02cd4-066f-4d79-8300-3bbc7037febd"
},
"source": [
"import pandas as pd\n",
"import warnings\n",
"warnings.filterwarnings('ignore')\n",
"\n",
"data = pd.read_csv('csv/patients.csv')\n",
"\n",
"# Columns used for training: all except the two HEALTHCARE_* columns are categorical\n",
"keep_features = ['MARITAL', 'RACE', 'ETHNICITY', 'GENDER', 'BIRTHPLACE', 'CITY', 'ZIP', 'STATE', 'HEALTHCARE_EXPENSES', 'HEALTHCARE_COVERAGE']\n",
"categorical_features = ['MARITAL', 'RACE', 'ETHNICITY', 'GENDER', 'BIRTHPLACE', 'CITY', 'STATE', 'ZIP']\n",
"\n",
"real_data = data[keep_features]\n",
"print(real_data.columns)"
],
"execution_count": 3,
"outputs": [
{
"output_type": "stream",
"text": [
"Index(['MARITAL', 'RACE', 'ETHNICITY', 'GENDER', 'BIRTHPLACE', 'CITY', 'ZIP',\n",
" 'STATE', 'HEALTHCARE_EXPENSES', 'HEALTHCARE_COVERAGE'],\n",
" dtype='object')\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 415
},
"id": "1ExnblkT9Erh",
"outputId": "8bd18ca4-7aae-482b-d407-2b1f0b96f4a0"
},
"source": [
"real_data"
],
"execution_count": 4,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>MARITAL</th>\n",
" <th>RACE</th>\n",
" <th>ETHNICITY</th>\n",
" <th>GENDER</th>\n",
" <th>BIRTHPLACE</th>\n",
" <th>CITY</th>\n",
" <th>ZIP</th>\n",
" <th>STATE</th>\n",
" <th>HEALTHCARE_EXPENSES</th>\n",
" <th>HEALTHCARE_COVERAGE</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>M</td>\n",
" <td>white</td>\n",
" <td>hispanic</td>\n",
" <td>M</td>\n",
" <td>Marigot Saint Andrew Parish DM</td>\n",
" <td>Chicopee</td>\n",
" <td>1013.0</td>\n",
" <td>Massachusetts</td>\n",
" <td>271227.08</td>\n",
" <td>1334.88</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>M</td>\n",
" <td>white</td>\n",
" <td>nonhispanic</td>\n",
" <td>M</td>\n",
" <td>Danvers Massachusetts US</td>\n",
" <td>Somerville</td>\n",
" <td>2143.0</td>\n",
" <td>Massachusetts</td>\n",
" <td>793946.01</td>\n",
" <td>3204.49</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>M</td>\n",
" <td>white</td>\n",
" <td>nonhispanic</td>\n",
" <td>M</td>\n",
" <td>Springfield Massachusetts US</td>\n",
" <td>Chicopee</td>\n",
" <td>1020.0</td>\n",
" <td>Massachusetts</td>\n",
" <td>574111.90</td>\n",
" <td>2606.40</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>M</td>\n",
" <td>white</td>\n",
" <td>nonhispanic</td>\n",
" <td>F</td>\n",
" <td>Yarmouth Massachusetts US</td>\n",
" <td>Lowell</td>\n",
" <td>1851.0</td>\n",
" <td>Massachusetts</td>\n",
" <td>935630.30</td>\n",
" <td>8756.19</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>NaN</td>\n",
" <td>white</td>\n",
" <td>nonhispanic</td>\n",
" <td>M</td>\n",
" <td>Patras Achaea GR</td>\n",
" <td>Boston</td>\n",
" <td>2135.0</td>\n",
" <td>Massachusetts</td>\n",
" <td>598763.07</td>\n",
" <td>3772.20</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1166</th>\n",
" <td>M</td>\n",
" <td>asian</td>\n",
" <td>hispanic</td>\n",
" <td>F</td>\n",
" <td>Juarez Chihuahua MX</td>\n",
" <td>Cambridge</td>\n",
" <td>2141.0</td>\n",
" <td>Massachusetts</td>\n",
" <td>1622314.87</td>\n",
" <td>32086.31</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1167</th>\n",
" <td>S</td>\n",
" <td>white</td>\n",
" <td>nonhispanic</td>\n",
" <td>M</td>\n",
" <td>Upton Massachusetts US</td>\n",
" <td>Beverly</td>\n",
" <td>1915.0</td>\n",
" <td>Massachusetts</td>\n",
" <td>979724.25</td>\n",
" <td>3130.52</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1168</th>\n",
" <td>S</td>\n",
" <td>white</td>\n",
" <td>nonhispanic</td>\n",
" <td>F</td>\n",
" <td>Fall River Massachusetts US</td>\n",
" <td>Norwood</td>\n",
" <td>NaN</td>\n",
" <td>Massachusetts</td>\n",
" <td>1560540.35</td>\n",
" <td>52391.24</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1169</th>\n",
" <td>M</td>\n",
" <td>white</td>\n",
" <td>nonhispanic</td>\n",
" <td>F</td>\n",
" <td>Springfield Massachusetts US</td>\n",
" <td>Norwood</td>\n",
" <td>2062.0</td>\n",
" <td>Massachusetts</td>\n",
" <td>1375833.47</td>\n",
" <td>13157.00</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1170</th>\n",
" <td>M</td>\n",
" <td>white</td>\n",
" <td>nonhispanic</td>\n",
" <td>F</td>\n",
" <td>Worcester Massachusetts US</td>\n",
" <td>Norwood</td>\n",
" <td>2090.0</td>\n",
" <td>Massachusetts</td>\n",
" <td>1510158.34</td>\n",
" <td>26565.65</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>1171 rows × 10 columns</p>\n",
"</div>"
],
"text/plain": [
" MARITAL RACE ... HEALTHCARE_EXPENSES HEALTHCARE_COVERAGE\n",
"0 M white ... 271227.08 1334.88\n",
"1 M white ... 793946.01 3204.49\n",
"2 M white ... 574111.90 2606.40\n",
"3 M white ... 935630.30 8756.19\n",
"4 NaN white ... 598763.07 3772.20\n",
"... ... ... ... ... ...\n",
"1166 M asian ... 1622314.87 32086.31\n",
"1167 S white ... 979724.25 3130.52\n",
"1168 S white ... 1560540.35 52391.24\n",
"1169 M white ... 1375833.47 13157.00\n",
"1170 M white ... 1510158.34 26565.65\n",
"\n",
"[1171 rows x 10 columns]"
]
},
"metadata": {
"tags": []
},
"execution_count": 4
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RB4ukzuX3RXU"
},
"source": [
"## Training the model\n"
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "cpV1FWHevaWO",
"outputId": "c99b6c52-4072-4baf-e467-2ff933765304"
},
"source": [
"from ctgan import CTGANSynthesizer\n",
"\n",
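"# epochs is set on the constructor; passing it to fit() is deprecated in ctgan 0.4.x\n",
"ctgan = CTGANSynthesizer(epochs=300, verbose=True)\n",
"ctgan.fit(real_data, discrete_columns=categorical_features)"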
],
"execution_count": 5,
"outputs": [
{
"output_type": "stream",
"text": [
"Epoch 1, Loss G: 2.7220,Loss D: -0.0167\n",
"Epoch 2, Loss G: 2.6461,Loss D: -0.0450\n",
"Epoch 3, Loss G: 2.4273,Loss D: -0.0735\n",
"Epoch 4, Loss G: 2.7161,Loss D: -0.1428\n",
"Epoch 5, Loss G: 2.5028,Loss D: -0.1914\n",
"Epoch 6, Loss G: 2.5397,Loss D: -0.2294\n",
"Epoch 7, Loss G: 2.5450,Loss D: -0.2382\n",
"Epoch 8, Loss G: 2.5592,Loss D: -0.2775\n",
"Epoch 9, Loss G: 2.5199,Loss D: -0.3047\n",
"Epoch 10, Loss G: 2.5094,Loss D: -0.2992\n",
"Epoch 11, Loss G: 2.4861,Loss D: -0.3178\n",
"Epoch 12, Loss G: 2.4973,Loss D: -0.2785\n",
"Epoch 13, Loss G: 2.5476,Loss D: -0.3551\n",
"Epoch 14, Loss G: 2.3788,Loss D: -0.2778\n",
"Epoch 15, Loss G: 2.4980,Loss D: -0.2382\n",
"Epoch 16, Loss G: 2.2980,Loss D: -0.2273\n",
"Epoch 17, Loss G: 2.5093,Loss D: -0.3117\n",
"Epoch 18, Loss G: 2.3282,Loss D: -0.2865\n",
"Epoch 19, Loss G: 2.5446,Loss D: -0.3035\n",
"Epoch 20, Loss G: 2.3724,Loss D: -0.3000\n",
"Epoch 21, Loss G: 2.2893,Loss D: -0.3692\n",
"Epoch 22, Loss G: 2.4974,Loss D: -0.3598\n",
"Epoch 23, Loss G: 2.5267,Loss D: -0.3151\n",
"Epoch 24, Loss G: 2.4257,Loss D: -0.4370\n",
"Epoch 25, Loss G: 2.4397,Loss D: -0.3936\n",
"Epoch 26, Loss G: 2.5698,Loss D: -0.4311\n",
"Epoch 27, Loss G: 2.4059,Loss D: -0.4092\n",
"Epoch 28, Loss G: 2.6675,Loss D: -0.4984\n",
"Epoch 29, Loss G: 2.6464,Loss D: -0.4963\n",
"Epoch 30, Loss G: 2.3802,Loss D: -0.5488\n",
"Epoch 31, Loss G: 2.5288,Loss D: -0.4798\n",
"Epoch 32, Loss G: 2.4125,Loss D: -0.5222\n",
"Epoch 33, Loss G: 2.4611,Loss D: -0.4862\n",
"Epoch 34, Loss G: 2.3612,Loss D: -0.4904\n",
"Epoch 35, Loss G: 2.4119,Loss D: -0.4830\n",
"Epoch 36, Loss G: 2.2983,Loss D: -0.3992\n",
"Epoch 37, Loss G: 2.2561,Loss D: -0.3575\n",
"Epoch 38, Loss G: 2.1506,Loss D: -0.3187\n",
"Epoch 39, Loss G: 2.3781,Loss D: -0.4043\n",
"Epoch 40, Loss G: 2.0233,Loss D: -0.3109\n",
"Epoch 41, Loss G: 2.0642,Loss D: -0.3248\n",
"Epoch 42, Loss G: 2.0557,Loss D: -0.1483\n",
"Epoch 43, Loss G: 2.0589,Loss D: -0.1364\n",
"Epoch 44, Loss G: 1.9122,Loss D: -0.0944\n",
"Epoch 45, Loss G: 1.8047,Loss D: -0.2662\n",
"Epoch 46, Loss G: 1.7784,Loss D: -0.0055\n",
"Epoch 47, Loss G: 1.7492,Loss D: -0.0721\n",
"Epoch 48, Loss G: 1.5633,Loss D: -0.0962\n",
"Epoch 49, Loss G: 1.6252,Loss D: 0.0612\n",
"Epoch 50, Loss G: 1.3566,Loss D: 0.1531\n",
"Epoch 51, Loss G: 1.4798,Loss D: 0.1587\n",
"Epoch 52, Loss G: 1.3238,Loss D: 0.1388\n",
"Epoch 53, Loss G: 1.1748,Loss D: 0.2707\n",
"Epoch 54, Loss G: 1.2396,Loss D: 0.3116\n",
"Epoch 55, Loss G: 1.0063,Loss D: 0.4126\n",
"Epoch 56, Loss G: 1.2137,Loss D: 0.3624\n",
"Epoch 57, Loss G: 1.1719,Loss D: 0.3525\n",
"Epoch 58, Loss G: 1.1455,Loss D: 0.2201\n",
"Epoch 59, Loss G: 1.0129,Loss D: 0.3502\n",
"Epoch 60, Loss G: 1.2132,Loss D: 0.2884\n",
"Epoch 61, Loss G: 1.0484,Loss D: 0.3396\n",
"Epoch 62, Loss G: 1.2568,Loss D: 0.1721\n",
"Epoch 63, Loss G: 1.0787,Loss D: 0.3294\n",
"Epoch 64, Loss G: 1.2225,Loss D: 0.2763\n",
"Epoch 65, Loss G: 1.1407,Loss D: 0.2441\n",
"Epoch 66, Loss G: 1.2368,Loss D: 0.2397\n",
"Epoch 67, Loss G: 1.0690,Loss D: 0.2842\n",
"Epoch 68, Loss G: 1.3808,Loss D: 0.2167\n",
"Epoch 69, Loss G: 1.2878,Loss D: 0.1394\n",
"Epoch 70, Loss G: 1.3125,Loss D: 0.2123\n",
"Epoch 71, Loss G: 1.1866,Loss D: 0.1712\n",
"Epoch 72, Loss G: 1.1126,Loss D: 0.1258\n",
"Epoch 73, Loss G: 1.0783,Loss D: 0.1324\n",
"Epoch 74, Loss G: 1.2846,Loss D: 0.1744\n",
"Epoch 75, Loss G: 1.1055,Loss D: 0.1394\n",
"Epoch 76, Loss G: 1.1162,Loss D: 0.1411\n",
"Epoch 77, Loss G: 1.1037,Loss D: 0.1205\n",
"Epoch 78, Loss G: 1.1642,Loss D: 0.1726\n",
"Epoch 79, Loss G: 1.2163,Loss D: 0.1492\n",
"Epoch 80, Loss G: 1.3392,Loss D: 0.1172\n",
"Epoch 81, Loss G: 1.3620,Loss D: 0.0744\n",
"Epoch 82, Loss G: 1.2348,Loss D: 0.0469\n",
"Epoch 83, Loss G: 1.3188,Loss D: 0.0818\n",
"Epoch 84, Loss G: 1.2584,Loss D: 0.0977\n",
"Epoch 85, Loss G: 0.9544,Loss D: 0.0757\n",
"Epoch 86, Loss G: 1.3084,Loss D: 0.1224\n",
"Epoch 87, Loss G: 1.2146,Loss D: 0.0708\n",
"Epoch 88, Loss G: 1.3828,Loss D: 0.0582\n",
"Epoch 89, Loss G: 1.2468,Loss D: 0.0272\n",
"Epoch 90, Loss G: 0.9287,Loss D: 0.0418\n",
"Epoch 91, Loss G: 1.2695,Loss D: -0.0580\n",
"Epoch 92, Loss G: 1.1650,Loss D: 0.0311\n",
"Epoch 93, Loss G: 1.1005,Loss D: -0.0075\n",
"Epoch 94, Loss G: 1.0958,Loss D: -0.0779\n",
"Epoch 95, Loss G: 1.0175,Loss D: -0.0479\n",
"Epoch 96, Loss G: 1.0798,Loss D: -0.0387\n",
"Epoch 97, Loss G: 0.9573,Loss D: 0.0024\n",
"Epoch 98, Loss G: 1.2633,Loss D: 0.0258\n",
"Epoch 99, Loss G: 1.1980,Loss D: -0.0468\n",
"Epoch 100, Loss G: 0.8943,Loss D: -0.0297\n",
"Epoch 101, Loss G: 1.1320,Loss D: -0.0520\n",
"Epoch 102, Loss G: 1.1132,Loss D: 0.0590\n",
"Epoch 103, Loss G: 1.0539,Loss D: -0.0086\n",
"Epoch 104, Loss G: 0.9532,Loss D: -0.0071\n",
"Epoch 105, Loss G: 0.9822,Loss D: 0.0503\n",
"Epoch 106, Loss G: 0.9551,Loss D: 0.0449\n",
"Epoch 107, Loss G: 1.0087,Loss D: 0.0597\n",
"Epoch 108, Loss G: 0.9967,Loss D: -0.0227\n",
"Epoch 109, Loss G: 0.8712,Loss D: -0.0156\n",
"Epoch 110, Loss G: 1.0134,Loss D: 0.0081\n",
"Epoch 111, Loss G: 1.0755,Loss D: 0.0219\n",
"Epoch 112, Loss G: 0.7897,Loss D: -0.0476\n",
"Epoch 113, Loss G: 1.1207,Loss D: -0.0420\n",
"Epoch 114, Loss G: 1.0223,Loss D: 0.0482\n",
"Epoch 115, Loss G: 1.1407,Loss D: 0.0039\n",
"Epoch 116, Loss G: 1.0872,Loss D: -0.0750\n",
"Epoch 117, Loss G: 1.4322,Loss D: 0.0596\n",
"Epoch 118, Loss G: 1.2366,Loss D: -0.0371\n",
"Epoch 119, Loss G: 1.0306,Loss D: 0.0141\n",
"Epoch 120, Loss G: 0.8798,Loss D: 0.0237\n",
"Epoch 121, Loss G: 1.0133,Loss D: 0.0176\n",
"Epoch 122, Loss G: 1.0250,Loss D: 0.0203\n",
"Epoch 123, Loss G: 0.8172,Loss D: 0.1568\n",
"Epoch 124, Loss G: 0.9424,Loss D: 0.0468\n",
"Epoch 125, Loss G: 0.8411,Loss D: 0.1041\n",
"Epoch 126, Loss G: 0.9051,Loss D: 0.0677\n",
"Epoch 127, Loss G: 1.0352,Loss D: 0.0956\n",
"Epoch 128, Loss G: 1.1504,Loss D: 0.0208\n",
"Epoch 129, Loss G: 0.9771,Loss D: 0.0187\n",
"Epoch 130, Loss G: 1.1199,Loss D: 0.0049\n",
"Epoch 131, Loss G: 1.0768,Loss D: 0.0004\n",
"Epoch 132, Loss G: 0.9676,Loss D: -0.0217\n",
"Epoch 133, Loss G: 0.7691,Loss D: -0.0302\n",
"Epoch 134, Loss G: 0.8795,Loss D: -0.0011\n",
"Epoch 135, Loss G: 0.9124,Loss D: -0.0395\n",
"Epoch 136, Loss G: 1.0391,Loss D: -0.0147\n",
"Epoch 137, Loss G: 0.8113,Loss D: -0.0358\n",
"Epoch 138, Loss G: 0.8814,Loss D: -0.0458\n",
"Epoch 139, Loss G: 0.8053,Loss D: 0.0013\n",
"Epoch 140, Loss G: 0.8118,Loss D: 0.0303\n",
"Epoch 141, Loss G: 0.7793,Loss D: 0.0663\n",
"Epoch 142, Loss G: 0.8917,Loss D: 0.0327\n",
"Epoch 143, Loss G: 0.9252,Loss D: 0.0500\n",
"Epoch 144, Loss G: 0.8735,Loss D: 0.0354\n",
"Epoch 145, Loss G: 0.9865,Loss D: -0.0071\n",
"Epoch 146, Loss G: 1.0630,Loss D: 0.0339\n",
"Epoch 147, Loss G: 0.9280,Loss D: 0.0647\n",
"Epoch 148, Loss G: 1.0534,Loss D: -0.0397\n",
"Epoch 149, Loss G: 0.9826,Loss D: -0.0063\n",
"Epoch 150, Loss G: 0.9695,Loss D: 0.0367\n",
"Epoch 151, Loss G: 0.8811,Loss D: -0.0083\n",
"Epoch 152, Loss G: 0.8932,Loss D: 0.0303\n",
"Epoch 153, Loss G: 1.1060,Loss D: 0.0387\n",
"Epoch 154, Loss G: 0.9923,Loss D: 0.0255\n",
"Epoch 155, Loss G: 0.9793,Loss D: 0.1147\n",
"Epoch 156, Loss G: 0.7850,Loss D: 0.0505\n",
"Epoch 157, Loss G: 0.9503,Loss D: 0.0085\n",
"Epoch 158, Loss G: 0.9086,Loss D: 0.0088\n",
"Epoch 159, Loss G: 0.9582,Loss D: 0.0356\n",
"Epoch 160, Loss G: 1.0433,Loss D: -0.0400\n",
"Epoch 161, Loss G: 0.9563,Loss D: -0.0370\n",
"Epoch 162, Loss G: 0.8524,Loss D: -0.0236\n",
"Epoch 163, Loss G: 0.9272,Loss D: -0.0684\n",
"Epoch 164, Loss G: 0.9566,Loss D: -0.0119\n",
"Epoch 165, Loss G: 0.8715,Loss D: 0.0315\n",
"Epoch 166, Loss G: 0.7799,Loss D: 0.0613\n",
"Epoch 167, Loss G: 0.7510,Loss D: 0.0142\n",
"Epoch 168, Loss G: 0.9471,Loss D: -0.0258\n",
"Epoch 169, Loss G: 0.7951,Loss D: 0.0221\n",
"Epoch 170, Loss G: 0.7874,Loss D: -0.0423\n",
"Epoch 171, Loss G: 0.7759,Loss D: 0.0210\n",
"Epoch 172, Loss G: 0.6541,Loss D: 0.0298\n",
"Epoch 173, Loss G: 0.7881,Loss D: 0.0037\n",
"Epoch 174, Loss G: 0.7806,Loss D: 0.0232\n",
"Epoch 175, Loss G: 1.0056,Loss D: 0.0311\n",
"Epoch 176, Loss G: 0.7557,Loss D: 0.0046\n",
"Epoch 177, Loss G: 0.7037,Loss D: -0.0429\n",
"Epoch 178, Loss G: 0.9277,Loss D: -0.0689\n",
"Epoch 179, Loss G: 0.7939,Loss D: -0.0163\n",
"Epoch 180, Loss G: 0.8747,Loss D: 0.0225\n",
"Epoch 181, Loss G: 0.7599,Loss D: 0.0364\n",
"Epoch 182, Loss G: 1.0194,Loss D: 0.0112\n",
"Epoch 183, Loss G: 0.7092,Loss D: 0.0488\n",
"Epoch 184, Loss G: 0.7849,Loss D: 0.0450\n",
"Epoch 185, Loss G: 0.7199,Loss D: -0.0126\n",
"Epoch 186, Loss G: 0.7728,Loss D: 0.0037\n",
"Epoch 187, Loss G: 0.8458,Loss D: 0.0303\n",
"Epoch 188, Loss G: 0.8293,Loss D: -0.0230\n",
"Epoch 189, Loss G: 0.8573,Loss D: -0.0484\n",
"Epoch 190, Loss G: 0.8272,Loss D: 0.0031\n",
"Epoch 191, Loss G: 0.7993,Loss D: 0.0784\n",
"Epoch 192, Loss G: 0.7652,Loss D: -0.0231\n",
"Epoch 193, Loss G: 1.0990,Loss D: -0.0149\n",
"Epoch 194, Loss G: 0.6609,Loss D: 0.0494\n",
"Epoch 195, Loss G: 0.9004,Loss D: -0.0013\n",
"Epoch 196, Loss G: 0.9205,Loss D: -0.0169\n",
"Epoch 197, Loss G: 0.9881,Loss D: 0.0463\n",
"Epoch 198, Loss G: 0.7775,Loss D: 0.0623\n",
"Epoch 199, Loss G: 0.7577,Loss D: -0.0290\n",
"Epoch 200, Loss G: 0.4771,Loss D: 0.0394\n",
"Epoch 201, Loss G: 0.7217,Loss D: 0.0466\n",
"Epoch 202, Loss G: 0.7339,Loss D: 0.0388\n",
"Epoch 203, Loss G: 0.5890,Loss D: 0.0724\n",
"Epoch 204, Loss G: 0.6561,Loss D: 0.0555\n",
"Epoch 205, Loss G: 0.8715,Loss D: -0.0098\n",
"Epoch 206, Loss G: 0.7008,Loss D: 0.0330\n",
"Epoch 207, Loss G: 0.7343,Loss D: 0.0189\n",
"Epoch 208, Loss G: 0.7506,Loss D: -0.0591\n",
"Epoch 209, Loss G: 0.7358,Loss D: -0.1123\n",
"Epoch 210, Loss G: 0.8187,Loss D: -0.0802\n",
"Epoch 211, Loss G: 0.8247,Loss D: -0.0270\n",
"Epoch 212, Loss G: 0.7559,Loss D: -0.0096\n",
"Epoch 213, Loss G: 0.5888,Loss D: 0.1001\n",
"Epoch 214, Loss G: 0.5673,Loss D: 0.1054\n",
"Epoch 215, Loss G: 0.5811,Loss D: 0.0744\n",
"Epoch 216, Loss G: 0.9078,Loss D: 0.1532\n",
"Epoch 217, Loss G: 0.5786,Loss D: 0.1000\n",
"Epoch 218, Loss G: 0.7233,Loss D: 0.0323\n",
"Epoch 219, Loss G: 0.7319,Loss D: 0.0424\n",
"Epoch 220, Loss G: 0.7076,Loss D: 0.0082\n",
"Epoch 221, Loss G: 0.6885,Loss D: 0.0088\n",
"Epoch 222, Loss G: 0.7126,Loss D: -0.0506\n",
"Epoch 223, Loss G: 0.5396,Loss D: 0.0225\n",
"Epoch 224, Loss G: 0.5722,Loss D: -0.0346\n",
"Epoch 225, Loss G: 0.9440,Loss D: -0.0512\n",
"Epoch 226, Loss G: 0.8581,Loss D: -0.0096\n",
"Epoch 227, Loss G: 0.7515,Loss D: -0.0632\n",
"Epoch 228, Loss G: 0.5515,Loss D: -0.0164\n",
"Epoch 229, Loss G: 0.6719,Loss D: -0.0063\n",
"Epoch 230, Loss G: 0.4614,Loss D: 0.0704\n",
"Epoch 231, Loss G: 0.4597,Loss D: 0.0108\n",
"Epoch 232, Loss G: 0.5004,Loss D: 0.1508\n",
"Epoch 233, Loss G: 0.5043,Loss D: 0.0843\n",
"Epoch 234, Loss G: 0.6481,Loss D: -0.0396\n",
"Epoch 235, Loss G: 0.5557,Loss D: -0.0045\n",
"Epoch 236, Loss G: 0.7464,Loss D: -0.0184\n",
"Epoch 237, Loss G: 0.5849,Loss D: 0.0124\n",
"Epoch 238, Loss G: 0.5131,Loss D: -0.0530\n",
"Epoch 239, Loss G: 0.6377,Loss D: -0.0355\n",
"Epoch 240, Loss G: 0.5646,Loss D: 0.0026\n",
"Epoch 241, Loss G: 0.6161,Loss D: -0.1150\n",
"Epoch 242, Loss G: 0.4843,Loss D: -0.0011\n",
"Epoch 243, Loss G: 0.8724,Loss D: -0.0231\n",
"Epoch 244, Loss G: 0.5319,Loss D: -0.0043\n",
"Epoch 245, Loss G: 0.4672,Loss D: 0.0975\n",
"Epoch 246, Loss G: 0.5956,Loss D: 0.1880\n",
"Epoch 247, Loss G: 0.7126,Loss D: 0.0936\n",
"Epoch 248, Loss G: 0.4375,Loss D: 0.1400\n",
"Epoch 249, Loss G: 0.7449,Loss D: 0.0571\n",
"Epoch 250, Loss G: 0.4647,Loss D: 0.0067\n",
"Epoch 251, Loss G: 0.5160,Loss D: 0.0214\n",
"Epoch 252, Loss G: 0.7419,Loss D: -0.0534\n",
"Epoch 253, Loss G: 0.5988,Loss D: -0.0425\n",
"Epoch 254, Loss G: 0.3704,Loss D: -0.0227\n",
"Epoch 255, Loss G: 0.5754,Loss D: -0.0804\n",
"Epoch 256, Loss G: 0.5358,Loss D: 0.0334\n",
"Epoch 257, Loss G: 0.3127,Loss D: 0.0243\n",
"Epoch 258, Loss G: 0.2116,Loss D: 0.0720\n",
"Epoch 259, Loss G: 0.3320,Loss D: 0.0687\n",
"Epoch 260, Loss G: 0.2940,Loss D: 0.0422\n",
"Epoch 261, Loss G: 0.6401,Loss D: 0.0058\n",
"Epoch 262, Loss G: 0.4052,Loss D: 0.0804\n",
"Epoch 263, Loss G: 0.5086,Loss D: -0.0041\n",
"Epoch 264, Loss G: 0.4924,Loss D: 0.0162\n",
"Epoch 265, Loss G: 0.5532,Loss D: -0.0401\n",
"Epoch 266, Loss G: 0.5285,Loss D: -0.0381\n",
"Epoch 267, Loss G: 0.4629,Loss D: -0.0970\n",
"Epoch 268, Loss G: 0.6104,Loss D: -0.0922\n",
"Epoch 269, Loss G: 0.3659,Loss D: 0.0015\n",
"Epoch 270, Loss G: 0.4166,Loss D: 0.0393\n",
"Epoch 271, Loss G: 0.4172,Loss D: 0.0997\n",
"Epoch 272, Loss G: 0.2331,Loss D: 0.0999\n",
"Epoch 273, Loss G: 0.3405,Loss D: 0.0730\n",
"Epoch 274, Loss G: 0.3338,Loss D: 0.0279\n",
"Epoch 275, Loss G: 0.3571,Loss D: 0.0667\n",
"Epoch 276, Loss G: 0.3951,Loss D: 0.0140\n",
"Epoch 277, Loss G: 0.2518,Loss D: 0.0726\n",
"Epoch 278, Loss G: 0.3112,Loss D: -0.0216\n",
"Epoch 279, Loss G: 0.4161,Loss D: -0.0146\n",
"Epoch 280, Loss G: 0.3186,Loss D: -0.0345\n",
"Epoch 281, Loss G: 0.2794,Loss D: -0.0716\n",
"Epoch 282, Loss G: 0.3211,Loss D: -0.0570\n",
"Epoch 283, Loss G: 0.2710,Loss D: -0.0596\n",
"Epoch 284, Loss G: 0.3674,Loss D: -0.0093\n",
"Epoch 285, Loss G: 0.2425,Loss D: -0.0092\n",
"Epoch 286, Loss G: 0.3579,Loss D: 0.0631\n",
"Epoch 287, Loss G: 0.4247,Loss D: 0.0355\n",
"Epoch 288, Loss G: 0.2433,Loss D: -0.0335\n",
"Epoch 289, Loss G: 0.3719,Loss D: -0.0625\n",
"Epoch 290, Loss G: 0.4448,Loss D: 0.0788\n",
"Epoch 291, Loss G: 0.2918,Loss D: 0.0329\n",
"Epoch 292, Loss G: 0.3490,Loss D: 0.1048\n",
"Epoch 293, Loss G: 0.3431,Loss D: 0.0557\n",
"Epoch 294, Loss G: 0.2143,Loss D: -0.0053\n",
"Epoch 295, Loss G: 0.2369,Loss D: 0.1335\n",
"Epoch 296, Loss G: 0.1772,Loss D: 0.0921\n",
"Epoch 297, Loss G: 0.3189,Loss D: 0.0471\n",
"Epoch 298, Loss G: 0.1971,Loss D: 0.0551\n",
"Epoch 299, Loss G: 0.0577,Loss D: 0.0345\n",
"Epoch 300, Loss G: 0.2606,Loss D: 0.0089\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Ij1MvT4Q3rDJ"
},
"source": [
"## Evaluation\n",
"\n",
"Using the `sample()` method of the trained CTGAN model to generate 1,000 fake samples."
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 202
},
"id": "dAOZhYjp96NN",
"outputId": "7f04425a-51aa-40d2-83d2-20dc5f74611c"
},
"source": [
"fake_data = ctgan.sample(1000)\n",
"\n",
"fake_data.head()\n"
],
"execution_count": 6,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>MARITAL</th>\n",
" <th>RACE</th>\n",
" <th>ETHNICITY</th>\n",
" <th>GENDER</th>\n",
" <th>BIRTHPLACE</th>\n",
" <th>CITY</th>\n",
" <th>ZIP</th>\n",
" <th>STATE</th>\n",
" <th>HEALTHCARE_EXPENSES</th>\n",
" <th>HEALTHCARE_COVERAGE</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>NaN</td>\n",
" <td>white</td>\n",
" <td>nonhispanic</td>\n",
" <td>F</td>\n",
" <td>Caguas Puerto Rico PR</td>\n",
" <td>Chatham</td>\n",
" <td>2148.0</td>\n",
" <td>Massachusetts</td>\n",
" <td>1.401285e+06</td>\n",
" <td>9040.259524</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>NaN</td>\n",
" <td>white</td>\n",
" <td>nonhispanic</td>\n",
" <td>F</td>\n",
" <td>Quincy Massachusetts US</td>\n",
" <td>Groveland</td>\n",
" <td>1106.0</td>\n",
" <td>Massachusetts</td>\n",
" <td>6.068832e+05</td>\n",
" <td>10299.582375</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>S</td>\n",
" <td>white</td>\n",
" <td>nonhispanic</td>\n",
" <td>F</td>\n",
" <td>Holyoke Massachusetts US</td>\n",
" <td>Methuen</td>\n",
" <td>2114.0</td>\n",
" <td>Massachusetts</td>\n",
" <td>1.200787e+06</td>\n",
" <td>6400.250611</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>M</td>\n",
" <td>white</td>\n",
" <td>nonhispanic</td>\n",
" <td>M</td>\n",
" <td>Worcester Massachusetts US</td>\n",
" <td>Lynn</td>\n",
" <td>NaN</td>\n",
" <td>Massachusetts</td>\n",
" <td>1.294887e+04</td>\n",
" <td>11408.045525</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>S</td>\n",
" <td>black</td>\n",
" <td>nonhispanic</td>\n",
" <td>F</td>\n",
" <td>Dartmouth Massachusetts US</td>\n",
" <td>Lowell</td>\n",
" <td>NaN</td>\n",
" <td>Massachusetts</td>\n",
" <td>6.576685e+05</td>\n",
" <td>11668.114383</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" MARITAL RACE ... HEALTHCARE_EXPENSES HEALTHCARE_COVERAGE\n",
"0 NaN white ... 1.401285e+06 9040.259524\n",
"1 NaN white ... 6.068832e+05 10299.582375\n",
"2 S white ... 1.200787e+06 6400.250611\n",
"3 M white ... 1.294887e+04 11408.045525\n",
"4 S black ... 6.576685e+05 11668.114383\n",
"\n",
"[5 rows x 10 columns]"
]
},
"metadata": {
"tags": []
},
"execution_count": 6
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "9e1ilRTGz8_9"
},
"source": [
"\n",
"import seaborn as sns\n",
"import matplotlib.pyplot as plt"
],
"execution_count": 7,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "LdG7lIppFETC"
},
"source": [
"### Comparing the Real and Fake datasets"
]
},
{
"cell_type": "code",
"metadata": {
"id": "94Cp1ZZ36-_I",
"outputId": "058bd62c-a509-4dd2-eac9-dd1a5675faa5",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"source": [
"!pip install dython"
],
"execution_count": 8,
"outputs": [
{
"output_type": "stream",
"text": [
"Requirement already satisfied: dython in /usr/local/lib/python3.7/dist-packages (0.6.5.post1)\n",
"Requirement already satisfied: seaborn in /usr/local/lib/python3.7/dist-packages (from dython) (0.11.1)\n",
"Requirement already satisfied: matplotlib in /usr/local/lib/python3.7/dist-packages (from dython) (3.2.2)\n",
"Requirement already satisfied: scikit-plot>=0.3.7 in /usr/local/lib/python3.7/dist-packages (from dython) (0.3.7)\n",
"Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from dython) (1.19.5)\n",
"Requirement already satisfied: scipy in /usr/local/lib/python3.7/dist-packages (from dython) (1.4.1)\n",
"Requirement already satisfied: scikit-learn in /usr/local/lib/python3.7/dist-packages (from dython) (0.24.2)\n",
"Requirement already satisfied: pandas>=0.23.4 in /usr/local/lib/python3.7/dist-packages (from dython) (1.1.4)\n",
"Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->dython) (2.4.7)\n",
"Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->dython) (1.3.1)\n",
"Requirement already satisfied: python-dateutil>=2.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->dython) (2.8.1)\n",
"Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.7/dist-packages (from matplotlib->dython) (0.10.0)\n",
"Requirement already satisfied: joblib>=0.10 in /usr/local/lib/python3.7/dist-packages (from scikit-plot>=0.3.7->dython) (1.0.1)\n",
"Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.7/dist-packages (from scikit-learn->dython) (2.1.0)\n",
"Requirement already satisfied: pytz>=2017.2 in /usr/local/lib/python3.7/dist-packages (from pandas>=0.23.4->dython) (2018.9)\n",
"Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.7/dist-packages (from python-dateutil>=2.1->matplotlib->dython) (1.15.0)\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 473
},
"id": "dml4lOPKE_gS",
"outputId": "441cb864-e7af-409b-8530-d1a0a3e16cff"
},
"source": [
"from dython.nominal import compute_associations\n",
"\n",
"real_corr = compute_associations(real_data, nominal_columns=categorical_features, theil_u=True)\n",
"fake_corr = compute_associations(fake_data, nominal_columns=categorical_features, theil_u=True)\n",
"\n",
"fig, axes = plt.subplots(1, 2, figsize=(35, 8))\n",
"\n",
"sns.set(style=\"white\")\n",
"sns.heatmap(real_corr, ax=axes[0], vmax=.3, square=True, center=0,\n",
" linewidths=.5, cbar_kws={\"shrink\": .5}, fmt='.2f')\n",
"axes[0].set_title('Real', size=20)\n",
"\n",
"sns.heatmap(fake_corr, ax=axes[1], vmax=.3, square=True, center=0,\n",
" linewidths=.5, cbar_kws={\"shrink\": .5}, fmt='.2f')\n",
"axes[1].set_title('Fake', size=20)\n",
"\n",
"\n",
"for ax in axes:\n",
" ax.set_xticklabels(\n",
" ax.get_xticklabels(),\n",
" rotation=45,\n",
" horizontalalignment='right'\n",
" )\n",
"\n"
],
"execution_count": 9,
"outputs": [
{
"output_type": "display_data",
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 2520x576 with 4 Axes>"
]
},
"metadata": {
"tags": [],
"needs_background": "light"
}
}
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 394
},
"id": "tf1K-qRSH4bp",
"outputId": "fa3deb46-3036-4666-bf02-0139bb38b4e9"
},
"source": [
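"# Continuous (non-categorical) columns to compare\n",
"cont_data = [c for c in real_data.columns if c not in categorical_features]\n",
"cols = 2\n",
"# Ceiling division so an odd number of continuous columns still gets an axis each\n",
"rows = max(1, -(-len(cont_data) // cols))\n",
"\n",
"fig, ax = plt.subplots(rows, cols, figsize=(25, 6 * rows))\n",
"axes = ax.flatten()\n",
"\n",
"# Overlay the real and fake distribution of each continuous column.\n",
"# Note: sns.distplot is deprecated since seaborn 0.11; sns.histplot(..., kde=True) is its replacement.\n",
"for i, col in enumerate(cont_data):\n",
"    sns.distplot(real_data[col], ax=axes[i], label='Real', color='blue')\n",
"    sns.distplot(fake_data[col], ax=axes[i], label='Fake', color='orange')\n",
"    axes[i].set_title(col, size=20)\n",
"    axes[i].legend(loc=0, prop={'size': 20})"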
],
"execution_count": 10,
"outputs": [
{
"output_type": "display_data",
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 1800x432 with 2 Axes>"
]
},
"metadata": {
"tags": []
}
}
]
}
]
}
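In the `real_data` preview in the notebook above, `ZIP` shows up as floats such as `1013.0` and `2143.0`. The fake sample shows the same pattern, with values such as `2148.0` and `1106.0`.

Massachusetts ZIP codes start with a zero (Chicopee is 01013). Because the column also contains missing values, pandas parses it as `float64`, so any leading zero in the file is lost. CTGAN then learns float-valued category labels.

The sketch below shows one way to avoid this: read `ZIP` as a string column with `dtype={"ZIP": str}`. It is a minimal example under stated assumptions:
- The three-row inline CSV is made up for illustration and is not Synthea data.
- It assumes the real file stores ZIPs with leading zeros.

```python
import io

import pandas as pd

# Hypothetical three-row excerpt shaped like Synthea's patients.csv (illustrative only);
# the last row has a missing ZIP, as some rows in the real file do.
csv_text = """CITY,ZIP,STATE
Chicopee,01013,Massachusetts
Somerville,02143,Massachusetts
Norwood,,Massachusetts
"""

# Default parsing: the missing value forces float64 and the leading zeros are dropped.
default = pd.read_csv(io.StringIO(csv_text))
print(default["ZIP"].tolist())  # [1013.0, 2143.0, nan]

# Reading ZIP as text keeps it a categorical label with its leading zero intact.
as_text = pd.read_csv(io.StringIO(csv_text), dtype={"ZIP": str})
print(as_text["ZIP"].tolist())  # ['01013', '02143', nan]
```

With the string dtype, the column can be passed to CTGAN as a discrete column exactly as before. The sampled ZIPs then come back as the same string labels rather than floats. Missing values remain `NaN`; whether to replace them with an explicit category is a separate modelling choice.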
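The real-vs-fake heatmaps in this notebook come from dython's `compute_associations(..., theil_u=True)`. Per dython's documentation, that call uses Theil's U (the uncertainty coefficient) for pairs of categorical columns. The coefficient is:

U(X | Y) = (H(X) - H(X | Y)) / H(X)

It is the fraction of the uncertainty in X that is removed by knowing Y. It ranges from 0 to 1 and is not symmetric, which is why the heatmaps are not mirror images across the diagonal.

The sketch below computes the coefficient from scratch to show what each cell measures. Its limits:
- It is not dython's implementation and does no NaN handling.
- The names `entropy`, `conditional_entropy` and `theils_u` are illustrative.
- The toy columns are invented.

```python
import math
from collections import Counter


def entropy(xs):
    """Shannon entropy (natural log) of a sequence of category labels."""
    n = len(xs)
    return -sum((c / n) * math.log(c / n) for c in Counter(xs).values())


def conditional_entropy(xs, ys):
    """H(X | Y) estimated from paired samples of two categorical columns."""
    n = len(xs)
    joint = Counter(zip(xs, ys))
    y_counts = Counter(ys)
    # H(X|Y) = -sum over (x, y) of p(x, y) * log(p(x, y) / p(y))
    return -sum((c / n) * math.log(c / y_counts[y]) for (x, y), c in joint.items())


def theils_u(xs, ys):
    """Uncertainty coefficient U(X | Y): share of H(X) explained by knowing Y."""
    h_x = entropy(xs)
    if h_x == 0:
        # X is constant, so it is trivially predictable; treated as 1.0 here by convention
        return 1.0
    return (h_x - conditional_entropy(xs, ys)) / h_x


# Invented toy columns, for illustration only
city = ["Boston", "Boston", "Lowell", "Lowell"]
gender = ["M", "F", "M", "F"]
zip_code = ["02135", "02134", "01851", "01851"]

print(theils_u(city, zip_code))            # 1.0: ZIP fully determines city
print(round(theils_u(zip_code, city), 3))  # 0.667: city only partly determines ZIP
print(theils_u(city, gender))              # close to 0: gender says nothing about city here
```

The asymmetry in the usage lines mirrors the notebook's columns:
- **ZIP and CITY:** ZIP pins down CITY, but CITY does not fully pin down ZIP.
- **STATE:** In every row shown, STATE is the constant value "Massachusetts". A constant column has zero entropy, so this sketch assigns it U = 1.0 by convention.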