@curegit
Last active February 9, 2021 05:39
Implementation of a Self-Attention layer over 2D feature maps in Chainer (Self-Attention GAN)
from chainer import Parameter, Chain
from chainer.links import Convolution2D
from chainer.functions import einsum, softmax
from chainer.initializers import Zero


def dot(a, b):
    # Batched matrix multiplication over the last two axes
    return einsum("...ji,...ik->...jk", a, b)


class SelfAttention(Chain):
    """Self-attention layer over 2D feature maps (Self-Attention GAN style)."""

    def __init__(self, channels, inner_channels=None):
        super().__init__()
        self.inner_channels = inner_channels or channels
        with self.init_scope():
            # 1x1 convolutions projecting the input into query, key, and value spaces
            self.f = Convolution2D(channels, self.inner_channels, ksize=1, stride=1, pad=0, nobias=True)
            self.g = Convolution2D(channels, self.inner_channels, ksize=1, stride=1, pad=0, nobias=True)
            self.h = Convolution2D(channels, self.inner_channels, ksize=1, stride=1, pad=0, nobias=True)
            # 1x1 convolution mapping the attention output back to the input channel count
            self.v = Convolution2D(self.inner_channels, channels, ksize=1, stride=1, pad=0, nobias=True)
            # Learnable scale of the attention branch, initialized to zero
            self.gamma = Parameter(initializer=Zero(), shape=1)

    def __call__(self, x):
        b, _, h, w = x.shape
        c = self.inner_channels
        # Flatten the spatial dimensions: (b, c, h, w) -> (b, c, h * w)
        fx = self.f(x).reshape(b, c, h * w)
        gx = self.g(x).reshape(b, c, h * w)
        hx = self.h(x).reshape(b, c, h * w)
        # Attention map over all pairs of spatial positions: (b, h * w, h * w)
        a = softmax(dot(fx.transpose(0, 2, 1), gx), axis=2)
        # Weighted sum of values, reshaped back to a 2D map and added as a residual
        return x + self.gamma * self.v(dot(hx, a).reshape(b, c, h, w))
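
A minimal usage sketch, not part of the original gist: since the layer returns a tensor with the same shape as its input, it can be inserted between existing convolution blocks. The batch size, channel count, spatial size, and inner_channels value below are arbitrary example values.

import numpy as np

# Hypothetical example: batch of 4, 64 channels, 32x32 feature maps
layer = SelfAttention(channels=64, inner_channels=8)
x = np.random.randn(4, 64, 32, 32).astype(np.float32)
y = layer(x)
print(y.shape)  # (4, 64, 32, 32) -- output shape matches the input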