@skystrife
Created July 8, 2018 13:16
SIGIR18: Probabilistic Topic Models Hands-on Exercise 1 (Retrieval)
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Exercise 1: Pseudo-feedback with Two-component Mixture Model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"First, let's import the Python bindings for MeTA:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import metapy"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you don't have `metapy` installed, you can install it by running\n",
"\n",
"```bash\n",
"pip install metapy\n",
"```\n",
"\n",
"on the command line on Linux, macOS, or Windows for either Python 2.7 or Python 3.x. (I will be using Python 3.6 in this tutorial.)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Double-check that you are running the latest version. Right now, that should be `0.2.10`."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'0.2.10'"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"metapy.__version__"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now, let's set MeTA to log to standard error so we can see progress output for long-running commands. (Only do this once, or you'll get double the output.)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"metapy.log_to_stderr()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now, let's download all of the files we need for the tutorial. (The cell below uses `urllib.request`, which is available only in Python 3.)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"import urllib.request\n",
"import os\n",
"import tarfile\n",
"\n",
"# Download the tutorial archive only if it is not already present.\n",
"if not os.path.exists('sigir18-tutorial.tar.gz'):\n",
"    urllib.request.urlretrieve('https://meta-toolkit.org/data/2018-06-25/sigir18-tutorial.tar.gz',\n",
"                               'sigir18-tutorial.tar.gz')\n",
"\n",
"# Unpack the archive only if the data directory does not exist yet.\n",
"if not os.path.exists('data'):\n",
"    with tarfile.open('sigir18-tutorial.tar.gz', 'r:gz') as files:\n",
"        files.extractall()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's index our data using the `InvertedIndex` format. In a search engine, we want to quickly determine which documents mention a specific query term, so the `InvertedIndex` stores a mapping from each term to the list of documents that contain it (along with the number of times the term occurs in each document)."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r",
" > Counting lines in file: [> ] 0% ETA 00:00:00 \r",
" > Counting lines in file: [=================================] 100% ETA 00:00:00 \n",
"1529953996: [info] Creating index: cranfield-idx/inv (/tmp/pip-req-build-m473bt6z/deps/meta/src/index/inverted_index.cpp:119)\n",
" \r",
" > Tokenizing Docs: [> ] 0% ETA 00:00:00 \n",
"1529953996: [warning] Empty document (id = 470) generated! (/tmp/pip-req-build-m473bt6z/deps/meta/src/index/inverted_index.cpp:228)\n",
" \n",
"1529953996: [warning] Empty document (id = 994) generated! (/tmp/pip-req-build-m473bt6z/deps/meta/src/index/inverted_index.cpp:228)\n",
" \r",
" > Tokenizing Docs: [========================================] 100% ETA 00:00:00 \n",
" \r",
" > Merging: [> ] 0% ETA 00:00:00 \r",
" > Merging: [================================================] 100% ETA 00:00:00 \n",
"1529953996: [info] Created uncompressed postings file cranfield-idx/inv/postings.index (197.770000 KB) (/tmp/pip-req-build-m473bt6z/deps/meta/src/index/inverted_index.cpp:148)\n",
" \r",
" > Compressing postings: [> ] 0% ETA 00:00:00 \r",
" > Compressing postings: [===================================] 100% ETA 00:00:00 \n",
"1529953996: [info] Created compressed postings file (168.060000 KB) (/tmp/pip-req-build-m473bt6z/deps/meta/src/index/inverted_index.cpp:279)\n",
"1529953996: [info] Done creating index: cranfield-idx/inv (/tmp/pip-req-build-m473bt6z/deps/meta/src/index/inverted_index.cpp:166)\n"
]
}
],
"source": [
"inv_idx = metapy.index.make_inverted_index('cranfield.toml')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This may take a minute the first time, since the index needs to be built. Subsequent calls to `make_inverted_index` with this config file will simply load the existing index from disk, which is much faster.\n",
"\n",
"Here's how we can interact with the index object:"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1400"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"inv_idx.num_docs()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"4137"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"inv_idx.unique_terms()"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"87.17857360839844"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"inv_idx.avg_doc_length()"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"122050"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"inv_idx.total_corpus_terms()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's search our index. We'll start by creating a ranker:"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"ranker = metapy.index.DirichletPrior()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we need a query. Let's create an example query."
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"query = metapy.index.Document()\n",
"query.content(\"flow equilibrium\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we can use this to search our index like so:"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[(235, 1.2931444644927979),\n",
" (1251, 1.256299614906311),\n",
" (316, 1.1081531047821045),\n",
" (655, 1.0878994464874268),\n",
" (574, 1.076568841934204)]"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"top_docs = ranker.score(inv_idx, query, num_results=5)\n",
"top_docs"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We are returned a ranked list of *(doc_id, score)* pairs. The scores are from the ranker, which in this case was `DirichletPrior` (a query-likelihood language model with Dirichlet prior smoothing). Since the `cranfield.toml` configuration file for the cranfield dataset has `store-full-text = true`, we can verify the content of our top documents by inspecting the document metadata field \"content\"."
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1. criteria for thermodynamic equilibrium in gas flow . when gases flow at high velocity, the rates of internal processes may not be fast enough to maintain thermodynamic equilibrium . by defining quasi-equilibrium in flow as the condition in which the...\n",
"\n",
"2. on the approach to chemical and vibrational equilibrium behind a strong normal shock wave . the concurrent approach to chemical and vibrational equilibrium of a pure diatomic gas passing through a strong normal shock wave is investigated . it is dem...\n",
"\n",
"3. non-equilibrium flow of an ideal dissociating gas . the theory of an'ideal dissociating'gas developed by lighthill/1957/for conditions of thermodynamic equilibrium is extended to non-equilibrium conditions by postulating a simple rate equation for th...\n",
"\n",
"4. departure from dissociation equilibrium in a hypersonic nozzle . the equations of motion for the flow of an ideal dissociating gas through a nearly conical nozzle have been solved numerically, assuming a simple equation for the rate of dissociation, ...\n",
"\n",
"5. atomic recombination in a hypersonic wind tunnel nozzle . the flow of an ideal dissociating gas through a nearly conical nozzle is considered . the equations of one-dimensional motion are solved numerically assuming a simple rate equation together wi...\n",
"\n"
]
}
],
"source": [
"for num, (d_id, _) in enumerate(top_docs):\n",
"    content = inv_idx.metadata(d_id).get('content')\n",
"    print(\"{}. {}...\\n\".format(num + 1, content[0:250]))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Since we have the queries file and relevance judgements, we can do an IR evaluation."
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"ev = metapy.index.IREval('cranfield.toml')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We will loop over the queries file and add each result to the `IREval` object `ev`."
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Query 1 average precision: 0.19\n",
"Query 2 average precision: 0.5433333333333332\n",
"Query 3 average precision: 0.6541666666666666\n",
"Query 4 average precision: 0.5\n",
"Query 5 average precision: 0.35\n",
"Query 6 average precision: 0.0625\n",
"Query 7 average precision: 0.10666666666666666\n",
"Query 8 average precision: 0.0\n",
"Query 9 average precision: 0.6984126984126983\n",
"Query 10 average precision: 0.0625\n",
"Query 11 average precision: 0.028571428571428574\n",
"Query 12 average precision: 0.18\n",
"Query 13 average precision: 0.0\n",
"Query 14 average precision: 0.5\n",
"Query 15 average precision: 0.7\n",
"Query 16 average precision: 0.08333333333333333\n",
"Query 17 average precision: 0.07142857142857142\n",
"Query 18 average precision: 0.3333333333333333\n",
"Query 19 average precision: 0.0\n",
"Query 20 average precision: 0.2685185185185185\n",
"Query 21 average precision: 0.0\n",
"Query 22 average precision: 0.0\n",
"Query 23 average precision: 0.04722222222222222\n",
"Query 24 average precision: 0.3333333333333333\n",
"Query 25 average precision: 0.6507936507936507\n",
"Query 26 average precision: 0.13888888888888887\n",
"Query 27 average precision: 0.0\n",
"Query 28 average precision: 0.0\n",
"Query 29 average precision: 0.14166666666666666\n",
"Query 30 average precision: 0.11904761904761904\n",
"Query 31 average precision: 0.0\n",
"Query 32 average precision: 0.023809523809523808\n",
"Query 33 average precision: 0.5317460317460317\n",
"Query 34 average precision: 0.20555555555555557\n",
"Query 35 average precision: 0.0\n",
"Query 36 average precision: 0.25\n",
"Query 37 average precision: 0.05555555555555555\n",
"Query 38 average precision: 0.01111111111111111\n",
"Query 39 average precision: 0.12857142857142856\n",
"Query 40 average precision: 0.07222222222222222\n",
"Query 41 average precision: 0.7666666666666666\n",
"Query 42 average precision: 0.1638095238095238\n",
"Query 43 average precision: 0.48333333333333334\n",
"Query 44 average precision: 0.0\n",
"Query 45 average precision: 0.155\n",
"Query 46 average precision: 0.5380952380952382\n",
"Query 47 average precision: 0.15083333333333332\n",
"Query 48 average precision: 0.1433333333333333\n",
"Query 49 average precision: 0.0\n",
"Query 50 average precision: 0.0\n",
"Query 51 average precision: 0.16666666666666666\n",
"Query 52 average precision: 0.027777777777777776\n",
"Query 53 average precision: 0.15857142857142856\n",
"Query 54 average precision: 0.05555555555555555\n",
"Query 55 average precision: 0.11499999999999999\n",
"Query 56 average precision: 0.05\n",
"Query 57 average precision: 0.01\n",
"Query 58 average precision: 0.15873015873015872\n",
"Query 59 average precision: 0.03571428571428571\n",
"Query 60 average precision: 0.4666666666666667\n",
"Query 61 average precision: 0.45999999999999996\n",
"Query 62 average precision: 0.0\n",
"Query 63 average precision: 0.0\n",
"Query 64 average precision: 0.5\n",
"Query 65 average precision: 0.0\n",
"Query 66 average precision: 0.0\n",
"Query 67 average precision: 0.0365079365079365\n",
"Query 68 average precision: 0.1\n",
"Query 69 average precision: 0.06666666666666667\n",
"Query 70 average precision: 0.1\n",
"Query 71 average precision: 0.03125\n",
"Query 72 average precision: 0.0\n",
"Query 73 average precision: 0.2861111111111111\n",
"Query 74 average precision: 0.020833333333333332\n",
"Query 75 average precision: 0.03333333333333333\n",
"Query 76 average precision: 0.028571428571428574\n",
"Query 77 average precision: 0.28888888888888886\n",
"Query 78 average precision: 0.6666666666666666\n",
"Query 79 average precision: 0.0\n",
"Query 80 average precision: 0.0\n",
"Query 81 average precision: 0.25\n",
"Query 82 average precision: 0.09\n",
"Query 83 average precision: 0.0625\n",
"Query 84 average precision: 0.3\n",
"Query 85 average precision: 0.05\n",
"Query 86 average precision: 0.41666666666666663\n",
"Query 87 average precision: 0.0\n",
"Query 88 average precision: 0.594047619047619\n",
"Query 89 average precision: 0.075\n",
"Query 90 average precision: 0.19\n",
"Query 91 average precision: 0.0873015873015873\n",
"Query 92 average precision: 0.4726984126984126\n",
"Query 93 average precision: 0.5\n",
"Query 94 average precision: 0.4\n",
"Query 95 average precision: 0.5\n",
"Query 96 average precision: 0.4463095238095239\n",
"Query 97 average precision: 0.15416666666666665\n",
"Query 98 average precision: 0.0\n",
"Query 99 average precision: 0.19642857142857142\n",
"Query 100 average precision: 0.2185185185185185\n",
"Query 101 average precision: 0.6180555555555555\n",
"Query 102 average precision: 0.08333333333333333\n",
"Query 103 average precision: 0.05555555555555555\n",
"Query 104 average precision: 0.1\n",
"Query 105 average precision: 0.35333333333333333\n",
"Query 106 average precision: 0.24666666666666667\n",
"Query 107 average precision: 0.10476190476190476\n",
"Query 108 average precision: 0.6261904761904761\n",
"Query 109 average precision: 0.0\n",
"Query 110 average precision: 0.0\n",
"Query 111 average precision: 0.07619047619047618\n",
"Query 112 average precision: 0.3611111111111111\n",
"Query 113 average precision: 0.10416666666666666\n",
"Query 114 average precision: 0.0\n",
"Query 115 average precision: 0.0\n",
"Query 116 average precision: 0.03333333333333333\n",
"Query 117 average precision: 0.0\n",
"Query 118 average precision: 0.041666666666666664\n",
"Query 119 average precision: 0.5\n",
"Query 120 average precision: 0.17724867724867724\n",
"Query 121 average precision: 0.4768707482993197\n",
"Query 122 average precision: 0.05925925925925926\n",
"Query 123 average precision: 0.0\n",
"Query 124 average precision: 0.0\n",
"Query 125 average precision: 0.02\n",
"Query 126 average precision: 0.20833333333333331\n",
"Query 127 average precision: 0.025\n",
"Query 128 average precision: 0.0\n",
"Query 129 average precision: 0.4773809523809524\n",
"Query 130 average precision: 0.3933333333333333\n",
"Query 131 average precision: 0.09375\n",
"Query 132 average precision: 0.6592063492063491\n",
"Query 133 average precision: 0.12976190476190477\n",
"Query 134 average precision: 0.5\n",
"Query 135 average precision: 0.38690476190476186\n",
"Query 136 average precision: 0.03333333333333333\n",
"Query 137 average precision: 0.1875\n",
"Query 138 average precision: 0.25\n",
"Query 139 average precision: 0.0\n",
"Query 140 average precision: 0.10833333333333334\n",
"Query 141 average precision: 0.05555555555555555\n",
"Query 142 average precision: 0.0\n",
"Query 143 average precision: 0.6111111111111112\n",
"Query 144 average precision: 0.14722222222222223\n",
"Query 145 average precision: 0.07142857142857142\n",
"Query 146 average precision: 0.41666666666666663\n",
"Query 147 average precision: 0.14666666666666667\n",
"Query 148 average precision: 0.125\n",
"Query 149 average precision: 0.18285714285714286\n",
"Query 150 average precision: 1.0\n",
"Query 151 average precision: 0.0\n",
"Query 152 average precision: 0.0\n",
"Query 153 average precision: 0.19999999999999998\n",
"Query 154 average precision: 1.0\n",
"Query 155 average precision: 0.06481481481481481\n",
"Query 156 average precision: 0.6365476190476189\n",
"Query 157 average precision: 0.43809523809523804\n",
"Query 158 average precision: 0.175\n",
"Query 159 average precision: 0.015625\n",
"Query 160 average precision: 0.2\n",
"Query 161 average precision: 0.43333333333333335\n",
"Query 162 average precision: 0.020833333333333332\n",
"Query 163 average precision: 0.21666666666666667\n",
"Query 164 average precision: 0.40208333333333335\n",
"Query 165 average precision: 0.16666666666666666\n",
"Query 166 average precision: 0.0\n",
"Query 167 average precision: 0.325\n",
"Query 168 average precision: 0.08333333333333333\n",
"Query 169 average precision: 0.125\n",
"Query 170 average precision: 0.4611111111111111\n",
"Query 171 average precision: 0.4888888888888889\n",
"Query 172 average precision: 0.5416666666666666\n",
"Query 173 average precision: 0.8333333333333333\n",
"Query 174 average precision: 0.03333333333333333\n",
"Query 175 average precision: 0.02222222222222222\n",
"Query 176 average precision: 0.0\n",
"Query 177 average precision: 0.475\n",
"Query 178 average precision: 0.49166666666666664\n",
"Query 179 average precision: 0.125\n",
"Query 180 average precision: 0.34523809523809523\n",
"Query 181 average precision: 0.2\n",
"Query 182 average precision: 0.7\n",
"Query 183 average precision: 0.5308730158730158\n",
"Query 184 average precision: 0.06938775510204082\n",
"Query 185 average precision: 0.6388888888888888\n",
"Query 186 average precision: 0.016666666666666666\n",
"Query 187 average precision: 0.16666666666666666\n",
"Query 188 average precision: 0.09\n",
"Query 189 average precision: 0.015873015873015872\n",
"Query 190 average precision: 0.1\n",
"Query 191 average precision: 0.014285714285714285\n",
"Query 192 average precision: 0.5\n",
"Query 193 average precision: 0.6481481481481481\n",
"Query 194 average precision: 0.08888888888888889\n",
"Query 195 average precision: 0.1111111111111111\n",
"Query 196 average precision: 0.03333333333333333\n",
"Query 197 average precision: 0.6666666666666666\n",
"Query 198 average precision: 0.3125\n",
"Query 199 average precision: 0.05208333333333333\n",
"Query 200 average precision: 0.24074074074074073\n",
"Query 201 average precision: 0.3\n",
"Query 202 average precision: 0.26\n",
"Query 203 average precision: 0.07222222222222222\n",
"Query 204 average precision: 0.0\n",
"Query 205 average precision: 0.75\n",
"Query 206 average precision: 0.3333333333333333\n",
"Query 207 average precision: 0.15\n",
"Query 208 average precision: 0.6829365079365081\n",
"Query 209 average precision: 0.02\n",
"Query 210 average precision: 0.27777777777777773\n",
"Query 211 average precision: 0.0125\n",
"Query 212 average precision: 0.36666666666666664\n",
"Query 213 average precision: 0.5583333333333333\n",
"Query 214 average precision: 0.08333333333333333\n",
"Query 215 average precision: 0.0\n",
"Query 216 average precision: 0.0\n",
"Query 217 average precision: 0.05833333333333333\n",
"Query 218 average precision: 0.0\n",
"Query 219 average precision: 0.014285714285714285\n",
"Query 220 average precision: 0.05333333333333333\n",
"Query 221 average precision: 0.24333333333333332\n",
"Query 222 average precision: 0.5148148148148147\n",
"Query 223 average precision: 0.44375\n",
"Query 224 average precision: 0.0\n",
"Query 225 average precision: 0.13333333333333333\n"
]
}
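],
"source": [
"def evaluate_ranker(ranker, ev, num_results):\n",
"    # Clear any statistics left over from a previous evaluation run.\n",
"    ev.reset_stats()\n",
"    # Use a local query document instead of relying on the global `query`.\n",
"    query = metapy.index.Document()\n",
"    with open('data/cranfield/cranfield-queries.txt') as query_file:\n",
"        # The file holds one query per line; query IDs are passed as 1-based.\n",
"        for query_num, line in enumerate(query_file):\n",
"            query.content(line.strip())\n",
"            results = ranker.score(inv_idx, query, num_results)\n",
"            avg_p = ev.avg_p(results, query_num + 1, num_results)\n",
"            print(\"Query {} average precision: {}\".format(query_num + 1, avg_p))\n",
"\n",
"evaluate_ranker(ranker, ev, 10)"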
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Afterwards, we can get the mean average precision of all the queries."
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"MAP: 0.21512203955656342\n"
]
}
],
"source": [
"dp_map = ev.map()\n",
"print(\"MAP: {}\".format(dp_map))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now, let's use the two-component mixture model we discussed as an implementation of pseudo-feedback for retrieval and see if it helps improve performance. The actual ranking function used here is KL-divergence, where the query model is adjusted to include pseudo-feedback from the retrieved documents.\n",
"\n",
"In order to work, the ranker needs to be able to quickly determine what words were used in the feedback document set. The `InvertedIndex` does not provide fast access to this (since it is a mapping from term to documents, rather than from documents to terms), so we will want to first create a `ForwardIndex` to get the document -> terms mapping."
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r",
" > Counting lines in file: [> ] 0% ETA 00:00:00 \r",
" > Counting lines in file: [=================================] 100% ETA 00:00:00 \n",
"1529954010: [info] Creating forward index: cranfield-idx/fwd (/tmp/pip-req-build-m473bt6z/deps/meta/src/index/forward_index.cpp:239)\n",
" \r",
" > Tokenizing Docs: [> ] 0% ETA 00:00:00 \n",
"1529954010: [warning] Empty document (id = 470) generated! (/tmp/pip-req-build-m473bt6z/deps/meta/src/index/forward_index.cpp:335)\n",
" \n",
"1529954010: [warning] Empty document (id = 994) generated! (/tmp/pip-req-build-m473bt6z/deps/meta/src/index/forward_index.cpp:335)\n",
" \r",
" > Tokenizing Docs: [========================================] 100% ETA 00:00:00 \n",
" \r",
" > Merging: [> ] 0% ETA 00:00:00 \r",
" > Merging: [================================================] 100% ETA 00:00:00 \n",
"1529954010: [info] Done creating index: cranfield-idx/fwd (/tmp/pip-req-build-m473bt6z/deps/meta/src/index/forward_index.cpp:278)\n"
]
}
],
"source": [
"fwd_idx = metapy.index.make_forward_index('cranfield.toml')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we can construct the KL-divergence pseudo-feedback ranker. The main components are:\n",
"1. The forward index\n",
"2. A base language-model ranker (here we'll use `DirichletPrior`)\n",
"3. $\\alpha$, the query interpolation parameter (how strongly do we prefer terms from the feedback model? default 0.5)\n",
"4. $\\lambda$, the language-model interpolation parameter (how strong is the background model in the two-component mixture? default 0.5)\n",
"5. $k$, the number of documents to retrieve for the feedback set (default 10)\n",
"6. `max_terms`, the number of terms from the feedback model to incorporate into the new query model (default 50) "
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"feedback = metapy.index.KLDivergencePRF(fwd_idx, metapy.index.DirichletPrior())"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Query 1 average precision: 0.13999999999999999\n",
"Query 2 average precision: 0.524047619047619\n",
"Query 3 average precision: 0.6642857142857143\n",
"Query 4 average precision: 0.5\n",
"Query 5 average precision: 0.6875\n",
"Query 6 average precision: 0.0625\n",
"Query 7 average precision: 0.11666666666666665\n",
"Query 8 average precision: 0.0\n",
"Query 9 average precision: 0.49999999999999994\n",
"Query 10 average precision: 0.0625\n",
"Query 11 average precision: 0.023809523809523808\n",
"Query 12 average precision: 0.15714285714285714\n",
"Query 13 average precision: 0.0\n",
"Query 14 average precision: 0.5\n",
"Query 15 average precision: 0.6428571428571428\n",
"Query 16 average precision: 0.05555555555555555\n",
"Query 17 average precision: 0.05\n",
"Query 18 average precision: 0.16666666666666666\n",
"Query 19 average precision: 0.0\n",
"Query 20 average precision: 0.33201058201058203\n",
"Query 21 average precision: 0.0\n",
"Query 22 average precision: 0.0\n",
"Query 23 average precision: 0.075\n",
"Query 24 average precision: 0.38888888888888884\n",
"Query 25 average precision: 0.7530864197530864\n",
"Query 26 average precision: 0.061111111111111116\n",
"Query 27 average precision: 0.0\n",
"Query 28 average precision: 0.0\n",
"Query 29 average precision: 0.091005291005291\n",
"Query 30 average precision: 0.19999999999999998\n",
"Query 31 average precision: 0.0\n",
"Query 32 average precision: 0.023809523809523808\n",
"Query 33 average precision: 0.6666666666666666\n",
"Query 34 average precision: 0.21031746031746032\n",
"Query 35 average precision: 0.03333333333333333\n",
"Query 36 average precision: 0.6\n",
"Query 37 average precision: 0.1111111111111111\n",
"Query 38 average precision: 0.02\n",
"Query 39 average precision: 0.12857142857142856\n",
"Query 40 average precision: 0.07222222222222222\n",
"Query 41 average precision: 0.7916666666666666\n",
"Query 42 average precision: 0.14523809523809522\n",
"Query 43 average precision: 0.3333333333333333\n",
"Query 44 average precision: 0.0\n",
"Query 45 average precision: 0.13999999999999999\n",
"Query 46 average precision: 0.49071428571428566\n",
"Query 47 average precision: 0.14027777777777778\n",
"Query 48 average precision: 0.1433333333333333\n",
"Query 49 average precision: 0.0\n",
"Query 50 average precision: 0.0\n",
"Query 51 average precision: 0.22999999999999998\n",
"Query 52 average precision: 0.08125\n",
"Query 53 average precision: 0.15\n",
"Query 54 average precision: 0.037037037037037035\n",
"Query 55 average precision: 0.11833333333333333\n",
"Query 56 average precision: 0.05\n",
"Query 57 average precision: 0.0125\n",
"Query 58 average precision: 0.16666666666666666\n",
"Query 59 average precision: 0.03571428571428571\n",
"Query 60 average precision: 0.45999999999999996\n",
"Query 61 average precision: 0.4083333333333333\n",
"Query 62 average precision: 0.0\n",
"Query 63 average precision: 0.0\n",
"Query 64 average precision: 0.5\n",
"Query 65 average precision: 0.0\n",
"Query 66 average precision: 0.0\n",
"Query 67 average precision: 0.03111111111111111\n",
"Query 68 average precision: 0.06666666666666667\n",
"Query 69 average precision: 0.2\n",
"Query 70 average precision: 0.1\n",
"Query 71 average precision: 0.041666666666666664\n",
"Query 72 average precision: 0.0\n",
"Query 73 average precision: 0.38916666666666666\n",
"Query 74 average precision: 0.05555555555555555\n",
"Query 75 average precision: 0.02222222222222222\n",
"Query 76 average precision: 0.02040816326530612\n",
"Query 77 average precision: 0.34027777777777773\n",
"Query 78 average precision: 0.7916666666666666\n",
"Query 79 average precision: 0.0\n",
"Query 80 average precision: 0.0\n",
"Query 81 average precision: 0.25\n",
"Query 82 average precision: 0.07333333333333333\n",
"Query 83 average precision: 0.05\n",
"Query 84 average precision: 0.2375\n",
"Query 85 average precision: 0.05\n",
"Query 86 average precision: 0.30952380952380953\n",
"Query 87 average precision: 0.0\n",
"Query 88 average precision: 0.594047619047619\n",
"Query 89 average precision: 0.10833333333333334\n",
"Query 90 average precision: 0.1875\n",
"Query 91 average precision: 0.1111111111111111\n",
"Query 92 average precision: 0.5747619047619048\n",
"Query 93 average precision: 0.5\n",
"Query 94 average precision: 0.4\n",
"Query 95 average precision: 0.5\n",
"Query 96 average precision: 0.6042063492063492\n",
"Query 97 average precision: 0.24166666666666664\n",
"Query 98 average precision: 0.0\n",
"Query 99 average precision: 0.25\n",
"Query 100 average precision: 0.25555555555555554\n",
"Query 101 average precision: 0.5444444444444444\n",
"Query 102 average precision: 0.08333333333333333\n",
"Query 103 average precision: 0.1\n",
"Query 104 average precision: 0.1\n",
"Query 105 average precision: 0.4333333333333333\n",
"Query 106 average precision: 0.3746031746031746\n",
"Query 107 average precision: 0.1619047619047619\n",
"Query 108 average precision: 0.5488095238095239\n",
"Query 109 average precision: 0.0\n",
"Query 110 average precision: 0.0\n",
"Query 111 average precision: 0.014285714285714287\n",
"Query 112 average precision: 0.3611111111111111\n",
"Query 113 average precision: 0.10555555555555556\n",
"Query 114 average precision: 0.0\n",
"Query 115 average precision: 0.0\n",
"Query 116 average precision: 0.02857142857142857\n",
"Query 117 average precision: 0.0\n",
"Query 118 average precision: 0.037037037037037035\n",
"Query 119 average precision: 0.3333333333333333\n",
"Query 120 average precision: 0.1527777777777778\n",
"Query 121 average precision: 0.42857142857142855\n",
"Query 122 average precision: 0.07777777777777778\n",
"Query 123 average precision: 0.0\n",
"Query 124 average precision: 0.0\n",
"Query 125 average precision: 0.03333333333333333\n",
"Query 126 average precision: 0.20833333333333331\n",
"Query 127 average precision: 0.02222222222222222\n",
"Query 128 average precision: 0.0\n",
"Query 129 average precision: 0.6738095238095239\n",
"Query 130 average precision: 0.55\n",
"Query 131 average precision: 0.18333333333333335\n",
"Query 132 average precision: 0.7254365079365078\n",
"Query 133 average precision: 0.18010204081632653\n",
"Query 134 average precision: 0.5\n",
"Query 135 average precision: 0.37351190476190477\n",
"Query 136 average precision: 0.047619047619047616\n",
"Query 137 average precision: 0.1125\n",
"Query 138 average precision: 0.5\n",
"Query 139 average precision: 0.0\n",
"Query 140 average precision: 0.10833333333333334\n",
"Query 141 average precision: 0.041666666666666664\n",
"Query 142 average precision: 0.0\n",
"Query 143 average precision: 0.39285714285714285\n",
"Query 144 average precision: 0.20555555555555557\n",
"Query 145 average precision: 0.07142857142857142\n",
"Query 146 average precision: 0.30952380952380953\n",
"Query 147 average precision: 0.14999999999999997\n",
"Query 148 average precision: 0.075\n",
"Query 149 average precision: 0.17666666666666667\n",
"Query 150 average precision: 1.0\n",
"Query 151 average precision: 0.0\n",
"Query 152 average precision: 0.0\n",
"Query 153 average precision: 0.2571428571428571\n",
"Query 154 average precision: 1.0\n",
"Query 155 average precision: 0.06666666666666667\n",
"Query 156 average precision: 0.7532142857142856\n",
"Query 157 average precision: 0.40555555555555556\n",
"Query 158 average precision: 0.19375\n",
"Query 159 average precision: 0.020833333333333332\n",
"Query 160 average precision: 0.2\n",
"Query 161 average precision: 0.3611111111111111\n",
"Query 162 average precision: 0.013888888888888888\n",
"Query 163 average precision: 0.19444444444444442\n",
"Query 164 average precision: 0.40208333333333335\n",
"Query 165 average precision: 0.125\n",
"Query 166 average precision: 0.0\n",
"Query 167 average precision: 0.5\n",
"Query 168 average precision: 0.08333333333333333\n",
"Query 169 average precision: 0.25\n",
"Query 170 average precision: 0.46990740740740744\n",
"Query 171 average precision: 0.5555555555555555\n",
"Query 172 average precision: 0.5583333333333333\n",
"Query 173 average precision: 0.7\n",
"Query 174 average precision: 0.03333333333333333\n",
"Query 175 average precision: 0.02857142857142857\n",
"Query 176 average precision: 0.0\n",
"Query 177 average precision: 0.475\n",
"Query 178 average precision: 0.75\n",
"Query 179 average precision: 0.25\n",
"Query 180 average precision: 0.3880952380952381\n",
"Query 181 average precision: 0.1\n",
"Query 182 average precision: 0.7\n",
"Query 183 average precision: 0.6842063492063492\n",
"Query 184 average precision: 0.05215419501133787\n",
"Query 185 average precision: 0.6296296296296297\n",
"Query 186 average precision: 0.04722222222222222\n",
"Query 187 average precision: 0.125\n",
"Query 188 average precision: 0.19\n",
"Query 189 average precision: 0.022222222222222223\n",
"Query 190 average precision: 0.05\n",
"Query 191 average precision: 0.05357142857142857\n",
"Query 192 average precision: 0.575\n",
"Query 193 average precision: 0.6565255731922398\n",
"Query 194 average precision: 0.19999999999999998\n",
"Query 195 average precision: 0.3333333333333333\n",
"Query 196 average precision: 0.05\n",
"Query 197 average precision: 0.6666666666666666\n",
"Query 198 average precision: 0.1875\n",
"Query 199 average precision: 0.048611111111111105\n",
"Query 200 average precision: 0.40740740740740744\n",
"Query 201 average precision: 0.33999999999999997\n",
"Query 202 average precision: 0.22666666666666666\n",
"Query 203 average precision: 0.07333333333333333\n",
"Query 204 average precision: 0.0\n",
"Query 205 average precision: 0.7\n",
"Query 206 average precision: 0.38888888888888884\n",
"Query 207 average precision: 0.3\n",
"Query 208 average precision: 0.7345238095238096\n",
"Query 209 average precision: 0.03333333333333333\n",
"Query 210 average precision: 0.27777777777777773\n",
"Query 211 average precision: 0.025\n",
"Query 212 average precision: 0.36666666666666664\n",
"Query 213 average precision: 0.49309523809523814\n",
"Query 214 average precision: 0.08333333333333333\n",
"Query 215 average precision: 0.0\n",
"Query 216 average precision: 0.0\n",
"Query 217 average precision: 0.07857142857142857\n",
"Query 218 average precision: 0.0\n",
"Query 219 average precision: 0.014285714285714285\n",
"Query 220 average precision: 0.07333333333333333\n",
"Query 221 average precision: 0.24333333333333332\n",
"Query 222 average precision: 0.5238095238095238\n",
"Query 223 average precision: 0.4583333333333333\n",
"Query 224 average precision: 0.0\n",
"Query 225 average precision: 0.15\n"
]
}
],
"source": [
"evaluate_ranker(feedback, ev, 10)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Feedback MAP: 0.22816526133086987\n",
"DP MAP: 0.21512203955656342\n"
]
}
],
"source": [
"fb_map = ev.map()\n",
"print(\"Feedback MAP: {}\".format(fb_map))\n",
"print(\"DP MAP: {}\".format(dp_map))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}