@prerakmody
Created August 12, 2021 11:40
Predictive (Entropy) + Model (Mutual Information) Uncertainty for DNNs
"""
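Understanding entropy (predictive uncertainty) and mutual information (model/epistemic uncertainty) with M stochastic forward passes
Binary Classification
- y = {(p1, (1-p1)), (p2, (1-p2)), ..., (pM, (1-pM))}, one prediction per forward pass
- pbar = avg_M(p1, ..., pM)
- Ent = -(pbar.log(pbar) + (1-pbar).log(1-pbar))
- MIF = Ent + avg_M(p1.log(p1) + ... + pM.log(pM) + (1-p1).log(1-p1) + ... + (1-pM).log(1-pM))
  i.e. Ent minus the average entropy of the individual passes (convention: 0.log(0) = 0)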
- Case 1
-- y = {(1,0), (1,0), ..., (1,0)}
-- pbar = (1,0)
-- Ent = -(1.log(1) + 0.log(0)) = 0 --> low predictive uncertainty
-- MIF = Ent + avg_M(1.log(1) + ... 1.log(1) + 0.log(0) + ... 0.log(0)) = 0 --> low model uncertainty
- Case 2
-- y = {(0.5,0.5), (0.5,0.5), ..., (0.5,0.5)}
-- pbar = (0.5,0.5)
-- Ent = -(0.5.log(0.5) + 0.5.log(0.5)) = 0.693147181 --> high predictive uncertainty
-- MIF = Ent + avg_M(0.5.log(0.5) + ... + 0.5.log(0.5) + 0.5.log(0.5) + ... + 0.5.log(0.5))
   = 0.693147181 - 0.693147181 = 0 --> low model uncertainty
- Case 3
-- y = {(1,0), (0,1), ..., (1,0)}, the prediction flips between passes (half of the M passes predict each class)
-- pbar = (0.5, 0.5)
-- Ent = -(0.5.log(0.5) + 0.5.log(0.5)) = 0.693147181 --> high predictive uncertainty
-- MIF = Ent + avg_M(1.log(1) + ... 0.log(0) + 1.log(1) + ... 0.log(0)) = 0.693147181 --> high model uncertainty
"""
######### Case 2: [same prob p in every forward pass (no perturbations), swept over p]
import numpy as np
import matplotlib.pyplot as plt
x = np.arange(0, 1.0, 0.1)[1:]  # p = 0.1, ..., 0.9 (avoids log(0))
ce = -x*np.log(x) + (-(1-x)*np.log(1-x))  # predictive entropy Ent (pbar = p here)
MC = 10  # number of Monte Carlo forward passes M (np.repeat needs an int)
X = np.repeat(np.expand_dims(x,0),MC,axis=0)  # shape (M, len(x)), identical rows
MIF = ce + (1/MC)*np.sum(X*np.log(X) + (1-X)*np.log(1-X), axis=0)
MIF[abs(MIF) < 1e-10] = 0  # zero out floating-point round-off
f,axarr = plt.subplots(1,2)
plt.suptitle('Case 2 (identical probs for each Monte Carlo (M) forward pass) \n Binary Classification Setting (e.g. Dog vs Cat)')
axarr[0].plot(x,ce)
axarr[0].set_title('Entropy')
axarr[0].set_xlabel('p')
axarr[0].set_ylabel('Ent = -(pbar.log(pbar) + (1-pbar).log(1-pbar))')
axarr[1].plot(x,MIF)
axarr[1].set_title('MIF')
axarr[1].set_xlabel('p')
axarr[1].set_ylabel('MIF = Ent + avg(p1.log(p1) + ... pm.log(pm) + (1-p1).log(1-p1) + ... (1-pm).log(1-pm))')
plt.show()
######### Case 1: [different probs, each forward pass perturbed by Gaussian noise N(0, 0.01), i.e. std 0.01]
x = np.arange(0, 1.0, 0.1)[1:]  # p = 0.1, ..., 0.9
MC = 10  # number of Monte Carlo forward passes M
X = x + np.random.normal(0, 0.01, size=(MC, len(x)))  # shape (M, len(x)): p plus small per-pass noise
Xbar = np.mean(X,axis=0)  # pbar for each p
ce = -(Xbar*np.log(Xbar) + (1-Xbar)*np.log(1-Xbar))  # predictive entropy Ent
MIF = ce + (1/MC)*np.sum(X*np.log(X) + (1-X)*np.log(1-X), axis=0)
f,axarr = plt.subplots(1,2)
plt.suptitle('Case 1 (perturbed-N(0,0.01) probs for each Monte Carlo (M) forward pass) \n Binary Classification Setting (e.g. Dog vs Cat)')
axarr[0].plot(x,ce)
axarr[0].set_title('Entropy')
axarr[0].set_xlabel('p')
axarr[0].set_ylabel('Ent = -(pbar.log(pbar) + (1-pbar).log(1-pbar))')
axarr[1].plot(x,MIF)
axarr[1].set_title('MIF')
axarr[1].set_xlabel('p')
axarr[1].set_ylabel('MIF = Ent + avg(p1.log(p1) + ... pm.log(pm) + (1-p1).log(1-p1) + ... (1-pm).log(1-pm))')
axarr[1].set_ylim([-0.01,0.01])  # MIF stays close to 0 for such small perturbations
plt.show()
######### Case 3: When the model keeps flipping its prediction, i.e. y = {(0.1,0.9), (0.9,0.1), ..., (0.1, 0.9)}
x = np.arange(0, 1.1, 0.1)  # p = 0.0, 0.1, ..., 1.0
MC = 10  # number of Monte Carlo forward passes M (even, so half the passes predict each class)
X = np.tile(np.vstack((x,1-x)), (MC//2,1))  # shape (M, len(x)): rows alternate p and 1-p
Xbar = np.mean(X,axis=0)  # approx. 0.5 for every p
ce = -(Xbar*np.log(Xbar) + (1-Xbar)*np.log(1-Xbar))  # predictive entropy Ent, approx. log(2)
with np.errstate(divide='ignore', invalid='ignore'):  # p = 0 or 1 gives 0*log(0) = nan
    MIF = ce + (1/MC)*np.sum(np.nan_to_num(X*np.log(X) + (1-X)*np.log(1-X)), axis=0)  # nan -> 0
f,axarr = plt.subplots(1,2)
plt.suptitle('Case 3 (flipped probs for each Monte Carlo (M) forward pass) \n y = {(0.1,0.9), (0.9,0.1), ..., (0.1, 0.9)} \n Binary Classification Setting (e.g. Dog vs Cat)')
axarr[0].plot(x,ce)
axarr[0].set_title('Entropy')
axarr[0].set_xlabel('p')
axarr[0].set_ylabel('Ent = -(pbar.log(pbar) + (1-pbar).log(1-pbar))')
axarr[0].set_ylim([0,0.8])
axarr[1].plot(x,MIF)
axarr[1].set_title('MIF')
axarr[1].set_xlabel('p')
axarr[1].set_ylabel('MIF = Ent + avg(p1.log(p1) + ... pm.log(pm) + (1-p1).log(1-p1) + ... (1-pm).log(1-pm))')
axarr[1].set_ylim([0,0.8])
plt.show()
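The docstring's Ent and MIF formulas can also be packaged as a small reusable function and checked against its three worked cases. This is a minimal sketch, assuming the M per-pass probabilities of the first class for a single input are already collected in a 1-D array; `entropy` and `binary_uncertainty` are hypothetical helper names, not part of the original gist.

```python
import numpy as np


def entropy(q, axis=-1):
    """Shannon entropy in nats along `axis`, using the convention 0*log(0) = 0."""
    q = np.asarray(q, dtype=float)
    safe = np.where(q > 0, q, 1.0)  # entries with q = 0 contribute 0 * log(1) = 0
    return np.sum(q * np.log(1.0 / safe), axis=axis)  # H = sum q*log(1/q)


def binary_uncertainty(p):
    """Predictive entropy (Ent) and mutual information (MIF) for one input.

    p: 1-D array of length M; p[m] is the first-class probability from the
       m-th stochastic forward pass, so that pass predicts (p[m], 1 - p[m]).
    """
    p = np.asarray(p, dtype=float)
    probs = np.stack([p, 1.0 - p], axis=-1)        # shape (M, 2)
    ent = entropy(probs.mean(axis=0))              # Ent = H[pbar]
    expected_ent = entropy(probs, axis=-1).mean()  # avg_M H[p_m]
    # MIF = Ent + avg_M(sum p.log(p)) = Ent - avg_M H[p_m]. It is >= 0 in exact
    # arithmetic (entropy is concave), so clamp tiny negative round-off to 0.
    mif = max(float(ent - expected_ent), 0.0)
    return float(ent), mif


M = 10
cases = {
    "Case 1": np.ones(M),                   # every pass predicts (1, 0)
    "Case 2": np.full(M, 0.5),              # every pass predicts (0.5, 0.5)
    "Case 3": np.tile([1.0, 0.0], M // 2),  # passes alternate (1, 0) and (0, 1)
}
for name, p in cases.items():
    ent, mif = binary_uncertainty(p)
    print(f"{name}: Ent = {ent:.6f}, MIF = {mif:.6f}")
# Expected printout:
# Case 1: Ent = 0.000000, MIF = 0.000000
# Case 2: Ent = 0.693147, MIF = 0.000000
# Case 3: Ent = 0.693147, MIF = 0.693147
```

Writing H as sum q*log(1/q) and substituting log(1) for zero-probability entries avoids the nan and -inf warnings that the Case 3 script handles with `np.nan_to_num`. For flipped predictions with any p, pbar is 0.5, so the same function gives MIF = log(2) - H(p), which is the curve plotted in Case 3.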