@Gananath
Last active April 20, 2022 07:36
Jain's Fairness Index to supplement multiclass classification loss functions.

While randomly surfing the web, I came across the paper Leveraging Uncertainties in Softmax Decision-Making Models for Low-Power IoT Devices by Chiwoo Cho et al. In this paper the authors propose using Jain's Fairness Index (JFI) to compute the uncertainty of a deep learning model's softmax output on IoT devices.

Jain's Fairness Index
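For a softmax output x = (x1, ..., xn), Jain's Fairness Index is J(x) = (x1 + ... + xn)^2 / (n * (x1^2 + ... + xn^2)). It equals 1 when all entries are equal (a maximally uncertain prediction) and drops towards 1/n when a single class holds all the probability mass. A minimal sketch, not part of the original gist, that computes it for two hand-made predictions:

import torch

def jfi(x):
    # Jain's Fairness Index: (sum of x_i)^2 / (n * sum of x_i^2)
    return x.sum() ** 2 / (x.numel() * (x ** 2).sum())

uniform = torch.tensor([1 / 3, 1 / 3, 1 / 3])  # maximally uncertain prediction
peaked = torch.tensor([0.98, 0.01, 0.01])      # confident prediction

print(jfi(uniform))  # -> 1.0, the upper bound
print(jfi(peaked))   # -> ~0.35, close to the lower bound 1/n for n = 3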

Instead of computing uncertainty, I thought of adding JFI as a supplementary criterion alongside the loss function to improve model training. I trained a model to classify the Iris dataset with and without Jain's Fairness Index. There is a slight improvement in the model's predictions when it is trained with Jain's Fairness Index.

The code for the Iris dataset training has been taken from here; I have only added the jains_fairness_index() function to it. I have tested this idea with this dataset alone. I added a seed for reproducibility and retrained the model with and without JFI from scratch, as shown in the snippet below.
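The only difference between the two runs is the loss line inside the training loop; the rest of the script is identical:

# "Without JFI": plain cross-entropy loss
loss = loss_fn(y_pred, y_train)

# "With JFI": cross-entropy plus the fairness term
loss = loss_fn(y_pred, y_train) + jains_fairness_index(y_pred)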

Without JFI

[results: before JFI]

With JFI

[results: after JFI]

#!/usr/bin/env python
# coding: utf-8
# Pytorch Iris dataset training code borrowed from
# https://janakiev.com/blog/pytorch-iris/
#
# Paper: Leveraging Uncertainties in Softmax Decision-Making Models for Low-Power IoT Devices
# Authors: Chiwoo Cho, Wooyeol Choi and Taewoon Kim
# Sensors 2020, 20(16), 4603; https://doi.org/10.3390/s20164603
# Link: https://www.mdpi.com/1424-8220/20/16/4603/htm
#
# Implemented by Gananath R
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.autograd import Variable
seed = 2022
torch.manual_seed(seed)
np.random.seed(seed)
def jains_fairness_index(output, precision=4):
    # Jain's Fairness Index per sample: (sum_i x_i)^2 / (n * sum_i x_i^2)
    sum_squared = torch.sum(output, axis=1) ** 2
    squared_sum = torch.sum(torch.square(output), axis=1)
    fairness_idx = (sum_squared / (output.shape[1] * squared_sum)) ** precision
    return 1 - fairness_idx.mean()
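# Sanity check (not part of the original gist): for 3-class softmax outputs,
# jains_fairness_index() returns ~0 for a uniform row and approaches
# 1 - (1/3) ** 4 for a near one-hot row, so adding it to the loss acts as a
# penalty on over-confident predictions, e.g.:
#   jains_fairness_index(torch.tensor([[1/3, 1/3, 1/3]]))     # ~0.0
#   jains_fairness_index(torch.tensor([[0.98, 0.01, 0.01]]))  # ~0.99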
iris = load_iris()
X = iris["data"]
y = iris["target"]
names = iris["target_names"]
feature_names = iris["feature_names"]
# Scale data to have mean 0 and variance 1
# which is important for the convergence of the neural network
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)
# Split the data set into training and testing
X_train, X_test, y_train, y_test = train_test_split(
    X_scaled, y, test_size=0.2, random_state=2
)
class Model(nn.Module):
    def __init__(self, input_dim):
        super(Model, self).__init__()
        self.layer1 = nn.Linear(input_dim, 50)
        self.layer2 = nn.Linear(50, 50)
        self.layer3 = nn.Linear(50, 3)

    def forward(self, x):
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        x = F.softmax(self.layer3(x), dim=1)
        return x
model = Model(X_train.shape[1])
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loss_fn = nn.CrossEntropyLoss()
print(model)
import tqdm
EPOCHS = 100
X_train = Variable(torch.from_numpy(X_train)).float()
y_train = Variable(torch.from_numpy(y_train)).long()
X_test = Variable(torch.from_numpy(X_test)).float()
y_test = Variable(torch.from_numpy(y_test)).long()
loss_list = np.zeros((EPOCHS,))
accuracy_list = np.zeros((EPOCHS,))
for epoch in tqdm.trange(EPOCHS):
    y_pred = model(X_train)
    loss = loss_fn(y_pred, y_train) + jains_fairness_index(y_pred)
    loss_list[epoch] = loss.item()

    # Zero gradients
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    with torch.no_grad():
        y_pred = model(X_test)
        correct = (torch.argmax(y_pred, dim=1) == y_test).type(torch.FloatTensor)
        accuracy_list[epoch] = correct.mean()
from sklearn.metrics import roc_curve, auc
from sklearn.preprocessing import OneHotEncoder
plt.figure(figsize=(10, 10))
plt.plot([0, 1], [0, 1], "k--")
# One hot encoding
enc = OneHotEncoder()
Y_onehot = enc.fit_transform(y_test[:, np.newaxis]).toarray()
with torch.no_grad():
    y_pred = model(X_test).numpy()
fpr, tpr, threshold = roc_curve(Y_onehot.ravel(), y_pred.ravel())
plt.plot(fpr, tpr, label="AUC = {:.3f}".format(auc(fpr, tpr)))
plt.xlabel("False positive rate")
plt.ylabel("True positive rate")
plt.title("ROC curve")
plt.legend()
plt.show()