Skip to content

Instantly share code, notes, and snippets.

@jsksxs360
Created October 10, 2021 02:30
Show Gist options
  • Save jsksxs360/fc925815e42d1e72e0398e3f971841d4 to your computer and use it in GitHub Desktop.
Save jsksxs360/fc925815e42d1e72e0398e3f971841d4 to your computer and use it in GitHub Desktop.
This class implements the Factored Tensor Network (FTN) (https://aclanthology.org/P19-1058.pdf). FTN is used to model complex interactions between two vectors.
#coding:utf-8
from keras import backend as K
import tensorflow as tf
from keras.engine.topology import Layer
class FactoredTensor(Layer):
    """Factored Tensor Network (FTN) layer.

    Implements the Factored Tensor Network described in
    https://aclanthology.org/P19-1058.pdf. FTN models complex interactions
    between two vectors.

    For every tensor slice ``i`` in ``range(m)`` and every sample, the layer
    computes::

        out[i] = relu(x1 @ J[:, :, i] @ K[:, :, i] @ x2
                      + concat([x1, x2]) @ U[:, i] + s[i])

    ``J[:, :, i]`` has shape ``(features, r)`` and ``K[:, :, i]`` has shape
    ``(r, features)``. Each slice's bilinear form therefore uses a
    ``(features, features)`` matrix of rank at most ``r``.

    # Input
        A list of two tensors ``[x1, x2]``, each with shape
        ``(samples, features)``. Both must have the same ``features`` size.

    # Output shape
        ``(samples, m)``: one value per tensor slice.

    Note: the layer has been tested with Keras 2.3.1 (TensorFlow 1.14.0 as
    backend).

    Example::

        x1_inputs = Input(shape=(None,), dtype='int32')
        x2_inputs = Input(shape=(None,), dtype='int32')
        x1_embeddings = Embedding(max_features, 128)(x1_inputs)
        x2_embeddings = Embedding(max_features, 128)(x2_inputs)
        x1_vector = GlobalAveragePooling1D()(x1_embeddings)
        x2_vector = GlobalAveragePooling1D()(x2_embeddings)
        interaction_features = FactoredTensor(256)([x1_vector, x2_vector])
        result_vec = Dropout(0.1)(interaction_features)
        outputs = Dense(1, activation='sigmoid')(result_vec)
    """

    def __init__(self, m, r=8, **kwargs):
        """Create a Factored Tensor Network layer.

        Args:
            m: number of tensor slices, i.e. the size of the output vector.
            r: rank of the low-rank factorization used for each slice.
            **kwargs: standard Keras ``Layer`` keyword arguments.
        """
        # Attributes are set before the base initializer, as in the original.
        self.m = m
        self.r = r
        super().__init__(**kwargs)

    def get_config(self):
        """Return the layer config, extended with ``m`` and ``r``."""
        config = super().get_config()
        config.update({'m': self.m, 'r': self.r})
        return config

    @staticmethod
    def _check_same_features(shape1, shape2):
        """Raise ValueError if the two statically known feature sizes differ.

        Unknown sizes (``None``) are not treated as a mismatch. Under TF1 a
        ``Dimension`` with an unknown value compares as ``None``, which is
        falsy, so that case is skipped as well.
        """
        d1, d2 = shape1[-1], shape2[-1]
        if d1 is not None and d2 is not None and d1 != d2:
            raise ValueError(
                'The two input vectors should have the same shape, '
                'got {} and {}.'.format(shape1, shape2))

    def build(self, input_shape):
        """Create the FTN weights.

        Args:
            input_shape: list of the two input shapes, each
                ``(samples, features)``.

        Raises:
            ValueError: if the feature size is undefined or differs between
                the two inputs.
        """
        self._check_same_features(input_shape[0], input_shape[1])
        self.d = input_shape[0][-1]
        if self.d is None:
            raise ValueError(
                'The last dimension of the inputs should be defined, '
                'got {}.'.format(input_shape))
        # Linear term over the concatenation [x1; x2].
        self.U = self.add_weight(shape=(2 * self.d, self.m),
                                 name='FTN_U',
                                 initializer='glorot_uniform',
                                 trainable=True)
        # Low-rank factors: per slice, J[:, :, i] is (d, r), K[:, :, i] is (r, d).
        self.J = self.add_weight(shape=(self.d, self.r, self.m),
                                 name='FTN_J',
                                 initializer='glorot_uniform',
                                 trainable=True)
        self.K = self.add_weight(shape=(self.r, self.d, self.m),
                                 name='FTN_K',
                                 initializer='glorot_uniform',
                                 trainable=True)
        # Per-slice bias.
        self.s = self.add_weight(shape=(self.m,),
                                 name='FTN_s',
                                 initializer='glorot_uniform',
                                 trainable=True)
        super().build(input_shape)

    def compute_mask(self, input, input_mask=None):
        """Return None: this layer does not propagate a mask."""
        return None

    def call(self, x, mask=None):
        """Compute the slice-wise interaction features.

        Args:
            x: list ``[x1, x2]`` of two tensors, each ``(samples, features)``.
            mask: ignored.

        Returns:
            Tensor of shape ``(samples, m)``.

        Raises:
            ValueError: if the statically known feature sizes differ.
        """
        x1, x2 = x
        self._check_same_features(x1.shape, x2.shape)
        # Note: ``K`` is the Keras backend module; ``self.K`` is a weight.
        # x1*J: (m, r, d) . (d, b) -> (m, r, b) -> (m, b, r)
        j = K.permute_dimensions(self.J, (2, 1, 0))
        t = K.dot(j, K.permute_dimensions(x1, (1, 0)))
        t = K.permute_dimensions(t, (0, 2, 1))
        # x1*J*K: per slice (b, r) @ (r, d) -> (m, b, d) -> (b, m, d)
        k = K.permute_dimensions(self.K, (2, 0, 1))
        t = tf.einsum('mbr,mrd->mbd', t, k)
        t = K.permute_dimensions(t, (1, 0, 2))
        # x1*J*K*x2: contract the feature axis with x2 -> (b, m)
        t = tf.einsum('bmd,bd->bm', t, x2)
        # U*[x1;x2]: (b, 2d) . (2d, m) -> (b, m)
        u = K.dot(K.concatenate([x1, x2]), self.U)
        return K.relu(t + u + self.s)

    def compute_output_shape(self, input_shape):
        """Return ``(samples, m)``."""
        return (input_shape[0][0], self.m)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment