Skip to content

Instantly share code, notes, and snippets.

@KushajveerSingh
Created April 19, 2019 07:49
Show Gist options
  • Save KushajveerSingh/62ff1d386e7561a201842ee4741f31fc to your computer and use it in GitHub Desktop.
Save KushajveerSingh/62ff1d386e7561a201842ee4741f31fc to your computer and use it in GitHub Desktop.
My implementation of the SPADE (spatially-adaptive normalization) layer from the paper arXiv:1903.07291.
class SPADE(Module):
    """Spatially-adaptive normalization layer.

    The input activations are normalized per channel (statistics taken over
    the batch and both spatial dimensions) and then modulated element-wise by
    a scale ``gamma`` and a shift ``beta`` that are predicted from the
    segmentation map by small convolutional branches::

        out = gamma(seg) * (x - mean) / sqrt(var + eps) + beta(seg)

    Args:
        args: Configuration object providing ``spade_filter`` (number of
            hidden filters in the shared convolution) and ``spade_kernel``
            (square kernel size used by every convolution).
        k: Number of channels of the activations to be modulated; ``gamma``
            and ``beta`` are produced with this many channels.
        eps: Small constant added to the variance so that constant channels
            do not cause a division by zero.

    Note:
        The segmentation map is expected to have a single channel, because
        the first convolution is built with one input channel.  Spatial size
        is only preserved for odd ``spade_kernel`` values.
    """

    def __init__(self, args, k, eps=1e-5):
        super().__init__()
        num_filters = args.spade_filter
        kernel_size = args.spade_kernel
        # "Same" padding for odd kernels, so gamma/beta keep the (H, W) of x.
        padding = kernel_size // 2
        self.eps = eps
        self.conv = spectral_norm(
            Conv2d(1, num_filters, kernel_size=(kernel_size, kernel_size), padding=padding)
        )
        self.conv_gamma = spectral_norm(
            Conv2d(num_filters, k, kernel_size=(kernel_size, kernel_size), padding=padding)
        )
        self.conv_beta = spectral_norm(
            Conv2d(num_filters, k, kernel_size=(kernel_size, kernel_size), padding=padding)
        )

    def forward(self, x, seg):
        """Normalize ``x`` and modulate it with parameters predicted from ``seg``.

        Args:
            x: Activations of shape ``(N, C, H, W)`` where ``C`` equals the
                ``k`` given at construction.
            seg: Single-channel segmentation map of shape ``(N, 1, H', W')``;
                it is resized to ``(H, W)`` with nearest-neighbour
                interpolation.

        Returns:
            Tensor with the same shape as ``x``.
        """
        _, _, H, W = x.size()

        # Per-channel statistics over batch and spatial dims (biased variance).
        mean = x.mean(dim=(0, 2, 3), keepdim=True)
        centered = x - mean
        var = centered.pow(2).mean(dim=(0, 2, 3), keepdim=True)
        x = centered / torch.sqrt(var + self.eps)

        # Nearest-neighbour keeps label values intact while matching x's size.
        seg = F.interpolate(seg, size=(H, W), mode='nearest')
        seg = relu(self.conv(seg))
        seg_gamma = self.conv_gamma(seg)
        seg_beta = self.conv_beta(seg)

        # Modulation is element-wise per pixel and channel, not a matrix product.
        x = seg_gamma * x + seg_beta
        return x
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment