Skip to content

Instantly share code, notes, and snippets.

@t-ae
Created August 8, 2017 13:48
Show Gist options
  • Star 12 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save t-ae/732f78671643de97bbe2c46519972491 to your computer and use it in GitHub Desktop.
Save t-ae/732f78671643de97bbe2c46519972491 to your computer and use it in GitHub Desktop.
Minibatch discrimination module in PyTorch
import torch
import torch.nn as nn
import torch.nn.init as init
class MinibatchDiscrimination(nn.Module):
    """Minibatch discrimination layer.

    Each sample is projected through a learned tensor ``T`` into
    ``out_features`` rows of ``kernel_dims`` values.  For every row, the L1
    distance to the corresponding row of every other sample in the batch is
    turned into a similarity ``exp(-distance)``, and the similarities are
    summed over the batch.  The resulting ``out_features`` values are
    concatenated onto the input features.

    Args:
        in_features: Size ``A`` of each input sample.
        out_features: Number ``B`` of similarity features appended per sample.
        kernel_dims: Size ``C`` of each projected row compared across samples.
        mean: If true, divide the summed similarities by the number of
            *other* samples in the batch (``N - 1``) instead of returning
            the raw sum.
    """

    def __init__(self, in_features, out_features, kernel_dims, mean=False):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.kernel_dims = kernel_dims
        self.mean = mean
        # Projection tensor T of shape A x B x C, drawn from N(0, 1).
        self.T = nn.Parameter(
            torch.empty(in_features, out_features, kernel_dims)
        )
        init.normal_(self.T, mean=0.0, std=1.0)

    def forward(self, x):
        """Append the minibatch similarity features to ``x``.

        Args:
            x: Tensor of shape ``N x A`` (``A == in_features``).

        Returns:
            Tensor of shape ``N x (A + B)``: the input followed by the
            ``B == out_features`` similarity features.
        """
        batch_size = x.size(0)

        # (N x A) @ (A x B*C) -> N x B*C, then unflatten to N x B x C.
        matrices = x.mm(self.T.view(self.in_features, -1))
        matrices = matrices.view(
            batch_size, self.out_features, self.kernel_dims
        )

        M = matrices.unsqueeze(0)  # 1 x N x B x C
        M_T = matrices.unsqueeze(1)  # N x 1 x B x C
        # Broadcasting yields every pairwise difference; summing the absolute
        # values over C gives the L1 distance per (sample, sample, row).
        norm = torch.abs(M - M_T).sum(3)  # N x N x B
        expnorm = torch.exp(-norm)

        # Each sample's distance to itself is 0 and contributes exp(0) == 1,
        # so subtract it to keep only the similarity to the other samples.
        o_b = expnorm.sum(0) - 1  # N x B

        if self.mean:
            # A batch of one has no other samples: o_b is already zero there,
            # and dividing by N - 1 == 0 would turn it into NaN, so the
            # divisor is clamped to at least 1.
            o_b = o_b / max(batch_size - 1, 1)

        return torch.cat([x, o_b], 1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment