Evaluate NER predictions

# In[1]:
from collections import Counter

# In[2]:
def split_label(label):
    if label == "O":
        return label, None
    return tuple(label.split("-", 1))

assert split_label("O") == ("O", None)
assert split_label("B-PER") == ("B", "PER")
assert split_label("B-PER-S") == ("B", "PER-S")

# In[3]:
def seq_metrics(source, target):
    """Evaluate the source against the target: an entity is scored in
    source only if it is present in target."""
    correct = Counter()
    total = Counter()
    # Append a sentinel "O" (to copies, so the callers' lists are not
    # mutated) to avoid handling the last entity outside the loop.
    source = source + ["O"]
    target = target + ["O"]
    i = -1
    while i < len(target) - 1:
        i += 1
        boundary_t, category_t = split_label(target[i])
        boundary_s, category_s = split_label(source[i])
        correct_entity = 0
        # Begin evaluation when an entity starts in the target
        if boundary_t == "B":
            if target[i] == source[i]:
                correct_entity = 1
            boundary_next_t, category_next_t = split_label(target[i + 1])
            boundary_next_s, category_next_s = split_label(source[i + 1])
            # Advance while either sequence continues the same entity
            while ((boundary_next_t == "I" and category_t == category_next_t)
                   or (boundary_next_s == "I" and category_s == category_next_s)):
                if target[i + 1] != source[i + 1]:
                    correct_entity = 0
                i += 1
                boundary_next_t, category_next_t = split_label(target[i + 1])
                boundary_next_s, category_next_s = split_label(source[i + 1])
            correct[category_t] += correct_entity
            total[category_t] += 1
    return correct, total
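
# A quick sanity check (added here for illustration; not part of the
# original notebook). Entities are enumerated from the *target* argument,
# so the evaluation is directional: an entity that exists only in target
# is counted, while one that exists only in source is ignored.
assert seq_metrics(["O", "O"], ["O", "B-LOC"]) == (Counter({"LOC": 0}), Counter({"LOC": 1}))
assert seq_metrics(["O", "B-LOC"], ["O", "O"]) == (Counter(), Counter())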

# In[4]:
Y_true = [["B-PER", "O", "B-LOC"], ["O", "B-LOC", "I-LOC"], ["B-PER", "O", "B-LOC", "O", "B-LOC", "I-LOC"]]
Y_pred = [["B-PER", "I-PER", "O"], ["O", "B-LOC", "I-LOC"], ["B-PER", "O", "O", "O", "B-LOC", "I-LOC"]]

# In[5]:
for y_true, y_pred in zip(Y_true, Y_pred):
    correct_predictions, all_predictions = seq_metrics(y_true, y_pred)
    correct_truths, all_truths = seq_metrics(y_pred, y_true)
    # Counter.update adds counts, so an entity matched in both passes is
    # counted twice here (see 'LOC': 2 below); eval_metrics fixes this.
    correct_predictions.update(correct_truths)
    print(correct_predictions, all_predictions, all_truths)

# Output:
# Counter({'PER': 0, 'LOC': 0}) Counter({'PER': 1}) Counter({'PER': 1, 'LOC': 1})
# Counter({'LOC': 2}) Counter({'LOC': 1}) Counter({'LOC': 1})
# Counter({'PER': 2, 'LOC': 2}) Counter({'PER': 1, 'LOC': 1}) Counter({'LOC': 2, 'PER': 1})
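
# Counter.update adds counts, while | takes the element-wise maximum; a
# minimal illustration of the difference (added, not from the notebook):
c = Counter(LOC=1)
c.update(Counter(LOC=1))
assert c == Counter(LOC=2)
assert (Counter(LOC=1) | Counter(LOC=1)) == Counter(LOC=1)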

# In[6]:
def eval_metrics(Y_true, Y_pred):
    correct_predictions = Counter()
    all_predictions = Counter()
    all_truths = Counter()
    for y_true, y_pred in zip(Y_true, Y_pred):
        cp, ap = seq_metrics(y_true, y_pred)
        ct, at = seq_metrics(y_pred, y_true)
        # cp | ct takes the element-wise max, so an entity found correct
        # in both passes is only counted once
        correct_predictions.update(cp | ct)
        all_predictions.update(ap)
        all_truths.update(at)
    return correct_predictions, all_predictions, all_truths
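
# With the max-merge, an exactly matched entity is counted once even
# though both passes find it (a quick check, added for illustration):
assert eval_metrics([["B-LOC", "I-LOC"]], [["B-LOC", "I-LOC"]]) == (
    Counter({"LOC": 1}), Counter({"LOC": 1}), Counter({"LOC": 1}))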

# In[7]:
correct_predictions, all_predictions, all_truths = eval_metrics(Y_true, Y_pred)
correct_predictions, all_predictions, all_truths

# Out[7]:
# (Counter({'LOC': 2, 'PER': 1}),
#  Counter({'PER': 2, 'LOC': 2}),
#  Counter({'PER': 2, 'LOC': 4}))

# In[8]:
print("{:5s}\t{:>10s}\t{:>10s}\t{:>10s}\t{:>12s}\t{:>10s}\t{:>10s}".format(
    "", "precision", "recall", "f1_score", "#true_preds", "#preds", "#truths"
))
for k in correct_predictions | all_predictions | all_truths:
    precision = correct_predictions[k] / all_predictions.get(k, 1)
    recall = correct_predictions[k] / all_truths.get(k, 1)
    f1_score = 2 * precision * recall / ((precision + recall) if (precision + recall) != 0 else 1)
    print("{:5}\t{:10.3f}\t{:10.3f}\t{:10.3f}\t{:12d}\t{:10d}\t{:10d}".format(
        k, precision, recall, f1_score,
        correct_predictions[k],
        all_predictions[k],
        all_truths[k]
    ))

# Output:
#      	 precision	    recall	  f1_score	 #true_preds	    #preds	   #truths
# LOC  	     1.000	     0.500	     0.667	           2	         2	         4
# PER  	     0.500	     0.500	     0.500	           1	         2	         2
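
# Micro-averaged overall scores, summed over categories (an added sketch,
# not part of the original notebook):
n_correct = sum(correct_predictions.values())
n_pred = sum(all_predictions.values())
n_true = sum(all_truths.values())
micro_p = n_correct / n_pred if n_pred else 0.0
micro_r = n_correct / n_true if n_true else 0.0
micro_f1 = 2 * micro_p * micro_r / (micro_p + micro_r) if (micro_p + micro_r) else 0.0
print("micro: precision={:.3f} recall={:.3f} f1={:.3f}".format(micro_p, micro_r, micro_f1))
# -> micro: precision=0.750 recall=0.500 f1=0.600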

# An alternative span-based implementation of the same evaluation: extract
# (start, end, tag) spans from each BIO sequence and compare them as sets.
from collections import Counter

def split_label(y):
    # Split on the first "-" so that "B-PER-S" yields ("B", "PER-S"),
    # consistent with split_label above (the draft used rsplit, which
    # would give ("B-PER", "S") instead).
    label = y.split("-", 1)
    if len(label) == 2:
        return tuple(label)
    return (label[0], None)

def extract_entities(Y):
    """Yield a (start, end, tag) span for every entity in a BIO sequence."""
    i = 0
    while i < len(Y):
        boundary, tag = split_label(Y[i])
        start = i
        if boundary == "B":
            i += 1
            # Consume the continuation tokens of the current entity
            while i < len(Y) and split_label(Y[i])[0] == "I":
                i += 1
            yield (start, i, tag)
        else:
            i += 1

def seq_metrics(Y_true, Y_pred):
    true_counts = Counter()
    pred_counts = Counter()
    true_pred_counts = Counter()
    for y_true, y_pred in zip(Y_true, Y_pred):
        true_entities = set(extract_entities(y_true))
        pred_entities = set(extract_entities(y_pred))
        # An entity is correct only if start, end, and tag all match
        true_pred_entities = true_entities & pred_entities
        true_counts += Counter(x[2] for x in true_entities)
        pred_counts += Counter(x[2] for x in pred_entities)
        true_pred_counts += Counter(x[2] for x in true_pred_entities)
    return true_counts, pred_counts, true_pred_counts

list(extract_entities(Y_pred[0]))
seq_metrics(Y_true, Y_pred)
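
# Expected results, worked out by hand from Y_true/Y_pred above (added for
# illustration). The first predicted sequence has the single span
# (0, 2, 'PER'), and the counts match the counters from eval_metrics
# earlier, returned here in the order truths, preds, correct:
# list(extract_entities(Y_pred[0]))  -> [(0, 2, 'PER')]
# seq_metrics(Y_true, Y_pred)        -> (Counter({'LOC': 4, 'PER': 2}),
#                                        Counter({'PER': 2, 'LOC': 2}),
#                                        Counter({'LOC': 2, 'PER': 1}))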