We're happy to see community interest in the QRNN architecture. Here's the core of our implementation, written in Chainer. STRNNFunction is the recurrent pooling function, while QRNNLayer implements a QRNN layer composed of convolutional and pooling subcomponents, with optional attention and state-saving features for the tasks in the paper.
# requirements:
# pip install -e git://github.com/jekbradbury/chainer.git@raw-kernel
from chainer import cuda, Function, Variable, Chain
import chainer.links as L
import chainer.functions as F
import numpy as np

THREADS_PER_BLOCK = 32

class STRNNFunction(Function):
    # Fused recurrent pooling: h_t = f_t * h_{t-1} + z_t, scanned over time on
    # the GPU with one thread per channel.

    def forward_gpu(self, inputs):
        f, z, hinit = inputs
        b, t, c = f.shape
        assert c % THREADS_PER_BLOCK == 0
        self.h = cuda.cupy.zeros((b, t + 1, c), dtype=np.float32)
        self.h[:, 0, :] = hinit
        cuda.raw('''
            #define THREADS_PER_BLOCK 32
            extern "C" __global__ void strnn_fwd(
                    const CArray<float, 3> f, const CArray<float, 3> z,
                    CArray<float, 3> h) {
                int index[3];
                const int t_size = f.shape()[1];
                index[0] = blockIdx.x;
                index[1] = 0;
                index[2] = blockIdx.y * THREADS_PER_BLOCK + threadIdx.x;
                float prev_h = h[index];
                for (int i = 0; i < t_size; i++){
                    index[1] = i;
                    const float ft = f[index];
                    const float zt = z[index];
                    index[1] = i + 1;
                    float &ht = h[index];
                    prev_h = prev_h * ft + zt;
                    ht = prev_h;
                }
            }''', 'strnn_fwd')(
            (b, c // THREADS_PER_BLOCK), (THREADS_PER_BLOCK,),
            (f, z, self.h))
        return self.h[:, 1:, :],

    def backward_gpu(self, inputs, grads):
        f, z = inputs[:2]
        gh, = grads
        b, t, c = f.shape
        # gz accumulates the gradient w.r.t. the hidden state, scanned in reverse.
        gz = cuda.cupy.zeros_like(gh)
        cuda.raw('''
            #define THREADS_PER_BLOCK 32
            extern "C" __global__ void strnn_back(
                    const CArray<float, 3> f, const CArray<float, 3> gh,
                    CArray<float, 3> gz) {
                int index[3];
                const int t_size = f.shape()[1];
                index[0] = blockIdx.x;
                index[2] = blockIdx.y * THREADS_PER_BLOCK + threadIdx.x;
                index[1] = t_size - 1;
                float &gz_last = gz[index];
                gz_last = gh[index];
                float prev_gz = gz_last;
                for (int i = t_size - 1; i > 0; i--){
                    index[1] = i;
                    const float ft = f[index];
                    index[1] = i - 1;
                    const float ght = gh[index];
                    float &gzt = gz[index];
                    prev_gz = prev_gz * ft + ght;
                    gzt = prev_gz;
                }
            }''', 'strnn_back')(
            (b, c // THREADS_PER_BLOCK), (THREADS_PER_BLOCK,),
            (f, gh, gz))
        gf = self.h[:, :-1, :] * gz
        ghinit = f[:, 0, :] * gz[:, 0, :]
        return gf, gz, ghinit

def strnn(f, z, h0):
    return STRNNFunction()(f, z, h0)


def attention_sum(encoding, query):
    # Soft attention: softmax scores over encoder positions, then a weighted
    # sum of encoder states for each query position.
    alpha = F.softmax(F.batch_matmul(encoding, query, transb=True))
    alpha, encoding = F.broadcast(alpha[:, :, :, None],
                                  encoding[:, :, None, :])
    return F.sum(alpha * encoding, axis=1)

class Linear(L.Linear):
    # Linear layer that also accepts (batch, time, features) inputs by folding
    # the time axis into the batch axis.

    def __call__(self, x):
        shape = x.shape
        if len(shape) == 3:
            x = F.reshape(x, (-1, shape[2]))
        y = super().__call__(x)
        if len(shape) == 3:
            y = F.reshape(y, shape[:2] + (-1,))
        return y

class QRNNLayer(Chain):
    def __init__(self, in_size, out_size, kernel_size=2, attention=False,
                 decoder=False):
        if kernel_size == 1:
            super().__init__(W=Linear(in_size, 3 * out_size))
        elif kernel_size == 2:
            super().__init__(W=Linear(in_size, 3 * out_size, nobias=True),
                             V=Linear(in_size, 3 * out_size))
        else:
            super().__init__(
                conv=L.ConvolutionND(1, in_size, 3 * out_size, kernel_size,
                                     stride=1, pad=kernel_size - 1))
        if attention:
            self.add_link('U', Linear(out_size, 3 * in_size))
            self.add_link('o', Linear(2 * out_size, out_size))
        self.in_size, self.size, self.attention = in_size, out_size, attention
        self.kernel_size = kernel_size

    def pre(self, x):
        # Gate pre-activations from a masked (causal) convolution over the input.
        dims = len(x.shape) - 1
        if self.kernel_size == 1:
            ret = self.W(x)
        elif self.kernel_size == 2:
            if dims == 2:
                xprev = Variable(
                    self.xp.zeros((self.batch_size, 1, self.in_size),
                                  dtype=np.float32), volatile='AUTO')
                xtminus1 = F.concat((xprev, x[:, :-1, :]), axis=1)
            else:
                xtminus1 = self.x
            ret = self.W(x) + self.V(xtminus1)
        else:
            # Truncate the padded convolution output to the input length so the
            # convolution remains causal.
            ret = F.swapaxes(self.conv(
                F.swapaxes(x, 1, 2))[:, :, :x.shape[1]], 1, 2)
        if not self.attention:
            return ret
        if dims == 1:
            enc = self.encoding[:, -1, :]
        else:
            enc = self.encoding[:, -1:, :]
        return sum(F.broadcast(self.U(enc), ret))

    def init(self, encoder_c=None, encoder_h=None):
        self.encoding = encoder_c
        self.c, self.x = None, None
        if self.encoding is not None:
            self.batch_size = self.encoding.shape[0]
            if not self.attention:
                self.c = self.encoding[:, -1, :]
        if self.c is None or self.c.shape[0] < self.batch_size:
            self.c = Variable(self.xp.zeros((self.batch_size, self.size),
                                            dtype=np.float32), volatile='AUTO')
        if self.x is None or self.x.shape[0] < self.batch_size:
            self.x = Variable(self.xp.zeros((self.batch_size, self.in_size),
                                            dtype=np.float32), volatile='AUTO')

    def __call__(self, x):
        if not hasattr(self, 'encoding') or self.encoding is None:
            self.batch_size = x.shape[0]
            self.init()
        dims = len(x.shape) - 1
        f, z, o = F.split_axis(self.pre(x), 3, axis=dims)
        f = F.sigmoid(f)
        z = (1 - f) * F.tanh(z)
        o = F.sigmoid(o)
        if dims == 2:
            self.c = strnn(f, z, self.c[:self.batch_size])
        else:
            self.c = f * self.c + z
        if self.attention:
            context = attention_sum(self.encoding, self.c)
            self.h = o * self.o(F.concat((self.c, context), axis=dims))
        else:
            self.h = self.c * o
        self.x = x
        return self.h

    def get_state(self):
        return F.concat((self.x, self.c, self.h), axis=1)

    def set_state(self, state):
        self.x, self.c, self.h = F.split_axis(
            state, (self.in_size, self.in_size + self.size), axis=1)

    state = property(get_state, set_state)
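
Below is a minimal usage sketch, not part of the gist itself: it drives a single QRNNLayer over a batch of sequences with hypothetical sizes, and moves the layer to the GPU first, since STRNNFunction only implements forward_gpu and requires the hidden size to be a multiple of THREADS_PER_BLOCK.

# Minimal usage sketch (illustrative only); the sizes below are assumptions,
# not values from the paper.
batch, seq_len, in_size, hidden = 16, 20, 128, 256

layer = QRNNLayer(in_size, hidden, kernel_size=2)
layer.to_gpu()  # the recurrent pooling kernel is GPU-only

x = Variable(layer.xp.random.randn(
    batch, seq_len, in_size).astype(np.float32))
h = layer(x)    # hidden states for the whole sequence: (batch, seq_len, hidden)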
Hi jekbradbury, thanks for your implementation of QRNN.
I have a question: why is
prev_h = prev_h * ft + zt
in the strnn_fwd kernel rather than prev_h = prev_h * ft + (1 - ft) * zt?
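
For reference, a short sketch of how the two forms line up, based only on the code above: the (1 - f) factor is applied to z inside QRNNLayer.__call__ before the kernel is launched, so the kernel only needs the simpler update.

# Inside QRNNLayer.__call__ (above), before strnn() is called:
#     f = F.sigmoid(f)
#     z = (1 - f) * F.tanh(z)    # (1 - f) is folded into z here
# so the kernel's update  prev_h = prev_h * ft + zt  works out to
#     h_t = f_t * h_{t-1} + (1 - f_t) * tanh(z_pre_t),
# the pooling recurrence from the paper, with the gating done outside the kernel.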