Created
October 28, 2019 17:38
-
-
Save elsheikh21/e20dcbf526200c592e8e856b801a5fb7 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf | |
import tensorflow_hub as hub | |
from tensorflow.keras import backend as K | |
class BertEmbeddingLayer(tf.keras.layers.Layer):
    """Keras layer that wraps the TF-Hub BERT module (TF1 hub.Module API).

    Produces a fixed-size sentence embedding from BERT, with a configurable
    number of top encoder layers left trainable for fine-tuning.

    Args:
        n_fine_tune_layers: how many of the top encoder layers (out of 12)
            to fine-tune; the rest are frozen.
        pooling: "first" uses BERT's pooled [CLS] output; "mean" averages
            the sequence output over non-masked tokens.
        bert_path: TF-Hub URL of the BERT module (12-layer, 768-hidden).

    References:
        1. https://github.com/strongio/keras-bert
        2. https://tfhub.dev/google/bert_uncased_L-12_H-768_A-12/1
    """

    def __init__(
        self,
        n_fine_tune_layers=10,
        pooling="first",
        bert_path="https://tfhub.dev/google/bert_uncased_L-12_H-768_A-12/1",
        **kwargs  # NOTE: trailing comma after **kwargs is a SyntaxError; removed
    ):
        self.n_fine_tune_layers = n_fine_tune_layers
        self.trainable = True
        self.output_size = 768  # hidden size of the BERT-Base module
        self.pooling = pooling
        self.bert_path = bert_path
        self._validate_pooling()
        super(BertEmbeddingLayer, self).__init__(**kwargs)

    def _validate_pooling(self):
        """Raise NameError if self.pooling is not a supported mode."""
        if self.pooling not in ["first", "mean"]:
            raise NameError(
                f"Undefined pooling type (must be either first or mean, "
                f"but is {self.pooling})"
            )

    def build(self, input_shape):
        """Instantiate the hub module and select trainable variables."""
        self.bert = hub.Module(
            self.bert_path, trainable=self.trainable, name=f"{self.name}_module"
        )
        # Remove layers unused by the chosen pooling strategy.
        trainable_vars = self.bert.variables
        if self.pooling == "first":
            # Drop the masked-LM head; keep the [CLS] pooler dense layer.
            trainable_vars = [
                var for var in trainable_vars if "/cls/" not in var.name
            ]
            trainable_layers = ["pooler/dense"]
        elif self.pooling == "mean":
            # Drop both the masked-LM head and the pooler (unused for mean).
            trainable_vars = [
                var
                for var in trainable_vars
                if "/cls/" not in var.name and "/pooler/" not in var.name
            ]
            trainable_layers = []
        else:
            self._validate_pooling()
        # Select how many of the top encoder layers to fine-tune
        # (layer_11 is the topmost layer of BERT-Base).
        for i in range(self.n_fine_tune_layers):
            trainable_layers.append(f"encoder/layer_{str(11 - i)}")
        # Keep only variables belonging to the selected layers.
        trainable_vars = [
            var
            for var in trainable_vars
            if any([l in var.name for l in trainable_layers])
        ]
        # Register trainable weights; everything else stays frozen.
        for var in trainable_vars:
            self._trainable_weights.append(var)
        for var in self.bert.variables:
            if var not in self._trainable_weights:
                self._non_trainable_weights.append(var)
        super(BertEmbeddingLayer, self).build(input_shape)

    def call(self, inputs):
        """Run BERT on (input_ids, input_mask, segment_ids) int tensors.

        Returns a (batch, 768) embedding: the pooled [CLS] output for
        pooling="first", or a mask-weighted mean of the sequence output
        for pooling="mean".
        """
        inputs = [K.cast(x, dtype="int32") for x in inputs]
        input_ids, input_mask, segment_ids = inputs
        bert_inputs = dict(
            input_ids=input_ids, input_mask=input_mask, segment_ids=segment_ids
        )
        if self.pooling == "first":
            pooled = self.bert(inputs=bert_inputs, signature="tokens", as_dict=True)[
                "pooled_output"
            ]
        elif self.pooling == "mean":
            result = self.bert(inputs=bert_inputs, signature="tokens", as_dict=True)[
                "sequence_output"
            ]
            # Zero out padded positions, then average over real tokens only;
            # the epsilon guards against an all-zero mask.
            mul_mask = lambda x, m: x * tf.expand_dims(m, axis=-1)
            masked_reduce_mean = lambda x, m: tf.reduce_sum(mul_mask(x, m), axis=1) / (
                tf.reduce_sum(m, axis=1, keepdims=True) + 1e-10)
            input_mask = tf.cast(input_mask, tf.float32)
            pooled = masked_reduce_mean(result, input_mask)
        else:
            self._validate_pooling()
        return pooled

    def compute_output_shape(self, input_shape):
        """Output is (batch, 768) regardless of sequence length."""
        return (input_shape[0], self.output_size)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment