Created
August 23, 2022 03:57
-
-
Save aniquetahir/e67ddbd33a0e57a3af4c4640f383d0f1 to your computer and use it in GitHub Desktop.
Simple MLP Fairness
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": { | |
"collapsed": true, | |
"pycharm": { | |
"name": "#%%\n" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"# Download adult dataset\n", | |
"#from aif360.algorithms.preprocessing.optim_preproc_helpers.data_preproc_functions import load_preproc_data_adult, load_preproc_data_compas, load_preproc_data_german\n", | |
"import jax\n", | |
"import jax.numpy as jnp\n", | |
"import numpy as onp\n", | |
"import haiku as hk\n", | |
"import optax\n", | |
"from typing import Dict, List\n", | |
"import chex\n", | |
"import pickle\n", | |
"from tqdm import tqdm" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"outputs": [], | |
"source": [ | |
"def pickle_save(obj, location: str):\n", | |
" with open(location, 'wb') as save_file:\n", | |
" pickle.dump(obj, save_file)\n", | |
"\n", | |
"def pickle_load(location: str):\n", | |
" obj = None\n", | |
" with open(location, 'rb') as load_file:\n", | |
" obj = pickle.load(load_file)\n", | |
" return obj" | |
], | |
"metadata": { | |
"collapsed": false, | |
"pycharm": { | |
"name": "#%%\n" | |
} | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": "DeviceArray([ 4, 8, 12], dtype=int32)" | |
}, | |
"execution_count": 3, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"jnp.array([1,2,3]) * 4" | |
], | |
"metadata": { | |
"collapsed": false, | |
"pycharm": { | |
"name": "#%%\n" | |
} | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"outputs": [], | |
"source": [ | |
"adult_data = load_preproc_data_adult()\n", | |
"compas_data = load_preproc_data_compas()\n", | |
"german_data = load_preproc_data_german()" | |
], | |
"metadata": { | |
"collapsed": false, | |
"pycharm": { | |
"name": "#%%\n" | |
} | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": "aif360.datasets.compas_dataset.CompasDataset" | |
}, | |
"execution_count": 5, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"type(compas_data)" | |
], | |
"metadata": { | |
"collapsed": false, | |
"pycharm": { | |
"name": "#%%\n" | |
} | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"outputs": [], | |
"source": [ | |
"df_adult = adult_data.convert_to_dataframe()[0]" | |
], | |
"metadata": { | |
"collapsed": false, | |
"pycharm": { | |
"name": "#%%\n" | |
} | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"outputs": [ | |
{ | |
"ename": "AttributeError", | |
"evalue": "'DataFrame' object has no attribute 'te'", | |
"output_type": "error", | |
"traceback": [ | |
"\u001B[0;31m---------------------------------------------------------------------------\u001B[0m", | |
"\u001B[0;31mAttributeError\u001B[0m Traceback (most recent call last)", | |
"Input \u001B[0;32mIn [7]\u001B[0m, in \u001B[0;36m<cell line: 1>\u001B[0;34m()\u001B[0m\n\u001B[0;32m----> 1\u001B[0m \u001B[43mdf_adult\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mte\u001B[49m\n", | |
"File \u001B[0;32m~/anaconda3/lib/python3.9/site-packages/pandas/core/generic.py:5575\u001B[0m, in \u001B[0;36mNDFrame.__getattr__\u001B[0;34m(self, name)\u001B[0m\n\u001B[1;32m 5568\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m (\n\u001B[1;32m 5569\u001B[0m name \u001B[38;5;129;01mnot\u001B[39;00m \u001B[38;5;129;01min\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_internal_names_set\n\u001B[1;32m 5570\u001B[0m \u001B[38;5;129;01mand\u001B[39;00m name \u001B[38;5;129;01mnot\u001B[39;00m \u001B[38;5;129;01min\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_metadata\n\u001B[1;32m 5571\u001B[0m \u001B[38;5;129;01mand\u001B[39;00m name \u001B[38;5;129;01mnot\u001B[39;00m \u001B[38;5;129;01min\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_accessors\n\u001B[1;32m 5572\u001B[0m \u001B[38;5;129;01mand\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_info_axis\u001B[38;5;241m.\u001B[39m_can_hold_identifiers_and_holds_name(name)\n\u001B[1;32m 5573\u001B[0m ):\n\u001B[1;32m 5574\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28mself\u001B[39m[name]\n\u001B[0;32m-> 5575\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28;43mobject\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[38;5;21;43m__getattribute__\u001B[39;49m\u001B[43m(\u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mname\u001B[49m\u001B[43m)\u001B[49m\n", | |
"\u001B[0;31mAttributeError\u001B[0m: 'DataFrame' object has no attribute 'te'" | |
] | |
} | |
], | |
"source": [ | |
"df_adult.te" | |
], | |
"metadata": { | |
"collapsed": false, | |
"pycharm": { | |
"name": "#%%\n" | |
} | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": " race sex Age (decade)=10 Age (decade)=20 Age (decade)=30 \\\n0 0.0 1.0 0.0 1.0 0.0 \n1 1.0 1.0 0.0 0.0 1.0 \n2 1.0 1.0 0.0 1.0 0.0 \n3 0.0 1.0 0.0 0.0 0.0 \n4 1.0 0.0 1.0 0.0 0.0 \n... ... ... ... ... ... \n48837 1.0 0.0 0.0 1.0 0.0 \n48838 1.0 1.0 0.0 0.0 0.0 \n48839 1.0 0.0 0.0 0.0 0.0 \n48840 1.0 1.0 0.0 1.0 0.0 \n48841 1.0 0.0 0.0 0.0 0.0 \n\n Age (decade)=40 Age (decade)=50 Age (decade)=60 Age (decade)=>=70 \\\n0 0.0 0.0 0.0 0.0 \n1 0.0 0.0 0.0 0.0 \n2 0.0 0.0 0.0 0.0 \n3 1.0 0.0 0.0 0.0 \n4 0.0 0.0 0.0 0.0 \n... ... ... ... ... \n48837 0.0 0.0 0.0 0.0 \n48838 1.0 0.0 0.0 0.0 \n48839 0.0 1.0 0.0 0.0 \n48840 0.0 0.0 0.0 0.0 \n48841 0.0 1.0 0.0 0.0 \n\n Education Years=6 Education Years=7 Education Years=8 \\\n0 0.0 1.0 0.0 \n1 0.0 0.0 0.0 \n2 0.0 0.0 0.0 \n3 0.0 0.0 0.0 \n4 0.0 0.0 0.0 \n... ... ... ... \n48837 0.0 0.0 0.0 \n48838 0.0 0.0 0.0 \n48839 0.0 0.0 0.0 \n48840 0.0 0.0 0.0 \n48841 0.0 0.0 0.0 \n\n Education Years=9 Education Years=10 Education Years=11 \\\n0 0.0 0.0 0.0 \n1 1.0 0.0 0.0 \n2 0.0 0.0 0.0 \n3 0.0 1.0 0.0 \n4 0.0 1.0 0.0 \n... ... ... ... \n48837 0.0 0.0 0.0 \n48838 1.0 0.0 0.0 \n48839 1.0 0.0 0.0 \n48840 1.0 0.0 0.0 \n48841 1.0 0.0 0.0 \n\n Education Years=12 Education Years=<6 Education Years=>12 \\\n0 0.0 0.0 0.0 \n1 0.0 0.0 0.0 \n2 1.0 0.0 0.0 \n3 0.0 0.0 0.0 \n4 0.0 0.0 0.0 \n... ... ... ... \n48837 1.0 0.0 0.0 \n48838 0.0 0.0 0.0 \n48839 0.0 0.0 0.0 \n48840 0.0 0.0 0.0 \n48841 0.0 0.0 0.0 \n\n Income Binary \n0 0.0 \n1 0.0 \n2 1.0 \n3 1.0 \n4 0.0 \n... ... \n48837 0.0 \n48838 1.0 \n48839 0.0 \n48840 0.0 \n48841 1.0 \n\n[48842 rows x 19 columns]", | |
"text/html": "<div>\n<style scoped>\n .dataframe tbody tr th:only-of-type {\n vertical-align: middle;\n }\n\n .dataframe tbody tr th {\n vertical-align: top;\n }\n\n .dataframe thead th {\n text-align: right;\n }\n</style>\n<table border=\"1\" class=\"dataframe\">\n <thead>\n <tr style=\"text-align: right;\">\n <th></th>\n <th>race</th>\n <th>sex</th>\n <th>Age (decade)=10</th>\n <th>Age (decade)=20</th>\n <th>Age (decade)=30</th>\n <th>Age (decade)=40</th>\n <th>Age (decade)=50</th>\n <th>Age (decade)=60</th>\n <th>Age (decade)=>=70</th>\n <th>Education Years=6</th>\n <th>Education Years=7</th>\n <th>Education Years=8</th>\n <th>Education Years=9</th>\n <th>Education Years=10</th>\n <th>Education Years=11</th>\n <th>Education Years=12</th>\n <th>Education Years=<6</th>\n <th>Education Years=>12</th>\n <th>Income Binary</th>\n </tr>\n </thead>\n <tbody>\n <tr>\n <th>0</th>\n <td>0.0</td>\n <td>1.0</td>\n <td>0.0</td>\n <td>1.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>1.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n </tr>\n <tr>\n <th>1</th>\n <td>1.0</td>\n <td>1.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>1.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>1.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n </tr>\n <tr>\n <th>2</th>\n <td>1.0</td>\n <td>1.0</td>\n <td>0.0</td>\n <td>1.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>1.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>1.0</td>\n </tr>\n <tr>\n <th>3</th>\n <td>0.0</td>\n <td>1.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>1.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>1.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>1.0</td>\n </tr>\n <tr>\n <th>4</th>\n <td>1.0</td>\n <td>0.0</td>\n <td>1.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>1.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n </tr>\n <tr>\n <th>...</th>\n <td>...</td>\n <td>...</td>\n <td>...</td>\n <td>...</td>\n <td>...</td>\n <td>...</td>\n <td>...</td>\n <td>...</td>\n <td>...</td>\n <td>...</td>\n <td>...</td>\n <td>...</td>\n <td>...</td>\n <td>...</td>\n <td>...</td>\n <td>...</td>\n <td>...</td>\n <td>...</td>\n <td>...</td>\n </tr>\n <tr>\n <th>48837</th>\n <td>1.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>1.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>1.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n </tr>\n <tr>\n <th>48838</th>\n <td>1.0</td>\n <td>1.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>1.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>1.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>1.0</td>\n </tr>\n <tr>\n <th>48839</th>\n <td>1.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>1.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>1.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n </tr>\n <tr>\n <th>48840</th>\n <td>1.0</td>\n <td>1.0</td>\n <td>0.0</td>\n <td>1.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>1.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n </tr>\n <tr>\n <th>48841</th>\n <td>1.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>1.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>1.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>1.0</td>\n </tr>\n </tbody>\n</table>\n<p>48842 rows × 19 columns</p>\n</div>" | |
}, | |
"execution_count": 8, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"df_adult" | |
], | |
"metadata": { | |
"collapsed": false, | |
"pycharm": { | |
"name": "#%%\n" | |
} | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"outputs": [], | |
"source": [ | |
"#pickle_save(df_adult, 'adult_df.pd')" | |
], | |
"metadata": { | |
"collapsed": false, | |
"pycharm": { | |
"name": "#%%\n" | |
} | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"outputs": [], | |
"source": [ | |
"df_adult = pickle_load('adult_df.pd')" | |
], | |
"metadata": { | |
"collapsed": false, | |
"pycharm": { | |
"name": "#%%\n" | |
} | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"outputs": [], | |
"source": [ | |
"data_numpy = df_adult.to_numpy()\n", | |
"onp.random.shuffle(data_numpy)\n", | |
"data = data_numpy[:, :-1]\n", | |
"labels = data_numpy[:, -1:]" | |
], | |
"metadata": { | |
"collapsed": false, | |
"pycharm": { | |
"name": "#%%\n" | |
} | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"outputs": [], | |
"source": [ | |
"train_test_ratio = 0.8\n", | |
"train_end_index = int(len(data_numpy) * train_test_ratio)" | |
], | |
"metadata": { | |
"collapsed": false, | |
"pycharm": { | |
"name": "#%%\n" | |
} | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"outputs": [], | |
"source": [ | |
"train_data = data[:train_end_index]\n", | |
"test_data = data[train_end_index:]\n", | |
"train_labels = labels[:train_end_index]\n", | |
"test_labels = labels[train_end_index:]" | |
], | |
"metadata": { | |
"collapsed": false, | |
"pycharm": { | |
"name": "#%%\n" | |
} | |
} | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## Calculate the bias metrics" | |
], | |
"metadata": { | |
"collapsed": false, | |
"pycharm": { | |
"name": "#%% md\n" | |
} | |
} | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"### Statistical Parity Difference" | |
], | |
"metadata": { | |
"collapsed": false, | |
"pycharm": { | |
"name": "#%% md\n" | |
} | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"outputs": [], | |
"source": [ | |
"# We first look at the race\n", | |
"# Get the index of the samples with favorable outcome\n", | |
"favorable_samples = train_data[onp.where(train_labels==1)[0]]\n", | |
"unfavorable_samples = train_data[onp.where(train_labels==0)[0]]" | |
], | |
"metadata": { | |
"collapsed": false, | |
"pycharm": { | |
"name": "#%%\n" | |
} | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"outputs": [], | |
"source": [ | |
"# number of favorable outcomes by race 0\n", | |
"race0_fav = (favorable_samples[:,0]==0).astype(float).sum()\n", | |
"race0_total = len(train_data[train_data[:, 0]==0])\n", | |
"race0_fav_rate = race0_fav/race0_total\n", | |
"\n", | |
"race1_fav = (favorable_samples[:,0]==1).astype(float).sum()\n", | |
"race1_total = len(train_data[train_data[:, 0]==1])\n", | |
"race1_fav_rate = race1_fav/race1_total\n" | |
], | |
"metadata": { | |
"collapsed": false, | |
"pycharm": { | |
"name": "#%%\n" | |
} | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": "(8494.0, 33397)" | |
}, | |
"execution_count": 9, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"race1_fav, race1_total" | |
], | |
"metadata": { | |
"collapsed": false, | |
"pycharm": { | |
"name": "#%%\n" | |
} | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": "(851.0, 5676)" | |
}, | |
"execution_count": 10, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"race0_fav, race0_total" | |
], | |
"metadata": { | |
"collapsed": false, | |
"pycharm": { | |
"name": "#%%\n" | |
} | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": "0.1044046938001694" | |
}, | |
"execution_count": 11, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"race1_fav_rate-race0_fav_rate" | |
], | |
"metadata": { | |
"collapsed": false, | |
"pycharm": { | |
"name": "#%%\n" | |
} | |
} | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## Create a Simple MLP" | |
], | |
"metadata": { | |
"collapsed": false, | |
"pycharm": { | |
"name": "#%% md\n" | |
} | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 20, | |
"outputs": [], | |
"source": [ | |
"# Define a simple MLP\n", | |
"class MLP(hk.Module):\n", | |
" def __init__(self, num_layers: int, hidden_dim: int, dropout: float):\n", | |
" super(MLP, self).__init__('mlp')\n", | |
" self.num_layers = num_layers\n", | |
" self.hidden_dim = hidden_dim\n", | |
" self.dropout = dropout\n", | |
"\n", | |
" def __call__(self, x: chex.Array, is_training:bool=False):\n", | |
" x_ = x\n", | |
" for i in range(self.num_layers - 1):\n", | |
" x_ = hk.Linear(self.hidden_dim)(x_)\n", | |
" # x_ = hk.BatchNorm(create_scale=True, create_offset=True, decay_rate=.999)(x_, is_training)\n", | |
" x_ = jax.nn.leaky_relu(x_)\n", | |
" x_ = hk.cond(is_training, lambda a: hk.dropout(hk.next_rng_key(), self.dropout, a), lambda a: a, x_)\n", | |
"\n", | |
" x_ = hk.Linear(2)(x_)\n", | |
" return x_" | |
], | |
"metadata": { | |
"collapsed": false, | |
"pycharm": { | |
"name": "#%%\n" | |
} | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 21, | |
"outputs": [], | |
"source": [ | |
"NUM_LAYERS = 3\n", | |
"HIDDEN_DIM = 128\n", | |
"DROPOUT = 0.25\n", | |
"\n", | |
"key = jax.random.PRNGKey(6)\n", | |
"mlp_init, mlp_apply = hk.transform(lambda x, t: MLP(NUM_LAYERS, HIDDEN_DIM, DROPOUT)(x, t))" | |
], | |
"metadata": { | |
"collapsed": false, | |
"pycharm": { | |
"name": "#%%\n" | |
} | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 22, | |
"outputs": [], | |
"source": [ | |
"key, split = jax.random.split(key)\n", | |
"mlp_params = mlp_init(split, jnp.array(train_data[:10]), True)" | |
], | |
"metadata": { | |
"collapsed": false, | |
"pycharm": { | |
"name": "#%%\n" | |
} | |
} | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## Train our model on the data" | |
], | |
"metadata": { | |
"collapsed": false, | |
"pycharm": { | |
"name": "#%% md\n" | |
} | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 23, | |
"outputs": [], | |
"source": [ | |
"optim = optax.adamw(0.005)\n", | |
"opt_state = optim.init(mlp_params)" | |
], | |
"metadata": { | |
"collapsed": false, | |
"pycharm": { | |
"name": "#%%\n" | |
} | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 24, | |
"outputs": [], | |
"source": [ | |
"NUM_EPOCHS = 1000\n", | |
"# BATCH_SIZE = 100\n", | |
"# For now we can use the entire data as the batch\n", | |
"train_data = jnp.array(train_data)\n", | |
"train_labels = jnp.array(train_labels).astype(int).flatten()\n", | |
"\n", | |
"test_data = jnp.array(test_data)\n", | |
"test_labels = jnp.array(test_labels).astype(int).flatten()\n", | |
"\n", | |
"def loss_fn(params, key, x, y, is_training=True):\n", | |
" logits = mlp_apply(params, key, x, is_training)\n", | |
" loss = optax.softmax_cross_entropy_with_integer_labels(logits, y)\n", | |
" return loss.mean()\n", | |
"\n", | |
"loss_value_grad = jax.jit(jax.value_and_grad(loss_fn))" | |
], | |
"metadata": { | |
"collapsed": false, | |
"pycharm": { | |
"name": "#%%\n" | |
} | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 25, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
" 1%| | 12/1000 [00:00<00:53, 18.42it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Loss: 0.6878519058227539, Train Acc: 0.7608323097229004, Test Acc: 0.7602620720863342\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
" 11%|█ | 112/1000 [00:01<00:13, 64.16it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Loss: 0.4207240045070648, Train Acc: 0.8057737946510315, Test Acc: 0.7966015338897705\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
" 22%|██▏ | 220/1000 [00:03<00:10, 77.82it/s] " | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Loss: 0.4189543128013611, Train Acc: 0.8058761954307556, Test Acc: 0.7969086170196533\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
" 32%|███▏ | 318/1000 [00:04<00:08, 79.91it/s] " | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Loss: 0.41802212595939636, Train Acc: 0.8060553669929504, Test Acc: 0.7966015338897705\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
" 42%|████▏ | 421/1000 [00:05<00:06, 85.12it/s] " | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Loss: 0.4177604019641876, Train Acc: 0.806029736995697, Test Acc: 0.7979322671890259\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
" 52%|█████▏ | 523/1000 [00:06<00:05, 85.70it/s] " | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Loss: 0.41721275448799133, Train Acc: 0.8061065077781677, Test Acc: 0.7964991331100464\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
" 63%|██████▎ | 626/1000 [00:07<00:04, 85.74it/s] " | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Loss: 0.41683855652809143, Train Acc: 0.8060809373855591, Test Acc: 0.7978298664093018\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
" 72%|███████▏ | 717/1000 [00:08<00:03, 85.12it/s] " | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Loss: 0.4164116084575653, Train Acc: 0.8060553669929504, Test Acc: 0.797522783279419\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
" 82%|████████▏ | 819/1000 [00:09<00:02, 83.34it/s] " | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Loss: 0.4168620705604553, Train Acc: 0.8061065077781677, Test Acc: 0.7964991331100464\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
" 92%|█████████▏| 918/1000 [00:10<00:01, 81.49it/s] " | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Loss: 0.41707655787467957, Train Acc: 0.8061065077781677, Test Acc: 0.7964991331100464\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 1000/1000 [00:11<00:00, 89.28it/s]\n" | |
] | |
} | |
], | |
"source": [ | |
"for i in tqdm(range(NUM_EPOCHS)):\n", | |
" key, split = jax.random.split(key)\n", | |
" loss, grad = loss_value_grad(mlp_params, split, train_data, train_labels)\n", | |
" updates, opt_state = optim.update(grad, opt_state, mlp_params)\n", | |
" mlp_params = optax.apply_updates(mlp_params, updates)\n", | |
" if i%100 == 0:\n", | |
" key, split = jax.random.split(key)\n", | |
" train_predictions = jnp.argmax(mlp_apply(mlp_params, split, train_data, False), axis=1)\n", | |
" train_acc = (train_predictions == train_labels).astype(float).mean()\n", | |
"\n", | |
" test_predictions = jnp.argmax(mlp_apply(mlp_params, split, test_data, False), axis=1)\n", | |
" test_acc = (test_predictions == test_labels).astype(float).mean()\n", | |
"\n", | |
" print(f'Loss: {loss}, Train Acc: {train_acc}, Test Acc: {test_acc}')" | |
], | |
"metadata": { | |
"collapsed": false, | |
"pycharm": { | |
"name": "#%%\n" | |
} | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 55, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": "DeviceArray([[0.98889744, 0.9593886 ],\n [0.95056295, 0.99793124],\n [0.96067035, 0.98759353],\n [0.4896833 , 0.4588622 ],\n [0.4600625 , 0.48842293],\n [0.960789 , 0.9874729 ],\n [1.0003633 , 0.9482032 ],\n [0.92749125, 1.0220175 ],\n [0.4901985 , 0.45837286],\n [0.94303775, 1.0057118 ]], dtype=float32)" | |
}, | |
"execution_count": 55, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"loss_fn(mlp_params, split, train_data[:10], train_labels[:10])" | |
], | |
"metadata": { | |
"collapsed": false, | |
"pycharm": { | |
"name": "#%%\n" | |
} | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 20, | |
"outputs": [], | |
"source": [ | |
"test_logits = mlp_apply(mlp_params, key, train_data[:10], True)" | |
], | |
"metadata": { | |
"collapsed": false, | |
"pycharm": { | |
"name": "#%%\n" | |
} | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 21, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": "DeviceArray([[0.51991534, 0.48008466],\n [0.50613105, 0.49386895],\n [0.51190686, 0.4880932 ],\n [0.56233084, 0.43766916],\n [0.547018 , 0.452982 ],\n [0.5108703 , 0.48912978],\n [0.55052614, 0.44947386],\n [0.52233845, 0.47766158],\n [0.46725613, 0.53274393],\n [0.44238308, 0.557617 ]], dtype=float32)" | |
}, | |
"execution_count": 21, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"test_logits" | |
], | |
"metadata": { | |
"collapsed": false, | |
"pycharm": { | |
"name": "#%%\n" | |
} | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 22, | |
"outputs": [], | |
"source": [ | |
"test_labels = train_labels[:10]" | |
], | |
"metadata": { | |
"collapsed": false, | |
"pycharm": { | |
"name": "#%%\n" | |
} | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 30, | |
"outputs": [], | |
"source": [ | |
"a = optax.softmax_cross_entropy_with_integer_labels(test_logits, test_labels.astype(int).flatten())" | |
], | |
"metadata": { | |
"collapsed": false, | |
"pycharm": { | |
"name": "#%%\n" | |
} | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 32, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": "DeviceArray(0.7031968, dtype=float32)" | |
}, | |
"execution_count": 32, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"a.mean()" | |
], | |
"metadata": { | |
"collapsed": false, | |
"pycharm": { | |
"name": "#%%\n" | |
} | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 38, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": "DeviceArray([0, 0, 0, 0, 0, 0, 0, 0, 1, 1], dtype=int32)" | |
}, | |
"execution_count": 38, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"jnp.argmax(test_logits, axis=1)" | |
], | |
"metadata": { | |
"collapsed": false, | |
"pycharm": { | |
"name": "#%%\n" | |
} | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"outputs": [], | |
"source": [], | |
"metadata": { | |
"collapsed": false, | |
"pycharm": { | |
"name": "#%%\n" | |
} | |
} | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 2 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython2", | |
"version": "2.7.6" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 0 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment