@kieranrcampbell
Created March 24, 2022 23:31
PyTorch implementation of MMD (maximum mean discrepancy)
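For reference, here is a sketch of what the code computes, read directly off the implementation rather than from any cited source. With samples $a \in \mathbb{R}^{n \times d}$ and $b \in \mathbb{R}^{m \times d}$, where $d$ is `depth`, `gaussian_kernel` uses a fixed bandwidth tied to the feature dimension and `MMD` returns the biased (V-statistic) estimate of the *squared* MMD:

```latex
k(x, y) = \exp\!\left(-\frac{\lVert x - y \rVert_2^2}{d^2}\right)

\widehat{\mathrm{MMD}}^2_b(a, b)
  = \frac{1}{n^2}\sum_{i=1}^{n}\sum_{j=1}^{n} k(a_i, a_j)
  + \frac{1}{m^2}\sum_{i=1}^{m}\sum_{j=1}^{m} k(b_i, b_j)
  - \frac{2}{nm}\sum_{i=1}^{n}\sum_{j=1}^{m} k(a_i, b_j)
```

The $d^2$ in the kernel comes from `.mean(2)` (divide by $d$) followed by `/depth` (divide by $d$ again). Because the diagonal terms $k(a_i, a_i) = 1$ are included, the estimate is biased upward, but it is non-negative up to floating-point error.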
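## Can't remember where this comes from unfortunately
import torch


def gaussian_kernel(a, b):
    # Pairwise Gaussian kernel between the rows of a (n x d) and b (m x d); returns n x m.
    dim1_1, dim1_2 = a.shape[0], b.shape[0]
    depth = a.shape[1]
    # Lay out every (a_i, b_j) pair on an (n, m, d) grid; reshape also works on non-contiguous inputs
    a = a.reshape(dim1_1, 1, depth)
    b = b.reshape(1, dim1_2, depth)
    a_core = a.expand(dim1_1, dim1_2, depth)
    b_core = b.expand(dim1_1, dim1_2, depth)
    # Mean over features, then / depth again: ||a_i - b_j||^2 / depth^2
    numerator = (a_core - b_core).pow(2).mean(2) / depth
    return torch.exp(-numerator)


def MMD(a, b):
    # Biased (V-statistic) estimate of the *squared* MMD between samples a and b
    return gaussian_kernel(a, a).mean() + gaussian_kernel(b, b).mean() - 2 * gaussian_kernel(a, b).mean()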
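A minimal, self-contained usage sketch follows. It assumes the same kernel as above, written with broadcasting (`unsqueeze`) instead of explicit `expand` calls. The names `gaussian_kernel_bcast` and `mmd_biased`, the sample sizes, and the shift of 3.0 are all hypothetical choices for illustration. The sketch compares two samples drawn from the same Gaussian against a pair in which one sample's mean has been shifted.

```python
import torch


def gaussian_kernel_bcast(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Hypothetical rewrite of the gist's kernel: exp(-||a_i - b_j||^2 / d^2).
    # (n, 1, d) - (1, m, d) broadcasts to (n, m, d), replacing the explicit expand().
    depth = a.shape[1]
    sq_dist = (a.unsqueeze(1) - b.unsqueeze(0)).pow(2).sum(2)  # shape (n, m)
    return torch.exp(-sq_dist / depth**2)


def mmd_biased(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Biased (V-statistic) estimate of squared MMD, same formula as the gist's MMD().
    return (
        gaussian_kernel_bcast(a, a).mean()
        + gaussian_kernel_bcast(b, b).mean()
        - 2 * gaussian_kernel_bcast(a, b).mean()
    )


torch.manual_seed(0)
d = 5
x = torch.randn(200, d)
y_same = torch.randn(200, d)         # drawn from the same distribution as x
y_shift = torch.randn(200, d) + 3.0  # mean shifted by 3.0 in every coordinate

print(mmd_biased(x, y_same).item())   # small: both samples come from N(0, I)
print(mmd_biased(x, y_shift).item())  # much larger: the means differ
```

Both versions materialise an (n, m, d) difference tensor. For large samples, `torch.cdist(a, b).pow(2)` is a lower-memory way to obtain the same squared distances, at the cost of small floating-point differences.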