Skip to content

Instantly share code, notes, and snippets.

@xmodar
Last active August 8, 2022 18:54
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save xmodar/185ca53b35b012c7fe781e4c567378a6 to your computer and use it in GitHub Desktop.
Save xmodar/185ca53b35b012c7fe781e4c567378a6 to your computer and use it in GitHub Desktop.
Frechet's distance entirely in PyTorch with data batches streaming support.
"""Frechet's distance between two multi-variate Gaussians"""
import torch
import torch.nn as nn
class FrechetDistance:
    """Frechet (2-Wasserstein) distance between two multivariate Gaussians.

    For N(mu1, sigma1) and N(mu2, sigma2) the value returned is
    ||mu1 - mu2||^2 + tr(sigma1) + tr(sigma2) - 2 tr((sigma1 @ sigma2)^(1/2)).
    https://www.sciencedirect.com/science/article/pii/0047259X8290077X
    """

    def __init__(self, double=True, num_iterations=20, eps=1e-12):
        # double: run the matrix square root in float64 and cast back after.
        # num_iterations: number of Newton-Schulz steps in psd_matrix_sqrt.
        # eps: lower bound for the normalizing matrix norm (guards 0 / 0).
        self.num_iterations = num_iterations
        self.double = double
        self.eps = eps

    def __call__(self, normal1, normal2):
        """Distance between two objects exposing `.mean` and `.covariance_matrix`.

        When passing streaming estimators, give both the same `unbiased`
        setting so their covariances are normalized the same way.
        """
        return self.compute(normal1.mean, normal1.covariance_matrix,
                            normal2.mean, normal2.covariance_matrix)

    def compute(self, mu1, sigma1, mu2, sigma2):
        """Frechet distance from explicit means and covariance matrices.

        Traces are taken over the last two dimensions and the squared mean
        difference over the last one.
        https://gist.github.com/ModarTensai/185ca53b35b012c7fe781e4c567378a6
        """
        def trace(mat):
            return mat.diagonal(0, -1, -2).sum(-1)

        mean_term = (mu1 - mu2).norm(2, dim=-1).pow(2)
        cross_term = trace(self.psd_matrix_sqrt(sigma1 @ sigma2))
        return mean_term + trace(sigma1) + trace(sigma2) - 2 * cross_term

    def psd_matrix_sqrt(self, matrix):
        """Matrix square root via the coupled Newton-Schulz iteration.

        The input is divided by its Frobenius norm (clamped below by
        `self.eps`) so the iteration starts in its convergence region, and the
        result is rescaled by the square root of that norm. The output has the
        input's dtype even when the work is done in float64.
        https://gist.github.com/ModarTensai/7c4aeb3d75bf1e0ab99b24cf2b3b37a3
        """
        original_dtype = matrix.dtype
        if self.double:
            matrix = matrix.double()
        scale = matrix.norm(dim=[-2, -1], keepdim=True).clamp_min_(self.eps)
        y = matrix / scale

        def half_three_minus(mat):
            # In place: mat <- 1.5 * I - 0.5 * mat, i.e. (3I - mat) / 2.
            mat.mul_(-0.5).diagonal(0, -1, -2).add_(1.5)
            return mat

        # First step with Z_0 = I, so Z_0 @ Y_0 == Y_0; clone to keep y intact.
        z = half_three_minus(y.clone())
        y = y @ z
        for step in range(1, self.num_iterations):
            t = half_three_minus(z @ y)
            y = y @ t
            # Z is not needed after the final step, so skip that product.
            if step + 1 < self.num_iterations:
                z = t @ z
        return (y * scale.sqrt()).to(original_dtype)
