Skip to content

Instantly share code, notes, and snippets.

@DmitryUlyanov
Created October 24, 2017 15:42
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save DmitryUlyanov/ef6fefb8a055a7739eefc0ab4d02b87d to your computer and use it in GitHub Desktop.
Save DmitryUlyanov/ef6fefb8a055a7739eefc0ab4d02b87d to your computer and use it in GitHub Desktop.
import torch
import torch.nn
from torch.autograd import Variable
def pairwise_euclidean(samples):
    """Return the matrix of pairwise *squared* Euclidean distances.

    Uses the expansion ``||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b`` so the
    whole matrix comes from a single matrix product.

    Args:
        samples: 2-D ``B x C`` tensor/Variable, one sample per row.

    Returns:
        ``B x B`` tensor whose ``[i, j]`` entry is
        ``||samples[i] - samples[j]||^2``. Despite the function name, no
        square root is taken.
    """
    B = samples.size(0)
    # Row-wise squared norms reshaped to an explicit B x 1 column, then
    # broadcast to B x B (entry [i, j] == ||x_i||^2). The explicit view keeps
    # the orientation fixed whether or not this PyTorch version's `sum`
    # keeps the reduced dimension.
    sq_norm = samples.mul(samples).sum(1).view(B, 1).expand(B, B)
    dist_mat = samples.mm(samples.t()).mul(-2) + (sq_norm + sq_norm.t())
    # Floating-point cancellation in the expansion can leave tiny negative
    # entries (e.g. for duplicated rows). True squared distances are >= 0,
    # and a negative value would make a later log() (as in sample_entropy)
    # return NaN.
    return dist_mat.clamp(min=0)
def sample_entropy(samples):
    """Return a nearest-neighbour entropy estimate for a batch of samples.

    For each row ``x_i`` of ``samples`` this finds ``d_i``, the smallest
    squared distance from ``x_i`` to any *other* row, and returns

        (C / B) * sum_i log(d_i + 1e-4)

    where ``B, C = samples.size()``. This has the form of a nearest-neighbour
    entropy estimator with additive constants dropped.
    NOTE(review): confirm that this is the intended estimator. The ``1e-4``
    term keeps the log finite when two rows coincide.

    Args:
        samples: 2-D ``B x C`` tensor/Variable (B samples, C features).

    Returns:
        Scalar-valued tensor (1-element or 0-dim, depending on PyTorch
        version). It remains differentiable w.r.t. ``samples``; the
        diagonal shift below is a detached constant.
    """
    dist_mat = pairwise_euclidean(samples)
    B = dist_mat.size(0)
    # Largest distance as a Python float. `m.data[0]` raises IndexError on
    # PyTorch >= 0.4, where max() returns a 0-dim tensor; `view(-1)[0]`
    # works for both 1-element and 0-dim results.
    m = float(dist_mat.max().detach().data.view(-1)[0])
    # Lift every diagonal (self-distance) entry above all off-diagonal
    # entries, so the row-wise min picks the nearest *other* sample.
    diag_shift = torch.eye(B).type_as(samples.data) * (m + 1)
    dist_mat_d = dist_mat + Variable(diag_shift)
    nn_dist = dist_mat_d.min(1)[0]
    entropy = (nn_dist + 1e-4).log().sum()
    # Scale by C / B. float() forces true division even under Python 2.
    return entropy * (float(samples.size(1)) / samples.size(0))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment