
@rohit-gupta
Last active December 17, 2022 03:10
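import torch


def var_covar_loss(Z, alpha=1.0, beta=0.01):
    """Variance-covariance regulariser for a batch of embeddings Z of shape (N, K)."""
    eps = 1e-5
    K = Z.shape[1]  # number of features
    # K x K covariance matrix; torch.cov treats rows as variables, so transpose Z
    C = torch.cov(Z.t())
    # Hinge on the std dev: push the sqrt of each diagonal term up towards 1.0.
    # Clamping at 1 means features whose std is already >= 1 are not penalised;
    # the lower bound eps avoids sqrt(0), whose gradient is infinite
    # (clamp also passes zero gradient for variances below eps).
    var_loss = K - torch.diag(C).clamp(eps, 1).sqrt().sum()
    # Push off-diagonal terms to 0 (features should not be correlated with each other).
    # C is symmetric, so 2 * sum(upper triangle ** 2) equals the sum over all off-diagonal entries.
    cov_loss = 2 * torch.triu(C, diagonal=1).square().sum()
    loss = alpha * var_loss + beta * cov_loss
    return loss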
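To make the two terms concrete, the following sketch recomputes the same loss in plain Python on tiny hand-made batches. It is a hypothetical reference, not part of the gist: `covariance_matrix` and `var_covar_loss_reference` are illustrative names. It assumes the unbiased N-1 normalisation that `torch.cov` uses by default and at least two rows. It also returns the variance and covariance components alongside the total so each can be inspected.

```python
import math


def covariance_matrix(rows):
    """Unbiased (N-1) covariance of the columns of `rows` (assumes N >= 2)."""
    n = len(rows)
    k = len(rows[0])
    means = [sum(r[j] for r in rows) / n for j in range(k)]
    return [
        [sum((r[i] - means[i]) * (r[j] - means[j]) for r in rows) / (n - 1) for j in range(k)]
        for i in range(k)
    ]


def var_covar_loss_reference(rows, alpha=1.0, beta=0.01, eps=1e-5):
    """Hypothetical pure-Python mirror of var_covar_loss; returns (loss, var_loss, cov_loss)."""
    C = covariance_matrix(rows)
    k = len(C)
    # Each feature costs 1 - std, but only while its variance is below 1 (clamped to [eps, 1]).
    var_loss = k - sum(math.sqrt(min(max(C[j][j], eps), 1.0)) for j in range(k))
    # Squared off-diagonal covariances over both triangles, i.e. 2 * triu(C, 1) ** 2 summed.
    cov_loss = sum(C[i][j] ** 2 for i in range(k) for j in range(k) if i != j)
    return alpha * var_loss + beta * cov_loss, var_loss, cov_loss


# Two uncorrelated +/-1 features: std is already >= 1 and covariance is 0.
print(var_covar_loss_reference([[1, 1], [1, -1], [-1, 1], [-1, -1]]))  # (0.0, 0.0, 0.0)
```

With the default `alpha=1.0` and `beta=0.01`, a batch of identical rows (collapsed features) pays almost the full variance penalty, about 1.99 for K = 2, and no covariance penalty. Two perfectly correlated ±1 columns pay only the covariance term, 2·(4/3)² ≈ 3.56, which `beta` scales down to roughly 0.036.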