@alanbuxton
Created January 1, 2021 16:24
Compare use of IOB vs non-IOB tags in NER
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"See <https://ai.stackexchange.com/questions/22514/do-we-have-to-use-the-iob-format-on-labels-in-the-ner-dataset-if-so-why>\n",
"\n",
"This notebook compares using full IOB tags (e.g. B-PER, I-PER) vs stripping off the B/I (e.g. just using PER)\n",
"\n",
"Requires the `ner_dataset.csv` which you can download from <https://www.kaggle.com/abhinavwalia95/entity-annotated-corpus>"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import trax\n",
"from trax import layers as tl\n",
"import pandas as pd\n",
"import numpy as np\n",
"from trax.supervised import training\n",
"from sklearn.model_selection import train_test_split\n",
"import random as rnd\n",
"import os"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"def simplify_tag(iob_tag):\n",
" if iob_tag == 'O': \n",
" return iob_tag\n",
" else:\n",
" return iob_tag[2:]"
]
},
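{
"cell_type": "markdown",
"metadata": {},
"source": [
"*Added for illustration (not executed in the original run):* a quick sanity check that `simplify_tag` collapses the B-/I- variants of a tag to a single label."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative: B- and I- prefixes collapse to the bare entity type; 'O' is unchanged\n",
"for t in ['O', 'B-per', 'I-per', 'B-geo']:\n",
"    print(t, '->', simplify_tag(t))"
]
},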
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"data = pd.read_csv(\"ner_dataset.csv\",encoding = 'ISO-8859-1')\n",
"data = data.fillna(method = 'ffill')\n",
"\n",
"sentences = data.groupby(\"Sentence #\").apply(lambda x:[(w,t) for w,t in zip(x[\"Word\"].values.tolist(),\n",
" x[\"Tag\"].values.tolist())])\n",
"\n",
"words = data.loc[:, \"Word\"]\n",
"words = set(data[\"Word\"].values)\n",
"words = list(words.union(('<PAD>','<UNK>')))\n",
"iob_tags = list(set(data[\"Tag\"].values))\n",
"simplified_tags = list(set([simplify_tag(x) for x in iob_tags ]))\n",
"\n",
"vocab = {w : i for i ,w in enumerate(words)}\n",
"iob_tags_map = {t : i for i ,t in enumerate(iob_tags)}\n",
"simplified_tags_map = {t: i for i, t in enumerate(simplified_tags)}\n",
"\n",
"X = [[vocab[w[0]] for w in s] for s in sentences]\n",
"y = [[ (iob_tags_map[w[1]],\n",
" simplified_tags_map[simplify_tag(w[1])]) for w in s] for s in sentences]\n",
"\n",
"\n",
"t_sentences, v_sentences, t_labels, v_labels = train_test_split(X,y,test_size=0.1)"
]
},
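{
"cell_type": "markdown",
"metadata": {},
"source": [
"*Added for illustration (not executed in the original run):* a peek at the first parsed sentence to confirm the (word, tag) pairing."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative: each element of `sentences` is a list of (word, tag) tuples\n",
"print(sentences.iloc[0][:5])"
]
},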
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"def split_labels(labels):\n",
" l1 = []\n",
" l2 = []\n",
" for row in labels:\n",
" l1.append([x[0] for x in row])\n",
" l2.append([x[1] for x in row])\n",
" return l1,l2"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"iob_t_labels,simplified_t_labels = split_labels(t_labels)\n",
"iob_v_labels,simplified_v_labels = split_labels(v_labels)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Full tag map using IOB: {'O': 0, 'B-per': 1, 'B-nat': 2, 'B-eve': 3, 'I-nat': 4, 'I-tim': 5, 'I-eve': 6, 'B-geo': 7, 'B-art': 8, 'I-geo': 9, 'B-org': 10, 'I-art': 11, 'I-gpe': 12, 'I-org': 13, 'B-tim': 14, 'I-per': 15, 'B-gpe': 16}\n",
"Simplified tag map without IOB: {'tim': 0, 'O': 1, 'art': 2, 'org': 3, 'gpe': 4, 'per': 5, 'eve': 6, 'geo': 7, 'nat': 8}\n",
"Row of IOB values: [0, 14, 0, 0, 10, 13, 0, 14, 5, 0, 10, 14, 5, 5, 0]\n",
"Row of Simplified values: [1, 0, 1, 1, 3, 3, 1, 0, 0, 1, 3, 0, 0, 0, 1]\n",
"Row of IOB tags: ['O', 'B-tim', 'O', 'O', 'B-org', 'I-org', 'O', 'B-tim', 'I-tim', 'O', 'B-org', 'B-tim', 'I-tim', 'I-tim', 'O']\n",
"Row of Simplified tags: ['O', 'tim', 'O', 'O', 'org', 'org', 'O', 'tim', 'tim', 'O', 'org', 'tim', 'tim', 'tim', 'O']\n"
]
}
],
"source": [
"print(\"Full tag map using IOB: %s\" % iob_tags_map)\n",
"print(\"Simplified tag map without IOB: %s\" % simplified_tags_map)\n",
"print(\"Row of IOB values: %s\" % iob_t_labels[0])\n",
"print(\"Row of Simplified values: %s\" % simplified_t_labels[0])\n",
"print(\"Row of IOB tags: %s\" % [iob_tags[a] for a in iob_t_labels[0]])\n",
"print(\"Row of Simplified tags: %s\" % [simplified_tags[a] for a in simplified_t_labels[0]])"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"iob_tag_count = len(iob_tags)\n",
"simplified_tag_count = len(simplified_tags)\n",
"words_count = len(words)\n",
"batch_size = 64\n",
"d_model = 300\n",
"train_steps = 500"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"def Model(tag_count, vocab_size=words_count, d_model=d_model):\n",
" model = tl.Serial(\n",
" tl.Embedding(vocab_size, d_model),\n",
" tl.LSTM(d_model),\n",
" tl.Dense(tag_count),\n",
" tl.LogSoftmax()\n",
" )\n",
" return model"
]
},
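{
"cell_type": "markdown",
"metadata": {},
"source": [
"*Added for illustration (not executed in the original run):* printing a trax `Serial` layer shows its structure; the two models should differ only in the width of the final `Dense` layer."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative: same architecture, different output dimension\n",
"print(Model(simplified_tag_count))\n",
"print(Model(iob_tag_count))"
]
},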
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"def data_generator(batch_size, x, y, pad, shuffle=False, verbose=False):\n",
" '''\n",
" Input:\n",
" batch_size - integer describing the batch size\n",
" x - list containing sentences where words are represented as integers\n",
" y - list containing tags associated with the sentences\n",
" shuffle - Shuffle the data order\n",
" pad - an integer representing a pad character\n",
" verbose - Print information during runtime\n",
" Output:\n",
" a tuple containing 2 elements:\n",
" X - np.ndarray of dim (batch_size, max_len) of padded sentences\n",
" Y - np.ndarray of dim (batch_size, max_len) of tags associated with the sentences in X\n",
" '''\n",
" num_lines = len(x)\n",
" lines_index = [*range(num_lines)]\n",
" if shuffle:\n",
" rnd.shuffle(lines_index)\n",
" index = 0 # tracks current location in x, y\n",
" while True:\n",
" buffer_x = [0] * batch_size # Temporal array to store the raw x data for this batch\n",
" buffer_y = [0] * batch_size # Temporal array to store the raw y data for this batch\n",
" max_len = 0\n",
" for i in range(batch_size):\n",
" if index >= num_lines:\n",
" index = 0\n",
" if shuffle:\n",
" rnd.shuffle(lines_index)\n",
" buffer_x[i] = x[lines_index[index]]\n",
" buffer_y[i] = y[lines_index[index]]\n",
" lenx = len(buffer_x[i])\n",
" if lenx > max_len:\n",
" max_len = lenx\n",
" index += 1\n",
" X = np.full((batch_size,max_len),pad)\n",
" Y = np.full((batch_size,max_len),pad)\n",
" for i in range(batch_size):\n",
" x_i = buffer_x[i]\n",
" y_i = buffer_y[i]\n",
" for j in range(len(x_i)):\n",
" X[i, j] = x_i[j]\n",
" Y[i, j] = y_i[j]\n",
" if verbose: print(\"X shape %s, first entry: %s\" % (X.shape,[words[i] for i in X[0]]))\n",
" yield((X,Y))"
]
},
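{
"cell_type": "markdown",
"metadata": {},
"source": [
"*Added for illustration (not executed in the original run):* pull a single batch from a throwaway generator to check that both `X` and `Y` are padded to the length of the longest sentence in the batch."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative: a fresh generator, so the training generators below are unaffected\n",
"_X, _Y = next(data_generator(2, t_sentences, iob_t_labels, vocab['<PAD>']))\n",
"print(_X.shape, _Y.shape)  # both (2, max_len)\n",
"print(_X[:, -3:])  # the row of the shorter sentence ends in vocab['<PAD>']"
]
},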
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"def train_model(model, train_generator, eval_generator, train_steps=1, output_dir='.'):\n",
" '''\n",
" Input:\n",
" NER - the model you are building\n",
" train_generator - The data generator for training examples\n",
" eval_generator - The data generator for validation examples,\n",
" train_steps - number of training steps\n",
" output_dir - folder to save your model\n",
" Output:\n",
" training_loop - a trax supervised training Loop\n",
" '''\n",
" train_task = training.TrainTask(\n",
" train_generator,\n",
" loss_layer = tl.CrossEntropyLoss(),\n",
" optimizer = trax.optimizers.Adam(0.01),\n",
" )\n",
"\n",
" eval_task = training.EvalTask(\n",
" labeled_data = eval_generator,\n",
" metrics = [tl.CrossEntropyLoss(), tl.Accuracy()],\n",
" n_eval_batches = 10\n",
" )\n",
"\n",
" training_loop = training.Loop(\n",
" model,\n",
" train_task,\n",
" eval_tasks = eval_task,\n",
" output_dir = output_dir)\n",
"\n",
" training_loop.run(n_steps = train_steps)\n",
" return training_loop\n",
"\n",
"train_generator1 = trax.data.inputs.add_loss_weights(\n",
" data_generator(batch_size, t_sentences, simplified_t_labels, vocab['<PAD>'], True),\n",
" id_to_mask=vocab['<PAD>'])\n",
"\n",
"eval_generator1 = trax.data.inputs.add_loss_weights(\n",
" data_generator(batch_size, v_sentences, simplified_v_labels, vocab['<PAD>'], shuffle=False,verbose=False),\n",
" id_to_mask=vocab['<PAD>'])\n",
"\n",
"train_generator2 = trax.data.inputs.add_loss_weights(\n",
" data_generator(batch_size, t_sentences, iob_t_labels, vocab['<PAD>'], True),\n",
" id_to_mask=vocab['<PAD>'])\n",
"\n",
"eval_generator2 = trax.data.inputs.add_loss_weights(\n",
" data_generator(batch_size, v_sentences, iob_v_labels, vocab['<PAD>'], shuffle=False,verbose=False),\n",
" id_to_mask=vocab['<PAD>'])"
]
},
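{
"cell_type": "markdown",
"metadata": {},
"source": [
"*Added for illustration (not executed in the original run):* `add_loss_weights` turns each `(inputs, targets)` batch into an `(inputs, targets, weights)` triple, with weight 0 wherever the target equals `id_to_mask`; this is how the `<PAD>` positions are kept out of the loss and accuracy. A minimal sketch with a hand-made batch:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative: with id_to_mask=0 the trailing pad position should get weight 0\n",
"_demo = trax.data.inputs.add_loss_weights(\n",
"    iter([(np.array([[4, 7, 0]]), np.array([[2, 5, 0]]))]), id_to_mask=0)\n",
"print(next(_demo))  # (inputs, targets, weights)"
]
},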
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"simplified_model_dir = 'simplified_model'\n",
"iob_model_dir = 'iob_model'\n",
"simplified_model_file = '%s/model.pkl.gz' % simplified_model_dir\n",
"iob_model_file = '%s/model.pkl.gz' % iob_model_dir\n",
"\n",
"for x in [simplified_model_file,iob_model_file]:\n",
" if os.path.isfile(x): os.remove(x)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"#####################################\n",
"##### TRAINING SIMPLIFIED MODEL #####\n",
"#####################################\n",
"\n",
"Step 1: Total number of trainable weights: 11277909\n",
"Step 1: Ran 1 train steps in 6.22 secs\n",
"Step 1: train CrossEntropyLoss | 1.78805971\n",
"Step 1: eval CrossEntropyLoss | 1.55769454\n",
"Step 1: eval Accuracy | 0.84874840\n",
"\n",
"Step 100: Ran 99 train steps in 108.00 secs\n",
"Step 100: train CrossEntropyLoss | 0.41855213\n",
"Step 100: eval CrossEntropyLoss | 0.19427041\n",
"Step 100: eval Accuracy | 0.94703357\n",
"\n",
"Step 200: Ran 100 train steps in 50.29 secs\n",
"Step 200: train CrossEntropyLoss | 0.15863052\n",
"Step 200: eval CrossEntropyLoss | 0.13855599\n",
"Step 200: eval Accuracy | 0.95947566\n",
"\n",
"Step 300: Ran 100 train steps in 33.70 secs\n",
"Step 300: train CrossEntropyLoss | 0.13378948\n",
"Step 300: eval CrossEntropyLoss | 0.13570744\n",
"Step 300: eval Accuracy | 0.95844606\n",
"\n",
"Step 400: Ran 100 train steps in 40.36 secs\n",
"Step 400: train CrossEntropyLoss | 0.12307262\n",
"Step 400: eval CrossEntropyLoss | 0.11592426\n",
"Step 400: eval Accuracy | 0.96128523\n",
"\n",
"Step 500: Ran 100 train steps in 36.53 secs\n",
"Step 500: train CrossEntropyLoss | 0.11574913\n",
"Step 500: eval CrossEntropyLoss | 0.10583764\n",
"Step 500: eval Accuracy | 0.96447751\n",
"##############################\n",
"##### TRAINING IOB MODEL #####\n",
"##############################\n",
"\n",
"Step 1: Total number of trainable weights: 11280317\n",
"Step 1: Ran 1 train steps in 7.30 secs\n",
"Step 1: train CrossEntropyLoss | 3.58617210\n",
"Step 1: eval CrossEntropyLoss | 1.32238330\n",
"Step 1: eval Accuracy | 0.84874840\n",
"\n",
"Step 100: Ran 99 train steps in 118.14 secs\n",
"Step 100: train CrossEntropyLoss | 0.44481519\n",
"Step 100: eval CrossEntropyLoss | 0.22768956\n",
"Step 100: eval Accuracy | 0.94204758\n",
"\n",
"Step 200: Ran 100 train steps in 48.78 secs\n",
"Step 200: train CrossEntropyLoss | 0.18421331\n",
"Step 200: eval CrossEntropyLoss | 0.15019632\n",
"Step 200: eval Accuracy | 0.95503212\n",
"\n",
"Step 300: Ran 100 train steps in 43.62 secs\n",
"Step 300: train CrossEntropyLoss | 0.15317088\n",
"Step 300: eval CrossEntropyLoss | 0.15830390\n",
"Step 300: eval Accuracy | 0.95257186\n",
"\n",
"Step 400: Ran 100 train steps in 41.65 secs\n",
"Step 400: train CrossEntropyLoss | 0.14070258\n",
"Step 400: eval CrossEntropyLoss | 0.13295509\n",
"Step 400: eval Accuracy | 0.95788263\n",
"\n",
"Step 500: Ran 100 train steps in 43.64 secs\n",
"Step 500: train CrossEntropyLoss | 0.13118610\n",
"Step 500: eval CrossEntropyLoss | 0.12584044\n",
"Step 500: eval Accuracy | 0.95981656\n"
]
}
],
"source": [
"print(\"#####################################\")\n",
"print(\"##### TRAINING SIMPLIFIED MODEL #####\")\n",
"print(\"#####################################\")\n",
"simplified_training_loop = train_model(Model(simplified_tag_count), train_generator1, eval_generator1, train_steps, simplified_model_dir)\n",
"\n",
"print(\"##############################\")\n",
"print(\"##### TRAINING IOB MODEL #####\")\n",
"print(\"##############################\")\n",
"iob_training_loop = train_model(Model(iob_tag_count), train_generator2, eval_generator2, train_steps, iob_model_dir)\n"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"def predict(sentence, model, vocab, tag_map):\n",
" s = [vocab[token] if token in vocab else vocab['<UNK>'] for token in sentence.split(' ')]\n",
" batch_data = np.ones((1, len(s)))\n",
" batch_data[0][:] = s\n",
" sentence = np.array(batch_data).astype(int)\n",
" output = model(sentence)\n",
" outputs = np.argmax(output, axis=2)\n",
" labels = list(tag_map.keys())\n",
" pred = []\n",
" for i in range(len(outputs[0])):\n",
" idx = outputs[0][i]\n",
" pred_label = labels[idx]\n",
" pred.append(pred_label)\n",
" return pred"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"shape11 = trax.shapes.ShapeDtype((1, 1), dtype=np.int32)\n",
"\n",
"simplified_model = Model(simplified_tag_count)\n",
"simplified_model.init_from_file(simplified_model_file, weights_only=True, input_signature=shape11)\n",
"\n",
"iob_model = Model(iob_tag_count)\n",
"iob_model.init_from_file(iob_model_file, weights_only=True, input_signature=shape11)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"test_sentence = \"Peter Navarro, the White House director of trade and manufacturing policy of U.S, said in an interview on Sunday morning that the White House was working to prepare for the possibility of a second wave of the coronavirus in the fall, though he said it wouldn’t necessarily come\"\n",
"test_words = test_sentence.split(' ')\n",
"s = [vocab[token] if token in vocab else vocab['<UNK>'] for token in test_words]"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" Peter: per B-per\n",
" Navarro,: per I-per\n",
" White: org B-org\n",
" House: org I-org\n",
" Sunday: tim B-tim\n",
" morning: tim I-tim\n",
" White: org B-org\n",
" House: org I-org\n"
]
}
],
"source": [
"simplified_predictions = predict(test_sentence, simplified_model, vocab, simplified_tags_map)\n",
"iob_predictions = predict(test_sentence, iob_model, vocab, iob_tags_map)\n",
"for x,y,z in zip(test_words, simplified_predictions,iob_predictions ):\n",
" if y != 'O' or z!= 'O':\n",
" print(\"%15s: %s %s\" % (x,y,z))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
}
},
"nbformat": 4,
"nbformat_minor": 4
}