Last active
February 9, 2021 05:39
-
-
Save curegit/a1daca4eaf071eafa58895db50d06638 to your computer and use it in GitHub Desktop.
Chainer で二次元マップ Self-Attention 層の実装 (Self-Attention GAN) — Implementation of a self-attention layer over 2-D feature maps in Chainer (Self-Attention GAN)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from chainer import Parameter, Chain | |
from chainer.links import Convolution2D | |
from chainer.functions import einsum, softmax | |
from chainer.initializers import Zero | |
def dot(a, b):
    """Matrix-multiply the last two axes of ``a`` and ``b``, batching over any leading axes."""
    product = einsum("...ji,...ik->...jk", a, b)
    return product
class SelfAttention(Chain):
    """SAGAN-style self-attention layer over 2-D feature maps.

    Computes attention between every pair of spatial positions and adds a
    gamma-scaled attention output back onto the input (residual form), as in
    "Self-Attention Generative Adversarial Networks" (Zhang et al., 2019).

    Args:
        channels: Number of input (and output) channels.
        inner_channels: Channel count of the query/key/value projections.
            Defaults to ``channels`` when ``None`` or 0.
    """

    def __init__(self, channels, inner_channels=None):
        super().__init__()
        self.inner_channels = inner_channels or channels
        with self.init_scope():
            # 1x1 convolutions producing query (f), key (g) and value (h) maps.
            self.f = Convolution2D(channels, self.inner_channels, ksize=1, stride=1, pad=0, nobias=True)
            self.g = Convolution2D(channels, self.inner_channels, ksize=1, stride=1, pad=0, nobias=True)
            self.h = Convolution2D(channels, self.inner_channels, ksize=1, stride=1, pad=0, nobias=True)
            # Final 1x1 convolution mapping attended features back to `channels`.
            self.v = Convolution2D(self.inner_channels, channels, ksize=1, stride=1, pad=0, nobias=True)
            # Learnable residual gate, zero-initialized so the layer starts as
            # the identity and gradually learns to rely on attention.
            self.gamma = Parameter(initializer=Zero(), shape=1)

    def __call__(self, x):
        """Apply self-attention to ``x``; expects shape (batch, channels, height, width)."""
        b, _, h, w = x.shape
        c = self.inner_channels
        # Flatten the spatial grid to n = h * w positions.
        fx = self.f(x).reshape(b, c, h * w)
        gx = self.g(x).reshape(b, c, h * w)
        hx = self.h(x).reshape(b, c, h * w)
        # a[i, j] = f_i . g_j over all position pairs. dot(hx, a) below
        # contracts over axis 1 (index i), so the softmax must normalize over
        # that same axis for each output position's weights to sum to 1.
        # Bug fix: the original used axis=2, which normalized over the wrong
        # axis and left the combination weights unnormalized (canonical SAGAN
        # normalizes over the summed index).
        a = softmax(dot(fx.transpose(0, 2, 1), gx), axis=1)
        # Weighted sum of values, projected back and added residually.
        return x + self.gamma * self.v(dot(hx, a).reshape(b, c, h, w))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment