Skip to content

Instantly share code, notes, and snippets.

@mkolod
Last active February 8, 2019 00:59
Show Gist options
  • Star 2 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save mkolod/b1222fb84b762ee0b78b2d626f7ecfd4 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Imports\n",
"\n",
"import numpy as np\n",
"import lda.datasets"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Helper functions\n",
"\n",
"import re, string\n",
"\n",
"def remove_date(text):\n",
" return re.sub(' ?\\d\\d\\d\\d-\\d\\d-\\d\\d ?', '', text)\n",
"\n",
"exclude = set(string.punctuation)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Load data\n",
"\n",
"X = lda.datasets.load_reuters()\n",
"vocab = list(lda.datasets.load_reuters_vocab())\n",
"reverse_vocab = dict(zip(vocab, range(len(vocab))))\n",
"titles = [remove_date(x.lower().split(' ', 1)[1]).translate(None, string.punctuation) \n",
" for x in lda.datasets.load_reuters_titles()]\n",
"reverse_titles = dict(zip(titles, range(len(titles))))\n",
"num_docs, vocab_size = X.shape"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# TF-IDF\n",
"\n",
"X_binary = np.copy(X)\n",
"X_binary[X_binary > 0] = 1\n",
"\n",
"sums = np.sum(X_binary, axis = 0)\n",
"idf = np.log(1.0 * num_docs / sums)\n",
"mul = np.multiply(X, idf)\n",
"tf_idf = np.transpose(mul)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# SVD\n",
"\n",
"u, s, v = np.linalg.svd(tf_idf, full_matrices=False)\n",
"s = np.diag(s)\n",
"np.allclose(tf_idf, np.dot(u, np.dot(s, v)))"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Reduced-rank approximation\n",
"\n",
"num_docs = v.shape[1]\n",
"num_concepts = 200\n",
"assert(num_concepts < num_docs and num_concepts < vocab_size)\n",
"\n",
"u_trunc = u[:, :num_concepts]\n",
"s_trunc = s[:num_concepts, :num_concepts]\n",
"v_trunc = v[:num_concepts, :]\n",
"\n",
"X_approx = np.dot(u_trunc, np.dot(s_trunc, v_trunc))"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Normalize rows of u and columns of v' and calculate cosine similarity\n",
"from sklearn.preprocessing import normalize\n",
"\n",
"u_trunc_norm = normalize(u_trunc, axis = 1, norm = 'l2')\n",
"v_trunc_norm = normalize(v_trunc, axis = 0, norm = 'l2')\n",
"\n",
"term_sim = np.dot(u_trunc_norm, np.transpose(u_trunc_norm))\n",
"doc_sim = np.dot(np.transpose(v_trunc_norm), v_trunc_norm)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Results for 'clinton'\n",
"['mrs', 'arkansas', 'blind', 'jones', 'watched']\n",
"\n",
"Results for 'chirac'\n",
"['jacques', 'hailed', 'french', 'france', 'secular']\n",
"\n",
"Results for 'yeltsin'\n",
"['akchurin', 'renat', 'viktor', 'kremlin', 'interfax']\n"
]
}
],
"source": [
"# List similar terms to query term based on the scaled, truncated u matrix\n",
"from operator import itemgetter\n",
"\n",
"def term_query(term, num_to_retrieve = 5):\n",
" query = reverse_vocab[term]\n",
" sims = list(enumerate(term_sim[query, :]))\n",
" nearest = sorted(sims, key=itemgetter(1), reverse=True)\n",
" result = []\n",
" for i in range(1, num_to_retrieve + 1):\n",
" result.append(vocab[nearest[i][0]])\n",
" return result\n",
"\n",
"def report(query, fun, num_items = 2):\n",
" print(\"\\nResults for '%s'\" % query)\n",
" print(fun(query, num_items))\n",
" \n",
"def report_term_query(query, num_items = 2):\n",
" return report(query, term_query, num_items)\n",
"\n",
"report_term_query(\"clinton\", 5)\n",
"report_term_query(\"chirac\", 5)\n",
"report_term_query(\"yeltsin\", 5)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Results for 'usa oj simpson attacks media hints at lawsuits washington'\n",
"['usa simpsons life story defies imagination santa monica calif', 'usa oj simpsons incredible reversal of fortune santa monica calif']\n",
"\n",
"Results for 'uk prince charles spearheads british royal revolution london'\n",
"['uk queen elizabeth to hold royal summit report london', 'uk british royal family holds meeting on future london']\n",
"\n",
"Results for 'russia kremlin slams report on return of tsars heir moscow'\n",
"['russia top russian official meets royal relative moscow', 'russia wouldbe russian tsar to return for family ceremony moscow']\n"
]
}
],
"source": [
"# List similar documents based on the query term and the scaled, truncated v matrix\n",
"def doc_query(title, num_to_retrieve = 5):\n",
" query = reverse_titles[title]\n",
" sims = list(enumerate(doc_sim[query, :]))\n",
" nearest = sorted(sims, key=itemgetter(1), reverse=True)\n",
" result = []\n",
" for i in range(1, num_to_retrieve + 1):\n",
" result.append(titles[nearest[i][0]])\n",
" return result\n",
"\n",
"def report_doc_sim(query, num_items = 2):\n",
" return report(query, doc_query, num_items)\n",
"\n",
"report_doc_sim(\"usa oj simpson attacks media hints at lawsuits washington\")\n",
"report_doc_sim(\"uk prince charles spearheads british royal revolution london\")\n",
"report_doc_sim(\"russia kremlin slams report on return of tsars heir moscow\")"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Results for 'uk camilla charles'\n",
"['uk uks blair says charles can wed camilla report london', 'uk uks blair dismisses charles and camilla reports london', 'uk cool it with camilla major tells charles paper london', 'uk charles under fire over prospect of queen camilla london', 'uk prince charles holds 50th bash for lover camilla london']\n"
]
}
],
"source": [
"# Query new set of terms not based on exact existing documents\n",
"\n",
"def search_query(query, num_to_retrieve = 5):\n",
" vec = np.zeros((X.shape[1]), dtype = np.int32)\n",
" clean = remove_date(query.lower().translate(None, string.punctuation))\n",
" for term in clean.split():\n",
" match = reverse_vocab.get(term, None)\n",
" if match is not None:\n",
" vec[match] += 1\n",
" mul = np.multiply(vec, idf)\n",
" sims = np.dot(np.linalg.inv(s_trunc), np.dot(np.transpose(u_trunc), mul))\n",
" sims = sims / np.linalg.norm(sims)\n",
" sims = list(enumerate(np.dot(np.transpose(v_trunc), sims)))\n",
" nearest = sorted(sims, key=itemgetter(1), reverse=True)\n",
" result = []\n",
" for i in range(1, num_to_retrieve + 1):\n",
" result.append(titles[nearest[i][0]])\n",
" return result\n",
"\n",
"def report_search_query(query, num_items = 2):\n",
" return report(query, search_query, num_items)\n",
"\n",
"report_search_query(\"uk camilla charles\", 5)"
]
}
],
"metadata": {
"anaconda-cloud": {},
"kernelspec": {
"display_name": "Python [default]",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.12"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment