@rasmusbergpalm
Last active November 26, 2015 13:24
Pseudo-code for the encoder part (section 4) of http://arxiv.org/pdf/1511.06391v1.pdf
import math

def set2vector(x, init, f, g, h, RNN, T):
    """
    x: input data with N samples
    init: function that returns the initial hidden state of the RNN
    f: f(x_i): function that maps each sample of x to a vector, e.g. the identity function or an MLP
    g: g(m_i, q): function that maps vectors m_i and q to a single number, e.g. the dot product or an MLP
    h: h(q, r): function that maps vectors q and r to a vector, e.g. an MLP or similar
    RNN: RNN(qs): function that maps the previous state qs to the next vector q
    T: int, how many steps to run

    Vectors must support scalar multiplication and addition (e.g. NumPy arrays).
    """
    qs = init()                    # qs is a vector (RNN "post" hidden state?)
    m = [f(x_i) for x_i in x]      # x[i] and m[i] are vectors (input, embedded input)
    for t in range(T):
        q = RNN(qs)                # q is a vector (RNN "pre" hidden state?)
        scores = [g(m_i, q) for m_i in m]        # one scalar per sample
        top = max(scores)                        # subtracted before exp() for numerical stability
        e = [math.exp(s - top) for s in scores]  # e[i] is a scalar ("attention importance")
        d = sum(e)                               # d is a scalar (softmax denominator)
        a = [e_i / d for e_i in e]               # a[i] is a scalar ("softmaxed attention importance")
        r = sum(a_i * m_i for a_i, m_i in zip(a, m))  # r is a vector (softmax-weighted sum of embedded input)
        qs = h(q, r)               # combine q and r into the state for the next step
    return qs
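
Usage sketch (not from the paper or the original gist): a minimal, plain-Python run of the loop above, to check that the result does not depend on the order of the samples in x. The concrete choices here are assumptions made only for illustration: vectors are lists of floats, f is the identity, g is the dot product, h concatenates q and r, and the "RNN" is a single random tanh layer standing in for a trained recurrent cell. The names dot, encode, DIM and W are hypothetical helpers.

```python
import math
import random


def dot(u, v):
    """Dot product of two equal-length lists of floats."""
    return sum(u_k * v_k for u_k, v_k in zip(u, v))


def set2vector(x, init, f, g, h, rnn, T):
    """Plain-Python version of the loop above; vectors are lists of floats."""
    qs = init()
    m = [f(x_i) for x_i in x]                  # embedded inputs
    for _ in range(T):
        q = rnn(qs)                            # vector produced from the previous state
        scores = [g(m_i, q) for m_i in m]      # one scalar per sample
        top = max(scores)                      # shift for numerical stability
        e = [math.exp(s - top) for s in scores]
        d = sum(e)                             # softmax denominator
        a = [e_i / d for e_i in e]             # attention weights, sum to 1
        # Weighted sum of the embedded inputs, one coordinate at a time.
        r = [sum(a_i * m_i[k] for a_i, m_i in zip(a, m))
             for k in range(len(m[0]))]
        qs = h(q, r)
    return qs


# Assumed toy setup: 3-dimensional samples, random untrained weights.
DIM = 3
random.seed(0)
W = [[random.uniform(-1, 1) for _ in range(2 * DIM)] for _ in range(DIM)]


def encode(samples):
    return set2vector(
        samples,
        init=lambda: [0.0] * (2 * DIM),
        f=lambda x_i: x_i,                     # identity embedding
        g=dot,                                 # dot-product score
        h=lambda q, r: q + r,                  # list concatenation [q, r]
        rnn=lambda qs: [math.tanh(dot(row, qs)) for row in W],  # not an LSTM
        T=3,
    )


x = [[random.gauss(0, 1) for _ in range(DIM)] for _ in range(5)]
out = encode(x)
out_reversed = encode(x[::-1])                 # same set, different order

same = all(math.isclose(u, v, rel_tol=1e-9, abs_tol=1e-12)
           for u, v in zip(out, out_reversed))
print(len(out), same)
```

Because the samples only enter through the softmax-weighted sum r, reversing x should change the output only by floating-point rounding, which is why the comparison uses a tolerance instead of exact equality.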