Created November 4, 2019 00:40
"cells": [
"source": [
"try:\n",
"    # %tensorflow_version only exists in Colab.\n",
"    %tensorflow_version 2.x\n",
"except Exception:\n",
"    # The magic is unavailable outside Colab; continue with the TensorFlow\n",
"    # that the import below finds.\n",
"    pass\n",
"import tensorflow as tf\n",
"import numpy as np\n",
"import os\n",
"import time\n",
"source": [
"labels = set()\n",
"def file2Examples(file_name):\n",
" '''\n",
" Read data files and return input/output pairs\n",
" '''\n",
" \n",
" examples=[]\n",
" with open(file_name,\"r\") as f:\n",
" next(f)\n",
" next(f)\n",
" example = [[],[]]\n",
" for line in f:\n",
" input_output_split= line.split()\n",
" if len(input_output_split)==4:\n",
" example[0].append(input_output_split[0])\n",
" example[1].append(input_output_split[-1])\n",
" labels.add(input_output_split[-1])\n",
" elif len(input_output_split)==0:\n",
" examples.append(example)\n",
" example=[[],[]]\n",
" else:\n",
" example=[[],[]]\n",
" f.close()\n",
" \n",
" return examples\n",
" \n",
"# Load the train, test and validation examples; the three files are read\n",
"# from the current working directory.\n",
"train_examples, test_examples, valid_examples = [\n",
"    file2Examples(split_file)\n",
"    for split_file in (\"train.txt\", \"test.txt\", \"valid.txt\")\n",
"]"
"source": [
" # create character vocab\n",
" all_text = \" \".join([\" \".join(x[0]) for x in train_examples+valid_examples+test_examples])\n",
" vocab = sorted(set(all_text))\n",
"\n",
" # Sort the labels: iterating the raw set would give ids that can change\n",
" # from one kernel session to the next.\n",
" label_names = sorted(labels)\n",
"\n",
" # create character/id and label/id mapping; ids start at 1 because 0 is the padding id\n",
" char2idx = {u:i+1 for i, u in enumerate(vocab)}\n",
" label2idx = {u:i+1 for i, u in enumerate(label_names)}\n",
" # The inverse tables hold a placeholder at index 0 so that\n",
" # idx2char[char2idx[c]] == c and idx2label[label2idx[l]] == l.\n",
" idx2char = np.array(['<pad>'] + vocab)\n",
" idx2label = np.array(['<pad>'] + label_names)\n",
"\n",
" print(idx2label)\n",
" print(char2idx)"
"source": [
" def split_char_labels(eg):\n",
" '''\n",
" For a given input/output example, break tokens into characters while keeping \n",
" the same label.\n",
" '''\n",
" tokens = eg[0]\n",
" labels=eg[1]\n",
" input_chars = []\n",
" output_char_labels = []\n",
" for token,label in zip(tokens,labels):\n",
" input_chars.extend([char for char in token])\n",
" input_chars.extend(' ')\n",
" output_char_labels.extend([label]*len(token))\n",
" output_char_labels.extend('O')\n",
" return [[char2idx[x] for x in input_chars[:-1]],np.array([label2idx[x] for x in output_char_labels[:-1]])]\n",
" \n",
" # Convert every split to its character-level form.\n",
" train_formatted = list(map(split_char_labels, train_examples))\n",
" test_formatted = list(map(split_char_labels, test_examples))\n",
" valid_formatted = list(map(split_char_labels, valid_examples))\n",
"\n",
" # Report the number of examples per split: train, test, validation.\n",
" print(len(train_formatted), len(test_formatted), len(valid_formatted), sep=\"\\n\")"
"source": [
" # training generator\n",
" def gen_train_series():\n",
" for eg in train_formatted:\n",
" yield eg[0],eg[1]\n",
" \n",
" # validation generator\n",
" def gen_valid_series():\n",
" \n",
" for eg in valid_formatted:\n",
" yield eg[0],eg[1]\n",
" \n",
" # test generator\n",
" def gen_test_series():\n",
" for eg in test_formatted:\n",
" yield eg[0],eg[1]\n",
" \n",
" # create Dataset objects for train, test and validation sets\n",
" # Each generator yields an (inputs, targets) pair; both parts are converted to int32.\n",
" series = tf.data.Dataset.from_generator(\n",
"     gen_train_series, output_types=(tf.int32, tf.int32), output_shapes=(None, None))\n",
" series_valid = tf.data.Dataset.from_generator(\n",
"     gen_valid_series, output_types=(tf.int32, tf.int32), output_shapes=(None, None))\n",
" series_test = tf.data.Dataset.from_generator(\n",
"     gen_test_series, output_types=(tf.int32, tf.int32), output_shapes=(None, None))\n",
" BATCH_SIZE = 128\n",
" BUFFER_SIZE=1000\n",
"\n",
" # create padded batch series objects for train, test and validation sets\n",
" # Only the training set is shuffled. drop_remainder=True keeps every batch at\n",
" # exactly BATCH_SIZE examples.\n",
" # NOTE(review): this also discards the trailing partial batch of the\n",
" # validation and test sets, so those examples are never evaluated.\n",
" ds_series_batch = series.shuffle(BUFFER_SIZE).padded_batch(BATCH_SIZE, padded_shapes=([None], [None]), drop_remainder=True)\n",
" ds_series_batch_valid = series_valid.padded_batch(BATCH_SIZE, padded_shapes=([None], [None]), drop_remainder=True)\n",
" ds_series_batch_test = series_test.padded_batch(BATCH_SIZE, padded_shapes=([None], [None]), drop_remainder=True)\n",
" \n",
" # Show the first padded validation batch: inputs first, then targets.\n",
" for input_example_batch, target_example_batch in ds_series_batch_valid.take(1):\n",
"     print(input_example_batch, target_example_batch, sep=\"\\n\")"
"source": [
" # Number of character ids; +1 because ids start at 1 (0 is the padding id).\n",
" vocab_size = len(vocab)+1\n",
" # The embedding dimension\n",
" embedding_dim = 256\n",
" # Number of RNN units\n",
" rnn_units = 1024\n",
" # Number of label ids; +1 for the same reason, matching the value the model is built with.\n",
" label_size = len(labels)+1\n",
" \n",
" # build LSTM model\n",
" def build_model(vocab_size,label_size, embedding_dim, rnn_units, batch_size, stateful=True):\n",
"     '''\n",
"     Build the tagging network: Embedding -> LSTM -> Dense.\n",
"\n",
"     Args:\n",
"         vocab_size: Number of distinct input ids, including the padding id 0.\n",
"         label_size: Number of units of the final Dense layer.\n",
"         embedding_dim: Size of each embedding vector.\n",
"         rnn_units: Number of LSTM units.\n",
"         batch_size: Fixed batch dimension of the model input.\n",
"         stateful: Passed to the LSTM layer. True (the default, and the value\n",
"             that used to be hard-coded) carries the recurrent state of each\n",
"             batch row over to the same row of the next batch.\n",
"             NOTE(review): the batches here come from a shuffled dataset, so\n",
"             carried-over state links unrelated rows - consider False.\n",
"\n",
"     Returns:\n",
"         An uncompiled tf.keras.Sequential model. The Dense layer has no\n",
"         activation, so the outputs are unnormalized scores.\n",
"     '''\n",
"     model = tf.keras.Sequential([\n",
"         # mask_zero=True marks id 0 as padding for the layers that follow.\n",
"         tf.keras.layers.Embedding(vocab_size, embedding_dim,\n",
"                                   batch_input_shape=[batch_size, None],mask_zero=True),\n",
"         # return_sequences=True gives one output per input position.\n",
"         tf.keras.layers.LSTM(rnn_units,\n",
"                              return_sequences=True,\n",
"                              stateful=stateful,\n",
"                              recurrent_initializer='glorot_uniform'),\n",
"         tf.keras.layers.Dense(label_size)\n",
"     ])\n",
"     return model\n",
" # Build the network (both sizes include the padding id 0) and list its layers.\n",
" model = build_model(len(vocab)+1, len(labels)+1, embedding_dim, rnn_units, BATCH_SIZE)\n",
" model.summary()"
"source": [
" import os\n",
" # define loss function\n",
" def loss(labels, logits):\n",
"     '''Sparse categorical cross-entropy between integer labels and unnormalized logits.'''\n",
"     return tf.keras.losses.sparse_categorical_crossentropy(\n",
"         y_true=labels, y_pred=logits, from_logits=True)\n",
" # Checkpoint files go under this directory; the file name contains the epoch number.\n",
" checkpoint_dir = './training_checkpoints'\n",
" checkpoint_prefix = os.path.join(checkpoint_dir, \"ckpt_{epoch}\")\n",
" # Save only the weights, not the full model.\n",
" checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(\n",
"     filepath=checkpoint_prefix,\n",
"     save_weights_only=True,\n",
" )\n",
"\n",
" # Adam optimizer, the custom loss defined above, and sparse categorical accuracy as metric.\n",
" model.compile(\n",
"     optimizer='adam',\n",
"     loss=loss,\n",
"     metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],\n",
" )"
"source": [
" EPOCHS=20\n",
"\n",
" # Train on the shuffled, padded training batches, validate on the validation\n",
" # batches, and let the checkpoint callback write checkpoints.\n",
" history = model.fit(ds_series_batch, epochs=EPOCHS, validation_data=ds_series_batch_valid,callbacks=[checkpoint_callback])"
"source": [
"from sklearn.metrics import classification_report, confusion_matrix\n",
"# Per-batch results are collected in lists and concatenated once after the\n",
"# loop; concatenating inside the loop re-copies everything on every step.\n",
"pred_batches = []\n",
"true_batches = []\n",
"# iterate through test set, make predictions based on trained model\n",
"for input_example_batch, target_example_batch in ds_series_batch_test:\n",
"    pred = model.predict(input_example_batch)\n",
"    # The arg-max over the class axis is the predicted id; a softmax first\n",
"    # would not change which class is largest.\n",
"    pred_max = np.argmax(pred, axis=2).flatten()\n",
"    y_true = target_example_batch.numpy().flatten()\n",
"    pred_batches.append(pred_max)\n",
"    true_batches.append(y_true)\n",
"# Integer arrays keep the class ids as ints; an empty test set gives empty arrays.\n",
"preds = np.concatenate(pred_batches) if pred_batches else np.array([], dtype=np.int64)\n",
"y_trues = np.concatenate(true_batches) if true_batches else np.array([], dtype=np.int64)\n",
"# remove padding from evaluation (positions whose true id is 0)\n",
"not_padding = y_trues != 0\n",
"remove_padding = list(zip(preds[not_padding], y_trues[not_padding]))\n",
"r_p = [x[0] for x in remove_padding]\n",
"r_t = [x[1] for x in remove_padding]\n",
"# print confusion matrix and classification report\n",
