@adiprasad
Created September 8, 2019 21:46
CRF model training and prediction on Moby Dick tokens
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Imports"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"import pycrfsuite\n",
"from collections import defaultdict\n",
"import os\n",
"from sklearn.preprocessing import LabelBinarizer\n",
"from sklearn.metrics import accuracy_score\n",
"from itertools import chain"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Helper functions for converting data in suitable input format for CRFSuite "
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"def return_file_path(dir_path, file_path):\n",
"\treturn os.path.join(dir_path, file_path)\n",
"\n",
"\n",
"def convert_x_i(x_i):\n",
"\tfeatures_dict = defaultdict()\n",
"\tnum_features = len(x_i)\n",
"\n",
"\tdd = defaultdict()\n",
"\n",
"\tdd['bias'] = 1.0\n",
"\n",
"\tfor idx in range(num_features):\n",
"\t\tpixel_i = \"pixel_\" + str(idx)\n",
"\t\tdd[pixel_i] = x_i[idx]\n",
"\n",
"\treturn dd\n",
"\n",
"\n",
"def convert_x(file_path):\n",
"\tx_arr = []\n",
"\n",
"\twith open(file_path, \"r\") as x_file:\n",
"\t\tfor x_i_str in x_file:\n",
"\t\t\tx_i_str = x_i_str.strip()\n",
"\t\t\tx_i_str_arr = x_i_str.split()\n",
"\t\t\tx_i = [float(x_ij) for x_ij in x_i_str_arr]\n",
"\n",
"\t\t\tx_i_features = convert_x_i(x_i)\n",
"\n",
"\t\t\tx_arr.append(x_i_features)\n",
"\n",
"\treturn x_arr\n",
"\n",
"\n",
"def prepare_data(data_dir, mode = \"train\"):\n",
"\tfile_dir = os.path.join(data_dir, \"{}_words\".format(mode))\n",
"\twords_file = return_file_path(data_dir, \"{}_words.txt\".format(mode))\n",
"\n",
"\tX = []\n",
"\tY = []\n",
"\n",
"\twith open(words_file) as f:\n",
"\t\tfor line in f:\n",
"\t\t\tline = line.strip()\n",
"\t\t\ti, word = line.split()\n",
"\n",
"\t\t\tx_i_file_path = return_file_path(file_dir, \"img_{}.txt\".format(i))\n",
"\t\t\tx_i_arr = convert_x(x_i_file_path)\n",
"\n",
"\t\t\ty_i_arr = list(word)\n",
"\n",
"\t\t\tX.append(x_i_arr)\n",
"\t\t\tY.append(y_i_arr)\n",
"\n",
"\treturn X, Y"
]
},
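{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch (using a hypothetical 4-pixel input, not one of the actual image files) of the feature dictionary `convert_x_i` builds for a single character image; pycrfsuite expects one such mapping per position in a sequence."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical toy input, only to illustrate the feature-dict format\n",
"toy_pixels = [0.0, 1.0, 1.0, 0.0]\n",
"print(convert_x_i(toy_pixels))  # contains 'bias' plus 'pixel_0' .. 'pixel_3'"
]
},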
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Functions to train and test the model"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"def train_model(X, Y, max_iter_count, model_store = \"handwriting-reco.crfsuite\"):\n",
"\ttrainer = pycrfsuite.Trainer(verbose=False)\n",
"\n",
"\tfor xseq, yseq in zip(X, Y):\n",
"\t\ttrainer.append(xseq, yseq)\n",
"\n",
"\ttrainer.set_params({\n",
"\t 'c1': 1.0, # coefficient for L1 penalty\n",
"\t 'c2': 1e-3, # coefficient for L2 penalty\n",
"\t 'max_iterations': max_iter_count, # stop earlier\n",
"\n",
"\t # include transitions that are possible, but not observed\n",
"\t 'feature.possible_transitions': True\n",
"\t})\n",
"\n",
"\ttrainer.train(model_store)\n",
"\n",
"\tprint(trainer.logparser.last_iteration)\n",
"\n",
"\n",
"def get_preds(X, model_store = \"handwriting-reco.crfsuite\"):\n",
"\ttagger = pycrfsuite.Tagger()\n",
"\ttagger.open(model_store)\n",
"\tY_pred = [tagger.tag(x) for x in X]\n",
"\n",
"\treturn Y_pred\n",
"\n",
"\n",
"def test_model(X_test, Y_test):\n",
"\tY_test_pred = get_preds(X_test)\n",
"\t\n",
"\tlb = LabelBinarizer()\n",
"\t\n",
"\ty_test_combined = lb.fit_transform(list(chain.from_iterable(Y_test)))\n",
"\ty_pred_combined = lb.transform(list(chain.from_iterable(Y_test_pred)))\n",
"\n",
"\tprint \"Test accuracy : {}\".format(accuracy_score(y_test_combined, y_pred_combined))"
]
},
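{
"cell_type": "markdown",
"metadata": {},
"source": [
"Aside (a sketch, not needed for the results below): `c1` and `c2` above are the L1/L2 penalty coefficients for the default `lbfgs` algorithm; `Trainer.params()` lists every parameter that can be passed to `set_params`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# List the tunable parameters of the default (lbfgs) training algorithm\n",
"print(pycrfsuite.Trainer(verbose=False).params())"
]
},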
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Train the model for 500 iterations"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'loss': 51853.452636, 'error_norm': 57.176499, 'linesearch_trials': 2, 'active_features': 3636, 'num': 500, 'time': 1.73, 'scores': {}, 'linesearch_step': 0.5, 'feature_norm': 85.522574}\n",
"Training successful with 500 iterations.. Enable verbose in the CRF model above and re-run to track progress\n"
]
}
],
"source": [
"data_dir = './data'\n",
"X_train, Y_train = prepare_data(data_dir)\n",
"train_model(X_train, Y_train, 500)\n",
"\n",
"print \"Training successful with 500 iterations.. Enable verbose in the CRF model above and re-run to track progress\""
]
},
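{
"cell_type": "markdown",
"metadata": {},
"source": [
"A sketch (not required for the accuracy number below) of how the learned weights could be inspected via python-crfsuite's `Tagger.info()`: the strongest character-to-character transitions and pixel/state features, assuming `train_model` above has already written `handwriting-reco.crfsuite`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from collections import Counter\n",
"\n",
"# Open the model trained above and dump its parameters\n",
"tagger = pycrfsuite.Tagger()\n",
"tagger.open(\"handwriting-reco.crfsuite\")\n",
"info = tagger.info()\n",
"\n",
"print(\"Strongest character transitions:\")\n",
"for (char_from, char_to), weight in Counter(info.transitions).most_common(5):\n",
"\tprint(\"  {} -> {} : {:.3f}\".format(char_from, char_to, weight))\n",
"\n",
"print(\"Strongest pixel/state features:\")\n",
"for (attr, char), weight in Counter(info.state_features).most_common(5):\n",
"\tprint(\"  {} -> {} : {:.3f}\".format(attr, char, weight))"
]
},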
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Test the model"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Test accuracy : 0.853043730931\n"
]
}
],
"source": [
"X_test, Y_test = prepare_data(data_dir, mode = \"test\")\n",
"test_model(X_test, Y_test)"
]
}
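,
{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick sanity check (a sketch, assuming `X_test`/`Y_test` from `prepare_data` above and the trained model file): tag a single test word and compare the predicted characters with the ground truth."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"tagger = pycrfsuite.Tagger()\n",
"tagger.open(\"handwriting-reco.crfsuite\")\n",
"\n",
"x_seq = X_test[0]\n",
"y_true = Y_test[0]\n",
"y_pred = tagger.tag(x_seq)\n",
"\n",
"print(\"True word      : {}\".format(\"\".join(y_true)))\n",
"print(\"Predicted word : {}\".format(\"\".join(y_pred)))"
]
}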
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.16"
}
},
"nbformat": 4,
"nbformat_minor": 2
}