Gist rasmusbergpalm/ee3ea7ea0b7962e3f033, last active November 26, 2015
Pseudo-code for Encoder part (section 4) of http://arxiv.org/pdf/1511.06391v1.pdf
from math import exp

def set2vector(x, init, f, g, h, RNN, T):
    """
    x:    input data with N samples (treated as a set)
    init: function that initializes the hidden state of the RNN
    f:    f(x): maps each sample of x to a vector, e.g. the identity function, an MLP, etc.
    g:    g(m, q): maps vectors m and q to a single number, e.g. the dot product, an MLP, etc.
    h:    h(q, r): maps vectors q and r to a vector, e.g. an MLP or similar.
    RNN:  recurrent cell that evolves the hidden state without external input.
    T:    int, how many steps to run.
    """
    qs = init()              # qs is a vector (RNN "post" hidden state?)
    N = len(x)
    m = [None] * N
    for i in range(N):
        m[i] = f(x[i])       # x[i] and m[i] are vectors (input, embedded input)
    for t in range(T):
        q = RNN(qs)          # q is a vector (RNN "pre" hidden state?)
        d = 0                # d is a scalar (softmax denominator)
        e = [None] * N
        for i in range(N):
            e[i] = exp(g(m[i], q))  # e[i] is a scalar ("attention importance")
            d += e[i]
        a = [None] * N
        for i in range(N):
            a[i] = e[i] / d  # a[i] is a scalar ("softmaxed attention importance")
        r = 0                # r accumulates the softmax-weighted sum of embedded inputs
        for i in range(N):
            r = r + a[i] * m[i]
        qs = h(q, r)         # combine "pre" state and attention readout
    return qs                # only the final state qs[T] is needed
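The data flow above can be sketched end-to-end with NumPy. This is a minimal, hypothetical instantiation, not the paper's exact setup: it assumes f = identity, g = dot product, h = concatenation (mirroring the paper's q*_t = [q_t, r_t]), and a plain tanh recurrence with random untrained weights in place of the LSTM.

```python
import numpy as np

def process_block(x, T, seed=0):
    """Sketch of the set2vector process block with concrete (assumed) choices:
    f = identity, g = dot product, h = concatenation, RNN = tanh recurrence.
    Weights are random and untrained; this only illustrates the data flow."""
    rng = np.random.default_rng(seed)
    N, d = x.shape
    W = rng.normal(scale=0.1, size=(d, 2 * d))  # RNN weights: maps q*_{t-1} (2d) -> q_t (d)
    m = x                                       # f = identity: embedded inputs
    qs = np.zeros(2 * d)                        # init(): zero hidden state
    for _ in range(T):
        q = np.tanh(W @ qs)                     # q_t = RNN(q*_{t-1}), no external input
        scores = m @ q                          # g(m_i, q_t) = dot product, one scalar per sample
        e = np.exp(scores - scores.max())       # stabilized exponentials
        a = e / e.sum()                         # softmaxed attention importances
        r = a @ m                               # attention-weighted sum of embedded inputs
        qs = np.concatenate([q, r])             # q*_t = h(q_t, r_t) = [q_t, r_t]
    return qs

vec = process_block(np.arange(15.0).reshape(5, 3), T=4)  # set of 5 vectors -> one vector
```

Note that with these choices the output is permutation-invariant in the rows of `x`: shuffling the set leaves the softmax-weighted readout, and hence the final state, unchanged.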