@JaimieMurdock
Last active February 6, 2019 21:29
LDA Model Comparison Experiment
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# LDA Model Comparison\n",
"This notebook presents a framework for comparing LDA models.\n",
"\n",
"This follows a 4-step process:\n",
"1. Load in the two LDA models under investigation.\n",
"2. Perform heirarchical clustering among all topics in the two models.\n",
"3. Plot the topic alignment.\n",
"4. Plot the topic alignment agreement."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. Load the two models"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Running from notebook, using serial load function.\n",
"Loading LDA data from /tmp/ap/models/ap-nltk-en-freq5-N2000-LDA-K20-document-20.npz\n",
"Running from notebook, using serial load function.\n",
"Loading LDA data from /tmp/ap/models/ap-nltk-en-freq5-N2000-LDA-K20-document-20.npz\n"
]
}
],
"source": [
"from vsm import *\n",
"from configparser import ConfigParser\n",
"\n",
"def load_from_config(config_file, k):\n",
" \n",
" config = ConfigParser()\n",
" config.read(config_file)\n",
"\n",
" # path variables\n",
" path = config.get('main', 'path')\n",
" context_type = config.get('main', 'context_type')\n",
" corpus_file = config.get('main', 'corpus_file')\n",
" model_pattern = config.get('main', 'model_pattern') \n",
"\n",
" c = Corpus.load(corpus_file)\n",
" m = LdaCgsMulti.load(model_pattern.format(k))\n",
" v = LdaCgsViewer(c, m)\n",
" return v\n",
"\n",
"v1 = load_from_config('ap.ini', k=20)\n",
"v2 = load_from_config('ap.ini', k=20)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The following cells are more or less automatically importing everything you need to do the alignment"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"\n",
"from vsm import *\n",
"import numpy as np\n",
"import itertools\n",
"import copy\n",
"from scipy.stats import spearmanr, rankdata\n",
"from scipy.stats import pearsonr\n",
"\n",
"from random import randrange\n",
"import random\n",
"\n",
"\n",
"import os.path\n",
"from configparser import ConfigParser\n",
"\n",
"def is_valid_filepath(parser, arg):\n",
" if not os.path.exists(arg):\n",
" parser.error(\"The file %s does not exist!\" % arg)\n",
" else:\n",
" return arg\n",
" \n",
"\n",
"def deep_subcorpus(labels):\n",
" # resolve labels to indexes\n",
" docs_labels = [v._res_doc_type(d) for d in labels]\n",
" docs, labels = zip(*docs_labels)\n",
" \n",
" # get lengths of all contexts\n",
" lens = np.array([len(ctx) for ctx in v.corpus.view_contexts('book')])\n",
" \n",
" # get the context_type index for use with context_data\n",
" ctx_idx = v.corpus.context_types.index(v.model.context_type)\n",
" \n",
" # get original slices\n",
" slice_idxs = [range(s.start,s.stop) for i, s in enumerate(v.corpus.view_contexts('book',as_slices=True)) \n",
" if i in docs]\n",
" \n",
" new_corpus = copy.deepcopy(v.corpus)\n",
" # reduce corpus to subcorpus \n",
" new_corpus.corpus = new_corpus.corpus[list(itertools.chain(*slice_idxs))]\n",
" \n",
" # reinitialize index fields\n",
" for i,d in enumerate(docs):\n",
" new_corpus.context_data[ctx_idx]['idx'][d] = lens[list(docs[:i+1])].sum()\n",
" \n",
" # reduce metadata to only the new subcorpus\n",
" new_corpus.context_data[ctx_idx] = new_corpus.context_data[ctx_idx][list(docs)]\n",
" \n",
" \n",
" \n",
" return new_corpus\n",
"\n",
"\n",
"from vsm.spatial import *\n",
"\n",
"import numpy as np\n",
"import scipy.cluster.hierarchy as sch\n",
"from mpl_toolkits.axes_grid1 import make_axes_locatable\n",
"\n",
"__all__ = ['model_dist','avg_log_likelihood','perplexity','model_stats','plot_topic_similarity','compare_models']\n",
"\n",
"def topic_overlap(v1,v2):\n",
" \"\"\"\n",
" Calculates the overlap of two corpora and recalculates normalized topic matricies\n",
" including only the overlapping words, ordered by the overlap order.\n",
" \n",
" Returns the joint vocabulary, v1.topics() and v2.topics() filtered and renormed.\n",
" \"\"\"\n",
" vocab = set(v1.corpus.words)\n",
" t1 = np.array(v1.topics())['value']\n",
" t2 = np.array(v2.topics())['value']\n",
" \n",
" if v1.corpus.words_int != v2.corpus.words_int:\n",
" print(\"corpus.words_int different, aligning words\")\n",
" vocab = vocab.intersection(v2.corpus.words)\n",
" print(\"preserving {}% of words in v1; \".format(100 * len(vocab) / float(len(v1.corpus.words))))\n",
" print(\"{}% of words in v2 \".format(100 * len(vocab) / float(len(v2.corpus.words))))\n",
" \n",
" t1 = t1[:,np.array([v1.corpus.words_int[word] for word in vocab])]\n",
" t1 = (t1.T / t1.sum(axis=1)).T\n",
"\n",
" t2 = t2[:,np.array([v2.corpus.words_int[word] for word in vocab])]\n",
" t2 = (t2.T / t2.sum(axis=1)).T\n",
"\n",
" return (vocab, t1, t2)\n",
"\n",
"def model_dist(v1,v2, dist_fn=JS_dist):\n",
" \"\"\"\n",
" Takes two LdaViewer objects and a distance metric and calculates the topic-topic distance.\n",
" \"\"\"\n",
" vocab, t1, t2 = topic_overlap(v1,v2)\n",
" #t1 = np.array(v1.topics())['value']\n",
" #t2 = np.array(v2.topics())['value']\n",
" combined = np.concatenate((t1,t2))\n",
" \n",
" # NOTE: Doing this by row to reduce memory requirements and time requirements\n",
" D = np.column_stack(np.lib.pad(dist_fn(combined[i:,:],row.T), \n",
" (i,0), 'constant', constant_values=0)\n",
" if i + 1 < len(combined) else np.zeros(len(combined))\n",
" for i,row in enumerate(combined))\n",
" return D + D.T - np.diag(D.diagonal())\n",
" # Old simple version:\n",
" # return np.column_stack(dist_fn(combined,row.T) for row in combined)\n",
"\n",
"def doc_overlap(v1,v2, context_type, norm=True):\n",
" context_label = context_type + '_label'\n",
" ids = np.intersect1d(v1.corpus.view_metadata(context_type)[context_label], \n",
" v2.corpus.view_metadata(context_type)[context_label])\n",
" d1 = v1.doc_topic_matrix(ids)\n",
" d2 = v2.doc_topic_matrix(ids)\n",
"\n",
" # renormalize so that each topic is now a document probability\n",
" if norm:\n",
" d1 = (d1 / d1.sum(axis=0))\n",
" d2 = (d2 / d2.sum(axis=0))\n",
" \n",
" # original d1 and d2 are doc_topic, switch to topic_doc\n",
" return (ids, d1.T, d2.T)\n",
"\n",
"def model_doc_dist(v1, v2, context_type, dist_fn=JS_dist):\n",
" ids, d1, d2 = doc_overlap(v1,v2, context_type)\n",
" combined = np.concatenate((d1,d2))\n",
"\n",
" # NOTE: Doing this by row to reduce memory requirements and time requirements\n",
" D = np.column_stack(np.lib.pad(dist_fn(combined[i:,:],row.T),\n",
" (i,0), 'constant', constant_values=0)\n",
" if i + 1 < len(combined) else np.zeros(len(combined))\n",
" for i,row in enumerate(combined))\n",
" return D + D.T - np.diag(D.diagonal())\n",
"\n",
"\n",
"def pearson(v1, v2, context_type):\n",
" context_label = context_type + '_label'\n",
" ids, d1, d2 = doc_overlap(v1, v2, context_type)\n",
"\n",
" r_all = []\n",
" for id in ids:\n",
" sim1 = v1.dist_doc_doc(id)\n",
" ix = np.in1d(sim1['doc'], ids).reshape(sim1['doc'].shape)\n",
" sim1 = sim1[np.where(ix)]\n",
" sim1 = sim1[sim1['doc'].argsort()]\n",
" \n",
" sim2 = v2.dist_doc_doc(id)\n",
" ix = np.in1d(sim2['doc'], ids).reshape(sim2['doc'].shape)\n",
" sim2 = sim2[np.where(ix)]\n",
" sim2 = sim2[sim2['doc'].argsort()]\n",
" \n",
" r, pval = pearsonr(sim1['value'], sim2['value'])\n",
" r_all.append(r)\n",
" \n",
" return sum(r_all)/len(r_all)\n",
"\n",
"def spearman(v1, v2, context_type):\n",
" context_label = context_type + '_label'\n",
" ids, d1, d2 = doc_overlap(v1, v2, context_type)\n",
"\n",
" r_all = []\n",
" for id in ids:\n",
" sim1 = v1.dist_doc_doc(id)\n",
" ix = np.in1d(sim1['doc'], ids).reshape(sim1['doc'].shape)\n",
" sim1 = sim1[np.where(ix)]\n",
" sim1 = sim1[sim1['doc'].argsort()]\n",
" \n",
" sim2 = v2.dist_doc_doc(id)\n",
" ix = np.in1d(sim2['doc'], ids).reshape(sim2['doc'].shape)\n",
" sim2 = sim2[np.where(ix)]\n",
" sim2 = sim2[sim2['doc'].argsort()]\n",
"\n",
" r, pval = spearmanr(rankdata(sim1['value']), rankdata(sim2['value']))\n",
" r_all.append(r)\n",
"\n",
" return sum(r_all)/len(r_all)\n",
"\n",
"def recall(v1,v2, context_type,N=10):\n",
" context_label = context_type + '_label'\n",
" ids, d1, d2 = doc_overlap(v1, v2, context_type)\n",
"\n",
" r_all = []\n",
" for id in ids:\n",
" sim1 = v1.dist_doc_doc(id)\n",
" ix = np.in1d(sim1['doc'], ids).reshape(sim1['doc'].shape)\n",
" sim1 = sim1[np.where(ix)]\n",
"\n",
" sim2 = v2.dist_doc_doc(id)\n",
" ix = np.in1d(sim2['doc'], ids).reshape(sim2['doc'].shape)\n",
" sim2 = sim2[np.where(ix)]\n",
"\n",
" sim1 = np.array(sim1[1:N+2])['doc']\n",
" sim2 = np.array(sim2[1:N+2])['doc']\n",
" r = np.where(np.in1d(sim1, sim2))[0].size / float(N)\n",
" r_all.append(r)\n",
" \n",
" return sum(r_all)/len(r_all)\n",
"\n",
"def avg_log_likelihood(viewer):\n",
" \"\"\" Calculates the average log likelihood per token. \"\"\"\n",
" return viewer.model.log_probs[-1][1] / len(viewer.corpus.corpus)\n",
"\n",
"def perplexity(viewer):\n",
" \"\"\" Calculates the perplexity. \"\"\"\n",
" return np.exp(-1*avg_log_likelihood(viewer))\n",
"\n",
"def model_stats(*viewers):\n",
" \"\"\"\n",
" Prints a table of avg log likelihood and perplexity for each viewer.\n",
" \"\"\"\n",
" print(\"model\", \"k\", \"tokens\", \"types\", \"avg-log-likelihood\", \"perplexity\")\n",
" for i,v in enumerate(viewers):\n",
" print(\"M{}\".format(i), v.model.K, len(v.corpus.corpus), len(v.corpus.words), avg_log_likelihood(v), perplexity(v))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2. Cluster Functions"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Populating the interactive namespace from numpy and matplotlib\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/jaimie/anaconda3/lib/python3.6/site-packages/IPython/core/magics/pylab.py:160: UserWarning: pylab import has clobbered these variables: ['copy', 'angle', 'random', '__version__']\n",
"`%matplotlib` prevents importing * from pylab and numpy\n",
" \"\\n`%matplotlib` prevents importing * from pylab and numpy\"\n"
]
}
],
"source": [
"%pylab inline\n",
"from mpl_toolkits.axes_grid1 import make_axes_locatable\n",
"\n",
"import scipy\n",
"import scipy.cluster.hierarchy as sch\n",
"\n",
"def create_dendogram(D, xdim=None, method='ward', show=True):\n",
" # create_dendogram\n",
" fig = pylab.figure(figsize=(3,9))\n",
" Y = sch.linkage(D, method=method)\n",
" Z = sch.dendrogram(Y, orientation='right')\n",
" \n",
" # create plot\n",
" if show:\n",
" ax = fig.gca()\n",
" ax.set_xticks([])\n",
" if xdim is not None:\n",
" ax.set_yticklabels([\"M2 \" + str(idx - xdim) if idx >= xdim else \"M1 \" + str(idx)\n",
" for idx in np.array(Z['leaves'])])\n",
" fig.show()\n",
" else:\n",
" pylab.close()\n",
" return Z\n",
"\n",
"import pylab\n",
"def plot_dendogram(D, Z, xdim,ydim, filter_axis=False, show_self=False):\n",
" # Compute and plot dendrogram.\n",
" fig = pylab.figure(figsize=(12,10))\n",
" \n",
" D = np.copy(D)\n",
" \n",
" # Now make plot better\n",
" if not show_self:\n",
" D[:xdim,:xdim] = 0\n",
" D[xdim:,xdim:] = 0\n",
"\n",
" # Plot distance matrix.\n",
" axmatrix = fig.add_axes([0.3,0.1,0.6,0.8])\n",
" index = np.array(Z['leaves'])\n",
" if filter_axis:\n",
" D = D[index[index < xdim],:]\n",
" D = D[:,index[index >= xdim]]\n",
" else:\n",
" D = D[index,:]\n",
" D = D[:,index]\n",
" im = axmatrix.imshow(D, interpolation='none', cmap=mpl.cm.Greys_r, vmax=0.25, vmin=0.0)#aspect='auto', origin='lower')\n",
" if filter_axis:\n",
" axmatrix.set_xticks(arange(ydim))\n",
" axmatrix.set_xticklabels([str(i -xdim) for i in index[index >= xdim]])\n",
" axmatrix.set_yticks(arange(xdim))\n",
" axmatrix.set_yticklabels([str(i) for i in index[index < xdim]])\n",
" else:\n",
" axmatrix.set_xticks(arange(xdim+ydim))\n",
" axmatrix.set_xticklabels(index)\n",
" axmatrix.set_yticks(arange(xdim+ydim))\n",
" axmatrix.set_yticklabels(index)\n",
" \n",
" # create an axes on the right side of ax. The width of cax will be 5%\n",
" # of ax and the padding between cax and ax will be fixed at 0.05 inch.\n",
" divider = make_axes_locatable(axmatrix)\n",
" cax = divider.append_axes(\"right\", size=\"5%\", pad=0.15)\n",
"\n",
" fig.colorbar(im, cax=cax)\n",
"\n",
" # Display and save figure.\n",
" fig.show()\n",
"\n",
"\n",
"def plot_dendogram(D, Z, xdim,ydim, dist=None, filter_axis=False, show_self=False, alignment=None):\n",
" # Compute and plot dendrogram.\n",
" fig = figure(figsize=(12,10))\n",
" \n",
" D = np.copy(D)\n",
" \n",
" # Now make plot better\n",
" if not show_self:\n",
" D[:xdim,:xdim] = 0\n",
" D[xdim:,xdim:] = 0\n",
"\n",
" # Plot distance matrix.\n",
" axmatrix = fig.add_axes([0.3,0.1,0.6,0.8])\n",
" index = np.array(Z['leaves'])\n",
" if filter_axis:\n",
" D = D[index[index < xdim],:]\n",
" D = D[:,index[index >= xdim]]\n",
" else:\n",
" D = D[index,:]\n",
" D = D[:,index]\n",
" \n",
" # generate palette\n",
" palette=cm.Blues_r\n",
" palette.set_bad(alpha=0.0)\n",
" MAX = np.sort(D.flatten())[int(.1*D.shape[0]*D.shape[1])]\n",
" MAX = np.max(np.diagonal(dist[:,dist.argsort(axis=0)[1]]))\n",
" MAX *= 1.1\n",
" im = axmatrix.imshow(D, interpolation='none', cmap=palette, vmax=MAX, vmin=0.0)#aspect='auto', origin='lower')\n",
" if filter_axis:\n",
" axmatrix.set_xticks(arange(ydim))\n",
" axmatrix.set_xticklabels([str(i - xdim) for i in index[index >= xdim]])\n",
" axmatrix.set_yticks(arange(xdim))\n",
" axmatrix.set_yticklabels([str(i) for i in index[index < xdim]])\n",
" else:\n",
" axmatrix.set_xticks(arange(xdim+ydim))\n",
" axmatrix.set_xticklabels(index)\n",
" axmatrix.set_yticks(arange(xdim+ydim))\n",
" axmatrix.set_yticklabels(index)\n",
" \n",
" if alignment is not None:\n",
" axmatrix.autoscale(False)\n",
" ys,xs = zip(*alignment)\n",
" xindex = index[index >= xdim] - xdim\n",
" yindex = index[index < xdim]\n",
" \n",
" xs = np.array([np.squeeze(np.where(xindex == x)) for x in xs])\n",
" ys = np.array([np.squeeze(np.where(yindex == y)) for y in ys])\n",
" \n",
" axmatrix.scatter(xs, ys, marker='o', s=125, color='w', lw=4, edgecolor='k')\n",
" \n",
" title(\"Jensen-Shannon Distance from topic to topic\")\n",
" ylabel(\"Model 1\")\n",
" xlabel(\"Model 2\")\n",
" # Plot colorbar.\n",
" #axcolor = fig.add_axes([0.91,0.1,0.02,0.8])\n",
" divider = make_axes_locatable(axmatrix)\n",
" cax = divider.append_axes(\"right\", size=\"3%\", pad=0.1)\n",
" colorbar(im, cax=cax, extend='max')\n",
"\n",
"def dendogram_demo():\n",
" # Generate features and distance matrix.\n",
" x = scipy.rand(40)\n",
" D = scipy.zeros([40,40])\n",
" for i in range(40):\n",
" for j in range(40):\n",
" D[i,j] = abs(x[i] - x[j])\n",
" \n",
"\n",
" gram = create_dendogram(D, show=False)\n",
" #dendo = plot_dendogram(D,gram,10,30, show_self=True)\n",
" dendo = plot_dendogram(D,gram,10,30, D, show_self=False)\n",
" #dendo = plot_dendogram(D,gram,10,30, show_self=True, filter_axis=True)\n",
" dendo = plot_dendogram(D,gram,10,30, D, show_self=False, filter_axis=True)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"## Model comparison code\n",
"\n",
"def plot_topic_similarity(v1, v2, dist=None, dist_fn=JS_dist, sorted=True, alignment=None):\n",
" # Calculate distance between all topics\n",
" if dist is None:\n",
" dist = model_dist(v1, v2, dist_fn)\n",
" # print dist\n",
"\n",
" if sorted:\n",
" dendo = create_dendogram(dist, xdim=v1.model.K, method='ward') \n",
" plot_dendogram(dist, dendo, v1.model.K, v2.model.K, dist=dist, filter_axis=True, alignment=alignment)\n",
" else:\n",
" # plot distance heatmap\n",
" figure(figsize=(12,10))\n",
" \n",
" # filtering axis\n",
" xdim = v1.model.K\n",
" index = np.arange(0,len(dist))\n",
" D = np.copy(dist)\n",
" D = D[index[index < xdim],:]\n",
" D = D[:,index[index >= xdim]]\n",
" \n",
" # generate palette\n",
" palette=cm.Blues_r\n",
" palette.set_bad(alpha=0.0)\n",
" MAX = np.sort(D.flatten())[int(.1*D.shape[0]*D.shape[1])]\n",
" MAX = np.max(np.diagonal(dist[:,dist.argsort(axis=0)[1]]))\n",
" MAX *= 1.1\n",
" \n",
" # plot data\n",
" im = imshow(D, cmap=palette, vmax=MAX, vmin=0.0, interpolation='none')\n",
" if alignment is not None:\n",
" ax = gca()\n",
" ax.autoscale(False)\n",
" ys,xs = zip(*alignment)\n",
" scatter(xs,ys, marker='x', s=200, color='w', lw=5)\n",
" \n",
" # label data\n",
" xticks(np.arange(D.shape[1]))\n",
" yticks(np.arange(D.shape[0]))\n",
" title(\"Jensen-Shannon Distance from topic to topic\")\n",
" xlabel(\"Model 1\")\n",
" ylabel(\"Model 2\")\n",
" \n",
" # create heatmap\n",
" divider = make_axes_locatable(gca())\n",
" cax = divider.append_axes(\"right\", size=\"3%\", pad=0.1)\n",
" colorbar(im, cax=cax, extend='max')\n",
"\n",
"def compare_models(v1, v2, context_type='document', dist_fn=JS_dist, sorted=True):\n",
" model_stats(v1,v2)\n",
" plot_topic_similarity(v1,v2,dist_fn=dist_fn,sorted=sorted)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3. Create asymmetric model comparison matrix"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"\n",
"def alignment_fitness(topic_pairs, v1, v2, dist=None, dist_fn=JS_dist):\n",
" \"\"\"\n",
" Takes a list of topic pair tuples and returns the sum of the JS_dist between them\n",
" \"\"\" \n",
" if dist is None:\n",
" dist = model_dist(v1, v2, dist_fn)\n",
" if dist.shape[0] == (v1.model.K + v2.model.K):\n",
" dist = filter_dist(v1, v2, dist)\n",
" \n",
" return sum([dist[t[0]][t[1]] for t in topic_pairs])\n",
"\n",
"def filter_dist(v1,v2,dist):\n",
" xdim = v1.model.K\n",
" index = np.arange(0,len(dist))\n",
" D = np.copy(dist)\n",
" D = D[index[index < xdim],:]\n",
" D = D[:,index[index >= xdim]]\n",
" return D\n",
"\n",
"def plot_alignment(v1, v2, dist, alignment=None, fn_name=None):\n",
" # Calculate distance between all topics\n",
" if dist is None:\n",
" dist = model_dist(v1, v2, dist_fn)\n",
" if dist.shape[0] == (v1.model.K + v2.model.K):\n",
" dist = filter_dist(v1, v2, dist)\n",
" \n",
" Xs = cm.jet_r(dist)\n",
" \n",
" if alignment is None:\n",
" alpha = 1.0\n",
" else:\n",
" alpha = np.zeros(dist.shape)\n",
" alpha[zip(*alignment)] = 1\n",
" \n",
" Xs[:,:,3] = alpha\n",
" \n",
" # plot distance heatmap\n",
" figure(figsize=(12,10))\n",
" imshow(Xs, interpolation='nearest', cmap='jet_r', vmin=0, vmax=1.0)\n",
" colorbar()\n",
" #imshow(alpha, interpolation=None, cmap=get_cmap('binary'), vmin=0, vmax=1.0, alpha=0.5)\n",
" xticks(np.arange(dist.shape[1]))\n",
" yticks(np.arange(dist.shape[0]))\n",
" \n",
" if fn_name is not None:\n",
" title(\"Topic Alignment %s() using Jensen-Shannon Distance\" % fn_name)\n",
" else:\n",
" title(\"Topic Alignment\")\n",
" xlabel(\"Model 1\")\n",
" ylabel(\"Model 2\")\n",
" show()\n",
"\n",
"\n",
"# In[46]:\n",
"\n",
"def basic_alignment(v1, v2, dist=None, dist_fn=JS_dist, debug=False):\n",
" \"\"\"\n",
" Simply aligns to the closest topic, allowing for multiple assignment. \n",
" Properties:\n",
" non-surjective, non-injective\n",
" \"\"\"\n",
" if dist is None:\n",
" dist = model_dist(v1, v2, dist_fn)\n",
" if dist.shape[0] == (v1.model.K + v2.model.K):\n",
" dist = filter_dist(v1, v2, dist)\n",
" \n",
" alignment = []\n",
" for i, topic in enumerate(dist):\n",
" # topic = a[i]\n",
" #s = topic[topic.argsort()]\n",
" #sim = topic.argsort()[s < 0.05]\n",
" closest = topic.argsort()[0]\n",
" alignment.append((i, closest))\n",
" if debug:\n",
" print(i, closest, topic[closest])\n",
" \n",
" return alignment\n",
"\n",
"\n",
"# In[47]:\n",
"\n",
"def naive_alignment(v1, v2, dist=None, dist_fn=JS_dist, debug=False):\n",
" \"\"\"\n",
" First naive overlap detector just goes to next closest element if the first topic has already been assigned\n",
" \n",
" Properties: \n",
" k1 < k2: injective, non-surjective\n",
" k1 == k2: bijective\n",
" \"\"\"\n",
" if v1.model.K > v2.model.K:\n",
" raise ValueError(\"Models must have k1 <= k2\")\n",
" if dist is None:\n",
" dist = model_dist(v1, v2, dist_fn)\n",
" if dist.shape[0] == (v1.model.K + v2.model.K):\n",
" dist = filter_dist(v1, v2, dist)\n",
" \n",
" alignment = []\n",
" aligned = []\n",
" for i, topic in enumerate(dist):\n",
" # topic = a[i]\n",
" #s = topic[topic.argsort()]\n",
" #sim = topic.argsort()[s < 0.05]\n",
" topic_idx = 0\n",
" closest = topic.argsort()[topic_idx]\n",
" if debug:\n",
" print(i, closest, topic[closest])\n",
" \n",
" while closest in aligned:\n",
" topic_idx += 1\n",
" closest = topic.argsort()[topic_idx]\n",
" if debug:\n",
" print(i, closest, topic[closest])\n",
" \n",
" \n",
" aligned.append(closest)\n",
" alignment.append((i, closest))\n",
" \n",
" return alignment\n",
"\n",
"def compare(sample_v, v, context_type='document'):\n",
" sample_size = len(sample_v.labels)\n",
"\n",
" print(\"{k}\\t{N}\\t{seed}\\t{LL}\\t{corpus_size}\\t\".format(k=sample_v.model.K, \n",
" N=sample_size, seed=sample_v.model.seed, \n",
" LL=sample_v.model.log_probs[-1][1],\n",
" corpus_size=len(sample_v.corpus)))\n",
"\n",
" # compute similarity on topic-word matrix - given a topic, what is its\n",
" # distribution over words?\n",
" dist = model_dist(sample_v, v)\n",
" basic = basic_alignment(sample_v, v, dist=dist)\n",
" naive = naive_alignment(sample_v, v, dist=dist)\n",
" m1, m2 = zip(*basic)\n",
" \n",
" print(\"{fitness}\\t{naive_fitness}\\t{overlap}\\t\".format(\n",
" fitness=alignment_fitness(basic, sample_v, v, dist=dist),\n",
" naive_fitness=alignment_fitness(naive, sample_v, v, dist=dist),\n",
" overlap=len(set(m2))))\n",
"\n",
" # Compute similarity on topic-document matrix - given a topic, what is its\n",
" # distribution over documents?\n",
" dist = model_doc_dist(sample_v, v, context_type)\n",
" basic = basic_alignment(sample_v, v, dist=dist)\n",
" naive = naive_alignment(sample_v, v, dist=dist)\n",
" m1, m2 = zip(*basic)\n",
" \n",
" print(\"{fitness}\\t{naive_fitness}\\t{overlap}\\t\".format(\n",
" fitness=alignment_fitness(basic, sample_v, v, dist=dist),\n",
" naive_fitness=alignment_fitness(naive, sample_v, v, dist=dist),\n",
" overlap=len(set(m2))))\n",
"\n",
" # Calculate Spearman, Pearson, top-10 recall, and top-10-percent recall\n",
" # for each document - more of an IR-related search\n",
" \"\"\"\n",
" print(\"{spearman}\\t{pearson}\\t{recall}\\t{recall10p}\".format(\n",
" spearman=spearman(sample_v, v,context_type),\n",
" pearson=pearson(sample_v, v, context_type),\n",
" recall=recall(sample_v, v, context_type, N=10),\n",
" recall10p=recall(sample_v,v,context_type, N=int(np.floor(0.1*sample_size)))))\n",
" \"\"\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 4. Model Comparison"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Running from notebook, using serial load function.\n",
"Loading LDA data from /tmp/ap/models/ap-nltk-en-freq5-N2000-LDA-K20-document-20.npz\n",
"Running from notebook, using serial load function.\n",
"Loading LDA data from /tmp/ap/models/ap-nltk-en-freq5-N2000-LDA-K20-document-20.npz\n",
"model k tokens types avg-log-likelihood perplexity\n",
"M0 20 460795 10602 -8.471835631897047 4778.278584687178\n",
"M1 20 460795 10602 -8.471835631897047 4778.278584687178\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/jaimie/anaconda3/lib/python3.6/site-packages/ipykernel_launcher.py:10: ClusterWarning: scipy.cluster: The symmetric non-negative hollow observation matrix looks suspiciously like an uncondensed distance matrix\n",
" # Remove the CWD from sys.path while we load stuff.\n",
"/home/jaimie/anaconda3/lib/python3.6/site-packages/matplotlib/figure.py:459: UserWarning: matplotlib is currently using a non-GUI backend, so cannot show the figure\n",
" \"matplotlib is currently using a non-GUI backend, \"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"20\t2250\t1263529745\t-3903779.5\t460795\t\n",
"0.0\t0.0\t20\t\n",
"7.290601886635126e-07\t7.290601886635126e-07\t20\t\n"
]
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 216x648 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 864x720 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"v1 = load_from_config('ap.ini', k=20)\n",
"v2 = load_from_config('ap.ini', k=20)\n",
"\n",
"# prints statistical comparisons\n",
"compare_models(v1,v2)\n",
"\n",
"# renders chart comparing two models.\n",
"compare(v1,v2)"
]
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": true
},
"source": [
"# Interpretaiton Notes\n",
"A perfect alignment between two models would show a dark diagonal all the way down the alignment heatmap.\n",
"\n",
"If multiple cells for a particular row or column are dark, this indicates that the topic is captured by multiple topics in the other model.\n",
"\n",
"Note that often missing data is as illuminating as the presence of data in a model alignment: if a particular row or column is entirely white, then that indicates that it is not covered in the other model at all. When looking at data aggregated by year from a particular discipline, this may indicate the introduction or removal of a particular topic from discussion."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"# viewing particular topics in each model\n",
"print v1.topics([26])\n",
"print v2.topics([2])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.5"
}
},
"nbformat": 4,
"nbformat_minor": 1
}
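
The notebook's topic-to-topic comparison rests on a Jensen-Shannon distance between topic-word distributions. The sketch below shows that computation in plain Python, without `vsm`. It assumes the common definition (square root of the Jensen-Shannon divergence with base-2 logarithms); `vsm`'s `JS_dist` may differ in detail. The helper names and the toy distributions are hypothetical and exist only for illustration.

```python
import math

def kl_div(p, q):
    """Kullback-Leibler divergence in bits; terms with p_i == 0 contribute 0."""
    return sum(pi * math.log2(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def js_dist(p, q):
    """Jensen-Shannon distance: square root of the JS divergence (assumed definition)."""
    m = [(pi + qi) / 2 for pi, qi in zip(p, q)]
    div = 0.5 * kl_div(p, m) + 0.5 * kl_div(q, m)
    return math.sqrt(max(div, 0.0))  # clamp tiny negative rounding error

def cross_model_dist(topics1, topics2):
    """K1 x K2 matrix: rows are Model 1 topics, columns are Model 2 topics."""
    return [[js_dist(t1, t2) for t2 in topics2] for t1 in topics1]

# Toy topic-word distributions over a shared 3-word vocabulary (made-up numbers).
model1 = [[0.7, 0.2, 0.1], [0.1, 0.1, 0.8]]
model2 = [[0.1, 0.2, 0.7], [0.6, 0.3, 0.1]]

dist = cross_model_dist(model1, model2)
for row in dist:
    print(["%.3f" % d for d in row])
```

With base-2 logarithms the distance is bounded by 0 (identical distributions) and 1 (disjoint support), which is why a fixed color scale works for the heatmaps.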
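
The two alignment strategies in the notebook differ only in whether a Model 2 topic may be claimed more than once. The following is a minimal sketch of both on a small, made-up distance matrix held as nested lists. It takes a pre-filtered K1 x K2 matrix directly rather than viewer objects, so the signatures here are simplified assumptions and not the notebook's own.

```python
def basic_alignment(dist):
    """Pair each row topic with its closest column topic; columns may repeat."""
    return [(i, min(range(len(row)), key=row.__getitem__))
            for i, row in enumerate(dist)]

def naive_alignment(dist):
    """Greedy one-to-one pairing: each row takes its closest unclaimed column."""
    if dist and len(dist) > len(dist[0]):
        raise ValueError("need no more rows than columns (k1 <= k2)")
    taken, alignment = set(), []
    for i, row in enumerate(dist):
        for j in sorted(range(len(row)), key=row.__getitem__):
            if j not in taken:
                taken.add(j)
                alignment.append((i, j))
                break
    return alignment

def alignment_fitness(pairs, dist):
    """Sum of distances over the aligned pairs; lower means a tighter alignment."""
    return sum(dist[i][j] for i, j in pairs)

# Hypothetical 3 x 3 distances: rows are Model 1 topics, columns Model 2 topics.
dist = [
    [0.10, 0.40, 0.90],
    [0.20, 0.30, 0.80],
    [0.70, 0.60, 0.15],
]

basic = basic_alignment(dist)  # rows 0 and 1 both choose column 0
naive = naive_alignment(dist)  # row 1 falls back to column 1
overlap = len({j for _, j in basic})  # distinct Model 2 topics used

print(basic, alignment_fitness(basic, dist), overlap)
print(naive, alignment_fitness(naive, dist))
```

The greedy version depends on row order, so it is not guaranteed to minimise total distance; an overlap count below K in the basic alignment is the signal that several Model 1 topics collapse onto one Model 2 topic.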