
@ChuaCheowHuan
Created April 20, 2020 06:38
KL divergence for multivariate normal distributions
import numpy as np


def kl_mvn(m0, S0, m1, S1):
    """Kullback-Leibler divergence KL( N(m0, S0) || N(m1, S1) ) between two
    multivariate normal distributions, expressed in nats.

    The covariance matrices do not need to be diagonal; this computes the
    divergence for a single pair of means and full covariance matrices.

    Reference:
    https://stackoverflow.com/questions/44549369/kullback-leibler-divergence-from-gaussian-pm-pv-to-gaussian-qm-qv

    From Wikipedia:
    KL( (m0, S0) || (m1, S1) )
        = .5 * ( tr(S1^{-1} S0) + log(|S1| / |S0|)
                 + (m1 - m0)^T S1^{-1} (m1 - m0) - N )

    For diagonal covariances there is a simpler special case (see the
    Wikipedia article on KL divergence), e.g. the familiar VAE regularizer
    against N(0, I):
        KL = -0.5 * sum(1 + log_var - mean^2 - exp(log_var)),
    and TensorFlow exposes tf.distributions.kl_divergence for Distribution
    objects, as used in PPO's KL-penalty objective.
    """
    N = m0.shape[0]
    iS1 = np.linalg.inv(S1)  # inverse of covariance S1
    diff = m1 - m0           # difference between means

    # The KL divergence is made of three terms.
    tr_term = np.trace(iS1 @ S0)
    det_term = np.log(np.linalg.det(S1) / np.linalg.det(S0))
    quad_term = diff.T @ iS1 @ diff  # reuse iS1 rather than inverting S1 again
    return .5 * (tr_term + det_term + quad_term - N)
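
A quick sanity check (my sketch, not part of the original gist; it assumes kl_mvn as defined above is in scope): KL(p||p) should be exactly zero, and for a small example the three terms can be verified by hand.

import numpy as np

m0 = np.zeros(3)
S0 = np.eye(3)
m1 = np.array([1.0, 0.0, 0.0])
S1 = 2.0 * np.eye(3)

print(kl_mvn(m0, S0, m0, S0))  # 0.0: any distribution has zero divergence to itself
# tr = 1.5, log-det = log(8), quad = 0.5, N = 3  =>  0.5 * (1.5 + log(8) + 0.5 - 3)
print(kl_mvn(m0, S0, m1, S1))  # ~0.5397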
@mvsoom commented Dec 8, 2022

Here is a version that gives identical results but should be significantly faster and more numerically stable: it Cholesky-factors the covariance once instead of forming an explicit inverse, and uses slogdet instead of a ratio of determinants. I also changed the parametrization. Enjoy! :)

import numpy as np
import scipy.linalg

def kl_mvn(to, fr):
    """Calculate `KL(to||fr)`, where `to` and `fr` are pairs of means and covariance matrices"""
    m_to, S_to = to
    m_fr, S_fr = fr

    d = m_fr - m_to

    # Factor S_fr once with Cholesky, then reuse the factorization for every solve.
    c, lower = scipy.linalg.cho_factor(S_fr)
    def solve(B):
        return scipy.linalg.cho_solve((c, lower), B)

    # slogdet is more numerically stable than taking log(det(S)) directly.
    def logdet(S):
        return np.linalg.slogdet(S)[1]

    term1 = np.trace(solve(S_to))        # tr(S_fr^{-1} S_to)
    term2 = logdet(S_fr) - logdet(S_to)  # log(|S_fr| / |S_to|)
    term3 = d.T @ solve(d)               # quadratic term
    return (term1 + term2 + term3 - len(d))/2.
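
To confirm the two implementations agree, here is a small check (my addition, not from the thread; kl_mvn_cho is a hypothetical rename of the function above, to avoid the clash with the original kl_mvn):

import numpy as np

rng = np.random.default_rng(0)

def random_spd(n):
    # Random symmetric positive-definite matrix: A @ A.T plus a diagonal shift.
    A = rng.standard_normal((n, n))
    return A @ A.T + n * np.eye(n)

m0, m1 = rng.standard_normal(4), rng.standard_normal(4)
S0, S1 = random_spd(4), random_spd(4)

# Same divergence, two parametrizations: KL( (m0, S0) || (m1, S1) ).
assert np.isclose(kl_mvn(m0, S0, m1, S1), kl_mvn_cho((m0, S0), (m1, S1)))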

@ChuaCheowHuan (Author) commented

@mvsoom Hey, thanks for sharing :)

@keizerzilla commented

Thank you, @ChuaCheowHuan and @mvsoom!
