Skip to content

Instantly share code, notes, and snippets.

@devashishd12
Last active July 17, 2016 16:02
Show Gist options
  • Save devashishd12/3d082b1791dde82eaf645f7f3301508a to your computer and use it in GitHub Desktop.
Save devashishd12/3d082b1791dde82eaf645f7f3301508a to your computer and use it in GitHub Desktop.
Benchmark testing of the gensim topic coherence pipeline on 20NG dataset
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import re\n",
"\n",
"from sklearn.datasets import fetch_20newsgroups\n",
"from scipy.stats import pearsonr\n",
"from datetime import datetime\n",
"\n",
"from gensim.models import CoherenceModel\n",
"from gensim.corpora.dictionary import Dictionary"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"dataset = fetch_20newsgroups(subset='all', remove=('headers', 'footers', 'quotes'))"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"documents = dataset['data'] # is a list of documents"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"texts = []\n",
"for document in documents:\n",
" # lower case all words\n",
" lowered = document.lower()\n",
" # remove punctuation and split into separate words\n",
" words = re.findall(r'\\w+', lowered, flags = re.UNICODE | re.LOCALE)\n",
" texts.append(words)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"dictionary = Dictionary(texts)\n",
"corpus = [dictionary.doc2bow(text) for text in texts]"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"18846\n",
"Dictionary(134435 unique tokens: [u'3ds2scn', u'diagnositic', u'9l2t', u'l1tbk', u'porkification']...)\n"
]
}
],
"source": [
"print len(documents)\n",
"print dictionary"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"topics = [] # list of 100 topics\n",
"for l in open('/home/devashish/datasets/20NG/topics20NG.txt'):\n",
" topics.append([l.split()])"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"human_scores = []\n",
"for l in open('/home/devashish/datasets/20NG/gold20NG.txt'):\n",
" human_scores.append(float(l.strip()))"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Time taken: 0:04:10.235838\n"
]
}
],
"source": [
"start = datetime.now()\n",
"u_mass = []\n",
"flags = []\n",
"for n, topic in enumerate(topics):\n",
" try:\n",
" cm = CoherenceModel(topics=topic, corpus=corpus, dictionary=dictionary, coherence='u_mass')\n",
" u_mass.append(cm.get_coherence())\n",
" except KeyError:\n",
" flags.append(n)\n",
"end = datetime.now()\n",
"print \"Time taken: %s\" % (end - start)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Time taken: 0:50:03.377507\n"
]
}
],
"source": [
"start = datetime.now()\n",
"c_v = []\n",
"for n, topic in enumerate(topics):\n",
" try:\n",
" cm = CoherenceModel(topics=topic, texts=texts, dictionary=dictionary, coherence='c_v')\n",
" c_v.append(cm.get_coherence())\n",
" except KeyError:\n",
" pass\n",
"end = datetime.now()\n",
"print \"Time taken: %s\" % (end - start)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"final_scores = []\n",
"for n, score in enumerate(human_scores):\n",
" if n not in flags:\n",
" final_scores.append(score)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"93 93 93\n"
]
}
],
"source": [
"print len(u_mass), len(c_v), len(final_scores)\n",
"# 3 topics have words that are not in the dictionary. This is due to a difference\n",
"# in preprocessing or because of the absence of ~900 documents"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.554915406168\n",
"0.616997328993\n"
]
}
],
"source": [
"print pearsonr(u_mass, final_scores)[0]\n",
"print pearsonr(c_v, final_scores)[0]"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.11"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment