
@coreyjs
Created January 23, 2021 13:18
Calculate Gram Matrix
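The snippet relies on the standard definition below. A feature map with d channels of size h × w is flattened into a matrix F with one row of length hw per channel. The Gram matrix holds the inner products between every pair of channel rows. This is what `tensor.view(d, h * w)` followed by `torch.mm(tensor, tensor.t())` computes.

```latex
F \in \mathbb{R}^{d \times hw}, \qquad
G = F F^{\top} \in \mathbb{R}^{d \times d}, \qquad
G_{ij} = \sum_{k=1}^{hw} F_{ik} F_{jk}, \quad i, j \in \{1, \dots, d\}
```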
```python
import torch


def gram_matrix(tensor):
    """Calculate the Gram Matrix of a given tensor.

    Gram Matrix: https://en.wikipedia.org/wiki/Gramian_matrix
    """
    # Get the batch_size, depth, height, and width of the tensor.
    batch_size, d, h, w = tensor.size()
    # The reshape below assumes a single image per batch.
    assert batch_size == 1, "gram_matrix expects a batch of exactly one image"
    # Flatten each channel into one row of length h * w.
    features = tensor.view(d, h * w)
    # Inner products between every pair of channel rows give a (d, d) matrix.
    gram = torch.mm(features, features.t())
    return gram
```
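The function above handles one image at a time. The sketch below covers a whole batch. It is a hypothetical variant: the name `batched_gram_matrix` is not part of the original gist. It applies the same flatten-then-multiply step to every image using `torch.bmm`, and returns one (d, d) Gram matrix per image.

```python
import torch


def batched_gram_matrix(tensor: torch.Tensor) -> torch.Tensor:
    """Hypothetical batched variant: one Gram matrix per image in (b, d, h, w)."""
    b, d, h, w = tensor.size()
    # One (d, h*w) feature matrix per image; reshape also accepts non-contiguous input.
    features = tensor.reshape(b, d, h * w)
    # Batched F @ F^T gives shape (b, d, d).
    return torch.bmm(features, features.transpose(1, 2))


if __name__ == "__main__":
    x = torch.randn(2, 3, 4, 5)
    g = batched_gram_matrix(x)
    print(g.shape)  # torch.Size([2, 3, 3])
```

Neither version normalizes the result. Some style-transfer code divides the Gram matrix by `h * w` (or `d * h * w`) so that its scale does not depend on the image size. Whether you need that depends on how the matrix is used downstream.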