Skip to content

Instantly share code, notes, and snippets.

@jekbradbury
Last active October 12, 2020 03:29
Show Gist options
  • Star 27 You must be signed in to star a gist
  • Fork 5 You must be signed in to fork a gist
  • Save jekbradbury/a3a5ae890328db49d8093c1a5bdc8a1e to your computer and use it in GitHub Desktop.
We're happy to see community interest in the QRNN architecture. Here's the core of our implementation, written in Chainer. STRNNFunction is the recurrent pooling function, while QRNNLayer implements a QRNN layer composed of convolutional and pooling subcomponents, with optional attention and state-saving features for the tasks in the paper.
# requirements:
# pip install -e git://github.com/jekbradbury/chainer.git@raw-kernel
from chainer import cuda, Function, Variable, Chain
import chainer.links as L
import chainer.functions as F
import numpy as np
THREADS_PER_BLOCK = 32
class STRNNFunction(Function):
    """Fused GPU scan for the QRNN pooling recurrence h_t = h_{t-1}*f_t + z_t.

    Inputs: f and z are (batch, time, channels) float32 arrays; hinit is the
    (batch, channels) initial state.  Each CUDA thread owns one
    (batch-element, channel) pair and walks the whole time axis sequentially,
    so the recurrence is parallel over batch and channels only.
    """

    def forward_gpu(self, inputs):
        f, z, hinit = inputs
        b, t, c = f.shape
        # The grid's y dimension covers channels in groups of
        # THREADS_PER_BLOCK, so the channel count must divide evenly.
        assert c % THREADS_PER_BLOCK == 0
        # h stores hinit at time index 0 followed by the t computed states.
        # Kept on self because backward_gpu reuses it to form gf.
        self.h = cuda.cupy.zeros((b, t + 1, c), dtype=np.float32)
        self.h[:, 0, :] = hinit
        # Forward kernel: for each (batch, channel) thread,
        # h[i+1] = h[i] * f[i] + z[i] for i in [0, t).
        cuda.raw('''
#define THREADS_PER_BLOCK 32
extern "C" __global__ void strnn_fwd(
const CArray<float, 3> f, const CArray<float, 3> z,
CArray<float, 3> h) {
int index[3];
const int t_size = f.shape()[1];
index[0] = blockIdx.x;
index[1] = 0;
index[2] = blockIdx.y * THREADS_PER_BLOCK + threadIdx.x;
float prev_h = h[index];
for (int i = 0; i < t_size; i++){
index[1] = i;
const float ft = f[index];
const float zt = z[index];
index[1] = i + 1;
float &ht = h[index];
prev_h = prev_h * ft + zt;
ht = prev_h;
}
}''', 'strnn_fwd')(
            (b, c // THREADS_PER_BLOCK), (THREADS_PER_BLOCK,),
            (f, z, self.h))
        # Drop the prepended initial state; return only the t outputs.
        return self.h[:, 1:, :],

    def backward_gpu(self, inputs, grads):
        f, z = inputs[:2]
        gh, = grads
        b, t, c = f.shape
        gz = cuda.cupy.zeros_like(gh)
        # Reverse-time scan: gz_t = dL/dh_t accumulated through the
        # recurrence, i.e. gz_t = gh_t + f_{t+1} * gz_{t+1}, seeded with
        # the final timestep's gh.  (gz is also dL/dz, since dh_t/dz_t = 1.)
        cuda.raw('''
#define THREADS_PER_BLOCK 32
extern "C" __global__ void strnn_back(
const CArray<float, 3> f, const CArray<float, 3> gh,
CArray<float, 3> gz) {
int index[3];
const int t_size = f.shape()[1];
index[0] = blockIdx.x;
index[2] = blockIdx.y * THREADS_PER_BLOCK + threadIdx.x;
index[1] = t_size - 1;
float &gz_last = gz[index];
gz_last = gh[index];
float prev_gz = gz_last;
for (int i = t_size - 1; i > 0; i--){
index[1] = i;
const float ft = f[index];
index[1] = i - 1;
const float ght = gh[index];
float &gzt = gz[index];
prev_gz = prev_gz * ft + ght;
gzt = prev_gz;
}
}''', 'strnn_back')(
            (b, c // THREADS_PER_BLOCK), (THREADS_PER_BLOCK,),
            (f, gh, gz))
        # dL/df_t = h_{t-1} * dL/dh_t; self.h still holds hinit at index 0,
        # so h[:, :-1, :] lines up h_{t-1} with gz_t.
        gf = self.h[:, :-1, :] * gz
        # hinit feeds only h_1 (scaled by f_1), so dL/dhinit = f_1 * gz_1.
        ghinit = f[:, 0, :] * gz[:, 0, :]
        return gf, gz, ghinit
def strnn(f, z, h0):
    """Run the sequential pooling recurrence h_t = h_{t-1}*f_t + z_t.

    Thin functional wrapper around STRNNFunction so callers don't have to
    instantiate the Function object themselves.
    """
    pooling = STRNNFunction()
    return pooling(f, z, h0)
def attention_sum(encoding, query):
    """Soft attention: weight encoder timesteps by similarity to `query`.

    Computes softmax(encoding @ query^T) attention weights, then returns
    the weights-weighted sum of `encoding` over the time axis.
    """
    scores = F.batch_matmul(encoding, query, transb=True)
    alpha = F.softmax(scores)
    # Broadcast weights against the encoding so they multiply elementwise.
    alpha_b, enc_b = F.broadcast(alpha[:, :, :, None],
                                 encoding[:, :, None, :])
    weighted = alpha_b * enc_b
    return F.sum(weighted, axis=1)
class Linear(L.Linear):
    """chainer.links.Linear that additionally accepts 3-d input.

    A (batch, time, in_units) input is flattened to (batch*time, in_units),
    passed through the underlying 2-d linear map, and unflattened back to
    (batch, time, out_units).  2-d input passes straight through.
    """

    def __call__(self, x):
        shape = x.shape
        if len(shape) == 3:
            # Fold batch and time together so the 2-d Linear applies.
            x = F.reshape(x, (-1, shape[2]))
        # Bug fix: the bound super().__call__ must not receive `self` again;
        # the original `super().__call__(self, x)` passed two positional
        # arguments and raised a TypeError on every call.
        y = super().__call__(x)
        if len(shape) == 3:
            # Bug fix: restore (batch, time, out_units), not the input
            # shape — the last dimension is now the output size, so
            # `F.reshape(y, shape)` only worked when out_units == in_units.
            y = F.reshape(y, (shape[0], shape[1], -1))
        return y
class QRNNLayer(Chain):
    """One quasi-recurrent layer: a width-`kernel_size` convolution over
    time produces the three gates (f, z, o), followed by the sequential
    pooling recurrence, with optional attention over an encoder and
    get/set_state support for carrying state between calls.
    """

    def __init__(self, in_size, out_size, kernel_size=2, attention=False,
                 decoder=False):
        # kernel_size 1 and 2 are implemented as (sums of) plain linear maps
        # over the current / previous timestep; wider kernels fall back to an
        # explicit 1-d convolution, left-padded by kernel_size - 1.
        # NOTE(review): `decoder` is accepted but never read in this class as
        # shown here — confirm whether it is vestigial.
        if kernel_size == 1:
            super().__init__(W=Linear(in_size, 3 * out_size))
        elif kernel_size == 2:
            super().__init__(W=Linear(in_size, 3 * out_size, nobias=True),
                             V=Linear(in_size, 3 * out_size))
        else:
            super().__init__(
                conv=L.ConvolutionND(1, in_size, 3 * out_size, kernel_size,
                                     stride=1, pad=kernel_size - 1))
        if attention:
            # U projects the encoder's final state onto the gate
            # pre-activations; o mixes cell state with attention context.
            # NOTE(review): U's output width 3 * in_size only matches `ret`
            # (which is 3 * out_size wide) when in_size == out_size — verify
            # whether 3 * out_size was intended.
            self.add_link('U', Linear(out_size, 3 * in_size))
            self.add_link('o', Linear(2 * out_size, out_size))
        self.in_size, self.size, self.attention = in_size, out_size, attention
        self.kernel_size = kernel_size

    def pre(self, x):
        """Compute the gate pre-activations for `x`.

        x is either a full sequence (batch, time, in_size) or a single step
        (batch, in_size); `dims` distinguishes the two (2 = sequence).
        """
        dims = len(x.shape) - 1
        if self.kernel_size == 1:
            ret = self.W(x)
        elif self.kernel_size == 2:
            if dims == 2:
                # Sequence mode: pair each timestep with its predecessor,
                # zero-padding before the first step.
                xprev = Variable(
                    self.xp.zeros((self.batch_size, 1, self.in_size),
                                  dtype=np.float32), volatile='AUTO')
                xtminus1 = F.concat((xprev, x[:, :-1, :]), axis=1)
            else:
                # Single-step mode: the previous input was stashed on
                # self.x by __call__ during the last invocation.
                xtminus1 = self.x
            ret = self.W(x) + self.V(xtminus1)
        else:
            # Convolve over time (channels-first layout), then trim the
            # right-hand overhang produced by the left padding.
            # NOTE(review): the slice bound x.shape[2] is in_size, not the
            # time length x.shape[1]; these coincide only when the sequence
            # length equals in_size — looks like a bug, confirm.
            ret = F.swapaxes(self.conv(
                F.swapaxes(x, 1, 2))[:, :, :x.shape[2]], 1, 2)
        if not self.attention:
            return ret
        # Attention: add a projection of the encoder's final state to the
        # pre-activations (broadcast over the time axis in sequence mode).
        if dims == 1:
            enc = self.encoding[:, -1, :]
        else:
            enc = self.encoding[:, -1:, :]
        return sum(F.broadcast(self.U(enc), ret))

    def init(self, encoder_c=None, encoder_h=None):
        """Reset recurrent state, optionally seeding it from an encoder.

        NOTE(review): `encoder_h` is accepted but unused here — confirm.
        """
        self.encoding = encoder_c
        self.c, self.x = None, None
        if self.encoding is not None:
            self.batch_size = self.encoding.shape[0]
            if not self.attention:
                # Without attention, start from the encoder's last state.
                self.c = self.encoding[:, -1, :]
        # Fall back to zeros when state is absent or the batch grew.
        if self.c is None or self.c.shape[0] < self.batch_size:
            self.c = Variable(self.xp.zeros((self.batch_size, self.size),
                                            dtype=np.float32), volatile='AUTO')
        if self.x is None or self.x.shape[0] < self.batch_size:
            self.x = Variable(self.xp.zeros((self.batch_size, self.in_size),
                                            dtype=np.float32), volatile='AUTO')

    def __call__(self, x):
        # Lazily initialize state on the first call of a fresh layer.
        if not hasattr(self, 'encoding') or self.encoding is None:
            self.batch_size = x.shape[0]
            self.init()
        dims = len(x.shape) - 1
        # Split the pre-activations into forget / update / output gates.
        f, z, o = F.split_axis(self.pre(x), 3, axis=dims)
        f = F.sigmoid(f)
        # Input gate is coupled to the forget gate: z <- (1 - f) * tanh(z).
        z = (1 - f) * F.tanh(z)
        o = F.sigmoid(o)
        if dims == 2:
            # Sequence mode: fused GPU scan over the whole time axis.
            self.c = strnn(f, z, self.c[:self.batch_size])
        else:
            # Single-step mode: one step of the same recurrence.
            self.c = f * self.c + z
        if self.attention:
            context = attention_sum(self.encoding, self.c)
            self.h = o * self.o(F.concat((self.c, context), axis=dims))
        else:
            self.h = self.c * o
        # Remember the input for the next single-step call (kernel_size==2).
        self.x = x
        return self.h

    def get_state(self):
        # Pack (input, cell, hidden) into one array so the full recurrent
        # state can be carried across calls / batches as a unit.
        return F.concat((self.x, self.c, self.h), axis=1)

    def set_state(self, state):
        # Inverse of get_state: split at in_size and in_size + size.
        self.x, self.c, self.h = F.split_axis(
            state, (self.in_size, self.in_size + self.size), axis=1)

    state = property(get_state, set_state)
@Agoniii
Copy link

Agoniii commented Sep 28, 2018

@jekbradbury
hi jekbradbury, thanks for your implementation of QRNN.
I have a question, why is prev_h = prev_h * ft + zt in line 36 rather than prev_h = prev_h * ft + (1 - ft) * zt ?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment