@kastnerkyle
Last active September 21, 2019 07:28
Mutual information for discrete variables (in this example, two images)
import numpy as np
# scipy.misc.imread was removed in SciPy 1.2; imageio.imread is a drop-in replacement
from imageio import imread

# Port of the MATLAB implementation at:
# https://www.mathworks.com/matlabcentral/fileexchange/36538-very-fast-mutual-information-betweentwo-images?focused=3869473&tab=function
def mutual_info(X, Y, levels=256):
    # example inputs are images, but any equal-sized discrete-valued arrays work
    X = X.astype("float32")
    Y = Y.astype("float32")
    tsz = X.size
    # shift both variables so the smallest value is 1, as in the MATLAB code
    X_norm = X - X.min() + 1
    Y_norm = Y - Y.min() + 1
    matAB = np.zeros((tsz, 2))
    matAB[:, 0] = X_norm.ravel()
    matAB[:, 1] = Y_norm.ravel()
    # joint histogram, matching the MATLAB accumarray-based construction
    h, _, _ = np.histogram2d(matAB[:, 0] + 1, matAB[:, 1] + 1, bins=levels)
    # pad with a zero row and column to mirror MATLAB's 1-based indexing;
    # empty bins contribute nothing to the entropies below
    h = np.concatenate((0. * h[:, 0][:, None], h), axis=1)
    h = np.concatenate((0. * h[0, :][None, :], h), axis=0)
    # normalize counts to a joint probability mass function
    hn = h / np.sum(h.ravel())
    y_marg = np.sum(hn, axis=0)
    x_marg = np.sum(hn, axis=1)
    # adding (p == 0) inside the log makes log2 return 0 for empty bins
    Hy = -np.sum(y_marg * np.log2(y_marg + (y_marg == 0)))  # entropy H(Y)
    Hx = -np.sum(x_marg * np.log2(x_marg + (x_marg == 0)))  # entropy H(X)
    arg_xy2 = hn * np.log2(hn + (hn == 0))
    h_xy = np.sum(-arg_xy2.ravel())  # joint entropy H(X, Y)
    M = Hx + Hy - h_xy  # mutual information I(X; Y) = H(X) + H(Y) - H(X, Y)
    return M
# imageio.imread takes no flatten/mode arguments; both were left at their defaults anyway
aa = imread("first_image.png")
bb = imread("second_image.png")
mi = mutual_info(aa, bb)
print("mutual information (bits):", mi)
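
As a quick sanity check (a sketch, not part of the original gist), the estimate can be compared against scikit-learn's mutual_info_score, which computes the same plug-in quantity in nats; dividing by np.log(2) converts it to bits. The random arrays below are stand-ins for real images.

# sanity check (added for illustration): scikit-learn reports MI in nats,
# so divide by log(2) to compare in bits
from sklearn.metrics import mutual_info_score

rng = np.random.RandomState(0)
a = rng.randint(0, 256, size=(64, 64))
b = rng.randint(0, 256, size=(64, 64))
mi_bits = mutual_info(a, b)
sk_bits = mutual_info_score(a.ravel(), b.ravel()) / np.log(2)
print(mi_bits, sk_bits)  # the two estimates should agree closely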