neural tensor network
'''
Created on Nov 17, 2014
@author: Minh Ngoc Le
'''
from pylearn2.models.mlp import Layer
from pylearn2.space import IndexSpace, VectorSpace
from pylearn2.utils import sharedX, wraps
import numpy
from theano import tensor as T
class NeuralTensorLayer(Layer):

    def __init__(self, nrelations, k, max_input_labels, dim, irange=0.1,
                 layer_name="input"):
        super(NeuralTensorLayer, self).__init__()
        self.layer_name = layer_name
        self.nrelations = nrelations
        self.dim = dim
        self.max_input_labels = max_input_labels
        # Relation-specific tensor: k bilinear slices of shape (dim, dim)
        # per relation.
        W_value = numpy.random.uniform(-irange, irange,
                                       size=(nrelations, k, dim, dim))
        self.W = sharedX(W_value, 'W')
        # Relation-specific bias, one k-vector per relation.
        b_value = numpy.zeros((nrelations, k))
        self.b = sharedX(b_value, 'b')
        # Entity embeddings, one dim-vector per entity label.
        e_value = numpy.random.uniform(-irange, irange,
                                       size=(max_input_labels, dim))
        self.e = sharedX(e_value, 'e')
        # Standard feed-forward weights applied to the concatenation [e1; e2].
        V_value = numpy.random.uniform(-irange, irange,
                                       size=(nrelations, k, 2 * dim))
        self.V = sharedX(V_value, 'V')
        self._params = [self.e, self.W, self.b, self.V]
        # Inputs are (entity1, relation, entity2) index triples.
        self.input_space = IndexSpace(max_input_labels, dim=3)
        self.output_space = VectorSpace(dim=k)

    @wraps(Layer.fprop)
    def fprop(self, inputs):
        # Look up the embeddings of both entities and the parameters of the
        # relation for every triple in the batch.
        e1 = self.e[inputs[:, 0].flatten()]
        W = self.W[inputs[:, 1].flatten()]
        e2 = self.e[inputs[:, 2].flatten()]
        V = self.V[inputs[:, 1].flatten()]
        b = self.b[inputs[:, 1].flatten()]
        e = T.concatenate([e1, e2], axis=1)
        # Bilinear tensor term: e1^T W_r^[1:k] e2, computed per example.
        tensor = T.batched_dot(T.batched_tensordot(e1, W, axes=[[1], [2]]), e2)
        # Standard layer term: V_r [e1; e2] + b_r.
        neural = T.batched_dot(V, e) + b
        return tensor + neural

    @wraps(Layer.get_layer_monitoring_channels)
    def get_layer_monitoring_channels(self, state_below=None,
                                      state=None, targets=None):
        return super(NeuralTensorLayer, self).get_layer_monitoring_channels()

    @wraps(Layer.set_input_space)
    def set_input_space(self, space):
        """
        TODO: check
        """
        pass

    @wraps(Layer.get_weight_decay)
    def get_weight_decay(self, coeff):
        if isinstance(coeff, str):
            coeff = float(coeff)
        assert isinstance(coeff, float) or hasattr(coeff, 'dtype')
        # Penalize the tensor, standard-layer, and embedding parameters.
        return coeff * (T.sqr(self.W).sum() +
                        T.sqr(self.V).sum() +
                        T.sqr(self.e).sum())
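
A minimal usage sketch, not part of the original gist: assuming pylearn2 and Theano are installed and the NeuralTensorLayer class above is in scope, it compiles the layer's fprop into a Theano function and scores a small batch of (entity1, relation, entity2) index triples. All sizes and index values below are made-up illustration values.

# Hypothetical usage sketch (illustration values only, not from the gist).
import numpy
import theano
from theano import tensor as T

# Made-up sizes: 5 relations, k=4 slices, 100 entity labels, 10-dim embeddings.
layer = NeuralTensorLayer(nrelations=5, k=4, max_input_labels=100, dim=10)

triples = T.imatrix('triples')           # rows are (entity1, relation, entity2)
score = layer.fprop(triples)             # symbolic (batch, k) score matrix
score_fn = theano.function([triples], score)

batch = numpy.array([[3, 1, 7],
                     [42, 0, 13]], dtype='int32')
print(score_fn(batch).shape)             # expected: (2, 4)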