tf-idf tutorial
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8" />
<title>tf-idf tutorial</title>
<script src="https://d3js.org/d3.v5.min.js"></script>
<script src="./set-documents.js"></script>
</head>
<body>
<h1>tf-idf</h1>
<figure style="float: right;">
<img src="machine-learning.png" alt="A robot reads a book" />
<figcaption>Language processing</figcaption>
</figure>
<h2>Goal</h2>
<p>You have a whole bunch of documents: articles, product descriptions, projects, etc. You want to find
<strong>similarities</strong> between them.</p>
<p>Examples: A search engine. A recommendation engine.</p>
<p>For this tutorial, I'll use the 5917 featured articles on Wikipedia.</p>
<h2>Principle</h2>
<p>What does it mean for two documents to be similar, or related? Can we give a formal definition?</p>
<p>We're looking for a concept of <strong>distance</strong>: two copies of the same document should be at
distance 0. Similar documents should be closer together than totally unrelated documents.</p>
<p>A distance… in what? In what vector space?</p>
<p>Machines are stupid. They can't understand the documents, just see what words appear in them. Hey, what
if we used word count?</p>
<p>That defines dimensions: The more times the word "cheese" appears in the document, the further the
document is on the "cheese" axis.</p>
<p>We're using the <strong>frequency of the term</strong> "cheese": the "tf" in "tf-idf".</p>
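<p>The snippets in this tutorial read word counts from <code>document.terms</code>, an object that maps each word to its number of occurrences. As a rough sketch of how such an object could be built from raw text (the <code>countTerms</code> helper, its crude ASCII-only tokenising rule, and the sample sentence are assumptions for illustration, not part of the tutorial's code):</p>

```javascript
// Hypothetical helper: lowercase the text, split it into words, count each word.
// Object.create(null) avoids clashes with inherited keys such as "constructor".
const countTerms = text => {
  const terms = Object.create(null);
  const words = text.toLowerCase().match(/[a-z']+/g) || [];
  words.forEach(word => {
    terms[word] = (terms[word] || 0) + 1;
  });
  return terms;
};

// Made-up document in the { title, terms } shape the tutorial code expects
const exampleDocument = {
  title: 'Cheese shop',
  terms: countTerms('I like cheese. Cheese is made from milk.')
};

console.log(exampleDocument.terms.cheese); // 2
console.log(exampleDocument.terms.horse);  // undefined: the word never appears
```

<p>This document sits at 2 on the "cheese" axis and at 0 on the "horse" axis.</p>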
<figure>
<figcaption>Frequency of the word "cheese" in all documents</figcaption>
<div id="singleTerm1d"></div>
</figure>
<p>That defines an absolutely enormous space: one dimension per unique word.</p>
<p>There are thousands of words in the English language but if I draw diagrams with thousands of dimensions,
they won't be very clear. So I'll just show 2: "horse" and "island".</p>
<h2>Term frequency</h2>
<p>I lied to you! Counting occurrences of the word isn't the only way to do it. Other ways:</p>
<dl>
<dt>simple</dt>
<dd>Count occurrences.
<pre><code>const freqSimple = (document, term) => document.terms[term] || 0;</code></pre>
</dd>
<dt>normalised</dt>
<dd>Count occurrences then divide by total word count in the document, so it's actually a frequency.
<pre><code>const totalTermCount = document => Object.values(document.terms).reduce((a, b) => a + b, 0);
const freqNorm = (document, term) => freqSimple(document, term) / totalTermCount(document);</code></pre>
</dd>
<dt>boolean</dt>
<dd>Count whether the word appears or not.
<pre><code>const freqBool = (document, term) => document.terms[term] ? 1 : 0;</code></pre>
</dd>
<dt>logarithmic</dt>
<dd>Logarithm of the number of occurrences (+ 1 so that it's 0 if missing).
<pre><code>const freqLog = (document, term) => Math.log(1 + freqSimple(document, term));</code></pre>
</dd>
<dt>augmented</dt>
<dd>Instead of dividing by total word count, divide by the number of occurrences of the most frequent
word. For very long documents, that helps identify which words matter most.
<pre><code>const maxWordCount = document => Math.max(...Object.values(document.terms));
const freqAug = (document, term) => 1 + freqSimple(document, term) / maxWordCount(document);</code></pre>
</dd>
</dl>
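<p>To get a feel for how the five variants differ, here is a small comparison on a toy document. The counts are made up, and the <code>variants</code> and <code>compare</code> names are only for this sketch; the formulas are the ones listed above.</p>

```javascript
// Made-up word counts for a single document
const toyDocument = { terms: { horse: 4, island: 1, the: 15 } };

const freqSimple = (document, term) => document.terms[term] || 0;
const totalTermCount = document => Object.values(document.terms).reduce((a, b) => a + b, 0);
const maxWordCount = document => Math.max(...Object.values(document.terms));

// Same formulas as in the list above, gathered in one object
const variants = {
  simple: freqSimple,
  norm: (document, term) => freqSimple(document, term) / totalTermCount(document),
  bool: (document, term) => (document.terms[term] ? 1 : 0),
  log: (document, term) => Math.log(1 + freqSimple(document, term)),
  aug: (document, term) => 1 + freqSimple(document, term) / maxWordCount(document)
};

// Evaluate every variant for one term
const compare = term => Object.fromEntries(
  Object.entries(variants).map(([name, freq]) => [name, freq(toyDocument, term)])
);

console.log(compare('horse'));  // simple 4, norm 0.2, bool 1, log ≈ 1.609, aug ≈ 1.267
console.log(compare('cheese')); // simple 0, norm 0, bool 0, log 0, aug 1
```

<p>Note that the augmented variant, as written here, gives 1 rather than 0 for a word that never appears in the document.</p>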
<p><strong>Mouse over</strong> any point to see the other terms.</p>
<figure>
<figcaption>
<select id="tf2d-select">
<option value="simple">Simple</option>
<option value="norm">Normalised</option>
<option value="bool">Boolean</option>
<option value="log">Logarithmic</option>
<option value="aug">Augmented</option>
</select>
frequency of the terms "horse" and "island" in all documents
</figcaption>
<div id="tf2d"></div>
</figure>
<h2>Inverse frequency</h2>
<p>Okay, that makes a pretty picture for those 2 words, but it doesn't really work: very common words
like "the" always come out on top.</p>
<p>I could build a stopword list, but I'm lazy. Also, some words will still be more common.</p>
<p>The secret ingredient (Spärck Jones, 1972): multiply each term frequency by the log of the total number
of documents divided by the number of documents that contain the term.</p>
<p><strong>The rarer the term is, the more it counts.</strong></p>
<figure style="float: right;">
<img src="spaerck_jones.jpg" alt="Karen Spärck Jones in 2002" />
<figcaption>Professor Karen Spärck Jones</figcaption>
</figure>
<p>That's the "idf" ("inverse document frequency") in "tf-idf".</p>
<pre><code>const docCount = (allDocuments, term) => allDocuments.filter(document => document.terms[term]).length;
const idf = (allDocuments, term) => Math.log(allDocuments.length / (1 + docCount(allDocuments, term)));
const tfIdf = (allDocuments, document, term, freq) => freq(document, term) * idf(allDocuments, term);</code></pre>
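<p>A quick worked example may help here. The four-document corpus below is made up; <code>docCount</code> and <code>idf</code> are the same functions as above.</p>

```javascript
const docCount = (allDocuments, term) => allDocuments.filter(document => document.terms[term]).length;
const idf = (allDocuments, term) => Math.log(allDocuments.length / (1 + docCount(allDocuments, term)));

// Made-up corpus: "the" is everywhere, "horse" appears in one document only
const toyCorpus = [
  { terms: { the: 10, horse: 3 } },
  { terms: { the: 8, island: 2 } },
  { terms: { the: 12, cheese: 1 } },
  { terms: { the: 9 } }
];

console.log(idf(toyCorpus, 'horse')); // log(4 / 2) ≈ 0.693
console.log(idf(toyCorpus, 'the'));   // log(4 / 5) ≈ -0.223
```

<p>Because of the <code>1 +</code> in the denominator, a word that appears in every document ends up with a slightly negative weight instead of exactly 0. On a large corpus the difference is tiny.</p>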
<figure>
<figcaption>
<select id="tfIdf2d-select">
<option value="simple">Simple</option>
<option value="norm">Normalised</option>
<option value="bool">Boolean</option>
<option value="log">Logarithmic</option>
<option value="aug">Augmented</option>
</select>
tf-idf of the terms "horse" and "island" in all documents
(mouse over to see other terms)
</figcaption>
<div id="tfIdf2d"></div>
</figure>
<p>Success! We get terms strongly related to the topic of each article.</p>
<h2>Similarities</h2>
<p>So we've transformed each document into (term, tf-idf) pairs. How do you measure distance between
two of those?</p>
<p>If a term is very frequent in document 1, you want its frequency in document 2 to matter a lot, and
vice versa. To model this, <strong>multiply</strong> the tf-idf frequencies of the term in the
two documents.</p>
<p>The contribution of each term is independent, so we can just <strong>add</strong> them together.</p>
<pre><code>const cosineSimilarity = (allDocuments, document1, document2, freq) => {
const byTerms = Object.keys(document1.terms).map(term =>
tfIdf(allDocuments, document1, term, freq) * tfIdf(allDocuments, document2, term, freq)
);
return byTerms.reduce((a, b) => a + b, 0);
};</code></pre>
<p>Say, that's a scalar product! Divided by the lengths of the two document vectors, it's the cosine of
the angle between them. Hence the name "cosine similarity".</p>
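<p>A plain scalar product only equals the cosine when both vectors have length 1; otherwise documents with larger weights score higher. Here is a minimal sketch of the length-normalised version. It works on plain term-to-weight objects with made-up values, and the names <code>dot</code>, <code>norm</code> and <code>normalisedCosine</code> are assumptions for this example rather than part of the tutorial's code.</p>

```javascript
// A vector is a plain object mapping each term to its tf-idf weight.
const dot = (vector1, vector2) => Object.keys(vector1)
  .reduce((sum, term) => sum + vector1[term] * (vector2[term] || 0), 0);

// Euclidean length of a vector
const norm = vector => Math.sqrt(dot(vector, vector));

// Scalar product divided by both lengths; 0 if either vector is empty
const normalisedCosine = (vector1, vector2) => {
  const lengths = norm(vector1) * norm(vector2);
  return lengths === 0 ? 0 : dot(vector1, vector2) / lengths;
};

const shortDoc = { horse: 1, saddle: 2 };
const longDoc = { horse: 10, saddle: 20 }; // same direction, 10 times longer
const otherDoc = { island: 5 };            // no term in common

console.log(dot(shortDoc, longDoc));               // 50
console.log(normalisedCosine(shortDoc, longDoc));  // 1, up to floating-point rounding
console.log(normalisedCosine(shortDoc, otherDoc)); // 0
```

<p>For ranking against a single target document, dividing by the target's length changes nothing, since it is the same for every candidate. Dividing by each candidate's length is what can change the order.</p>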
<pre><code>const getMostSimilar = (allDocuments, toDocument, freq) => {
const otherDocuments = allDocuments.filter(document => document !== toDocument);
const similarities = otherDocuments.map(document => ({
document,
similarity: cosineSimilarity(allDocuments, document, toDocument, freq)
}));
similarities.sort((document1, document2) => document2.similarity - document1.similarity);
return similarities.slice(0, 10);
};</code></pre>
<p>A cool feature of tf-idf: you can tell <strong>why</strong> two documents are similar:
look at the terms with the biggest product.</p>
<div id="similarities">
<caption>
Articles similar to
<cite id="document-name"></cite>
<button type="button" id="random-article">change</button>
by
<select id="similarity-select">
<option value="simple">simple</option>
<option value="norm">normalised</option>
<option value="bool">boolean</option>
<option value="log">logarithmic</option>
<option value="aug">augmented</option>
</select>
tf-idf
</caption>
<table id="similarity-table"></table>
</div>
<p>Question time!</p>
<script src="./tough-dough.js"></script>
</body>
</html>
// Terms for demo
const TERM_SINGLE = 'cheese';
const TERM_X = 'horse';
const TERM_Y = 'island';
const NB_TOP_TERMS = 3;
// Term frequency functions
const freqSimple = (document, term) => document.terms[term] || 0;
const totalTermCount = document => { // memoised
document.totalTermCount = document.totalTermCount || Object.values(document.terms).reduce((a, b) => a + b, 0);
return document.totalTermCount;
};
const freqNorm = (document, term) => freqSimple(document, term) / totalTermCount(document);
const freqBool = (document, term) => document.terms[term] ? 1 : 0;
const freqLog = (document, term) => Math.log(1 + freqSimple(document, term));
const maxWordCount = document => { // memoised
document.maxWordCount = document.maxWordCount === undefined
? Math.max(...Object.values(document.terms))
: document.maxWordCount;
return document.maxWordCount;
};
const freqAug = (document, term) => 1 + freqSimple(document, term) / maxWordCount(document);
// Inverse document frequency
// memoised; a Map avoids clashes with inherited object keys such as "constructor"
const idfMap = new Map();
const docCount = (allDocuments, term) => allDocuments.filter(document => document.terms[term]).length;
const idf = (allDocuments, term) => {
if (!idfMap.has(term)) {
idfMap.set(term, Math.log(allDocuments.length / (1 + docCount(allDocuments, term))));
}
return idfMap.get(term);
};
// tf-idf
const tfIdf = (allDocuments, document, term, freq) => freq(document, term) * idf(allDocuments, term);
// Similarity
// pre-compile tfIdfs for target document
const cosineSimilarity = (allDocuments, targetTfIdf, otherDocument, freq) => {
const byTerms = targetTfIdf.map(({ term, tfIdfValue }) => ({
term,
weight: tfIdfValue * tfIdf(allDocuments, otherDocument, term, freq)
}));
byTerms.sort((a, b) => b.weight - a.weight);
const similarity = byTerms.reduce((acc, term) => acc + term.weight, 0);
const why = byTerms.slice(0, NB_TOP_TERMS);
return { similarity, why };
};
const getMostSimilar = (allDocuments, toDocument, freq) => {
const otherDocuments = allDocuments.filter(document => document !== toDocument);
const precompiledTfIdf = Object.keys(toDocument.terms).map(term => ({
term,
tfIdfValue: tfIdf(allDocuments, toDocument, term, freq)
}));
precompiledTfIdf.sort(({ tfIdfValue: val1 }, { tfIdfValue: val2 }) => val2 - val1);
const targetTfIdf = precompiledTfIdf.slice(0, 500); // only use most-significant terms
const similarities = otherDocuments.map(document => ({
document,
...cosineSimilarity(allDocuments, targetTfIdf, document, freq)
}));
similarities.sort((document1, document2) => document2.similarity - document1.similarity);
return similarities.slice(0, 10);
};
// Interactive examples
const getTicks = values => {
const maxValue = Math.max(...values);
const maxTick = parseFloat(maxValue.toPrecision(1));
const nbTicks = 5;
const ticks = [];
for (let ii = 0; ii < nbTicks; ii++) {
const tickValue = ii / nbTicks * maxTick;
const tickLabel = (Math.round(tickValue * 100) / 100).toString();
ticks.push(tickLabel);
}
return ticks;
};
const drawXAxis = (
axisGroup,
label,
data,
scale,
{ width, height }
) => {
axisGroup.append('svg:line')
.attr('x1', 0)
.attr('y1', height)
.attr('x2', width)
.attr('y2', height)
.attr('stroke', 'black')
.attr('class', 'xTicks');
const labelHeight = height + 15;
const ticks = getTicks(data.map(point => point.x));
axisGroup.selectAll('text.xAxisBottom')
.data(ticks)
.enter()
.append('svg:text')
.text(count => count)
.attr('x', scale)
.attr('y', labelHeight)
.attr('text-anchor', 'middle')
.attr('class', 'xAxisBottom');
axisGroup.append('svg:text')
.text(label)
.attr('x', width - 20)
.attr('y', labelHeight)
.attr('text-anchor', 'middle');
};
const drawYAxis = (
axisGroup,
label,
data,
scale,
{ width, height }
) => {
axisGroup.append('svg:line')
.attr('x1', width)
.attr('y1', 0)
.attr('x2', width)
.attr('y2', height)
.attr('stroke', 'black')
.attr('class', 'yTicks');
const labelLeft = width - 30;
const ticks = getTicks(data.map(point => point.y));
axisGroup.selectAll('text.yAxisLeft')
.data(ticks)
.enter()
.append('svg:text')
.text(count => count)
.attr('x', labelLeft)
.attr('y', scale)
.attr('text-anchor', 'start') // 'right' is not a valid value; 'start' is what browsers fell back to
.attr('class', 'yAxisLeft');
axisGroup.append('svg:text')
.text(label)
.attr('x', labelLeft)
.attr('y', 20)
.attr('text-anchor', 'start'); // 'right' is not a valid value; 'start' is what browsers fell back to
};
const makeTooltip = visRoot => {
const tooltip = visRoot.append('div')
.style('display', 'none')
.style('position', 'absolute')
.style('background-color', 'white')
.style('border', 'solid')
.style('border-width', '1px')
.style('padding', '3px')
.attr('class', 'tooltip');
const showTooltip = () => tooltip.style('display', 'block');
const setTooltipText = point => {
tooltip
.html(point.getTitle())
.style('left', (d3.event.pageX + 5) + 'px')
.style('top', d3.event.pageY + 'px');
};
return { showTooltip, setTooltipText };
};
// move points slightly to avoid overlap
const jiggle = coord => coord - 2 + 4 * Math.random();
const drawCircles = (nodes, { scaleX, scaleY }, { showTooltip, setTooltipText }) => nodes
.append('svg:circle')
.attr('class', 'nodes')
.attr('cx', point => jiggle(scaleX(point.x)))
.attr('cy', point => jiggle(scaleY(point.y)))
.attr('r', '6px')
.attr('stroke', 'black')
.attr('fill', 'white')
.on('mouseover', showTooltip)
.on('mousemove', setTooltipText);
// 1D: simple frequency of a single term
const set1DExample = () => {
const points = window.allDocuments.map(document => ({
getTitle: () => `${document.title} (${freqSimple(document, TERM_SINGLE)})`,
x: freqSimple(document, TERM_SINGLE),
y: 40
}));
const visRoot = d3.select('#singleTerm1d');
const vis = visRoot
.append('svg:svg')
.attr('width', 620)
.attr('height', 80);
const scale = coord => 10 + 33 * coord;
const axisGroup = vis.append('svg:g');
drawXAxis(axisGroup, TERM_SINGLE, points, scale, {
width: 600,
height: 45
});
drawCircles(
vis.selectAll('circle.nodes').data(points).enter(),
{ scaleX: scale, scaleY: y => y },
makeTooltip(visRoot)
);
};
set1DExample();
// 2D example: raw term frequencies
const getFreq = shortName => {
const freqMap = {
simple: freqSimple,
norm: freqNorm,
bool: freqBool,
log: freqLog,
aug: freqAug
};
return freqMap[shortName];
};
const getTopTerms = (document, freq) => {
const withFreqs = Object.keys(document.terms).map(term => ({
term,
termFreq: freq(document, term)
}));
withFreqs.sort((term1, term2) => term2.termFreq - term1.termFreq);
return withFreqs.slice(0, 5);
};
const setTf2DExample = freq => {
const points = window.allDocuments.map(document => ({
getTitle: () => `<p>${document.title}</p><table>${
getTopTerms(document, freq)
.map(({ term, termFreq }) => `<tr><td>${term}</td><td>${termFreq}</td></tr>`)
.join('')
}</table>`,
x: freq(document, TERM_X),
y: freq(document, TERM_Y)
}));
const visRoot = d3.select('#tf2d');
visRoot.selectAll('*').remove();
const vis = visRoot
.append('svg:svg')
.attr('width', 650)
.attr('height', 590);
const maxCoord = Math.max(
...points.map(point => point.x),
...points.map(point => point.y)
);
const scaleX = coord => 40 + 500 * coord / maxCoord;
const scaleY = coord => 545 - 500 * coord / maxCoord;
const axisGroup = vis.append('svg:g');
drawXAxis(axisGroup, TERM_X, points, scaleX, {
width: 600,
height: 550
});
drawYAxis(axisGroup, TERM_Y, points, scaleY, {
width: 45,
height: 550
});
drawCircles(
vis.selectAll('circle.nodes').data(points).enter(),
{ scaleX, scaleY },
makeTooltip(visRoot)
);
};
const updateTf2DExample = e => setTf2DExample(getFreq(e.target.value));
document.getElementById('tf2d-select').addEventListener('change', updateTf2DExample);
setTf2DExample(freqSimple);
// 2D example: tf-idf
const setTfIdf2DExample = freq => {
const toTfIdf = (document, term) => tfIdf(window.allDocuments, document, term, freq);
const points = window.allDocuments.map(document => ({
getTitle: () => `<p>${document.title}</p><table>${
getTopTerms(document, toTfIdf)
.map(({ term, termFreq }) => `<tr><td>${term}</td><td>${termFreq}</td></tr>`)
.join('')
}</table>`,
x: toTfIdf(document, TERM_X),
y: toTfIdf(document, TERM_Y)
}));
const visRoot = d3.select('#tfIdf2d');
visRoot.selectAll('*').remove();
const vis = visRoot
.append('svg:svg')
.attr('width', 650)
.attr('height', 590);
const maxCoord = Math.max(
...points.map(point => point.x),
...points.map(point => point.y)
);
const scaleX = coord => 40 + 500 * coord / maxCoord;
const scaleY = coord => 545 - 500 * coord / maxCoord;
const axisGroup = vis.append('svg:g');
drawXAxis(axisGroup, TERM_X, points, scaleX, {
width: 600,
height: 550
});
drawYAxis(axisGroup, TERM_Y, points, scaleY, {
width: 45,
height: 550
});
drawCircles(
vis.selectAll('circle.nodes').data(points).enter(),
{ scaleX, scaleY },
makeTooltip(visRoot)
);
};
const updateTfIdf2DExample = e => setTfIdf2DExample(getFreq(e.target.value));
document.getElementById('tfIdf2d-select').addEventListener('change', updateTfIdf2DExample);
setTfIdf2DExample(freqSimple);
// Similarities
const setSimilarities = (toDocument, freq) => {
const nameDisplay = window.document.getElementById('document-name');
nameDisplay.innerText = toDocument.title;
const table = window.document.getElementById('similarity-table');
table.innerHTML = '';
const headerRow = window.document.createElement('tr');
headerRow.innerHTML = `<th>Article</th><th>Similarity</th><th>Top terms</th>`;
table.appendChild(headerRow);
const similarities = getMostSimilar(window.allDocuments, toDocument, freq);
similarities.forEach(({ document: { title }, similarity, why }) => {
const row = window.document.createElement('tr');
const topTerms = why.map(
({ term, weight }) => `${term}<small> (${weight.toPrecision(5)})</small>`
).join(', ');
row.innerHTML = `<td>${title}</td><td>${similarity.toPrecision(6)}</td><td>${topTerms}</td>`;
table.appendChild(row);
});
};
const getRandomDocument = () => window.allDocuments[
Math.floor(window.allDocuments.length * Math.random())
];
window.selectedDocument = getRandomDocument();
const changeSelectedDocument = () => {
window.selectedDocument = getRandomDocument();
setSimilarities(
window.selectedDocument,
getFreq(document.getElementById('similarity-select').value)
);
};
document.getElementById('random-article').addEventListener('click', changeSelectedDocument);
const updateSimilarities = e => setSimilarities(window.selectedDocument, getFreq(e.target.value));
document.getElementById('similarity-select').addEventListener('change', updateSimilarities);
setSimilarities(window.selectedDocument, freqSimple);