class MultivariateNormal(nn.Module):
    """Multivariate normal (also called Gaussian) distribution
    with streaming estimates of its mean and covariance.

    State: the running mean (`mean` buffer), the running sum of outer
    products of deviations from the mean (`mass` buffer, i.e. the scatter
    matrix) and the number of samples folded in so far (`count`). In training
    mode each `forward` call merges the batch statistics into this state.
    https://gist.github.com/ModarTensai/185ca53b35b012c7fe781e4c567378a6
    """

    def __init__(self, feature_size, unbiased=True):
        """Create an empty estimator or copy an existing distribution.

        Args:
            feature_size: an int (number of features, starts empty), another
                `MultivariateNormal` (copied, including its `unbiased`
                setting), or a `torch.distributions.MultivariateNormal` with
                an empty batch shape.
            unbiased: divide `mass` by `count - 1` instead of `count` when
                producing the covariance matrix.
        """
        super().__init__()
        self.unbiased = bool(unbiased)
        if isinstance(feature_size, MultivariateNormal):
            self.unbiased = feature_size.unbiased
            mean = feature_size.mean.clone()
            mass = feature_size.mass.clone()
            self.count = feature_size.count
        elif isinstance(feature_size, torch.distributions.MultivariateNormal):
            assert feature_size.batch_shape == (), 'currently support 1D only'
            mean = feature_size.mean.clone()
            # No sample count is available; heuristically use the feature
            # size, but keep it large enough for a finite normalization factor.
            self.count = max(mean.numel(), int(self.unbiased) + 1)
            # Invert `covariance_matrix = mass * factor` so this object reports
            # the same covariance as the source distribution.
            mass = feature_size.covariance_matrix * (
                self.count - int(self.unbiased))
        else:
            self.count = 0
            mean = torch.zeros(feature_size)
            mass = torch.zeros(feature_size, feature_size)
        self.register_buffer('mean', mean)
        self.register_buffer('mass', mass)
        if not hasattr(self, 'mean'):  # dummy to suppress pylint no-member
            self.mean = self.mass = None

    @property
    def factor(self):
        """Get the normalization factor 1 / (count - ddof)."""
        # Raises ZeroDivisionError when count == ddof (e.g. nothing seen yet
        # with unbiased=False, or a single sample with unbiased=True).
        return 1 / (self.count - int(bool(self.unbiased)))

    @property
    def covariance_matrix(self):
        """Get the covariance matrix"""
        return self.mass * self.factor

    @property
    def variance(self):
        """Get the variance (diagonal of the covariance matrix)."""
        return self.mass.diag() * self.factor

    def forward(self, batch):
        """Return the batch (mean, covariance); update state only in training.

        `covariance` is None when the batch holds a single sample.
        """
        mean, covariance, count = self.get_stats(batch, self.unbiased)
        if self.training:
            self.stats_update(mean, covariance, count, self.unbiased)
        return mean, covariance

    @staticmethod
    def get_stats(batch, unbiased=True):
        """Compute the statistics of a batch
        https://gist.github.com/ModarTensai/5ab449acba9df1a26c12060240773110

        Args:
            batch: a 2D (samples, features) tensor, or a 1D tensor holding a
                single sample.
            unbiased: normalize the covariance by `count - 1` instead of
                `count`.

        Returns:
            (mean, covariance, count); `covariance` is None if `count == 1`.
        """
        assert 1 <= batch.ndim <= 2
        if batch.ndim == 1:
            # unsqueeze is not in-place: keep the result so a single sample is
            # not mistaken for a batch of scalars.
            batch = batch.unsqueeze(0)
        count = batch.shape[0]
        mean = batch.mean(0)
        if count == 1:
            covariance = None
        else:
            centered = batch - mean
            factor = 1 / (count - int(bool(unbiased)))
            covariance = factor * centered.t().conj() @ centered
        return mean, covariance, count

    def stats_update(self, mean, covariance, count, unbiased=None):
        """Update the metric given batch statistics
        https://gist.github.com/ModarTensai/dc95444faf3624ed979b4d0b2088fdf1

        Merges a batch's (mean, covariance, count) into the running state
        using the pairwise (Chan et al.) update.

        Args:
            mean: batch mean.
            covariance: batch covariance normalized with the same `unbiased`
                setting; may be None when `count == 1`.
            count: number of samples in the batch; 0 is a no-op.
            unbiased: how `covariance` was normalized; defaults to
                `self.unbiased`.
        """
        if not count:
            # Nothing to merge; avoids 0 / 0 and NaN leaking into the mean.
            return
        if unbiased is None:
            unbiased = self.unbiased
        # Running statistics are not differentiable state; keep the buffers
        # out of the autograd graph so it does not grow across batches.
        with torch.no_grad():
            diff1 = mean - self.mean
            self.mean += diff1 * (count / (self.count + count))
            # diff2 == diff1 * old_count / total, so this outer product times
            # count is the between-group term delta delta^T * n_a n_b / n.
            diff2 = mean - self.mean
            mass = (diff1[:, None].conj() @ diff2[None, :]) * count
            if count > 1:
                # Undo the batch normalization to get the batch scatter matrix.
                mass += covariance * (count - int(bool(unbiased)))
            self.mass += mass
        self.count += count
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment