This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Build model: one Keras Input per BERT feature, each a sequence of
# length max_seq_length, created in the order ids -> masks -> segments.
in_id, in_mask, in_segment = (
    tf.keras.layers.Input(shape=(max_seq_length,), name=input_name)
    for input_name in ("input_ids", "input_masks", "segment_ids")
)
bert_inputs = [in_id, in_mask, in_segment]

# Feed all three inputs into the custom BertLayer (defined elsewhere);
# NOTE(review): n_fine_tune_layers presumably controls how many layers stay
# trainable — confirm against BertLayer.build.
bert_output = BertLayer(n_fine_tune_layers=10)(bert_inputs)

# Build the rest of the classifier
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class BertLayer(tf.layers.Layer): | |
def __init__(self, n_fine_tune_layers=10, **kwargs): | |
self.n_fine_tune_layers = n_fine_tune_layers | |
self.trainable = True | |
self.output_size = 768 | |
super(BertLayer, self).__init__(**kwargs) | |
def build(self, input_shape): | |
self.bert = hub.Module( | |
bert_path, |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Instantiate the tokenizer shipped with the hub module.
tokenizer = create_tokenizer_from_hub_module()

# Wrap raw (text, label) pairs as InputExample objects.
train_examples = convert_text_to_examples(train_text, train_label)
test_examples = convert_text_to_examples(test_text, test_label)

# Tokenize/pad the training examples into the four parallel feature arrays.
(
    train_input_ids,
    train_input_masks,
    train_segment_ids,
    train_labels,
) = convert_examples_to_features(
    tokenizer,
    train_examples,
    max_seq_length=max_seq_length,
)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Load all files from a directory in a DataFrame.
def load_directory_data(directory):
    """Load every file in ``directory`` into a two-column DataFrame.

    Each file name must match ``<digits>_<digits>.txt``; the second group of
    digits is stored as that file's ``sentiment`` value.

    Args:
        directory: Directory to scan; files are opened with ``tf.gfile.GFile``.

    Returns:
        ``pandas.DataFrame`` with a ``sentence`` column (each file's full text)
        and a ``sentiment`` column (the digits captured from the file name,
        kept as strings). Rows follow the sorted order of the file names.

    Raises:
        ValueError: If a file name does not match ``<digits>_<digits>.txt``.
    """
    # Raw string: "\d" and "\." are invalid escapes in a plain string literal
    # (DeprecationWarning, SyntaxWarning since Python 3.12). Compiled once
    # rather than re-parsed for every file.
    name_pattern = re.compile(r"\d+_(\d+)\.txt")
    data = {"sentence": [], "sentiment": []}
    # os.listdir order is arbitrary; sorting makes the row order reproducible.
    for file_path in sorted(os.listdir(directory)):
        match = name_pattern.match(file_path)
        if match is None:
            # Previously this surfaced as an opaque AttributeError on None.
            raise ValueError(
                "Unexpected file name %r in %r; expected '<id>_<rating>.txt'"
                % (file_path, directory)
            )
        with tf.gfile.GFile(os.path.join(directory, file_path), "r") as f:
            data["sentence"].append(f.read())
        data["sentiment"].append(match.group(1))
    return pd.DataFrame.from_dict(data)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*- | |
""" | |
Created on Sun Sep 24 06:22:04 2017 | |
@author: jacobzweig | |
""" | |
import random | |
class Prisoner: | |
def __init__(self, id, numPrisoners): |