Instantly share code, notes, and snippets.

@jekbradbury /qrnn.py Secret
Last active Sep 28, 2018

Embed
What would you like to do?
We're happy to see community interest in the QRNN architecture. Here's the core of our implementation, written in Chainer. STRNNFunction is the recurrent pooling function, while QRNNLayer implements a QRNN layer composed of convolutional and pooling subcomponents, with optional attention and state-saving features for the tasks in the paper.
# requirements:
# pip install -e git://github.com/jekbradbury/chainer.git@raw-kernel
from chainer import cuda, Function, Variable, Chain
import chainer.links as L
import chainer.functions as F
import numpy as np
THREADS_PER_BLOCK = 32
class STRNNFunction(Function):
    """GPU recurrent pooling h_t = h_{t-1} * f_t + z_t.

    This is the QRNN "f-pooling" recurrence; note the caller pre-scales
    z by (1 - f), so the kernel itself only does a multiply-accumulate.
    The raw CUDA kernels parallelize over batch (blockIdx.x) and channels
    (blockIdx.y * THREADS_PER_BLOCK + threadIdx.x) and scan the time axis
    sequentially inside each thread.
    """

    def forward_gpu(self, inputs):
        # inputs: f and z of shape (batch, time, channels), plus hinit of
        # shape (batch, channels); all assumed float32 on the GPU.
        f, z, hinit = inputs
        b, t, c = f.shape
        # The launch geometry tiles channels by THREADS_PER_BLOCK exactly.
        assert c % THREADS_PER_BLOCK == 0
        # h has an extra leading timestep holding hinit; it is kept on self
        # because backward_gpu needs h_{t-1} to form the gradient wrt f.
        self.h = cuda.cupy.zeros((b, t + 1, c), dtype=np.float32)
        self.h[:, 0, :] = hinit
        cuda.raw('''
        #define THREADS_PER_BLOCK 32
        extern "C" __global__ void strnn_fwd(
            const CArray<float, 3> f, const CArray<float, 3> z,
            CArray<float, 3> h) {
          int index[3];
          const int t_size = f.shape()[1];
          index[0] = blockIdx.x;
          index[1] = 0;
          index[2] = blockIdx.y * THREADS_PER_BLOCK + threadIdx.x;
          float prev_h = h[index];
          for (int i = 0; i < t_size; i++){
            index[1] = i;
            const float ft = f[index];
            const float zt = z[index];
            index[1] = i + 1;
            float &ht = h[index];
            prev_h = prev_h * ft + zt;
            ht = prev_h;
          }
        }''', 'strnn_fwd')(
            (b, c // THREADS_PER_BLOCK), (THREADS_PER_BLOCK,),
            (f, z, self.h))
        # Drop the hinit slot; Function.forward returns a tuple of outputs.
        return self.h[:, 1:, :],

    def backward_gpu(self, inputs, grads):
        # grads holds gh, the gradient wrt the forward output, (b, t, c).
        f, z = inputs[:2]
        gh, = grads
        b, t, c = f.shape
        gz = cuda.cupy.zeros_like(gh)
        # Reverse-time scan computing the accumulated gradient wrt h_t:
        # gz_t = gh_t + f_{t+1} * gz_{t+1}, from t_size-1 down to 0.
        # Since z enters the recurrence additively, this is also dL/dz_t.
        cuda.raw('''
        #define THREADS_PER_BLOCK 32
        extern "C" __global__ void strnn_back(
            const CArray<float, 3> f, const CArray<float, 3> gh,
            CArray<float, 3> gz) {
          int index[3];
          const int t_size = f.shape()[1];
          index[0] = blockIdx.x;
          index[2] = blockIdx.y * THREADS_PER_BLOCK + threadIdx.x;
          index[1] = t_size - 1;
          float &gz_last = gz[index];
          gz_last = gh[index];
          float prev_gz = gz_last;
          for (int i = t_size - 1; i > 0; i--){
            index[1] = i;
            const float ft = f[index];
            index[1] = i - 1;
            const float ght = gh[index];
            float &gzt = gz[index];
            prev_gz = prev_gz * ft + ght;
            gzt = prev_gz;
          }
        }''', 'strnn_back')(
            (b, c // THREADS_PER_BLOCK), (THREADS_PER_BLOCK,),
            (f, gh, gz))
        # dL/df_t = h_{t-1} * gz_t, using h saved by forward_gpu (this
        # Function is stateful: backward assumes forward_gpu ran first).
        gf = self.h[:, :-1, :] * gz
        # dL/dhinit = f_0 * gz_0 (hinit only feeds the first step).
        ghinit = f[:, 0, :] * gz[:, 0, :]
        return gf, gz, ghinit
def strnn(f, z, h0):
    """Functional wrapper: run the sequential pooling h_t = h_{t-1}*f_t + z_t.

    f and z are (batch, time, channels) gate/update tensors and h0 is the
    (batch, channels) initial state; returns the pooled (batch, time,
    channels) sequence.
    """
    pooling = STRNNFunction()
    return pooling(f, z, h0)
def attention_sum(encoding, query):
    """Soft attention over the encoder outputs.

    Scores every encoder timestep against every query timestep with a batch
    matmul, softmax-normalizes the scores, and returns the attention-weighted
    sum of the encoding for each query position.
    """
    scores = F.batch_matmul(encoding, query, transb=True)
    weights = F.softmax(scores)
    weights_b, encoding_b = F.broadcast(
        weights[:, :, :, None], encoding[:, :, None, :])
    return F.sum(weights_b * encoding_b, axis=1)
class Linear(L.Linear):
    """chainer.links.Linear that also accepts 3-D (batch, time, features)
    input by folding the time axis into the batch axis, applying the 2-D
    linear map, and unfolding again.
    """

    def __call__(self, x):
        shape = x.shape
        if len(shape) == 3:
            # Fold (b, t, c) -> (b*t, c) so the underlying 2-D link applies.
            x = F.reshape(x, (-1, shape[2]))
        # Fix: super().__call__ is already bound to self, so passing `self`
        # again (as in the original) hands it to the link as the data arg.
        y = super().__call__(x)
        if len(shape) == 3:
            # Fix: restore (b, t, out_size) rather than the *input* shape;
            # reshaping to `shape` fails whenever out_size != shape[2]
            # (always true here, since every use maps in_size -> 3*out_size).
            y = F.reshape(y, shape[:2] + (-1,))
        return y
class QRNNLayer(Chain):
    """One quasi-recurrent layer: a masked convolution producing gates
    (f, z, o) followed by the STRNN recurrent pooling, with optional
    encoder attention and get/set state for sequence-to-sequence use.

    Supports two calling modes, distinguished by input rank: a whole
    (batch, time, features) sequence, or a single (batch, features) step
    (decoding), where the previous input/state are carried on self.
    """

    def __init__(self, in_size, out_size, kernel_size=2, attention=False,
                 decoder=False):
        # NOTE(review): `decoder` is accepted but never read in this code.
        if kernel_size == 1:
            # Width-1 "convolution" is just a pointwise linear map.
            super().__init__(W=Linear(in_size, 3 * out_size))
        elif kernel_size == 2:
            # Width-2 masked convolution expressed as W*x_t + V*x_{t-1};
            # only one of the two carries a bias.
            super().__init__(W=Linear(in_size, 3 * out_size, nobias=True),
                             V=Linear(in_size, 3 * out_size))
        else:
            # General case: 1-D convolution, left-padded so output at t only
            # sees inputs up to t (the padding is trimmed in pre()).
            super().__init__(
                conv=L.ConvolutionND(1, in_size, 3 * out_size, kernel_size,
                                     stride=1, pad=kernel_size - 1))
        if attention:
            # NOTE(review): U projects the encoder state to 3 * in_size, but
            # it is summed with a 3 * out_size pre-activation in pre(); this
            # only lines up when in_size == out_size — confirm.
            self.add_link('U', Linear(out_size, 3 * in_size))
            self.add_link('o', Linear(2 * out_size, out_size))
        self.in_size, self.size, self.attention = in_size, out_size, attention
        self.kernel_size = kernel_size

    def pre(self, x):
        """Compute the pre-activation (f, z, o) features for input x."""
        # dims == 2 for a (batch, time, features) sequence, 1 for a single
        # (batch, features) decoding step.
        dims = len(x.shape) - 1
        if self.kernel_size == 1:
            ret = self.W(x)
        elif self.kernel_size == 2:
            if dims == 2:
                # Shift x right one timestep (zero-padded) so each position
                # sees x_t and x_{t-1}.
                xprev = Variable(
                    self.xp.zeros((self.batch_size, 1, self.in_size),
                                  dtype=np.float32), volatile='AUTO')
                xtminus1 = F.concat((xprev, x[:, :-1, :]), axis=1)
            else:
                # Step mode: the previous input was stashed on self.x by the
                # last __call__ (or zero-initialized by init()).
                xtminus1 = self.x
            ret = self.W(x) + self.V(xtminus1)
        else:
            # Convolve along time (channels-first), then trim the causal
            # padding back to the input length.
            # NOTE(review): x is (batch, time, features), so x.shape[2] is
            # the feature count; trimming the time axis to x.shape[1] looks
            # like the intent — confirm.
            ret = F.swapaxes(self.conv(
                F.swapaxes(x, 1, 2))[:, :, :x.shape[2]], 1, 2)
        if not self.attention:
            return ret
        # Attention variant: add a projection of the encoder's final state
        # to the gate pre-activations (broadcast over time in sequence mode).
        if dims == 1:
            enc = self.encoding[:, -1, :]
        else:
            enc = self.encoding[:, -1:, :]
        return sum(F.broadcast(self.U(enc), ret))

    def init(self, encoder_c=None, encoder_h=None):
        """Reset recurrent state, optionally seeding from an encoder.

        encoder_c is the encoder's cell-state sequence; without attention its
        final timestep becomes this layer's initial c. encoder_h is accepted
        but unused here.
        """
        self.encoding = encoder_c
        self.c, self.x = None, None
        if self.encoding is not None:
            self.batch_size = self.encoding.shape[0]
            if not self.attention:
                self.c = self.encoding[:, -1, :]
        # Zero-fill c/x when absent or smaller than the current batch.
        if self.c is None or self.c.shape[0] < self.batch_size:
            self.c = Variable(self.xp.zeros((self.batch_size, self.size),
                                            dtype=np.float32), volatile='AUTO')
        if self.x is None or self.x.shape[0] < self.batch_size:
            self.x = Variable(self.xp.zeros((self.batch_size, self.in_size),
                                            dtype=np.float32), volatile='AUTO')

    def __call__(self, x):
        # Lazily self-initialize when used without an encoder.
        if not hasattr(self, 'encoding') or self.encoding is None:
            self.batch_size = x.shape[0]
            self.init()
        dims = len(x.shape) - 1
        f, z, o = F.split_axis(self.pre(x), 3, axis=dims)
        f = F.sigmoid(f)
        # z is pre-scaled by (1 - f) here, so the pooling recurrence is a
        # plain multiply-accumulate (matches STRNNFunction's kernel).
        z = (1 - f) * F.tanh(z)
        o = F.sigmoid(o)
        if dims == 2:
            # Sequence mode: fused recurrent pooling over all timesteps.
            self.c = strnn(f, z, self.c[:self.batch_size])
        else:
            # Step mode: one step of c_t = f * c_{t-1} + (1-f) * tanh(z).
            self.c = f * self.c + z
        if self.attention:
            context = attention_sum(self.encoding, self.c)
            self.h = o * self.o(F.concat((self.c, context), axis=dims))
        else:
            self.h = self.c * o
        # Remember the raw input for the next step's x_{t-1} term.
        self.x = x
        return self.h

    def get_state(self):
        # Pack (x, c, h) along the feature axis for external checkpointing.
        return F.concat((self.x, self.c, self.h), axis=1)

    def set_state(self, state):
        # Inverse of get_state: split at in_size and in_size + size.
        self.x, self.c, self.h = F.split_axis(
            state, (self.in_size, self.in_size + self.size), axis=1)

    state = property(get_state, set_state)
@Agoniii

This comment has been minimized.

Show comment
Hide comment
@Agoniii

Agoniii Sep 28, 2018

@jekbradbury
hi jekbradbury, thanks for your implementation of QRNN.
I have a question, why is prev_h = prev_h * ft + zt in line 36 rather than prev_h = prev_h * ft + (1 - ft) * zt ?

Agoniii commented Sep 28, 2018

@jekbradbury
hi jekbradbury, thanks for your implementation of QRNN.
I have a question, why is prev_h = prev_h * ft + zt in line 36 rather than prev_h = prev_h * ft + (1 - ft) * zt ?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment