@samarthbhargav
Created January 25, 2024 14:07
Pairwise KL divergence between two diagonal Gaussian distributions
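The function below computes, for every pair (i, j) of rows across two batches, the standard closed-form KL divergence between k-dimensional Gaussians P = N(mu_1, Sigma_1) and Q = N(mu_2, Sigma_2):

KL(P \| Q) = \frac{1}{2}\left[\log\frac{\det\Sigma_2}{\det\Sigma_1} - k + \operatorname{tr}\!\left(\Sigma_2^{-1}\Sigma_1\right) + (\mu_2 - \mu_1)^\top \Sigma_2^{-1} (\mu_2 - \mu_1)\right]

Because both covariances are diagonal, the determinant is the product of per-dimension variances, and the inverse and trace reduce to elementwise operations, which is what the code exploits.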
import torch


def pairwise_kl_divergence(mean1, var1, mean2, var2):
    # mean1/var1: (BZ_1, D); mean2/var2: (BZ_2, D)
    # Returns a (BZ_1, BZ_2) matrix where entry (i, j) = KL(N_i || N_j)
    # Clamp variances away from zero to avoid log(0) and division by zero
    var1 = torch.clamp(var1, min=1e-10)
    var2 = torch.clamp(var2, min=1e-10)
    k = mean1.size(1)
    # shape = BZ_1 x D
    logvar1 = torch.log(var1)
    # log determinant of a diagonal covariance = sum of log-variances
    logvar1det = logvar1.sum(1)
    # shape = BZ_2 x D
    logvar2 = torch.log(var2)
    logvar2det = logvar2.sum(1)
    # matrix of log(det(var2_j)) - log(det(var1_i)) - k
    # shape = BZ_1 x BZ_2, pairing row i of batch 1 with row j of batch 2
    log_var_diff = -logvar1det.reshape(-1, 1) + logvar2det - k
    # inverse of var2 (elementwise, since the covariance is diagonal)
    var2inv = 1 / var2
    # trace(var2^-1 @ var1) for diagonal var1/var2: (BZ_1, D) @ (D, BZ_2)
    tr_prod = var1.matmul(var2inv.T)
    # squared mean differences via broadcasting, shape BZ_1 x BZ_2 x D
    mudiff_sq = (mean1.reshape(-1, 1, k) - mean2) ** 2
    # Mahalanobis term (mu2 - mu1)^T var2^-1 (mu2 - mu1), shape BZ_1 x BZ_2
    diff_div = (mudiff_sq * var2inv).sum(dim=-1)
    # KL = 0.5 * (log-det ratio - k + trace + Mahalanobis)
    # Note the positive sign: -0.5 would yield the negated (negative) KL
    kl = 0.5 * (log_var_diff + tr_prod + diff_div)
    return kl
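A minimal sanity check, not part of the original gist (the batch sizes B1, B2, D are illustrative): the diagonal of the self-pairwise matrix should be zero, and individual entries should agree with torch.distributions.kl_divergence on Independent Normal distributions.

import torch
from torch.distributions import Independent, Normal, kl_divergence

torch.manual_seed(0)
B1, B2, D = 4, 3, 8
mean1, mean2 = torch.randn(B1, D), torch.randn(B2, D)
var1, var2 = torch.rand(B1, D) + 0.1, torch.rand(B2, D) + 0.1

kl = pairwise_kl_divergence(mean1, var1, mean2, var2)  # shape (B1, B2)

# KL of a distribution with itself is zero, so the self-pairwise diagonal vanishes
self_kl = pairwise_kl_divergence(mean1, var1, mean1, var1)
assert torch.allclose(self_kl.diagonal(), torch.zeros(B1), atol=1e-5)

# Cross-check one entry against torch.distributions (Normal takes std, not variance)
p = Independent(Normal(mean1[1], var1[1].sqrt()), 1)
q = Independent(Normal(mean2[2], var2[2].sqrt()), 1)
assert torch.allclose(kl[1, 2], kl_divergence(p, q), atol=1e-5)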