Skip to content

Instantly share code, notes, and snippets.

@aniquetahir
Created August 23, 2022 03:57
Show Gist options
  • Save aniquetahir/e67ddbd33a0e57a3af4c4640f383d0f1 to your computer and use it in GitHub Desktop.
Save aniquetahir/e67ddbd33a0e57a3af4c4640f383d0f1 to your computer and use it in GitHub Desktop.
Simple MLP Fairness
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": true,
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"# Imports. The aif360 dataset-loader import on the next line is left commented out.\n",
"#from aif360.algorithms.preprocessing.optim_preproc_helpers.data_preproc_functions import load_preproc_data_adult, load_preproc_data_compas, load_preproc_data_german\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import numpy as onp\n",
"import haiku as hk\n",
"import optax\n",
"from typing import Dict, List\n",
"import chex\n",
"import pickle\n",
"from tqdm import tqdm"
]
},
{
"cell_type": "code",
"execution_count": 2,
"outputs": [],
"source": [
"def pickle_save(obj, location: str):\n",
"    \"\"\"Serialise `obj` with pickle and write it to the file at `location`.\n",
"\n",
"    The file is opened in binary write mode, so an existing file at that\n",
"    path is overwritten.\n",
"    \"\"\"\n",
"    with open(location, mode='wb') as sink:\n",
"        sink.write(pickle.dumps(obj))\n",
"\n",
"def pickle_load(location: str):\n",
"    \"\"\"Read the file at `location` and return the object unpickled from it.\n",
"\n",
"    SECURITY: unpickling can execute arbitrary code, so only call this on\n",
"    files you created yourself or otherwise trust.\n",
"    \"\"\"\n",
"    with open(location, 'rb') as load_file:\n",
"        return pickle.load(load_file)"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 3,
"outputs": [
{
"data": {
"text/plain": "DeviceArray([ 4, 8, 12], dtype=int32)"
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"jnp.array([1,2,3]) * 4"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 4,
"outputs": [],
"source": [
"# The loaders come from aif360; import them here because the top-level import is commented out.\n",
"from aif360.algorithms.preprocessing.optim_preproc_helpers.data_preproc_functions import (\n",
"    load_preproc_data_adult,\n",
"    load_preproc_data_compas,\n",
"    load_preproc_data_german,\n",
")\n",
"\n",
"# Load the three preprocessed fairness benchmark datasets.\n",
"adult_data = load_preproc_data_adult()\n",
"compas_data = load_preproc_data_compas()\n",
"german_data = load_preproc_data_german()"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 5,
"outputs": [
{
"data": {
"text/plain": "aif360.datasets.compas_dataset.CompasDataset"
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"type(compas_data)"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 6,
"outputs": [],
"source": [
"df_adult = adult_data.convert_to_dataframe()[0]"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 7,
"outputs": [],
"source": [
"df_adult.head()"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 8,
"outputs": [
{
"data": {
"text/plain": " race sex Age (decade)=10 Age (decade)=20 Age (decade)=30 \\\n0 0.0 1.0 0.0 1.0 0.0 \n1 1.0 1.0 0.0 0.0 1.0 \n2 1.0 1.0 0.0 1.0 0.0 \n3 0.0 1.0 0.0 0.0 0.0 \n4 1.0 0.0 1.0 0.0 0.0 \n... ... ... ... ... ... \n48837 1.0 0.0 0.0 1.0 0.0 \n48838 1.0 1.0 0.0 0.0 0.0 \n48839 1.0 0.0 0.0 0.0 0.0 \n48840 1.0 1.0 0.0 1.0 0.0 \n48841 1.0 0.0 0.0 0.0 0.0 \n\n Age (decade)=40 Age (decade)=50 Age (decade)=60 Age (decade)=>=70 \\\n0 0.0 0.0 0.0 0.0 \n1 0.0 0.0 0.0 0.0 \n2 0.0 0.0 0.0 0.0 \n3 1.0 0.0 0.0 0.0 \n4 0.0 0.0 0.0 0.0 \n... ... ... ... ... \n48837 0.0 0.0 0.0 0.0 \n48838 1.0 0.0 0.0 0.0 \n48839 0.0 1.0 0.0 0.0 \n48840 0.0 0.0 0.0 0.0 \n48841 0.0 1.0 0.0 0.0 \n\n Education Years=6 Education Years=7 Education Years=8 \\\n0 0.0 1.0 0.0 \n1 0.0 0.0 0.0 \n2 0.0 0.0 0.0 \n3 0.0 0.0 0.0 \n4 0.0 0.0 0.0 \n... ... ... ... \n48837 0.0 0.0 0.0 \n48838 0.0 0.0 0.0 \n48839 0.0 0.0 0.0 \n48840 0.0 0.0 0.0 \n48841 0.0 0.0 0.0 \n\n Education Years=9 Education Years=10 Education Years=11 \\\n0 0.0 0.0 0.0 \n1 1.0 0.0 0.0 \n2 0.0 0.0 0.0 \n3 0.0 1.0 0.0 \n4 0.0 1.0 0.0 \n... ... ... ... \n48837 0.0 0.0 0.0 \n48838 1.0 0.0 0.0 \n48839 1.0 0.0 0.0 \n48840 1.0 0.0 0.0 \n48841 1.0 0.0 0.0 \n\n Education Years=12 Education Years=<6 Education Years=>12 \\\n0 0.0 0.0 0.0 \n1 0.0 0.0 0.0 \n2 1.0 0.0 0.0 \n3 0.0 0.0 0.0 \n4 0.0 0.0 0.0 \n... ... ... ... \n48837 1.0 0.0 0.0 \n48838 0.0 0.0 0.0 \n48839 0.0 0.0 0.0 \n48840 0.0 0.0 0.0 \n48841 0.0 0.0 0.0 \n\n Income Binary \n0 0.0 \n1 0.0 \n2 1.0 \n3 1.0 \n4 0.0 \n... ... \n48837 0.0 \n48838 1.0 \n48839 0.0 \n48840 0.0 \n48841 1.0 \n\n[48842 rows x 19 columns]",
"text/html": "<div>\n<style scoped>\n .dataframe tbody tr th:only-of-type {\n vertical-align: middle;\n }\n\n .dataframe tbody tr th {\n vertical-align: top;\n }\n\n .dataframe thead th {\n text-align: right;\n }\n</style>\n<table border=\"1\" class=\"dataframe\">\n <thead>\n <tr style=\"text-align: right;\">\n <th></th>\n <th>race</th>\n <th>sex</th>\n <th>Age (decade)=10</th>\n <th>Age (decade)=20</th>\n <th>Age (decade)=30</th>\n <th>Age (decade)=40</th>\n <th>Age (decade)=50</th>\n <th>Age (decade)=60</th>\n <th>Age (decade)=&gt;=70</th>\n <th>Education Years=6</th>\n <th>Education Years=7</th>\n <th>Education Years=8</th>\n <th>Education Years=9</th>\n <th>Education Years=10</th>\n <th>Education Years=11</th>\n <th>Education Years=12</th>\n <th>Education Years=&lt;6</th>\n <th>Education Years=&gt;12</th>\n <th>Income Binary</th>\n </tr>\n </thead>\n <tbody>\n <tr>\n <th>0</th>\n <td>0.0</td>\n <td>1.0</td>\n <td>0.0</td>\n <td>1.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>1.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n </tr>\n <tr>\n <th>1</th>\n <td>1.0</td>\n <td>1.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>1.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>1.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n </tr>\n <tr>\n <th>2</th>\n <td>1.0</td>\n <td>1.0</td>\n <td>0.0</td>\n <td>1.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>1.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>1.0</td>\n </tr>\n <tr>\n <th>3</th>\n <td>0.0</td>\n <td>1.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>1.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n 
<td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>1.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>1.0</td>\n </tr>\n <tr>\n <th>4</th>\n <td>1.0</td>\n <td>0.0</td>\n <td>1.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>1.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n </tr>\n <tr>\n <th>...</th>\n <td>...</td>\n <td>...</td>\n <td>...</td>\n <td>...</td>\n <td>...</td>\n <td>...</td>\n <td>...</td>\n <td>...</td>\n <td>...</td>\n <td>...</td>\n <td>...</td>\n <td>...</td>\n <td>...</td>\n <td>...</td>\n <td>...</td>\n <td>...</td>\n <td>...</td>\n <td>...</td>\n <td>...</td>\n </tr>\n <tr>\n <th>48837</th>\n <td>1.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>1.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>1.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n </tr>\n <tr>\n <th>48838</th>\n <td>1.0</td>\n <td>1.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>1.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>1.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>1.0</td>\n </tr>\n <tr>\n <th>48839</th>\n <td>1.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>1.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>1.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n </tr>\n <tr>\n <th>48840</th>\n <td>1.0</td>\n <td>1.0</td>\n <td>0.0</td>\n <td>1.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>1.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n 
<td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n </tr>\n <tr>\n <th>48841</th>\n <td>1.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>1.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>1.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>0.0</td>\n <td>1.0</td>\n </tr>\n </tbody>\n</table>\n<p>48842 rows × 19 columns</p>\n</div>"
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df_adult"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 10,
"outputs": [],
"source": [
"#pickle_save(df_adult, 'adult_df.pd')"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 3,
"outputs": [],
"source": [
"# Load the cached dataframe from 'adult_df.pd' (the path used by the commented-out pickle_save call in the previous cell).\n",
"# Unpickling can run arbitrary code, so only do this with a file you trust.\n",
"df_adult = pickle_load('adult_df.pd')"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 4,
"outputs": [],
"source": [
"# Shuffle the rows with a fixed seed so the train/test split (and every number\n",
"# derived from it) is reproducible on Restart & Run All.\n",
"SHUFFLE_SEED = 0\n",
"# permutation() returns a shuffled copy along axis 0 instead of shuffling in place.\n",
"data_numpy = onp.random.default_rng(SHUFFLE_SEED).permutation(df_adult.to_numpy())\n",
"# Every column but the last is a feature; the last column is the label (kept 2-D).\n",
"data = data_numpy[:, :-1]\n",
"labels = data_numpy[:, -1:]"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 5,
"outputs": [],
"source": [
"# Fraction of the shuffled rows assigned to the training split; the remainder is the test split.\n",
"train_test_ratio = 0.8\n",
"# Row index at which the training split ends (int() truncates the product).\n",
"train_end_index = int(len(data_numpy) * train_test_ratio)"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 6,
"outputs": [],
"source": [
"# Rows before train_end_index form the training split; the rest form the test split.\n",
"train_data, test_data = data[:train_end_index], data[train_end_index:]\n",
"train_labels, test_labels = labels[:train_end_index], labels[train_end_index:]"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "markdown",
"source": [
"## Calculate the bias metrics"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
}
},
{
"cell_type": "markdown",
"source": [
"### Statistical Parity Difference"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
}
},
{
"cell_type": "code",
"execution_count": 7,
"outputs": [],
"source": [
"# We first look at the race.\n",
"# Split the training rows by outcome: onp.nonzero(...)[0] gives the row indices whose label matches.\n",
"favorable_rows = onp.nonzero(train_labels == 1)[0]\n",
"unfavorable_rows = onp.nonzero(train_labels == 0)[0]\n",
"favorable_samples = train_data[favorable_rows]\n",
"unfavorable_samples = train_data[unfavorable_rows]"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 8,
"outputs": [],
"source": [
"def _group_counts(group_value):\n",
"    \"\"\"Return (favorable count, group size, favorable rate) for rows whose column 0 equals `group_value`.\"\"\"\n",
"    favorable_count = (favorable_samples[:, 0] == group_value).astype(float).sum()\n",
"    group_size = len(train_data[train_data[:, 0] == group_value])\n",
"    return favorable_count, group_size, favorable_count / group_size\n",
"\n",
"# Column 0 holds the race attribute: favorable-outcome counts and rates for race 0 and race 1.\n",
"race0_fav, race0_total, race0_fav_rate = _group_counts(0)\n",
"race1_fav, race1_total, race1_fav_rate = _group_counts(1)\n"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 9,
"outputs": [
{
"data": {
"text/plain": "(8494.0, 33397)"
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"race1_fav, race1_total"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 10,
"outputs": [
{
"data": {
"text/plain": "(851.0, 5676)"
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"race0_fav, race0_total"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 11,
"outputs": [
{
"data": {
"text/plain": "0.1044046938001694"
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"race1_fav_rate-race0_fav_rate"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "markdown",
"source": [
"## Create a Simple MLP"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
}
},
{
"cell_type": "code",
"execution_count": 20,
"outputs": [],
"source": [
"# Define a simple MLP\n",
"class MLP(hk.Module):\n",
"    \"\"\"Feed-forward classifier: `num_layers - 1` hidden blocks followed by a 2-unit output layer.\n",
"\n",
"    Each hidden block is Linear(hidden_dim) -> leaky ReLU -> dropout, where\n",
"    the dropout branch is only taken when `is_training` is true.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, num_layers: int, hidden_dim: int, dropout: float):\n",
"        super().__init__('mlp')\n",
"        self.num_layers = num_layers\n",
"        self.hidden_dim = hidden_dim\n",
"        self.dropout = dropout\n",
"\n",
"    def __call__(self, x: chex.Array, is_training:bool=False):\n",
"        \"\"\"Return the 2-column logits for `x`.\"\"\"\n",
"        hidden = x\n",
"        for _ in range(self.num_layers - 1):\n",
"            activated = jax.nn.leaky_relu(hk.Linear(self.hidden_dim)(hidden))\n",
"            # Fresh lambdas per layer: the training branch draws its own key from the Haiku RNG stream.\n",
"            hidden = hk.cond(\n",
"                is_training,\n",
"                lambda a: hk.dropout(hk.next_rng_key(), self.dropout, a),\n",
"                lambda a: a,\n",
"                activated,\n",
"            )\n",
"\n",
"        return hk.Linear(2)(hidden)"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 21,
"outputs": [],
"source": [
"# Architecture hyper-parameters handed to the MLP constructor.\n",
"NUM_LAYERS, HIDDEN_DIM, DROPOUT = 3, 128, 0.25\n",
"\n",
"# Root PRNG key; later cells split it whenever fresh randomness is needed.\n",
"key = jax.random.PRNGKey(6)\n",
"\n",
"def _forward(x, t):\n",
"    \"\"\"Build the MLP inside the transform and run it on `x` with training flag `t`.\"\"\"\n",
"    return MLP(NUM_LAYERS, HIDDEN_DIM, DROPOUT)(x, t)\n",
"\n",
"# hk.transform yields the (init, apply) pair used by the rest of the notebook.\n",
"mlp_init, mlp_apply = hk.transform(_forward)"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 22,
"outputs": [],
"source": [
"# Derive a sub-key for parameter initialisation and carry the other half forward as `key`.\n",
"key, split = jax.random.split(key)\n",
"# Initialise the parameters by running the network on the first 10 training rows with is_training=True.\n",
"mlp_params = mlp_init(split, jnp.array(train_data[:10]), True)"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "markdown",
"source": [
"## Train our model on the data"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
}
},
{
"cell_type": "code",
"execution_count": 23,
"outputs": [],
"source": [
"# AdamW optimiser constructed with 0.005 as its learning-rate argument.\n",
"optim = optax.adamw(0.005)\n",
"# Optimiser state is created from the freshly initialised parameters.\n",
"opt_state = optim.init(mlp_params)"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 24,
"outputs": [],
"source": [
"# Number of iterations of the training loop; each one uses the whole training set.\n",
"NUM_EPOCHS = 1000\n",
"# BATCH_SIZE = 100\n",
"# For now we can use the entire data as the batch\n",
"# Convert the splits to JAX arrays; labels are cast to int and flattened to 1-D for the integer-label loss.\n",
"train_data = jnp.array(train_data)\n",
"train_labels = jnp.array(train_labels).astype(int).flatten()\n",
"\n",
"test_data = jnp.array(test_data)\n",
"test_labels = jnp.array(test_labels).astype(int).flatten()\n",
"\n",
"def loss_fn(params, key, x, y, is_training=True):\n",
"    \"\"\"Mean softmax cross-entropy of the MLP's logits for `x` against integer labels `y`.\n",
"\n",
"    `key` is the PRNG key handed to `mlp_apply`; `is_training` is forwarded to the network.\n",
"    \"\"\"\n",
"    per_example = optax.softmax_cross_entropy_with_integer_labels(\n",
"        mlp_apply(params, key, x, is_training), y\n",
"    )\n",
"    return jnp.mean(per_example)\n",
"\n",
"# JIT-compiled function returning the loss together with its gradient w.r.t. the first argument (params).\n",
"loss_value_grad = jax.jit(jax.value_and_grad(loss_fn))"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 25,
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
" 1%| | 12/1000 [00:00<00:53, 18.42it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loss: 0.6878519058227539, Train Acc: 0.7608323097229004, Test Acc: 0.7602620720863342\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 11%|█ | 112/1000 [00:01<00:13, 64.16it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loss: 0.4207240045070648, Train Acc: 0.8057737946510315, Test Acc: 0.7966015338897705\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 22%|██▏ | 220/1000 [00:03<00:10, 77.82it/s] "
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loss: 0.4189543128013611, Train Acc: 0.8058761954307556, Test Acc: 0.7969086170196533\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 32%|███▏ | 318/1000 [00:04<00:08, 79.91it/s] "
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loss: 0.41802212595939636, Train Acc: 0.8060553669929504, Test Acc: 0.7966015338897705\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 42%|████▏ | 421/1000 [00:05<00:06, 85.12it/s] "
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loss: 0.4177604019641876, Train Acc: 0.806029736995697, Test Acc: 0.7979322671890259\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 52%|█████▏ | 523/1000 [00:06<00:05, 85.70it/s] "
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loss: 0.41721275448799133, Train Acc: 0.8061065077781677, Test Acc: 0.7964991331100464\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 63%|██████▎ | 626/1000 [00:07<00:04, 85.74it/s] "
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loss: 0.41683855652809143, Train Acc: 0.8060809373855591, Test Acc: 0.7978298664093018\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 72%|███████▏ | 717/1000 [00:08<00:03, 85.12it/s] "
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loss: 0.4164116084575653, Train Acc: 0.8060553669929504, Test Acc: 0.797522783279419\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 82%|████████▏ | 819/1000 [00:09<00:02, 83.34it/s] "
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loss: 0.4168620705604553, Train Acc: 0.8061065077781677, Test Acc: 0.7964991331100464\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 92%|█████████▏| 918/1000 [00:10<00:01, 81.49it/s] "
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loss: 0.41707655787467957, Train Acc: 0.8061065077781677, Test Acc: 0.7964991331100464\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 1000/1000 [00:11<00:00, 89.28it/s]\n"
]
}
],
"source": [
"def _accuracy(params, rng, features, targets):\n",
"    \"\"\"Fraction of rows whose arg-max logit (is_training=False) equals the target label.\"\"\"\n",
"    predicted = jnp.argmax(mlp_apply(params, rng, features, False), axis=1)\n",
"    return (predicted == targets).astype(float).mean()\n",
"\n",
"for i in tqdm(range(NUM_EPOCHS)):\n",
"    # One full-batch gradient step with a fresh dropout key.\n",
"    key, split = jax.random.split(key)\n",
"    loss, grad = loss_value_grad(mlp_params, split, train_data, train_labels)\n",
"    updates, opt_state = optim.update(grad, opt_state, mlp_params)\n",
"    mlp_params = optax.apply_updates(mlp_params, updates)\n",
"    if i % 100:\n",
"        continue\n",
"\n",
"    # Every 100th iteration: report the loss from before this update and the post-update accuracies.\n",
"    key, split = jax.random.split(key)\n",
"    train_acc = _accuracy(mlp_params, split, train_data, train_labels)\n",
"    test_acc = _accuracy(mlp_params, split, test_data, test_labels)\n",
"    print(f'Loss: {loss}, Train Acc: {train_acc}, Test Acc: {test_acc}')"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 55,
"outputs": [
{
"data": {
"text/plain": "DeviceArray([[0.98889744, 0.9593886 ],\n [0.95056295, 0.99793124],\n [0.96067035, 0.98759353],\n [0.4896833 , 0.4588622 ],\n [0.4600625 , 0.48842293],\n [0.960789 , 0.9874729 ],\n [1.0003633 , 0.9482032 ],\n [0.92749125, 1.0220175 ],\n [0.4901985 , 0.45837286],\n [0.94303775, 1.0057118 ]], dtype=float32)"
},
"execution_count": 55,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"loss_fn(mlp_params, split, train_data[:10], train_labels[:10])"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 20,
"outputs": [],
"source": [
"# Sanity check: despite the name, these are logits for the first 10 *training* rows,\n",
"# computed with is_training=True (so the dropout branch is active).\n",
"test_logits = mlp_apply(mlp_params, key, train_data[:10], True)"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 21,
"outputs": [
{
"data": {
"text/plain": "DeviceArray([[0.51991534, 0.48008466],\n [0.50613105, 0.49386895],\n [0.51190686, 0.4880932 ],\n [0.56233084, 0.43766916],\n [0.547018 , 0.452982 ],\n [0.5108703 , 0.48912978],\n [0.55052614, 0.44947386],\n [0.52233845, 0.47766158],\n [0.46725613, 0.53274393],\n [0.44238308, 0.557617 ]], dtype=float32)"
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"test_logits"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 22,
"outputs": [],
"source": [
"# Labels matching the 10 training rows used for test_logits; a separate name keeps the held-out test_labels intact.\n",
"sample_labels = train_labels[:10]"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 30,
"outputs": [],
"source": [
"a = optax.softmax_cross_entropy_with_integer_labels(test_logits, sample_labels.astype(int).flatten())"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 32,
"outputs": [
{
"data": {
"text/plain": "DeviceArray(0.7031968, dtype=float32)"
},
"execution_count": 32,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"a.mean()"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 38,
"outputs": [
{
"data": {
"text/plain": "DeviceArray([0, 0, 0, 0, 0, 0, 0, 0, 1, 1], dtype=int32)"
},
"execution_count": 38,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"jnp.argmax(test_logits, axis=1)"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment