Skip to content

Instantly share code, notes, and snippets.

@npow
Last active August 29, 2015 14:16
Show Gist options
  • Save npow/7c6c1d5f5641d37a7139 to your computer and use it in GitHub Desktop.
Save npow/7c6c1d5f5641d37a7139 to your computer and use it in GitHub Desktop.
from functools import partial
from theano.ifelse import ifelse
import numpy as np
import theano
import theano.tensor as T
def maxx(e):
    """Return the hinge ``max(e, 0)`` of a symbolic Theano expression.

    Implemented with the elementwise ``T.maximum`` rather than
    ``theano.ifelse.ifelse``. ``ifelse`` requires a scalar condition and
    two branches of identical type. ``s`` in this file returns a length-1
    vector, and the literal ``0.0`` is a scalar constant, so ``ifelse``
    rejects that combination with a ``TypeError``. ``T.maximum``
    broadcasts ``0.0`` against any shape and is differentiable almost
    everywhere, so gradients flow through the hinge.

    Parameters
    ----------
    e : Theano tensor variable of any shape.

    Returns
    -------
    Symbolic tensor with the shape of ``e``, negative entries replaced by 0.
    """
    return T.maximum(e, 0.0)
def s(x, y, U):
    """Bilinear score ``x^T U^T U y`` between two symbolic vectors.

    Both inputs are projected through ``U`` and the projections are
    dotted together. ``x`` is first reshaped into a single-row matrix, so
    the result is a length-1 vector (shape ``(1,)``) rather than a 0-d
    scalar.

    Expects ``x`` and ``y`` to be 1-D with length ``U.shape[1]``
    (``getTrain`` passes W-length vectors and a (D, W) ``U``).
    """
    x_row = x.reshape((1, -1))      # (1, W) row matrix
    x_proj = T.dot(x_row, U.T)      # (1, W) . (W, D) -> (1, D)
    y_proj = T.dot(U, y)            # (D, W) . (W,)   -> (D,)
    return T.dot(x_proj, y_proj)    # (1, D) . (D,)   -> (1,)
def getTrain(D=100, W=100, lr=0.1):
    """Build a compiled Theano SGD training step for two bilinear scorers.

    Two randomly initialised (D, W) shared matrices are created: ``U_O``
    scores memory candidates and ``U_R`` scores response candidates
    (both through ``s``). The loss appears to follow the margin-ranking
    objective of Memory Networks (Weston et al., 2014). It is the sum of
    three hinged terms, each comparing every candidate against a target
    with margin ``gamma``:

    1. each row ``f_bar`` of ``L`` vs. ``f1``, scored against ``x``;
    2. each row ``f_bar`` of ``L`` vs. ``f2``, scored against ``x`` and
       ``m_o1``;
    3. each row ``r_bar`` of ``V`` vs. ``r``, scored against ``x``,
       ``m_o1`` and ``m_o2``.

    Parameters
    ----------
    D, W : int
        Shape of each embedding matrix. Every input vector and every row
        of ``L`` / ``V`` must have length ``W`` so that ``T.dot(U, y)``
        inside ``s`` is well-formed.
    lr : float, optional
        SGD step size applied to both matrices. The default 0.1 is the
        value that was previously hard-coded.

    Returns
    -------
    theano function
        ``train(m_o1, m_o2, x, gamma, f1, f2, L, V, r)``. Each call
        applies one SGD update to ``U_O`` and ``U_R`` and returns
        ``[cost]``.
    """
    U_O = theano.shared(np.random.randn(D, W))
    U_R = theano.shared(np.random.randn(D, W))

    m_o1 = T.vector('m_o1')
    m_o2 = T.vector('m_o2')
    x = T.vector('x')
    y = T.vector('y')  # NOTE(review): declared but never used in the graph
    r = T.vector('r')
    gamma = T.scalar('gamma')
    f1 = T.vector('f1')
    f2 = T.vector('f2')
    L = T.matrix('L')  # list of messages (one candidate per row)
    V = T.matrix('V')  # vocab (one response candidate per row)

    sO = partial(s, U=U_O)
    sR = partial(s, U=U_R)

    # All three terms are hinged with maxx. An unhinged margin term is just
    # a difference of (indefinite quadratic) scores, so it is unbounded
    # below and SGD can push the cost toward -inf instead of enforcing the
    # margin. The scans update no shared state, so their updates are
    # discarded.
    cost1, _ = theano.scan(
        lambda f_bar: maxx(gamma - sO(x, f1) + sO(x, f_bar)),
        sequences=L)
    cost2, _ = theano.scan(
        lambda f_bar: maxx(gamma
                           - sO(x, f2) - sO(m_o1, f2)
                           + sO(x, f_bar) + sO(m_o1, f_bar)),
        sequences=L)
    cost3, _ = theano.scan(
        lambda r_bar: maxx(gamma
                           - sR(x, r) - sR(m_o1, r) - sR(m_o2, r)
                           + sR(x, r_bar) + sR(m_o1, r_bar) + sR(m_o2, r_bar)),
        sequences=V)
    cost = cost1.sum() + cost2.sum() + cost3.sum()

    g_uo, g_ur = T.grad(cost, [U_O, U_R])
    train = theano.function(
        inputs=[m_o1, m_o2, x, gamma, f1, f2, L, V, r],
        outputs=[cost],
        updates=[(U_O, U_O - lr * g_uo), (U_R, U_R - lr * g_ur)])
    return train
# Module-level smoke run: builds the symbolic graph and compiles the
# training function with the default sizes; the compiled function is
# not kept.
getTrain(D=100, W=100)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment