Created
July 8, 2018 13:16
SIGIR18: Probabilistic Topic Models Hands-on Exercise 1 (Retrieval)
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Exercise 1: Pseudo-feedback with Two-component Mixture Model" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"First, let's import the Python bindings for MeTA:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import metapy" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"If you don't have `metapy` installed, you can install it with a\n", | |
"\n", | |
"```bash\n", | |
"pip install metapy\n", | |
"```\n", | |
"\n", | |
"on the command line on Linux, macOS, or Windows for either Python 2.7 or Python 3.x. (I will be using Python 3.6 in this tutorial.)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Double-check that you are running the latest version. Right now, that should be `0.2.10`." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"'0.2.10'" | |
] | |
}, | |
"execution_count": 2, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"metapy.__version__" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Now, let's set MeTA to log to standard error so we can see progress output for long-running commands. (Only do this once, or you'll get double the output.)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"metapy.log_to_stderr()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Now, let's download all of the files we need for the tutorial." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import urllib.request\n", | |
"import os\n", | |
"import tarfile\n", | |
"\n", | |
"if not os.path.exists('sigir18-tutorial.tar.gz'):\n", | |
"    urllib.request.urlretrieve('https://meta-toolkit.org/data/2018-06-25/sigir18-tutorial.tar.gz',\n", | |
"                               'sigir18-tutorial.tar.gz')\n", | |
"\n", | |
"if not os.path.exists('data'):\n", | |
"    with tarfile.open('sigir18-tutorial.tar.gz', 'r:gz') as files:\n", | |
"        files.extractall()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Let's index our data using the `InvertedIndex` format. In a search engine, we want to quickly determine what documents mention a specific query term, so the `InvertedIndex` stores a mapping from term to a list of documents that contain that term (along with how many times they do)." | |
] | |
}, | |
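{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"As a rough illustration of that term-to-documents mapping, here is a toy inverted index in plain Python. The three documents are made up for this sketch, and this is not how MeTA lays out its postings on disk:\n", | |
"\n", | |
"```python\n", | |
"from collections import Counter, defaultdict\n", | |
"\n", | |
"# Hypothetical three-document corpus, for illustration only.\n", | |
"docs = {\n", | |
"    0: 'equilibrium flow of a gas',\n", | |
"    1: 'flow through a nozzle',\n", | |
"    2: 'shock wave in a gas',\n", | |
"}\n", | |
"\n", | |
"# term -> {doc_id: number of times the term occurs in that document}\n", | |
"postings = defaultdict(dict)\n", | |
"for doc_id, text in docs.items():\n", | |
"    for term, count in Counter(text.split()).items():\n", | |
"        postings[term][doc_id] = count\n", | |
"\n", | |
"print(postings['flow'])  # documents 0 and 1, once each\n", | |
"print(postings['gas'])   # documents 0 and 2, once each\n", | |
"```\n", | |
"\n", | |
"Looking up a query term touches only the documents that contain it, which is what keeps query-time scoring fast." | |
] | |
}, | |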
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
" \r", | |
" > Counting lines in file: [> ] 0% ETA 00:00:00 \r", | |
" > Counting lines in file: [=================================] 100% ETA 00:00:00 \n", | |
"1529953996: [info] Creating index: cranfield-idx/inv (/tmp/pip-req-build-m473bt6z/deps/meta/src/index/inverted_index.cpp:119)\n", | |
" \r", | |
" > Tokenizing Docs: [> ] 0% ETA 00:00:00 \n", | |
"1529953996: [warning] Empty document (id = 470) generated! (/tmp/pip-req-build-m473bt6z/deps/meta/src/index/inverted_index.cpp:228)\n", | |
" \n", | |
"1529953996: [warning] Empty document (id = 994) generated! (/tmp/pip-req-build-m473bt6z/deps/meta/src/index/inverted_index.cpp:228)\n", | |
" \r", | |
" > Tokenizing Docs: [========================================] 100% ETA 00:00:00 \n", | |
" \r", | |
" > Merging: [> ] 0% ETA 00:00:00 \r", | |
" > Merging: [================================================] 100% ETA 00:00:00 \n", | |
"1529953996: [info] Created uncompressed postings file cranfield-idx/inv/postings.index (197.770000 KB) (/tmp/pip-req-build-m473bt6z/deps/meta/src/index/inverted_index.cpp:148)\n", | |
" \r", | |
" > Compressing postings: [> ] 0% ETA 00:00:00 \r", | |
" > Compressing postings: [===================================] 100% ETA 00:00:00 \n", | |
"1529953996: [info] Created compressed postings file (168.060000 KB) (/tmp/pip-req-build-m473bt6z/deps/meta/src/index/inverted_index.cpp:279)\n", | |
"1529953996: [info] Done creating index: cranfield-idx/inv (/tmp/pip-req-build-m473bt6z/deps/meta/src/index/inverted_index.cpp:166)\n" | |
] | |
} | |
], | |
"source": [ | |
"inv_idx = metapy.index.make_inverted_index('cranfield.toml')" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"This may take a minute the first time, since the index needs to be built. Subsequent calls to `make_inverted_index` with this config file will simply load the existing index from disk, which is much faster.\n", | |
"\n", | |
"Here's how we can interact with the index object:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"1400" | |
] | |
}, | |
"execution_count": 6, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"inv_idx.num_docs()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"4137" | |
] | |
}, | |
"execution_count": 7, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"inv_idx.unique_terms()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"87.17857360839844" | |
] | |
}, | |
"execution_count": 8, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"inv_idx.avg_doc_length()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"122050" | |
] | |
}, | |
"execution_count": 9, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"inv_idx.total_corpus_terms()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Let's search our index. We'll start by creating a ranker:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"ranker = metapy.index.DirichletPrior()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Now we need a query. Let's create an example query." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"query = metapy.index.Document()\n", | |
"query.content(\"flow equilibrium\")" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Now we can use this to search our index like so:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"[(235, 1.2931444644927979),\n", | |
" (1251, 1.256299614906311),\n", | |
" (316, 1.1081531047821045),\n", | |
" (655, 1.0878994464874268),\n", | |
" (574, 1.076568841934204)]" | |
] | |
}, | |
"execution_count": 12, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"top_docs = ranker.score(inv_idx, query, num_results=5)\n", | |
"top_docs" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"We get back a ranked list of *(doc_id, score)* pairs. The scores come from the ranker, which in this case is `DirichletPrior` (query likelihood with Dirichlet prior smoothing). Since the `cranfield.toml` config file for the Cranfield dataset has `store-full-text = true`, we can verify the content of our top documents by inspecting the document metadata field \"content\"." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"1. criteria for thermodynamic equilibrium in gas flow . when gases flow at high velocity, the rates of internal processes may not be fast enough to maintain thermodynamic equilibrium . by defining quasi-equilibrium in flow as the condition in which the...\n", | |
"\n", | |
"2. on the approach to chemical and vibrational equilibrium behind a strong normal shock wave . the concurrent approach to chemical and vibrational equilibrium of a pure diatomic gas passing through a strong normal shock wave is investigated . it is dem...\n", | |
"\n", | |
"3. non-equilibrium flow of an ideal dissociating gas . the theory of an'ideal dissociating'gas developed by lighthill/1957/for conditions of thermodynamic equilibrium is extended to non-equilibrium conditions by postulating a simple rate equation for th...\n", | |
"\n", | |
"4. departure from dissociation equilibrium in a hypersonic nozzle . the equations of motion for the flow of an ideal dissociating gas through a nearly conical nozzle have been solved numerically, assuming a simple equation for the rate of dissociation, ...\n", | |
"\n", | |
"5. atomic recombination in a hypersonic wind tunnel nozzle . the flow of an ideal dissociating gas through a nearly conical nozzle is considered . the equations of one-dimensional motion are solved numerically assuming a simple rate equation together wi...\n", | |
"\n" | |
] | |
} | |
], | |
"source": [ | |
"for num, (d_id, _) in enumerate(top_docs):\n", | |
"    content = inv_idx.metadata(d_id).get('content')\n", | |
"    print(\"{}. {}...\\n\".format(num + 1, content[0:250]))" | |
] | |
}, | |
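{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"To give a feel for where the scores of the `DirichletPrior` ranker created above come from, here is a sketch of the textbook query-likelihood score with Dirichlet prior smoothing. The function name, the toy documents, the collection probabilities, and the value of `mu` are all assumptions made for illustration. The scores MeTA printed above are positive, so it evidently reports them on a different scale; treat this as the textbook form rather than MeTA's exact computation:\n", | |
"\n", | |
"```python\n", | |
"import math\n", | |
"from collections import Counter\n", | |
"\n", | |
"def dirichlet_score(query_terms, doc_terms, collection_prob, mu=2000.0):\n", | |
"    # Sum over query terms of log((c(w, d) + mu * p(w|C)) / (|d| + mu)).\n", | |
"    counts = Counter(doc_terms)\n", | |
"    doc_len = len(doc_terms)\n", | |
"    score = 0.0\n", | |
"    for w in query_terms:\n", | |
"        p = (counts[w] + mu * collection_prob[w]) / (doc_len + mu)\n", | |
"        score += math.log(p)\n", | |
"    return score\n", | |
"\n", | |
"# Hypothetical collection probabilities and documents.\n", | |
"collection_prob = {'flow': 0.01, 'equilibrium': 0.001}\n", | |
"doc_a = 'equilibrium flow of a gas'.split()\n", | |
"doc_b = 'shock wave in a gas'.split()\n", | |
"query_terms = ['flow', 'equilibrium']\n", | |
"\n", | |
"print(dirichlet_score(query_terms, doc_a, collection_prob))\n", | |
"print(dirichlet_score(query_terms, doc_b, collection_prob))\n", | |
"```\n", | |
"\n", | |
"The document that actually contains the query terms (`doc_a`) gets the higher score, while `doc_b` still receives a nonzero probability thanks to the smoothing term `mu * p(w|C)`." | |
] | |
}, | |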
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Since we have the queries file and relevance judgements, we can do an IR evaluation." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"ev = metapy.index.IREval('cranfield.toml')" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"We will loop over the queries file and add each result to the `IREval` object `ev`." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Query 1 average precision: 0.19\n", | |
"Query 2 average precision: 0.5433333333333332\n", | |
"Query 3 average precision: 0.6541666666666666\n", | |
"Query 4 average precision: 0.5\n", | |
"Query 5 average precision: 0.35\n", | |
"Query 6 average precision: 0.0625\n", | |
"Query 7 average precision: 0.10666666666666666\n", | |
"Query 8 average precision: 0.0\n", | |
"Query 9 average precision: 0.6984126984126983\n", | |
"Query 10 average precision: 0.0625\n", | |
"Query 11 average precision: 0.028571428571428574\n", | |
"Query 12 average precision: 0.18\n", | |
"Query 13 average precision: 0.0\n", | |
"Query 14 average precision: 0.5\n", | |
"Query 15 average precision: 0.7\n", | |
"Query 16 average precision: 0.08333333333333333\n", | |
"Query 17 average precision: 0.07142857142857142\n", | |
"Query 18 average precision: 0.3333333333333333\n", | |
"Query 19 average precision: 0.0\n", | |
"Query 20 average precision: 0.2685185185185185\n", | |
"Query 21 average precision: 0.0\n", | |
"Query 22 average precision: 0.0\n", | |
"Query 23 average precision: 0.04722222222222222\n", | |
"Query 24 average precision: 0.3333333333333333\n", | |
"Query 25 average precision: 0.6507936507936507\n", | |
"Query 26 average precision: 0.13888888888888887\n", | |
"Query 27 average precision: 0.0\n", | |
"Query 28 average precision: 0.0\n", | |
"Query 29 average precision: 0.14166666666666666\n", | |
"Query 30 average precision: 0.11904761904761904\n", | |
"Query 31 average precision: 0.0\n", | |
"Query 32 average precision: 0.023809523809523808\n", | |
"Query 33 average precision: 0.5317460317460317\n", | |
"Query 34 average precision: 0.20555555555555557\n", | |
"Query 35 average precision: 0.0\n", | |
"Query 36 average precision: 0.25\n", | |
"Query 37 average precision: 0.05555555555555555\n", | |
"Query 38 average precision: 0.01111111111111111\n", | |
"Query 39 average precision: 0.12857142857142856\n", | |
"Query 40 average precision: 0.07222222222222222\n", | |
"Query 41 average precision: 0.7666666666666666\n", | |
"Query 42 average precision: 0.1638095238095238\n", | |
"Query 43 average precision: 0.48333333333333334\n", | |
"Query 44 average precision: 0.0\n", | |
"Query 45 average precision: 0.155\n", | |
"Query 46 average precision: 0.5380952380952382\n", | |
"Query 47 average precision: 0.15083333333333332\n", | |
"Query 48 average precision: 0.1433333333333333\n", | |
"Query 49 average precision: 0.0\n", | |
"Query 50 average precision: 0.0\n", | |
"Query 51 average precision: 0.16666666666666666\n", | |
"Query 52 average precision: 0.027777777777777776\n", | |
"Query 53 average precision: 0.15857142857142856\n", | |
"Query 54 average precision: 0.05555555555555555\n", | |
"Query 55 average precision: 0.11499999999999999\n", | |
"Query 56 average precision: 0.05\n", | |
"Query 57 average precision: 0.01\n", | |
"Query 58 average precision: 0.15873015873015872\n", | |
"Query 59 average precision: 0.03571428571428571\n", | |
"Query 60 average precision: 0.4666666666666667\n", | |
"Query 61 average precision: 0.45999999999999996\n", | |
"Query 62 average precision: 0.0\n", | |
"Query 63 average precision: 0.0\n", | |
"Query 64 average precision: 0.5\n", | |
"Query 65 average precision: 0.0\n", | |
"Query 66 average precision: 0.0\n", | |
"Query 67 average precision: 0.0365079365079365\n", | |
"Query 68 average precision: 0.1\n", | |
"Query 69 average precision: 0.06666666666666667\n", | |
"Query 70 average precision: 0.1\n", | |
"Query 71 average precision: 0.03125\n", | |
"Query 72 average precision: 0.0\n", | |
"Query 73 average precision: 0.2861111111111111\n", | |
"Query 74 average precision: 0.020833333333333332\n", | |
"Query 75 average precision: 0.03333333333333333\n", | |
"Query 76 average precision: 0.028571428571428574\n", | |
"Query 77 average precision: 0.28888888888888886\n", | |
"Query 78 average precision: 0.6666666666666666\n", | |
"Query 79 average precision: 0.0\n", | |
"Query 80 average precision: 0.0\n", | |
"Query 81 average precision: 0.25\n", | |
"Query 82 average precision: 0.09\n", | |
"Query 83 average precision: 0.0625\n", | |
"Query 84 average precision: 0.3\n", | |
"Query 85 average precision: 0.05\n", | |
"Query 86 average precision: 0.41666666666666663\n", | |
"Query 87 average precision: 0.0\n", | |
"Query 88 average precision: 0.594047619047619\n", | |
"Query 89 average precision: 0.075\n", | |
"Query 90 average precision: 0.19\n", | |
"Query 91 average precision: 0.0873015873015873\n", | |
"Query 92 average precision: 0.4726984126984126\n", | |
"Query 93 average precision: 0.5\n", | |
"Query 94 average precision: 0.4\n", | |
"Query 95 average precision: 0.5\n", | |
"Query 96 average precision: 0.4463095238095239\n", | |
"Query 97 average precision: 0.15416666666666665\n", | |
"Query 98 average precision: 0.0\n", | |
"Query 99 average precision: 0.19642857142857142\n", | |
"Query 100 average precision: 0.2185185185185185\n", | |
"Query 101 average precision: 0.6180555555555555\n", | |
"Query 102 average precision: 0.08333333333333333\n", | |
"Query 103 average precision: 0.05555555555555555\n", | |
"Query 104 average precision: 0.1\n", | |
"Query 105 average precision: 0.35333333333333333\n", | |
"Query 106 average precision: 0.24666666666666667\n", | |
"Query 107 average precision: 0.10476190476190476\n", | |
"Query 108 average precision: 0.6261904761904761\n", | |
"Query 109 average precision: 0.0\n", | |
"Query 110 average precision: 0.0\n", | |
"Query 111 average precision: 0.07619047619047618\n", | |
"Query 112 average precision: 0.3611111111111111\n", | |
"Query 113 average precision: 0.10416666666666666\n", | |
"Query 114 average precision: 0.0\n", | |
"Query 115 average precision: 0.0\n", | |
"Query 116 average precision: 0.03333333333333333\n", | |
"Query 117 average precision: 0.0\n", | |
"Query 118 average precision: 0.041666666666666664\n", | |
"Query 119 average precision: 0.5\n", | |
"Query 120 average precision: 0.17724867724867724\n", | |
"Query 121 average precision: 0.4768707482993197\n", | |
"Query 122 average precision: 0.05925925925925926\n", | |
"Query 123 average precision: 0.0\n", | |
"Query 124 average precision: 0.0\n", | |
"Query 125 average precision: 0.02\n", | |
"Query 126 average precision: 0.20833333333333331\n", | |
"Query 127 average precision: 0.025\n", | |
"Query 128 average precision: 0.0\n", | |
"Query 129 average precision: 0.4773809523809524\n", | |
"Query 130 average precision: 0.3933333333333333\n", | |
"Query 131 average precision: 0.09375\n", | |
"Query 132 average precision: 0.6592063492063491\n", | |
"Query 133 average precision: 0.12976190476190477\n", | |
"Query 134 average precision: 0.5\n", | |
"Query 135 average precision: 0.38690476190476186\n", | |
"Query 136 average precision: 0.03333333333333333\n", | |
"Query 137 average precision: 0.1875\n", | |
"Query 138 average precision: 0.25\n", | |
"Query 139 average precision: 0.0\n", | |
"Query 140 average precision: 0.10833333333333334\n", | |
"Query 141 average precision: 0.05555555555555555\n", | |
"Query 142 average precision: 0.0\n", | |
"Query 143 average precision: 0.6111111111111112\n", | |
"Query 144 average precision: 0.14722222222222223\n", | |
"Query 145 average precision: 0.07142857142857142\n", | |
"Query 146 average precision: 0.41666666666666663\n", | |
"Query 147 average precision: 0.14666666666666667\n", | |
"Query 148 average precision: 0.125\n", | |
"Query 149 average precision: 0.18285714285714286\n", | |
"Query 150 average precision: 1.0\n", | |
"Query 151 average precision: 0.0\n", | |
"Query 152 average precision: 0.0\n", | |
"Query 153 average precision: 0.19999999999999998\n", | |
"Query 154 average precision: 1.0\n", | |
"Query 155 average precision: 0.06481481481481481\n", | |
"Query 156 average precision: 0.6365476190476189\n", | |
"Query 157 average precision: 0.43809523809523804\n", | |
"Query 158 average precision: 0.175\n", | |
"Query 159 average precision: 0.015625\n", | |
"Query 160 average precision: 0.2\n", | |
"Query 161 average precision: 0.43333333333333335\n", | |
"Query 162 average precision: 0.020833333333333332\n", | |
"Query 163 average precision: 0.21666666666666667\n", | |
"Query 164 average precision: 0.40208333333333335\n", | |
"Query 165 average precision: 0.16666666666666666\n", | |
"Query 166 average precision: 0.0\n", | |
"Query 167 average precision: 0.325\n", | |
"Query 168 average precision: 0.08333333333333333\n", | |
"Query 169 average precision: 0.125\n", | |
"Query 170 average precision: 0.4611111111111111\n", | |
"Query 171 average precision: 0.4888888888888889\n", | |
"Query 172 average precision: 0.5416666666666666\n", | |
"Query 173 average precision: 0.8333333333333333\n", | |
"Query 174 average precision: 0.03333333333333333\n", | |
"Query 175 average precision: 0.02222222222222222\n", | |
"Query 176 average precision: 0.0\n", | |
"Query 177 average precision: 0.475\n", | |
"Query 178 average precision: 0.49166666666666664\n", | |
"Query 179 average precision: 0.125\n", | |
"Query 180 average precision: 0.34523809523809523\n", | |
"Query 181 average precision: 0.2\n", | |
"Query 182 average precision: 0.7\n", | |
"Query 183 average precision: 0.5308730158730158\n", | |
"Query 184 average precision: 0.06938775510204082\n", | |
"Query 185 average precision: 0.6388888888888888\n", | |
"Query 186 average precision: 0.016666666666666666\n", | |
"Query 187 average precision: 0.16666666666666666\n", | |
"Query 188 average precision: 0.09\n", | |
"Query 189 average precision: 0.015873015873015872\n", | |
"Query 190 average precision: 0.1\n", | |
"Query 191 average precision: 0.014285714285714285\n", | |
"Query 192 average precision: 0.5\n", | |
"Query 193 average precision: 0.6481481481481481\n", | |
"Query 194 average precision: 0.08888888888888889\n", | |
"Query 195 average precision: 0.1111111111111111\n", | |
"Query 196 average precision: 0.03333333333333333\n", | |
"Query 197 average precision: 0.6666666666666666\n", | |
"Query 198 average precision: 0.3125\n", | |
"Query 199 average precision: 0.05208333333333333\n", | |
"Query 200 average precision: 0.24074074074074073\n", | |
"Query 201 average precision: 0.3\n", | |
"Query 202 average precision: 0.26\n", | |
"Query 203 average precision: 0.07222222222222222\n", | |
"Query 204 average precision: 0.0\n", | |
"Query 205 average precision: 0.75\n", | |
"Query 206 average precision: 0.3333333333333333\n", | |
"Query 207 average precision: 0.15\n", | |
"Query 208 average precision: 0.6829365079365081\n", | |
"Query 209 average precision: 0.02\n", | |
"Query 210 average precision: 0.27777777777777773\n", | |
"Query 211 average precision: 0.0125\n", | |
"Query 212 average precision: 0.36666666666666664\n", | |
"Query 213 average precision: 0.5583333333333333\n", | |
"Query 214 average precision: 0.08333333333333333\n", | |
"Query 215 average precision: 0.0\n", | |
"Query 216 average precision: 0.0\n", | |
"Query 217 average precision: 0.05833333333333333\n", | |
"Query 218 average precision: 0.0\n", | |
"Query 219 average precision: 0.014285714285714285\n", | |
"Query 220 average precision: 0.05333333333333333\n", | |
"Query 221 average precision: 0.24333333333333332\n", | |
"Query 222 average precision: 0.5148148148148147\n", | |
"Query 223 average precision: 0.44375\n", | |
"Query 224 average precision: 0.0\n", | |
"Query 225 average precision: 0.13333333333333333\n" | |
] | |
} | |
], | |
"source": [ | |
"def evaluate_ranker(ranker, ev, num_results):\n", | |
"    # Use a local query document so the function does not depend on the global `query`.\n", | |
"    query = metapy.index.Document()\n", | |
"    ev.reset_stats()\n", | |
"    with open('data/cranfield/cranfield-queries.txt') as query_file:\n", | |
"        for query_num, line in enumerate(query_file):\n", | |
"            query.content(line.strip())\n", | |
"            results = ranker.score(inv_idx, query, num_results)\n", | |
"            # Query IDs are passed to IREval as 1-based numbers.\n", | |
"            avg_p = ev.avg_p(results, query_num + 1, num_results)\n", | |
"            print(\"Query {} average precision: {}\".format(query_num + 1, avg_p))\n", | |
"\n", | |
"evaluate_ranker(ranker, ev, 10)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Afterwards, we can get the mean average precision of all the queries." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"MAP: 0.21512203955656342\n" | |
] | |
} | |
], | |
"source": [ | |
"dp_map = ev.map()\n", | |
"print(\"MAP: {}\".format(dp_map))" | |
] | |
}, | |
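{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"For reference, the two measures used above can be sketched in plain Python. This is one common definition, not necessarily the exact computation inside `IREval`: the function names and the toy ranking are hypothetical, and the choice of normaliser (marked in the code) is an assumption:\n", | |
"\n", | |
"```python\n", | |
"def average_precision(ranked_ids, relevant_ids, k):\n", | |
"    # Sample precision at each rank within the top k that holds a relevant document.\n", | |
"    hits = 0\n", | |
"    precision_sum = 0.0\n", | |
"    for rank, doc_id in enumerate(ranked_ids[:k], start=1):\n", | |
"        if doc_id in relevant_ids:\n", | |
"            hits += 1\n", | |
"            precision_sum += hits / rank\n", | |
"    # Assumption: normalise by the number of relevant documents that could\n", | |
"    # appear in the top k; other definitions divide by all relevant documents.\n", | |
"    denom = min(k, len(relevant_ids))\n", | |
"    return precision_sum / denom if denom else 0.0\n", | |
"\n", | |
"def mean_average_precision(per_query_ap):\n", | |
"    # MAP is the arithmetic mean of the per-query average precisions.\n", | |
"    return sum(per_query_ap) / len(per_query_ap)\n", | |
"\n", | |
"# Relevant documents 3 and 1 are found at ranks 1 and 3; document 5 is missed.\n", | |
"ap = average_precision([3, 7, 1, 9], {3, 1, 5}, k=4)\n", | |
"print(ap)  # (1/1 + 2/3) / 3\n", | |
"print(mean_average_precision([ap, 0.0]))\n", | |
"```\n", | |
"\n", | |
"Queries for which no relevant document appears in the top `k` contribute 0 to the mean, which is why the many `0.0` rows above pull the MAP down." | |
] | |
}, | |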
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Now, let's use the two-component mixture model we discussed as an implementation of pseudo-feedback for retrieval and see if it helps improve performance. The actual ranking function used here is KL-divergence, where the query model is adjusted to include pseudo-feedback from the retrieved documents.\n", | |
"\n", | |
"In order to work, the ranker needs to be able to quickly determine what words were used in the feedback document set. The `InvertedIndex` does not provide fast access to this (since it is a mapping from term to documents, rather than from documents to terms), so we will want to first create a `ForwardIndex` to get the document -> terms mapping." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 17, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
" \r", | |
" > Counting lines in file: [> ] 0% ETA 00:00:00 \r", | |
" > Counting lines in file: [=================================] 100% ETA 00:00:00 \n", | |
"1529954010: [info] Creating forward index: cranfield-idx/fwd (/tmp/pip-req-build-m473bt6z/deps/meta/src/index/forward_index.cpp:239)\n", | |
" \r", | |
" > Tokenizing Docs: [> ] 0% ETA 00:00:00 \n", | |
"1529954010: [warning] Empty document (id = 470) generated! (/tmp/pip-req-build-m473bt6z/deps/meta/src/index/forward_index.cpp:335)\n", | |
" \n", | |
"1529954010: [warning] Empty document (id = 994) generated! (/tmp/pip-req-build-m473bt6z/deps/meta/src/index/forward_index.cpp:335)\n", | |
" \r", | |
" > Tokenizing Docs: [========================================] 100% ETA 00:00:00 \n", | |
" \r", | |
" > Merging: [> ] 0% ETA 00:00:00 \r", | |
" > Merging: [================================================] 100% ETA 00:00:00 \n", | |
"1529954010: [info] Done creating index: cranfield-idx/fwd (/tmp/pip-req-build-m473bt6z/deps/meta/src/index/forward_index.cpp:278)\n" | |
] | |
} | |
], | |
"source": [ | |
"fwd_idx = metapy.index.make_forward_index('cranfield.toml')" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Now we can construct the KL-divergence pseudo-feedback ranker. The main components are:\n", | |
"1. The forward index\n", | |
"2. A base language-model ranker (here we'll use `DirichletPrior`)\n", | |
"3. $\\alpha$, the query interpolation parameter (how strongly do we prefer terms from the feedback model? default 0.5)\n", | |
"4. $\\lambda$, the language-model interpolation parameter (how strong is the background model in the two-component mixture? default 0.5)\n", | |
"5. $k$, the number of documents to retrieve for the feedback set (default 10)\n", | |
"6. `max_terms`, the number of terms from the feedback model to incorporate into the new query model (default 50) " | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 18, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"feedback = metapy.index.KLDivergencePRF(fwd_idx, metapy.index.DirichletPrior())" | |
] | |
}, | |
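{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"To make the roles of $\\lambda$ and $\\alpha$ concrete, here is a minimal pure-Python sketch of the two-component mixture model, assuming the textbook EM updates. The names `estimate_feedback_model` and `interpolate_query` and the toy counts are hypothetical, and renormalising after truncating to `max_terms` is an assumption rather than a description of what `KLDivergencePRF` does internally:\n", | |
"\n", | |
"```python\n", | |
"def estimate_feedback_model(feedback_counts, background, lam=0.5, iterations=50):\n", | |
"    # Start from the maximum-likelihood estimate over the feedback documents.\n", | |
"    total = sum(feedback_counts.values())\n", | |
"    theta = {w: c / total for w, c in feedback_counts.items()}\n", | |
"    for _ in range(iterations):\n", | |
"        # E-step: probability that each word was generated by the feedback\n", | |
"        # topic rather than by the background model (weight lam).\n", | |
"        z = {w: (1 - lam) * theta[w] / ((1 - lam) * theta[w] + lam * background[w])\n", | |
"             for w in feedback_counts}\n", | |
"        # M-step: re-estimate the topic from the expected topic counts.\n", | |
"        norm = sum(feedback_counts[w] * z[w] for w in feedback_counts)\n", | |
"        theta = {w: feedback_counts[w] * z[w] / norm for w in feedback_counts}\n", | |
"    return theta\n", | |
"\n", | |
"def interpolate_query(query_model, theta, alpha=0.5, max_terms=50):\n", | |
"    # Keep the highest-probability feedback terms and renormalise them\n", | |
"    # (an assumption of this sketch), then mix them into the query model.\n", | |
"    top = sorted(theta.items(), key=lambda kv: kv[1], reverse=True)[:max_terms]\n", | |
"    top_total = sum(p for _, p in top)\n", | |
"    top = {w: p / top_total for w, p in top}\n", | |
"    words = set(query_model) | set(top)\n", | |
"    return {w: (1 - alpha) * query_model.get(w, 0.0) + alpha * top.get(w, 0.0)\n", | |
"            for w in words}\n", | |
"\n", | |
"# Hypothetical word counts from the feedback documents and background probabilities.\n", | |
"feedback_counts = {'the': 10, 'flow': 6, 'equilibrium': 4}\n", | |
"background = {'the': 0.5, 'flow': 0.02, 'equilibrium': 0.005}\n", | |
"\n", | |
"theta = estimate_feedback_model(feedback_counts, background, lam=0.5)\n", | |
"new_query = interpolate_query({'flow': 0.5, 'equilibrium': 0.5}, theta, alpha=0.5, max_terms=2)\n", | |
"print(theta)\n", | |
"print(new_query)\n", | |
"```\n", | |
"\n", | |
"Because the background model already explains common words well, `the` ends up with a lower probability in `theta` than its raw relative frequency of 0.5. A larger $\\lambda$ strengthens that discounting, and a larger $\\alpha$ moves the final query model further toward the feedback terms." | |
] | |
}, | |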
{ | |
"cell_type": "code", | |
"execution_count": 19, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Query 1 average precision: 0.13999999999999999\n", | |
"Query 2 average precision: 0.524047619047619\n", | |
"Query 3 average precision: 0.6642857142857143\n", | |
"Query 4 average precision: 0.5\n", | |
"Query 5 average precision: 0.6875\n", | |
"Query 6 average precision: 0.0625\n", | |
"Query 7 average precision: 0.11666666666666665\n", | |
"Query 8 average precision: 0.0\n", | |
"Query 9 average precision: 0.49999999999999994\n", | |
"Query 10 average precision: 0.0625\n", | |
"Query 11 average precision: 0.023809523809523808\n", | |
"Query 12 average precision: 0.15714285714285714\n", | |
"Query 13 average precision: 0.0\n", | |
"Query 14 average precision: 0.5\n", | |
"Query 15 average precision: 0.6428571428571428\n", | |
"Query 16 average precision: 0.05555555555555555\n", | |
"Query 17 average precision: 0.05\n", | |
"Query 18 average precision: 0.16666666666666666\n", | |
"Query 19 average precision: 0.0\n", | |
"Query 20 average precision: 0.33201058201058203\n", | |
"Query 21 average precision: 0.0\n", | |
"Query 22 average precision: 0.0\n", | |
"Query 23 average precision: 0.075\n", | |
"Query 24 average precision: 0.38888888888888884\n", | |
"Query 25 average precision: 0.7530864197530864\n", | |
"Query 26 average precision: 0.061111111111111116\n", | |
"Query 27 average precision: 0.0\n", | |
"Query 28 average precision: 0.0\n", | |
"Query 29 average precision: 0.091005291005291\n", | |
"Query 30 average precision: 0.19999999999999998\n", | |
"Query 31 average precision: 0.0\n", | |
"Query 32 average precision: 0.023809523809523808\n", | |
"Query 33 average precision: 0.6666666666666666\n", | |
"Query 34 average precision: 0.21031746031746032\n", | |
"Query 35 average precision: 0.03333333333333333\n", | |
"Query 36 average precision: 0.6\n", | |
"Query 37 average precision: 0.1111111111111111\n", | |
"Query 38 average precision: 0.02\n", | |
"Query 39 average precision: 0.12857142857142856\n", | |
"Query 40 average precision: 0.07222222222222222\n", | |
"Query 41 average precision: 0.7916666666666666\n", | |
"Query 42 average precision: 0.14523809523809522\n", | |
"Query 43 average precision: 0.3333333333333333\n", | |
"Query 44 average precision: 0.0\n", | |
"Query 45 average precision: 0.13999999999999999\n", | |
"Query 46 average precision: 0.49071428571428566\n", | |
"Query 47 average precision: 0.14027777777777778\n", | |
"Query 48 average precision: 0.1433333333333333\n", | |
"Query 49 average precision: 0.0\n", | |
"Query 50 average precision: 0.0\n", | |
"Query 51 average precision: 0.22999999999999998\n", | |
"Query 52 average precision: 0.08125\n", | |
"Query 53 average precision: 0.15\n", | |
"Query 54 average precision: 0.037037037037037035\n", | |
"Query 55 average precision: 0.11833333333333333\n", | |
"Query 56 average precision: 0.05\n", | |
"Query 57 average precision: 0.0125\n", | |
"Query 58 average precision: 0.16666666666666666\n", | |
"Query 59 average precision: 0.03571428571428571\n", | |
"Query 60 average precision: 0.45999999999999996\n", | |
"Query 61 average precision: 0.4083333333333333\n", | |
"Query 62 average precision: 0.0\n", | |
"Query 63 average precision: 0.0\n", | |
"Query 64 average precision: 0.5\n", | |
"Query 65 average precision: 0.0\n", | |
"Query 66 average precision: 0.0\n", | |
"Query 67 average precision: 0.03111111111111111\n", | |
"Query 68 average precision: 0.06666666666666667\n", | |
"Query 69 average precision: 0.2\n", | |
"Query 70 average precision: 0.1\n", | |
"Query 71 average precision: 0.041666666666666664\n", | |
"Query 72 average precision: 0.0\n", | |
"Query 73 average precision: 0.38916666666666666\n", | |
"Query 74 average precision: 0.05555555555555555\n", | |
"Query 75 average precision: 0.02222222222222222\n", | |
"Query 76 average precision: 0.02040816326530612\n", | |
"Query 77 average precision: 0.34027777777777773\n", | |
"Query 78 average precision: 0.7916666666666666\n", | |
"Query 79 average precision: 0.0\n", | |
"Query 80 average precision: 0.0\n", | |
"Query 81 average precision: 0.25\n", | |
"Query 82 average precision: 0.07333333333333333\n", | |
"Query 83 average precision: 0.05\n", | |
"Query 84 average precision: 0.2375\n", | |
"Query 85 average precision: 0.05\n", | |
"Query 86 average precision: 0.30952380952380953\n", | |
"Query 87 average precision: 0.0\n", | |
"Query 88 average precision: 0.594047619047619\n", | |
"Query 89 average precision: 0.10833333333333334\n", | |
"Query 90 average precision: 0.1875\n", | |
"Query 91 average precision: 0.1111111111111111\n", | |
"Query 92 average precision: 0.5747619047619048\n", | |
"Query 93 average precision: 0.5\n", | |
"Query 94 average precision: 0.4\n", | |
"Query 95 average precision: 0.5\n", | |
"Query 96 average precision: 0.6042063492063492\n", | |
"Query 97 average precision: 0.24166666666666664\n", | |
"Query 98 average precision: 0.0\n", | |
"Query 99 average precision: 0.25\n", | |
"Query 100 average precision: 0.25555555555555554\n", | |
"Query 101 average precision: 0.5444444444444444\n", | |
"Query 102 average precision: 0.08333333333333333\n", | |
"Query 103 average precision: 0.1\n", | |
"Query 104 average precision: 0.1\n", | |
"Query 105 average precision: 0.4333333333333333\n", | |
"Query 106 average precision: 0.3746031746031746\n", | |
"Query 107 average precision: 0.1619047619047619\n", | |
"Query 108 average precision: 0.5488095238095239\n", | |
"Query 109 average precision: 0.0\n", | |
"Query 110 average precision: 0.0\n", | |
"Query 111 average precision: 0.014285714285714287\n", | |
"Query 112 average precision: 0.3611111111111111\n", | |
"Query 113 average precision: 0.10555555555555556\n", | |
"Query 114 average precision: 0.0\n", | |
"Query 115 average precision: 0.0\n", | |
"Query 116 average precision: 0.02857142857142857\n", | |
"Query 117 average precision: 0.0\n", | |
"Query 118 average precision: 0.037037037037037035\n", | |
"Query 119 average precision: 0.3333333333333333\n", | |
"Query 120 average precision: 0.1527777777777778\n", | |
"Query 121 average precision: 0.42857142857142855\n", | |
"Query 122 average precision: 0.07777777777777778\n", | |
"Query 123 average precision: 0.0\n", | |
"Query 124 average precision: 0.0\n", | |
"Query 125 average precision: 0.03333333333333333\n", | |
"Query 126 average precision: 0.20833333333333331\n", | |
"Query 127 average precision: 0.02222222222222222\n", | |
"Query 128 average precision: 0.0\n", | |
"Query 129 average precision: 0.6738095238095239\n", | |
"Query 130 average precision: 0.55\n", | |
"Query 131 average precision: 0.18333333333333335\n", | |
"Query 132 average precision: 0.7254365079365078\n", | |
"Query 133 average precision: 0.18010204081632653\n", | |
"Query 134 average precision: 0.5\n", | |
"Query 135 average precision: 0.37351190476190477\n", | |
"Query 136 average precision: 0.047619047619047616\n", | |
"Query 137 average precision: 0.1125\n", | |
"Query 138 average precision: 0.5\n", | |
"Query 139 average precision: 0.0\n", | |
"Query 140 average precision: 0.10833333333333334\n", | |
"Query 141 average precision: 0.041666666666666664\n", | |
"Query 142 average precision: 0.0\n", | |
"Query 143 average precision: 0.39285714285714285\n", | |
"Query 144 average precision: 0.20555555555555557\n", | |
"Query 145 average precision: 0.07142857142857142\n", | |
"Query 146 average precision: 0.30952380952380953\n", | |
"Query 147 average precision: 0.14999999999999997\n", | |
"Query 148 average precision: 0.075\n", | |
"Query 149 average precision: 0.17666666666666667\n", | |
"Query 150 average precision: 1.0\n", | |
"Query 151 average precision: 0.0\n", | |
"Query 152 average precision: 0.0\n", | |
"Query 153 average precision: 0.2571428571428571\n", | |
"Query 154 average precision: 1.0\n", | |
"Query 155 average precision: 0.06666666666666667\n", | |
"Query 156 average precision: 0.7532142857142856\n", | |
"Query 157 average precision: 0.40555555555555556\n", | |
"Query 158 average precision: 0.19375\n", | |
"Query 159 average precision: 0.020833333333333332\n", | |
"Query 160 average precision: 0.2\n", | |
"Query 161 average precision: 0.3611111111111111\n", | |
"Query 162 average precision: 0.013888888888888888\n", | |
"Query 163 average precision: 0.19444444444444442\n", | |
"Query 164 average precision: 0.40208333333333335\n", | |
"Query 165 average precision: 0.125\n", | |
"Query 166 average precision: 0.0\n", | |
"Query 167 average precision: 0.5\n", | |
"Query 168 average precision: 0.08333333333333333\n", | |
"Query 169 average precision: 0.25\n", | |
"Query 170 average precision: 0.46990740740740744\n", | |
"Query 171 average precision: 0.5555555555555555\n", | |
"Query 172 average precision: 0.5583333333333333\n", | |
"Query 173 average precision: 0.7\n", | |
"Query 174 average precision: 0.03333333333333333\n", | |
"Query 175 average precision: 0.02857142857142857\n", | |
"Query 176 average precision: 0.0\n", | |
"Query 177 average precision: 0.475\n", | |
"Query 178 average precision: 0.75\n", | |
"Query 179 average precision: 0.25\n", | |
"Query 180 average precision: 0.3880952380952381\n", | |
"Query 181 average precision: 0.1\n", | |
"Query 182 average precision: 0.7\n", | |
"Query 183 average precision: 0.6842063492063492\n", | |
"Query 184 average precision: 0.05215419501133787\n", | |
"Query 185 average precision: 0.6296296296296297\n", | |
"Query 186 average precision: 0.04722222222222222\n", | |
"Query 187 average precision: 0.125\n", | |
"Query 188 average precision: 0.19\n", | |
"Query 189 average precision: 0.022222222222222223\n", | |
"Query 190 average precision: 0.05\n", | |
"Query 191 average precision: 0.05357142857142857\n", | |
"Query 192 average precision: 0.575\n", | |
"Query 193 average precision: 0.6565255731922398\n", | |
"Query 194 average precision: 0.19999999999999998\n", | |
"Query 195 average precision: 0.3333333333333333\n", | |
"Query 196 average precision: 0.05\n", | |
"Query 197 average precision: 0.6666666666666666\n", | |
"Query 198 average precision: 0.1875\n", | |
"Query 199 average precision: 0.048611111111111105\n", | |
"Query 200 average precision: 0.40740740740740744\n", | |
"Query 201 average precision: 0.33999999999999997\n", | |
"Query 202 average precision: 0.22666666666666666\n", | |
"Query 203 average precision: 0.07333333333333333\n", | |
"Query 204 average precision: 0.0\n", | |
"Query 205 average precision: 0.7\n", | |
"Query 206 average precision: 0.38888888888888884\n", | |
"Query 207 average precision: 0.3\n", | |
"Query 208 average precision: 0.7345238095238096\n", | |
"Query 209 average precision: 0.03333333333333333\n", | |
"Query 210 average precision: 0.27777777777777773\n", | |
"Query 211 average precision: 0.025\n", | |
"Query 212 average precision: 0.36666666666666664\n", | |
"Query 213 average precision: 0.49309523809523814\n", | |
"Query 214 average precision: 0.08333333333333333\n", | |
"Query 215 average precision: 0.0\n", | |
"Query 216 average precision: 0.0\n", | |
"Query 217 average precision: 0.07857142857142857\n", | |
"Query 218 average precision: 0.0\n", | |
"Query 219 average precision: 0.014285714285714285\n", | |
"Query 220 average precision: 0.07333333333333333\n", | |
"Query 221 average precision: 0.24333333333333332\n", | |
"Query 222 average precision: 0.5238095238095238\n", | |
"Query 223 average precision: 0.4583333333333333\n", | |
"Query 224 average precision: 0.0\n", | |
"Query 225 average precision: 0.15\n" | |
] | |
} | |
], | |
"source": [ | |
"evaluate_ranker(feedback, ev, 10)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 20, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Feedback MAP: 0.22816526133086987\n", | |
"DP MAP: 0.21512203955656342\n" | |
] | |
} | |
], | |
"source": [ | |
"# Mean average precision over the queries scored by the evaluate_ranker call above\n", | |
"fb_map = ev.map()\n", | |
"print(\"Feedback MAP: {}\".format(fb_map))\n", | |
"print(\"DP MAP: {}\".format(dp_map))" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.5" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |