Reuters-21578 keyword extraction
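The notebook below (raw nbformat 3 JSON) trains a one-vs-rest logistic regression tagger on Reuters-21578 topic labels; the Reuters.py helper module it imports follows the notebook.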
{
"metadata": {
"name": ""
},
"nbformat": 3,
"nbformat_minor": 0,
"worksheets": [
{
"cells": [
{
"cell_type": "code",
"collapsed": false,
"input": [
"import pandas as pd\n",
"import numpy as np\n",
"from Reuters import *"
],
"language": "python",
"metadata": {},
"outputs": [],
"prompt_number": 1
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"# the downloaded dataset\n",
"!ls -la reuters/*.sgm"
],
"language": "python",
"metadata": {},
"outputs": [
{
"output_type": "stream",
"stream": "stdout",
"text": [
"-rw-r--r-- 1 eriza eriza 1324350 Dec 4 1996 reuters/reut2-000.sgm\r\n",
"-rw-r--r-- 1 eriza eriza 1254440 Dec 4 1996 reuters/reut2-001.sgm\r\n",
"-rw-r--r-- 1 eriza eriza 1217495 Dec 4 1996 reuters/reut2-002.sgm\r\n",
"-rw-r--r-- 1 eriza eriza 1298721 Dec 4 1996 reuters/reut2-003.sgm\r\n",
"-rw-r--r-- 1 eriza eriza 1321623 Dec 4 1996 reuters/reut2-004.sgm\r\n",
"-rw-r--r-- 1 eriza eriza 1388644 Dec 4 1996 reuters/reut2-005.sgm\r\n",
"-rw-r--r-- 1 eriza eriza 1254765 Dec 4 1996 reuters/reut2-006.sgm\r\n",
"-rw-r--r-- 1 eriza eriza 1256772 Dec 4 1996 reuters/reut2-007.sgm\r\n",
"-rw-r--r-- 1 eriza eriza 1410117 Dec 4 1996 reuters/reut2-008.sgm\r\n",
"-rw-r--r-- 1 eriza eriza 1338903 Dec 4 1996 reuters/reut2-009.sgm\r\n",
"-rw-r--r-- 1 eriza eriza 1371071 Dec 4 1996 reuters/reut2-010.sgm\r\n",
"-rw-r--r-- 1 eriza eriza 1304117 Dec 4 1996 reuters/reut2-011.sgm\r\n",
"-rw-r--r-- 1 eriza eriza 1323584 Dec 4 1996 reuters/reut2-012.sgm\r\n",
"-rw-r--r-- 1 eriza eriza 1129687 Dec 4 1996 reuters/reut2-013.sgm\r\n",
"-rw-r--r-- 1 eriza eriza 1128671 Dec 4 1996 reuters/reut2-014.sgm\r\n",
"-rw-r--r-- 1 eriza eriza 1258665 Dec 4 1996 reuters/reut2-015.sgm\r\n",
"-rw-r--r-- 1 eriza eriza 1316417 Dec 4 1996 reuters/reut2-016.sgm\r\n",
"-rw-r--r-- 1 eriza eriza 1546911 Dec 4 1996 reuters/reut2-017.sgm\r\n",
"-rw-r--r-- 1 eriza eriza 1258819 Dec 4 1996 reuters/reut2-018.sgm\r\n",
"-rw-r--r-- 1 eriza eriza 1261780 Dec 4 1996 reuters/reut2-019.sgm\r\n",
"-rw-r--r-- 1 eriza eriza 1049566 Dec 4 1996 reuters/reut2-020.sgm\r\n",
"-rw-r--r-- 1 eriza eriza 621648 Dec 4 1996 reuters/reut2-021.sgm\r\n"
]
}
],
"prompt_number": 2
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"!grep \\<TOPICS\\>\\<D\\> reuters/*.sgm | wc -l"
],
"language": "python",
"metadata": {},
"outputs": [
{
"output_type": "stream",
"stream": "stdout",
"text": [
"11367\r\n"
]
}
],
"prompt_number": 3
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"# read and parse the data\n",
"# this will download the data if it's not yet available locally\n",
"data_streamer = ReutersStreamReader('reuters').iterdocs()\n",
"data = get_minibatch(data_streamer, 50000)\n",
"data"
],
"language": "python",
"metadata": {},
"outputs": [
{
"html": [
"<pre>\n",
"&lt;class 'pandas.core.frame.DataFrame'&gt;\n",
"Int64Index: 19716 entries, 0 to 19715\n",
"Data columns (total 2 columns):\n",
"text 19716 non-null values\n",
"tags 19716 non-null values\n",
"dtypes: object(2)\n",
"</pre>"
],
"metadata": {},
"output_type": "pyout",
"prompt_number": 4,
"text": [
"<class 'pandas.core.frame.DataFrame'>\n",
"Int64Index: 19716 entries, 0 to 19715\n",
"Data columns (total 2 columns):\n",
"text 19716 non-null values\n",
"tags 19716 non-null values\n",
"dtypes: object(2)"
]
}
],
"prompt_number": 4
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"# 'text' is combined document title and body\n",
"data.head()"
],
"language": "python",
"metadata": {},
"outputs": [
{
"html": [
"<div style=\"max-height:1000px;max-width:1500px;overflow:auto;\">\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>text</th>\n",
" <th>tags</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td> SANDOZ PLANS WEEDKILLER JOINT VENTURE IN USSR\\...</td>\n",
" <td> [usa, ussr]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td> TAIWAN REJECTS TEXTILE MAKERS EXCHANGE RATE PL...</td>\n",
" <td> [usa, taiwan]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td> NATIONAL FSI INC &lt;NFSI&gt; 4TH QTR LOSS\\n\\nShr lo...</td>\n",
" <td> [earn, usa]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td> OCCIDENTAL &lt;OXY&gt; OFFICIAL RESIGNS\\n\\nMidCon Co...</td>\n",
" <td> [usa]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td> ITALY'S BNL TO ISSUE 120 MLN DLR CONVERTIBLE B...</td>\n",
" <td> [italy]</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"metadata": {},
"output_type": "pyout",
"prompt_number": 5,
"text": [
" text tags\n",
"0 SANDOZ PLANS WEEDKILLER JOINT VENTURE IN USSR\\... [usa, ussr]\n",
"1 TAIWAN REJECTS TEXTILE MAKERS EXCHANGE RATE PL... [usa, taiwan]\n",
"2 NATIONAL FSI INC <NFSI> 4TH QTR LOSS\\n\\nShr lo... [earn, usa]\n",
"3 OCCIDENTAL <OXY> OFFICIAL RESIGNS\\n\\nMidCon Co... [usa]\n",
"4 ITALY'S BNL TO ISSUE 120 MLN DLR CONVERTIBLE B... [italy]"
]
}
],
"prompt_number": 5
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"from sklearn.preprocessing import LabelBinarizer\n",
"\n",
"# binary encode the tags\n",
"lb = LabelBinarizer()\n",
"Y = lb.fit_transform(data.tags)\n",
"Y"
],
"language": "python",
"metadata": {},
"outputs": [
{
"metadata": {},
"output_type": "pyout",
"prompt_number": 6,
"text": [
"array([[0, 0, 0, ..., 0, 0, 0],\n",
" [0, 0, 0, ..., 0, 0, 0],\n",
" [0, 0, 0, ..., 0, 0, 0],\n",
" ..., \n",
" [0, 0, 0, ..., 0, 0, 0],\n",
" [0, 0, 0, ..., 0, 0, 0],\n",
" [0, 0, 0, ..., 0, 0, 0]])"
]
}
],
"prompt_number": 6
},
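{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note: this generation of scikit-learn lets `LabelBinarizer` consume a list of tag lists; newer releases moved that behaviour into `MultiLabelBinarizer`. A minimal sketch of the modern call, not executed here (`mlb` and `Y_alt` are illustrative names):"
]
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"# modern multi-label binarisation (illustrative sketch, not executed)\n",
"from sklearn.preprocessing import MultiLabelBinarizer\n",
"\n",
"mlb = MultiLabelBinarizer()\n",
"Y_alt = mlb.fit_transform(data.tags)  # same 0/1 indicator matrix as Y"
],
"language": "python",
"metadata": {},
"outputs": []
},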
{
"cell_type": "code",
"collapsed": false,
"input": [
"from sklearn.feature_extraction.text import TfidfVectorizer\n",
"\n",
"# get the TF-IDF of the text\n",
"vec = TfidfVectorizer(min_df=2, sublinear_tf=True, decode_error='ignore')\n",
"X = vec.fit_transform(data.text)\n",
"X"
],
"language": "python",
"metadata": {},
"outputs": [
{
"metadata": {},
"output_type": "pyout",
"prompt_number": 7,
"text": [
"<19716x25497 sparse matrix of type '<type 'numpy.float64'>'\n",
"\twith 1509007 stored elements in Compressed Sparse Row format>"
]
}
],
"prompt_number": 7
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"# split into train and test set\n",
"N = int(.8 * X.shape[0])\n",
"Xtr, ytr = X[:N,:], Y[:N,:]\n",
"Xte, yte = X[N:,:], Y[N:,:]"
],
"language": "python",
"metadata": {},
"outputs": [],
"prompt_number": 8
},
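{
"cell_type": "markdown",
"metadata": {},
"source": [
"The 80/20 split above keeps the corpus in file order; a shuffled split avoids any ordering bias. A hedged sketch using the API of this scikit-learn generation (`sklearn.cross_validation`, renamed `sklearn.model_selection` in 0.18), not executed here:"
]
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"# alternative: shuffled 80/20 split (illustrative sketch, not executed)\n",
"from sklearn.cross_validation import train_test_split\n",
"\n",
"Xtr, Xte, ytr, yte = train_test_split(X, Y, test_size=0.2, random_state=42)"
],
"language": "python",
"metadata": {},
"outputs": []
},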
{
"cell_type": "code",
"collapsed": false,
"input": [
"# there are warnings of ill-defined recall/precision etc.\n",
"# just ignore them for now\n",
"import warnings\n",
"warnings.filterwarnings(\"ignore\")"
],
"language": "python",
"metadata": {},
"outputs": [],
"prompt_number": 19
},
{
"cell_type": "code",
"collapsed": true,
"input": [
"from sklearn.multiclass import OneVsRestClassifier\n",
"from sklearn.linear_model import LogisticRegression\n",
"from sklearn.grid_search import GridSearchCV\n",
"\n",
"# logistic regression parameter to optimise\n",
"params = {\"estimator__C\": np.logspace(1, 1.5, num=5)}\n",
"# use OneVsRestClassifier for multiclass learning\n",
"model = OneVsRestClassifier(LogisticRegression())\n",
"# do a grid search on the params, with 5-fold cross-validation\n",
"# optimise for F1-score\n",
"clf = GridSearchCV(model, param_grid=params, scoring='f1', n_jobs=-1, cv=5)\n",
"clf.fit(Xtr, ytr)\n",
"clf.best_score_, clf.best_params_"
],
"language": "python",
"metadata": {},
"outputs": [
{
"metadata": {},
"output_type": "pyout",
"prompt_number": 27,
"text": [
"(0.83655035338102735, {'estimator__C': 31.622776601683793})"
]
}
],
"prompt_number": 27
},
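{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note: `scoring='f1'` on a multi-label indicator target relies on the lenient scorers of this era; newer scikit-learn requires an explicit averaging mode and hosts `GridSearchCV` in `sklearn.model_selection`. A sketch of the equivalent modern call, not executed here:"
]
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"# modern form of the grid search above (illustrative sketch, not executed)\n",
"from sklearn.model_selection import GridSearchCV\n",
"\n",
"clf = GridSearchCV(model, param_grid=params, scoring='f1_micro', n_jobs=-1, cv=5)"
],
"language": "python",
"metadata": {},
"outputs": []
},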
{
"cell_type": "code",
"collapsed": false,
"input": [
"from sklearn.metrics import f1_score\n",
"\n",
"# compute predictions on test set\n",
"pred = clf.predict(Xte)\n",
"# compute F1-score on test set\n",
"f1_score(yte, pred)"
],
"language": "python",
"metadata": {},
"outputs": [
{
"metadata": {},
"output_type": "pyout",
"prompt_number": 28,
"text": [
"0.77208685519257036"
]
}
],
"prompt_number": 28
},
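{
"cell_type": "markdown",
"metadata": {},
"source": [
"Likewise, newer scikit-learn releases require an explicit `average` argument for multi-label `f1_score`, e.g. `'micro'` or `'samples'`. A sketch, not executed here:"
]
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"# explicit averaging mode for multi-label F1 (illustrative, not executed)\n",
"f1_score(yte, pred, average='micro')"
],
"language": "python",
"metadata": {},
"outputs": []
},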
{
"cell_type": "code",
"collapsed": false,
"input": [
"# a quick look into some example predictions\n",
"# compare with tags in test data\n",
"tags = []\n",
"for n in xrange(20):\n",
" tags.append((lb.classes_[yte[n]==1], lb.classes_[pred[n]==1]))\n",
" \n",
"pd.DataFrame(tags, columns=['actual tags', 'predicted tags'])"
],
"language": "python",
"metadata": {},
"outputs": [
{
"html": [
"<div style=\"max-height:1000px;max-width:1500px;overflow:auto;\">\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>actual tags</th>\n",
" <th>predicted tags</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0 </th>\n",
" <td> [earn, usa]</td>\n",
" <td> [earn, usa]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1 </th>\n",
" <td> [gnp, west-germany]</td>\n",
" <td> [west-germany]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2 </th>\n",
" <td> [earn, usa]</td>\n",
" <td> [earn, usa]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3 </th>\n",
" <td> [acq, usa]</td>\n",
" <td> [acq, usa]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4 </th>\n",
" <td> [earn, uk, usa]</td>\n",
" <td> []</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5 </th>\n",
" <td> [usa]</td>\n",
" <td> [usa]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6 </th>\n",
" <td> [brazil, gnp, imf, trade]</td>\n",
" <td> [brazil, imf]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7 </th>\n",
" <td> [usa]</td>\n",
" <td> [usa]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8 </th>\n",
" <td> [cpu, usa]</td>\n",
" <td> []</td>\n",
" </tr>\n",
" <tr>\n",
" <th>9 </th>\n",
" <td> [earn, usa]</td>\n",
" <td> [earn, usa]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>10</th>\n",
" <td> [france]</td>\n",
" <td> [france]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>11</th>\n",
" <td> [usa]</td>\n",
" <td> [usa]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>12</th>\n",
" <td> [earn, usa]</td>\n",
" <td> [earn, usa]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>13</th>\n",
" <td> [acq, usa]</td>\n",
" <td> [usa]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>14</th>\n",
" <td> [usa]</td>\n",
" <td> [usa]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>15</th>\n",
" <td> [barley, corn, cotton, grain, sorghum, usa, wh...</td>\n",
" <td> [corn, grain, usa, wheat]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>16</th>\n",
" <td> [belgium]</td>\n",
" <td> []</td>\n",
" </tr>\n",
" <tr>\n",
" <th>17</th>\n",
" <td> [acq, usa]</td>\n",
" <td> [usa]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>18</th>\n",
" <td> [earn, sweden]</td>\n",
" <td> [earn, sweden]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>19</th>\n",
" <td> [coffee, colombia, ico-coffee]</td>\n",
" <td> [coffee, colombia, ico-coffee]</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"metadata": {},
"output_type": "pyout",
"prompt_number": 29,
"text": [
" actual tags predicted tags\n",
"0 [earn, usa] [earn, usa]\n",
"1 [gnp, west-germany] [west-germany]\n",
"2 [earn, usa] [earn, usa]\n",
"3 [acq, usa] [acq, usa]\n",
"4 [earn, uk, usa] []\n",
"5 [usa] [usa]\n",
"6 [brazil, gnp, imf, trade] [brazil, imf]\n",
"7 [usa] [usa]\n",
"8 [cpu, usa] []\n",
"9 [earn, usa] [earn, usa]\n",
"10 [france] [france]\n",
"11 [usa] [usa]\n",
"12 [earn, usa] [earn, usa]\n",
"13 [acq, usa] [usa]\n",
"14 [usa] [usa]\n",
"15 [barley, corn, cotton, grain, sorghum, usa, wh... [corn, grain, usa, wheat]\n",
"16 [belgium] []\n",
"17 [acq, usa] [usa]\n",
"18 [earn, sweden] [earn, sweden]\n",
"19 [coffee, colombia, ico-coffee] [coffee, colombia, ico-coffee]"
]
}
],
"prompt_number": 29
},
{
"cell_type": "code",
"collapsed": false,
"input": [],
"language": "python",
"metadata": {},
"outputs": []
}
],
"metadata": {}
}
]
}
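Reuters.py, the helper module imported by the notebook above: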
# Reuters-21578 dataset downloader and parser
#
# Author: Eustache Diemert <eustache@diemert.fr>
# http://scikit-learn.org/stable/auto_examples/applications/plot_out_of_core_classification.html
#
# Modified by @herrfz, get pandas DataFrame from the orig SGML
# License: BSD 3 clause
# NOTE: written for Python 2; sgmllib and urllib.urlretrieve were removed or
# relocated in Python 3.
from __future__ import print_function

import re
import os.path
import fnmatch
import sgmllib
import urllib
import tarfile
import itertools

from pandas import DataFrame

###############################################################################
# Reuters Dataset related routines
###############################################################################

def _not_in_sphinx():
    # Hack to detect whether we are being run by the sphinx builder
    return '__file__' in globals()


class ReutersParser(sgmllib.SGMLParser):
    """Utility class to parse a SGML file and yield documents one at a time."""

    def __init__(self, verbose=0):
        sgmllib.SGMLParser.__init__(self, verbose)
        self._reset()

    def _reset(self):
        self.in_title = 0
        self.in_body = 0
        self.in_topics = 0
        self.in_topic_d = 0
        self.title = ""
        self.body = ""
        self.topics = []
        self.topic_d = ""

    def parse(self, fd):
        self.docs = []
        for chunk in fd:
            self.feed(chunk)
            for doc in self.docs:
                yield doc
            self.docs = []
        self.close()

    def handle_data(self, data):
        if self.in_body:
            self.body += data
        elif self.in_title:
            self.title += data
        elif self.in_topic_d:
            self.topic_d += data

    def start_reuters(self, attributes):
        pass

    def end_reuters(self):
        self.body = re.sub(r'\s+', r' ', self.body)
        self.docs.append({'title': self.title,
                          'body': self.body,
                          'topics': self.topics})
        self._reset()

    def start_title(self, attributes):
        self.in_title = 1

    def end_title(self):
        self.in_title = 0

    def start_body(self, attributes):
        self.in_body = 1

    def end_body(self):
        self.in_body = 0

    def start_topics(self, attributes):
        self.in_topics = 1

    def end_topics(self):
        self.in_topics = 0

    def start_d(self, attributes):
        self.in_topic_d = 1

    def end_d(self):
        # <D> tags also appear inside PLACES and other lists, so 'topics'
        # collects those entries as well (hence tags like 'usa' above)
        self.in_topic_d = 0
        self.topics.append(self.topic_d)
        self.topic_d = ""


class ReutersStreamReader():
    """Iterate over documents of the Reuters dataset.

    The Reuters archive will automatically be downloaded and uncompressed if
    the `data_path` directory does not exist.

    Documents are represented as dictionaries with 'body' (str),
    'title' (str), 'topics' (list(str)) keys.
    """

    DOWNLOAD_URL = ('http://archive.ics.uci.edu/ml/machine-learning-databases/'
                    'reuters21578-mld/reuters21578.tar.gz')
    ARCHIVE_FILENAME = 'reuters21578.tar.gz'

    def __init__(self, data_path):
        self.data_path = data_path
        if not os.path.exists(self.data_path):
            self.download_dataset()

    def download_dataset(self):
        """Download the dataset."""
        print("downloading dataset (once and for all) into %s" %
              self.data_path)
        os.mkdir(self.data_path)

        def progress(blocknum, bs, size):
            total_sz_mb = '%.2f MB' % (size / 1e6)
            current_sz_mb = '%.2f MB' % ((blocknum * bs) / 1e6)
            if _not_in_sphinx():
                print('\rdownloaded %s / %s' % (current_sz_mb, total_sz_mb),
                      end='')

        urllib.urlretrieve(self.DOWNLOAD_URL,
                           filename=os.path.join(self.data_path,
                                                 self.ARCHIVE_FILENAME),
                           reporthook=progress)
        if _not_in_sphinx():
            print('\r', end='')
        print("untarring data ...")
        tfile = tarfile.open(os.path.join(self.data_path,
                                          self.ARCHIVE_FILENAME),
                             'r:gz')
        tfile.extractall(self.data_path)
        print("done!")

    def iterdocs(self):
        """Iterate doc by doc, yield a dict."""
        for root, _dirnames, filenames in os.walk(self.data_path):
            for filename in fnmatch.filter(filenames, '*.sgm'):
                path = os.path.join(root, filename)
                parser = ReutersParser()
                for doc in parser.parse(open(path)):
                    yield doc


def get_minibatch(doc_iter, size):
    """Extract a minibatch of examples, return a DataFrame with 'text' and
    'tags' columns.

    Note: size is before excluding invalid docs with no topics assigned.
    """
    data = [('{title}\n\n{body}'.format(**doc), doc['topics'])
            for doc in itertools.islice(doc_iter, size)
            if doc['topics']]
    if not len(data):
        return DataFrame([])
    else:
        return DataFrame(data, columns=['text', 'tags'])
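
A minimal usage sketch for the module (assumes Python 2 with pandas installed; the first run downloads the archive into ./reuters):

if __name__ == '__main__':
    # stream the SGML docs, keep those with topics, build a text/tags DataFrame
    streamer = ReutersStreamReader('reuters').iterdocs()
    df = get_minibatch(streamer, 50000)
    print(df.head())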