Created
October 10, 2021 02:30
-
-
Save jsksxs360/fc925815e42d1e72e0398e3f971841d4 to your computer and use it in GitHub Desktop.
This class implements Factored Tensor Network (FTN) (https://aclanthology.org/P19-1058.pdf). FTN is used to model complex interactions between two vectors.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#coding:utf-8 | |
from keras import backend as K | |
import tensorflow as tf | |
from keras.engine.topology import Layer | |
class FactoredTensor(Layer):
    """
    This class implements Factored Tensor Network (FTN) (https://aclanthology.org/P19-1058.pdf).
    FTN is used to model complex interactions between two vectors.

    # Input
        two vectors: x1, x2
        each with shape: `(samples, features)`.
    # Output shape
        vector with shape: `(samples, tensor slices)`.

    Note: The layer has been tested with Keras 2.3.1 (Tensorflow 1.14.0 as backend)

    Example:
        x1_inputs = Input(shape=(None,), dtype='int32')
        x2_inputs = Input(shape=(None,), dtype='int32')
        x1_embeddings = Embedding(max_features, 128)(x1_inputs)
        x2_embeddings = Embedding(max_features, 128)(x2_inputs)
        x1_vector = GlobalAveragePooling1D()(x1_embeddings)
        x2_vector = GlobalAveragePooling1D()(x2_embeddings)
        interaction_features = FactoredTensor(256)([x1_vector, x2_vector])
        result_vec = Dropout(0.1)(interaction_features)
        outputs = Dense(1, activation='sigmoid')(result_vec)
    """
    def __init__(self, m, r=8, **kwargs):
        '''Factored Tensor Network

        Args:
            m: number of tensor slices (the output feature dimension)
            r: low rank used to factor each (d, d) interaction matrix
        '''
        self.m = m
        self.r = r
        super(FactoredTensor, self).__init__(**kwargs)

    def get_config(self):
        # Serialize m and r so the layer can be rebuilt from its config.
        config = super().get_config()
        config['m'] = self.m
        config['r'] = self.r
        return config

    def build(self, input_shape):
        # d: feature dimension of each input vector (taken from x1;
        # call() asserts x1 and x2 agree on this).
        self.d = input_shape[0][-1]
        # U: linear projection of the concatenated pair [x1; x2] to m slices.
        self.U = self.add_weight(shape=(2*self.d, self.m),
                                 name='FTN_U',
                                 initializer='glorot_uniform',
                                 trainable=True)
        # J and K are the low-rank factors of the interaction tensor:
        # for each of the m slices, the full (d, d) bilinear matrix is
        # J[:, :, slice] @ K[:, :, slice], using only r*d*2 parameters
        # per slice instead of d*d.
        self.J = self.add_weight(shape=(self.d, self.r, self.m),
                                 name='FTN_J',
                                 initializer='glorot_uniform',
                                 trainable=True)
        self.K = self.add_weight(shape=(self.r, self.d, self.m),
                                 name='FTN_K',
                                 initializer='glorot_uniform',
                                 trainable=True)
        # s: per-slice bias.
        self.s = self.add_weight(shape=(self.m,),
                                 name='FTN_s',
                                 initializer='glorot_uniform',
                                 trainable=True)
        # Fix: the original never called super().build(), so self.built was
        # never set — required by the Keras Layer contract.
        super(FactoredTensor, self).build(input_shape)

    def compute_mask(self, input, input_mask=None):
        # This layer does not propagate masks.
        return None

    def call(self, x, mask=None):
        x1, x2 = x
        assert x1.shape[-1] == x2.shape[-1], 'The two input vectors should have the same shape, \
            got {} and {}.'.format(x1.shape, x2.shape)
        # Bilinear term x1^T (J K) x2, computed per slice without ever
        # materializing the full (d, d, m) tensor.
        # x1*J -> shape (m, batch, r) after the transposed dot, then
        # reordered to (m, batch, r).
        j = K.permute_dimensions(self.J, (2, 1, 0))
        _x1 = K.permute_dimensions(x1, (1, 0))
        t = K.dot(j, _x1)
        t = K.permute_dimensions(t, (0, 2, 1))
        # x1*J*K -> (batch, m, d)
        k = K.permute_dimensions(self.K, (2, 0, 1))
        t = tf.einsum('mbr,mrd->mbd', t, k)
        t = K.permute_dimensions(t, (1, 0, 2))
        # x1*J*K*x2 -> (batch, m): contract the d axis against x2.
        t = tf.einsum('bmd,bd->bm', t, x2)
        # Linear term U*[x1;x2].
        u = K.dot(K.concatenate([x1, x2]), self.U)
        # Slice-wise ReLU over bilinear + linear + bias.
        return K.relu(t + u + self.s)

    def compute_output_shape(self, input_shape):
        # Batch dimension of x1, m tensor slices.
        return input_shape[0][0], self.m
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment