Skip to content

Instantly share code, notes, and snippets.

@albertbuchard
Last active October 22, 2023 11:49
Show Gist options
  • Save albertbuchard/98b5739b40cee32ad3a33deec0709527 to your computer and use it in GitHub Desktop.
Save albertbuchard/98b5739b40cee32ad3a33deec0709527 to your computer and use it in GitHub Desktop.
MINE: Mutual Information Neural Estimation | Minimal Working Example
import math
import torch.optim as optim
import torch
from torch import nn
class MineWrapper(nn.Module):
    """Train and evaluate a MINE mutual-information estimator around a statistics network.

    The estimator maximises the Donsker-Varadhan style lower bound
    ``mean(T(x, y)) - log(mean(exp(T(x, y_shuffled))))``, where samples from
    the product of marginals are obtained by shuffling ``y`` within the batch.

    Args:
        stat_model: module called as ``stat_model(x, mine_y=y)``; presumably
            returns one score per row (TODO confirm for custom models).
        moving_average_rate: smoothing factor written to
            ``LogWithMovingAverageGrad.alpha``. That attribute is class-level,
            so the value is shared by every wrapper.
        unbiased: if True, ``get_loss`` takes the log of the marginal term via
            ``LogWithMovingAverage`` (moving-average-corrected gradient).
    """

    def __init__(self, stat_model, moving_average_rate=0.1, unbiased=False):
        super().__init__()
        self.stat_model = stat_model
        self.unbiased = unbiased
        # The moving-average state lives on the autograd Function class (it is
        # global). Reset it so a new estimator does not inherit the running
        # average left behind by an earlier training run.
        LogWithMovingAverageGrad.alpha = moving_average_rate
        LogWithMovingAverageGrad.moving_avg_input = None

    def _scores(self, x, y):
        """Return raw scores T on joint samples and on shuffled (marginal) samples."""
        # Shuffling y breaks its pairing with x, giving product-of-marginals samples.
        perm = torch.randperm(y.shape[0], device=y.device)
        joint = self.stat_model(x, mine_y=y)
        marginal = self.stat_model(x, mine_y=y[perm])
        return joint, marginal

    def _t_log_mean_exp_t(self, x, y):
        """Return ``(mean T on joint, log(mean(exp(T))) on marginal)``, computed stably."""
        joint, marginal = self._scores(x, y)
        # log(mean(exp(m))) == logsumexp(m) - log(numel); unlike exp(m).mean(),
        # this does not overflow to inf for large scores.
        flat = marginal.reshape(-1)
        log_mean_exp = torch.logsumexp(flat, dim=0) - math.log(flat.numel())
        return joint.mean(), log_mean_exp

    def get_t_exp_t(self, x, y):
        """Return ``(mean T on joint samples, mean exp(T) on shuffled samples)`` as tensors."""
        joint, marginal = self._scores(x, y)
        return joint.mean(), torch.exp(marginal).mean()

    def get_loss(self, x, y):
        """Return the negated lower bound (a scalar tensor) to minimise."""
        if self.unbiased:
            # The custom op needs the raw mean of exp(T) so that its backward
            # pass can divide by the moving average.
            t, exp_t = self.get_t_exp_t(x, y)
            lower_bound = t - LogWithMovingAverage(exp_t)
        else:
            t, log_exp_t = self._t_log_mean_exp_t(x, y)
            lower_bound = t - log_exp_t
        return -1.0 * lower_bound

    def get_mutual_information(self, x, y):
        """Return the current lower-bound estimate as a Python float, in bits."""
        # Evaluation only: no autograd graph is needed.
        with torch.no_grad():
            t, log_exp_t = self._t_log_mean_exp_t(x, y)
        # The bound is in nats; dividing by ln(2) converts it to bits.
        return (t - log_exp_t).item() / math.log(2)
class LogWithMovingAverageGrad(torch.autograd.Function):
    """Natural log whose backward pass divides by a moving average of its input.

    Forward returns ``input.log()``. Backward returns
    ``grad_output / moving_avg_input`` instead of the exact
    ``grad_output / input``. The moving average is an exponential average,
    with rate ``alpha``, of the inputs seen on backward passes.
    ``MineWrapper`` uses this op for ``log(mean(exp(T)))`` when
    ``unbiased=True``.

    The state lives on the class, so it is shared by every caller.
    """

    # Exponential moving average of inputs seen in backward (class-level, shared).
    moving_avg_input = None
    # Weight given to the newest input in the moving-average update.
    alpha = 0.01

    @staticmethod
    def forward(ctx, input):
        # Compute the log and keep the input for the backward pass.
        output = input.log()
        ctx.save_for_backward(input)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        (input,) = ctx.saved_tensors
        # Detach before storing in class state; otherwise the saved tensor
        # keeps its autograd history alive across training steps.
        current = input.detach()
        state = LogWithMovingAverageGrad
        if state.moving_avg_input is None:
            state.moving_avg_input = current
        else:
            state.moving_avg_input = (
                state.alpha * current + (1 - state.alpha) * state.moving_avg_input
            )
        # Use the moving average in place of the current input as the
        # denominator of d(log u)/du.
        return grad_output / state.moving_avg_input


# Functional alias: ``LogWithMovingAverage(u)`` applies the custom autograd op.
LogWithMovingAverage = LogWithMovingAverageGrad.apply
class StatModel(nn.Module):
    """Statistics network T(x, y): an MLP over the concatenation of x and y.

    Args:
        dim: total number of input features, i.e. x's features plus y's
            features along the last axis.
        hidden_dim: width of the single hidden layer (default 100, as before).
    """

    def __init__(self, dim, hidden_dim=100):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, x, mine_y):
        """Return one score per sample (trailing dimension of size 1)."""
        # Concatenating on the last (feature) axis also supports inputs with
        # extra leading dimensions; nn.Linear acts on the last axis.
        x_y = torch.cat([x, mine_y], dim=-1)
        return self.layers(x_y)
def train(x, y, num_epochs=100, *, lr=0.001, unbiased=False, verbose=True):
    """Fit a fresh MINE estimator on (x, y) and return its final MI estimate.

    Training is full-batch: each epoch takes one AdamW step on all rows.

    Args:
        x: samples of X, shape (n, dx); a 1-D tensor is treated as (n, 1).
        y: samples of Y, shape (n, dy); a 1-D tensor is treated as (n, 1).
        num_epochs: number of optimisation steps.
        lr: AdamW learning rate (default matches the previous hard-coded value).
        unbiased: forwarded to ``MineWrapper``.
        verbose: if True, print the loss and MI estimate after every epoch.

    Returns:
        The value of ``mine.get_mutual_information(x, y)`` after the last
        epoch, or None when ``num_epochs`` is 0.
    """
    # StatModel concatenates along the feature axis, so promote 1-D samples
    # to single-column matrices.
    if x.dim() == 1:
        x = x.unsqueeze(1)
    if y.dim() == 1:
        y = y.unsqueeze(1)

    stat_model = StatModel(x.shape[1] + y.shape[1])
    mine = MineWrapper(stat_model=stat_model, unbiased=unbiased)
    optimizer = optim.AdamW(mine.parameters(), lr=lr)

    mi = None
    for epoch in range(num_epochs):
        optimizer.zero_grad()
        loss = mine.get_loss(x, y)
        loss.backward()
        optimizer.step()
        # Evaluation only: avoid building an autograd graph.
        with torch.no_grad():
            mi = mine.get_mutual_information(x, y)
        if verbose:
            print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {loss.item()}, Mutual Information: {mi}")
    return mi
if __name__ == "__main__":
    n = 10000
    dim = 10  # features per variable, used for both x and y

    # Independent variables: y is random binary data unrelated to x.
    x = torch.randn(n, dim)
    y = torch.randint(0, 2, size=(n, dim)).float()
    independent_mi = train(x, y)

    # Dependent variables: each y column is 1 where the matching x column plus
    # Gaussian noise (mean 0, std 2) is positive, else 0.
    x = torch.randn(n, dim)
    y = ((x + torch.normal(0, 2, size=(n, dim))) > 0).float()
    dependent_mi = train(x, y)

    print(f"Independent MI: {independent_mi}, Dependent MI: {dependent_mi}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment