Last active
November 3, 2020 02:36
-
-
Save napsternxg/4b219d1f3b210aad923b74df219ba6c4 to your computer and use it in GitHub Desktop.
Evaluate NER predictions
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from collections import Counter" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def split_label(label):\n", | |
" if label == \"O\":\n", | |
" return label, None\n", | |
" return tuple(label.split(\"-\", 1))\n", | |
"\n", | |
"assert split_label(\"O\") == (\"O\", None)\n", | |
"assert split_label(\"B-PER\") == (\"B\", \"PER\")\n", | |
"assert split_label(\"B-PER-S\") == (\"B\", \"PER-S\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def seq_metrics(source, target):\n", | |
" \"\"\"Evaluate the source based on the target. \n", | |
" This means if an entity is present in target only then will it be evaluated in source.\"\"\"\n", | |
" correct = Counter()\n", | |
" total = Counter()\n", | |
" # Append O to ensure we don't need to check for the last entity outside the loop\n", | |
" source += [\"O\"]\n", | |
" target += [\"O\"]\n", | |
" i = -1\n", | |
" while i < len(target)-1:\n", | |
" i += 1\n", | |
" boundary_t, category_t = split_label(target[i])\n", | |
" boundary_s, category_s = split_label(source[i])\n", | |
" correct_entity = 0\n", | |
" # Begin evaluation if you encounter an entity in target\n", | |
" if boundary_t == \"B\":\n", | |
" # print(target[i], source[i])\n", | |
" if target[i] == source[i]:\n", | |
" correct_entity = 1\n", | |
" # print(\"\\t\", target[i+1], source[i+1])\n", | |
" boundary_next_t, category_next_t = split_label(target[i+1])\n", | |
" boundary_next_s, category_next_s = split_label(source[i+1])\n", | |
" # Iterate till we are in the same category\n", | |
" while ((boundary_next_t == \"I\" and category_t == category_next_t)\n", | |
" or (boundary_next_s == \"I\" and category_s == category_next_s)):\n", | |
" if target[i+1] != source[i+1]:\n", | |
" correct_entity = 0\n", | |
" i += 1\n", | |
" # print(\"\\t\", target[i+1], source[i+1])\n", | |
" boundary_next_t, category_next_t = split_label(target[i+1])\n", | |
" boundary_next_s, category_next_s = split_label(source[i+1])\n", | |
" correct[category_t] += correct_entity\n", | |
" total[category_t] += 1\n", | |
" return correct, total " | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"Y_true = [[\"B-PER\", \"O\", \"B-LOC\"],\t[\"O\", \"B-LOC\", \"I-LOC\"], [\"B-PER\", \"O\", \"B-LOC\", \"O\", \"B-LOC\", \"I-LOC\"]] \n", | |
"Y_pred = [[\"B-PER\", \"I-PER\", \"O\"],\t[\"O\", \"B-LOC\", \"I-LOC\"], [\"B-PER\", \"O\", \"O\", \"O\", \"B-LOC\", \"I-LOC\"]] " | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Counter({'PER': 0, 'LOC': 0}) Counter({'PER': 1}) Counter({'PER': 1, 'LOC': 1})\n", | |
"Counter({'LOC': 2}) Counter({'LOC': 1}) Counter({'LOC': 1})\n", | |
"Counter({'PER': 2, 'LOC': 2}) Counter({'PER': 1, 'LOC': 1}) Counter({'LOC': 2, 'PER': 1})\n" | |
] | |
} | |
], | |
"source": [ | |
"for y_true, y_pred in zip(Y_true, Y_pred):\n", | |
" correct_predictions, all_predictions = seq_metrics(y_true, y_pred)\n", | |
" correct_truths, all_truths = seq_metrics(y_pred, y_true)\n", | |
" correct_predictions.update(correct_truths)\n", | |
" print(correct_predictions, all_predictions, all_truths)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def eval_metrics(Y_true, Y_pred):\n", | |
" correct_predictions = Counter()\n", | |
" all_predictions = Counter()\n", | |
" all_truths = Counter()\n", | |
" for y_true, y_pred in zip(Y_true, Y_pred):\n", | |
" cp, ap = seq_metrics(y_true, y_pred)\n", | |
" ct, at = seq_metrics(y_pred, y_true)\n", | |
" correct_predictions.update(cp | ct)\n", | |
" all_predictions.update(ap)\n", | |
" all_truths.update(at)\n", | |
" return correct_predictions, all_predictions, all_truths" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(Counter({'LOC': 2, 'PER': 1}),\n", | |
" Counter({'PER': 2, 'LOC': 2}),\n", | |
" Counter({'PER': 2, 'LOC': 4}))" | |
] | |
}, | |
"execution_count": 7, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"correct_predictions, all_predictions, all_truths = eval_metrics(Y_true, Y_pred)\n", | |
"correct_predictions, all_predictions, all_truths" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
" \t precision\t recall\t f1_score\t #true_preds\t #preds\t #truths\n", | |
"LOC \t 1.000\t 0.500\t 0.667\t 2\t 2\t 4\n", | |
"PER \t 0.500\t 0.500\t 0.500\t 1\t 2\t 2\n" | |
] | |
} | |
], | |
"source": [ | |
"print(\"{:5s}\\t{:>10s}\\t{:>10s}\\t{:>10s}\\t{:>12s}\\t{:>10s}\\t{:>10s}\".format(\n", | |
" \"\", \"precision\", \"recall\", \"f1_score\", \"#true_preds\", \"#preds\", \"#truths\"\n", | |
"))\n", | |
"for k in correct_predictions | all_predictions | all_truths:\n", | |
" precision = correct_predictions[k]/all_predictions.get(k, 1)\n", | |
" recall = correct_predictions[k]/all_truths.get(k, 1)\n", | |
" f1_score = 2*precision*recall/((precision+recall) if (precision+recall) != 0 else 1)\n", | |
" print(\"{:5}\\t{:10.3f}\\t{:10.3f}\\t{:10.3f}\\t{:12d}\\t{:10d}\\t{:10d}\".format(\n", | |
" k, precision, recall, f1_score,\n", | |
" correct_predictions[k],\n", | |
" all_predictions[k],\n", | |
" all_truths[k]\n", | |
" ))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python [conda env:root] *", | |
"language": "python", | |
"name": "conda-root-py" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.7.3" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from collections import Counter | |
def split_label(y):
    """Split an IOB label into a (boundary, tag) pair.

    "O" (and any label without a hyphen) yields (label, None).
    Splits on the FIRST hyphen so multi-part tags survive intact:
    "B-PER-S" -> ("B", "PER-S").  The previous rsplit split on the
    LAST hyphen, mangling such tags into ("B-PER", "S") and breaking
    the `boundary == "B"` check in extract_entities.
    """
    parts = y.split("-", 1)
    if len(parts) == 2:
        return tuple(parts)
    return (parts[0], None)
def extract_entities(Y):
    """Yield one (start, end, tag) triple per entity in the IOB sequence Y.

    An entity begins at a "B-" label and extends over every immediately
    following "I-" label; `end` is exclusive.  Labels outside an entity
    ("O", or a stray "I-" with no preceding "B-") are skipped.
    NOTE: continuation is decided by the boundary prefix alone — the tag
    of an "I-" label is not compared against the entity's tag.
    """
    pos, n = 0, len(Y)
    while pos < n:
        boundary, tag = split_label(Y[pos])
        if boundary != "B":
            pos += 1
            continue
        begin = pos
        pos += 1
        while pos < n and split_label(Y[pos])[0] == "I":
            pos += 1
        yield (begin, pos, tag)
def seq_metrics(Y_true, Y_pred):
    """Count entities per tag over a corpus of IOB sequences.

    Returns three Counters keyed by entity tag:
      true_counts      — entities present in the gold sequences,
      pred_counts      — entities present in the predicted sequences,
      true_pred_counts — entities identical in both (same start, end, tag).
    Sequences are paired positionally; extra sequences in the longer of
    Y_true / Y_pred are ignored (zip semantics).
    """
    true_counts = Counter()
    pred_counts = Counter()
    true_pred_counts = Counter()
    for y_true, y_pred in zip(Y_true, Y_pred):
        gold = set(extract_entities(y_true))
        guessed = set(extract_entities(y_pred))
        matched = gold & guessed
        # Counter.update adds counts, equivalent to `+= Counter(...)`.
        true_counts.update(tag for _, _, tag in gold)
        pred_counts.update(tag for _, _, tag in guessed)
        true_pred_counts.update(tag for _, _, tag in matched)
    return true_counts, pred_counts, true_pred_counts
# Smoke checks (expect Y_true / Y_pred to be defined earlier in the session):
# entity spans found in the first predicted sequence, then corpus-level counts.
list(extract_entities(Y_pred[0]))
seq_metrics(Y_true, Y_pred)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment