Skip to content

Instantly share code, notes, and snippets.

@yunjey
Last active November 30, 2018 09:14
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save yunjey/f630383825b6af37b1c4bcd9f291da7f to your computer and use it in GitHub Desktop.
Save yunjey/f630383825b6af37b1c4bcd9f291da7f to your computer and use it in GitHub Desktop.
from torch.nn.utils import spectral_norm
import torch.nn.functional as F
import torch.nn as nn
import torch
class NonLocalBlock(nn.Module):
    """Non-local (self-attention) block over the spatial positions of a feature map.

    Every spatial position of the input produces a query; keys and values are
    computed from the same input but max-pooled by a factor of 2 so the
    attention matrix is 4x smaller. The attended values are projected back to
    ``conv_dim`` channels and added to the input, scaled by the learnable
    scalar ``gamma``. ``gamma`` starts at zero, so a freshly constructed block
    returns its input unchanged.

    All 1x1 convolutions are wrapped in spectral normalization.

    Args:
        conv_dim: Number of channels of the input (and output) feature map.
            Must be at least 8, because the query/key projections use
            ``conv_dim // 8`` channels.

    Raises:
        ValueError: If ``conv_dim`` is smaller than 8.
    """

    def __init__(self, conv_dim: int) -> None:
        # Reject sizes for which the query/key projections would have zero
        # channels; such a block would carry no learned attention scores.
        if conv_dim < 8:
            raise ValueError(
                "conv_dim must be >= 8 so that conv_dim // 8 is at least 1, "
                "got {}".format(conv_dim)
            )
        super().__init__()
        # Query / key projections: conv_dim -> conv_dim // 8 channels.
        self.conv1 = spectral_norm(nn.Conv2d(conv_dim, conv_dim // 8, 1, 1, 0))
        self.conv2 = spectral_norm(nn.Conv2d(conv_dim, conv_dim // 8, 1, 1, 0))
        # Value projection: conv_dim -> conv_dim // 2 channels.
        self.conv3 = spectral_norm(nn.Conv2d(conv_dim, conv_dim // 2, 1, 1, 0))
        # Output projection back to conv_dim channels.
        self.conv4 = spectral_norm(nn.Conv2d(conv_dim // 2, conv_dim, 1, 1, 0))
        # Shared 2x2 / stride-2 pooling applied to keys and values only.
        self.downsample = nn.MaxPool2d(2, 2)
        # Residual gate, initialised to zero (identity mapping at start).
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply spatial self-attention and add the result to ``x``.

        Args:
            x: Tensor of shape ``(N, conv_dim, H, W)``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        N, _, H, W = x.size()

        # Flatten spatial dims; channel counts come from the projections
        # themselves rather than being recomputed from the input shape.
        query = self.conv1(x).flatten(2)                    # (N, C//8, H*W)
        key = self.downsample(self.conv2(x)).flatten(2)     # (N, C//8, H'*W')
        value = self.downsample(self.conv3(x)).flatten(2)   # (N, C//2, H'*W')

        # Each query position gets a distribution over the pooled key positions.
        attn = torch.bmm(query.transpose(1, 2), key)        # (N, H*W, H'*W')
        attn = F.softmax(attn, dim=2)

        # Weighted sum of values for every query position.
        out = torch.bmm(value, attn.transpose(1, 2))        # (N, C//2, H*W)
        out = out.reshape(N, -1, H, W)                      # (N, C//2, H, W)

        return x + self.gamma * self.conv4(out)             # (N, C, H, W)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment