
@ChuaCheowHuan
Created April 20, 2020 06:38
KL divergence for multivariate normal distributions
import numpy as np

def kl_mvn(m0, S0, m1, S1):
    """
    Kullback-Leibler divergence KL( N(m0, S0) || N(m1, S1) ) between two
    multivariate normal distributions, expressed in nats.
    The covariance matrices do not need to be diagonal.

    From Wikipedia:
        KL( (m0, S0) || (m1, S1) )
            = .5 * ( tr(S1^{-1} S0) + log |S1|/|S0| +
                     (m1 - m0)^T S1^{-1} (m1 - m0) - N )

    See also:
    https://stackoverflow.com/questions/44549369/kullback-leibler-divergence-from-gaussian-pm-pv-to-gaussian-qm-qv
    """
    N = m0.shape[0]          # dimensionality
    iS1 = np.linalg.inv(S1)  # inverse of covariance S1
    diff = m1 - m0           # difference between means

    # The KL divergence is the sum of three terms
    tr_term = np.trace(iS1 @ S0)
    det_term = np.log(np.linalg.det(S1) / np.linalg.det(S0))
    quad_term = diff.T @ iS1 @ diff  # reuse iS1 instead of inverting S1 again

    return .5 * (tr_term + det_term + quad_term - N)
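A short usage sketch (the variable names and test values below are my own, not from the gist): it checks that the divergence of a distribution with itself is zero, and that for 1-D inputs the function agrees with the familiar univariate closed form log(s1/s0) + (s0^2 + (m0-m1)^2) / (2 s1^2) - 1/2.

```python
import numpy as np

def kl_mvn(m0, S0, m1, S1):
    # KL( N(m0, S0) || N(m1, S1) ), as defined in the gist above
    N = m0.shape[0]
    iS1 = np.linalg.inv(S1)
    diff = m1 - m0
    tr_term = np.trace(iS1 @ S0)
    det_term = np.log(np.linalg.det(S1) / np.linalg.det(S0))
    quad_term = diff.T @ iS1 @ diff
    return .5 * (tr_term + det_term + quad_term - N)

# KL divergence of a distribution with itself is zero
m = np.array([0.0, 1.0])
S = np.array([[2.0, 0.3],
              [0.3, 1.0]])
print(np.isclose(kl_mvn(m, S, m, S), 0.0))  # True

# For 1x1 covariances the result matches the univariate closed form
m0, s0, m1, s1 = 0.0, 1.0, 1.0, 2.0
kl_closed = np.log(s1 / s0) + (s0**2 + (m0 - m1)**2) / (2 * s1**2) - 0.5
kl_matrix = kl_mvn(np.array([m0]), np.array([[s0**2]]),
                   np.array([m1]), np.array([[s1**2]]))
print(np.isclose(kl_matrix, kl_closed))  # True
```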
@ChuaCheowHuan (Author)

@mvsoom Hey, thanks for sharing :)

@keizerzilla

Thank you, @ChuaCheowHuan and @mvsoom!
