Skip to content

Instantly share code, notes, and snippets.

@bkj
Created November 10, 2017 03:37
Show Gist options
  • Save bkj/f24bc7a646756b42459d23ecf15c65c5 to your computer and use it in GitHub Desktop.
Save bkj/f24bc7a646756b42459d23ecf15c65c5 to your computer and use it in GitHub Desktop.
from __future__ import division
import numpy as np
def mut_info(X, y):
    """Return the mutual information between each feature of X and the label y.

    For every column of ``X`` a 2x2 contingency table (label x feature) is
    built from the column sums, and the mutual information of that table is
    computed with the natural logarithm, so the result is in nats.

    Parameters
    ----------
    X : array-like with ``.shape``, boolean-mask row indexing and ``.sum(axis=0)``
        Sample-by-feature matrix (e.g. a numpy array or scipy sparse matrix).
        Entries are assumed to be 0/1 indicators; this is not validated.
    y : array-like of length ``X.shape[0]``
        Binary label per sample. Coerced to a flat boolean mask, so any
        non-zero / truthy value counts as the positive class.

    Returns
    -------
    numpy.ndarray
        1-D float array with one value per feature (length 1 when ``X`` has a
        single column or is 1-D). All values are NaN when ``X`` has no rows.
    """
    # A boolean mask is required: with integer labels ``X[y]`` would be fancy
    # indexing and ``~y`` a bitwise NOT, silently giving wrong counts.
    y = np.asarray(y).ravel().astype(bool)
    n = float(X.shape[0])

    # Cell counts n_<label><feature>. Floats avoid integer overflow in the
    # products below; ravel() (instead of squeeze()) keeps the arrays 1-D even
    # for a single feature, so the result is always an array.
    n_11 = np.asarray(X[y].sum(axis=0), dtype=float).ravel()
    n_01 = np.asarray(X[~y].sum(axis=0), dtype=float).ravel()
    n_10 = float(y.sum()) - n_11
    n_00 = n - n_11 - n_01 - n_10

    # Marginal totals: n_<label>x sums over the feature, n_x<feature> over the label.
    n_1x = n_11 + n_10
    n_x1 = n_11 + n_01
    n_0x = n_01 + n_00
    n_x0 = n_10 + n_00

    def _term(joint, label_total, feature_total):
        # One cell's contribution: joint * log(n * joint / (row * column)).
        # An empty cell evaluates to 0 * log(0) = NaN; by the usual convention
        # it contributes 0, so NaNs are replaced and the warnings silenced.
        with np.errstate(divide="ignore", invalid="ignore"):
            value = joint * np.log(n * joint / (label_total * feature_total))
        return np.where(np.isnan(value), 0.0, value)

    total = (
        _term(n_11, n_1x, n_x1)
        + _term(n_01, n_0x, n_x1)
        + _term(n_10, n_1x, n_x0)
        + _term(n_00, n_0x, n_x0)
    )

    # With zero samples this is 0 / 0 and yields NaN (undefined), without a warning.
    with np.errstate(divide="ignore", invalid="ignore"):
        return total / n
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment