Skip to content

Instantly share code, notes, and snippets.

@a-h
Created January 28, 2015 11:33
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save a-h/9ac483930cdd972dd171 to your computer and use it in GitHub Desktop.
Save a-h/9ac483930cdd972dd171 to your computer and use it in GitHub Desktop.
Term Frequency - Inverse Document Frequency (Tf-Idf)
using LemmaSharp;
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Globalization;
using System.IO;
using System.Linq;
using System.Security.Cryptography;
using System.Text;
using System.Text.RegularExpressions;
using System.Threading.Tasks;
namespace TfIdf
{
    /// <summary>
    /// Computes Term Frequency - Inverse Document Frequency (TF-IDF) scores.
    /// A corpus is built up with <see cref="AddDocumentToCorpus"/>, after which
    /// <see cref="CalculateTfIdf"/> scores a document against it. Text is split on
    /// regex word boundaries, lower-cased and lemmatised with a LemmaSharp prebuilt
    /// lemmatizer selected from a <see cref="CultureInfo"/>.
    /// </summary>
    /// <remarks>
    /// Although the dictionaries are <see cref="ConcurrentDictionary{TKey,TValue}"/>,
    /// <see cref="DocumentCount"/> is incremented non-atomically and new
    /// <see cref="DistinctWords"/> positions are derived from the current count, so the
    /// corpus-building methods should not be called concurrently.
    /// </remarks>
    public class TermFrequencyInverseDocumentFrequency
    {
        /// <summary>
        /// Document frequency of each lemmatised token: the number of documents added to
        /// the corpus that contain the token at least once. A token occurring several
        /// times in one document is counted once for that document.
        /// </summary>
        public ConcurrentDictionary<string, int> CorpusFrequency { get; set; }

        /// <summary>
        /// Every distinct lemmatised token seen in the corpus, mapped to its position in
        /// the vectors produced by <see cref="ConvertTfIdfToVector"/>. Positions are
        /// assigned in order of first appearance: the first token is at position zero and
        /// the most recently discovered token has the largest value.
        /// </summary>
        public ConcurrentDictionary<string, int> DistinctWords { get; set; }

        /// <summary>
        /// The number of documents added to the corpus.
        /// </summary>
        public int DocumentCount { get; set; }

        private readonly LemmatizerPrebuiltCompact lemmatizer;

        // Splitting on \b yields alternating word and non-word fragments; the
        // non-word fragments are filtered out in SplitAndLemmatise.
        private static readonly Regex wordBoundaryRegex = new Regex(@"\b", RegexOptions.Compiled);

        /// <summary>
        /// Creates an empty corpus whose lemmatizer matches the language of
        /// <paramref name="culture"/> (English when the language is not supported).
        /// </summary>
        /// <param name="culture">Culture whose neutral language selects the lemmatizer.</param>
        /// <exception cref="ArgumentNullException"><paramref name="culture"/> is null.</exception>
        public TermFrequencyInverseDocumentFrequency(CultureInfo culture)
            : this(GetLemmatizer(culture))
        {
        }

        private TermFrequencyInverseDocumentFrequency(LanguagePrebuilt language)
        {
            this.CorpusFrequency = new ConcurrentDictionary<string, int>();
            this.lemmatizer = new LemmatizerPrebuiltCompact(language);
            this.DistinctWords = new ConcurrentDictionary<string, int>();
        }

        /// <summary>
        /// Maps a culture to the LemmaSharp prebuilt language for its neutral culture,
        /// falling back to English for unsupported languages.
        /// </summary>
        private static LanguagePrebuilt GetLemmatizer(CultureInfo culture)
        {
            if (culture == null)
            {
                throw new ArgumentNullException("culture");
            }

            // Walk up to the neutral culture (e.g. "fr-CA" -> "fr"), stopping just
            // before the invariant culture, so only the language name is matched below.
            while (culture.Parent != null && culture.Parent != CultureInfo.InvariantCulture)
            {
                culture = culture.Parent;
            }

            switch (culture.Name)
            {
                case "bg":
                    return LanguagePrebuilt.Bulgarian;
                case "cs":
                    return LanguagePrebuilt.Czech;
                case "et":
                    return LanguagePrebuilt.Estonian;
                case "fa":
                    return LanguagePrebuilt.Persian;
                case "fr":
                    return LanguagePrebuilt.French;
                case "hu":
                    return LanguagePrebuilt.Hungarian;
                case "mk":
                    return LanguagePrebuilt.Macedonian;
                case "pl":
                    return LanguagePrebuilt.Polish;
                case "ro":
                    return LanguagePrebuilt.Romanian;
                case "ru":
                    return LanguagePrebuilt.Russian;
                case "sk":
                    return LanguagePrebuilt.Slovak;
                case "sl":
                    return LanguagePrebuilt.Slovene;
                case "sr":
                    return LanguagePrebuilt.Serbian;
                case "uk":
                    return LanguagePrebuilt.Ukrainian;
                case "de":
                    return LanguagePrebuilt.German;
                case "it":
                    return LanguagePrebuilt.Italian;
                case "es":
                    return LanguagePrebuilt.Spanish;
                case "en":
                default:
                    return LanguagePrebuilt.English;
            }
        }

        /// <summary>
        /// Adds a document to the corpus, updating <see cref="CorpusFrequency"/>,
        /// <see cref="DistinctWords"/> and <see cref="DocumentCount"/>.
        /// </summary>
        /// <param name="document">The document, as a sequence of sentences.</param>
        /// <exception cref="ArgumentNullException"><paramref name="document"/> is null.</exception>
        public void AddDocumentToCorpus(IEnumerable<string> document)
        {
            if (document == null)
            {
                throw new ArgumentNullException("document");
            }

            // Distinct() makes each document contribute at most 1 to a token's count.
            foreach (var token in document.SelectMany(sentence => SplitAndLemmatise(sentence)).Distinct())
            {
                CorpusFrequency.AddOrUpdate(token, 1, (key, value) => value + 1);

                // New tokens take the next free vector position; known tokens keep theirs.
                DistinctWords.TryAdd(token, this.DistinctWords.Count);
            }

            this.DocumentCount++;
        }

        /// <summary>
        /// Used for unit testing to set the corpus data directly, instead of adding it via
        /// <see cref="AddDocumentToCorpus"/>. This simplifies creation of test corpus data.
        /// Tokens are used as given (they are not lemmatised). Any token not already known
        /// is also registered in <see cref="DistinctWords"/>, so vector conversion works for
        /// the test corpus; positions follow the enumeration order of
        /// <paramref name="tokensAndCount"/>.
        /// </summary>
        /// <param name="tokensAndCount">Token to document-frequency map; existing counts are overwritten.</param>
        /// <param name="totalDocuments">Value assigned (not added) to <see cref="DocumentCount"/>.</param>
        /// <exception cref="ArgumentNullException"><paramref name="tokensAndCount"/> is null.</exception>
        public void AddDocumentDataToCorpusForUnitTest(Dictionary<string, int> tokensAndCount, int totalDocuments)
        {
            if (tokensAndCount == null)
            {
                throw new ArgumentNullException("tokensAndCount");
            }

            this.DocumentCount = totalDocuments;
            foreach (var item in tokensAndCount)
            {
                this.CorpusFrequency[item.Key] = item.Value;
                this.DistinctWords.TryAdd(item.Key, this.DistinctWords.Count);
            }
        }

        /// <summary>
        /// Calculates the TfIdf for a document against the current corpus.
        /// The term frequency is the token's count divided by the total number of tokens in
        /// the document. It is multiplied by <see cref="CalculateInverseDocumentFrequency"/>.
        /// </summary>
        /// <param name="document">A document containing sentences.</param>
        /// <returns>
        /// A dictionary of lemmatised terms and their corresponding TfIdf values; empty when
        /// the document contains no words.
        /// </returns>
        /// <exception cref="ArgumentNullException"><paramref name="document"/> is null.</exception>
        public Dictionary<string, double> CalculateTfIdf(IEnumerable<string> document)
        {
            if (document == null)
            {
                throw new ArgumentNullException("document");
            }

            // Occurrences of each token within this document only; the map is local to
            // this call, so a plain Dictionary is sufficient.
            var termCounts = new Dictionary<string, int>();
            int documentWordCount = 0;
            foreach (var sentence in document)
            {
                foreach (var word in SplitAndLemmatise(sentence))
                {
                    int count;
                    termCounts.TryGetValue(word, out count);
                    termCounts[word] = count + 1;
                    documentWordCount++;
                }
            }

            // documentWordCount is positive whenever termCounts is non-empty, so the
            // division below never sees a zero denominator.
            return termCounts.ToDictionary(
                kvp => kvp.Key,
                kvp =>
                {
                    double tf = kvp.Value / (double)documentWordCount;
                    double idf = CalculateInverseDocumentFrequency(kvp.Key);
                    return tf * idf;
                });
        }

        /// <summary>
        /// Calculates the inverse document frequency of a token:
        /// ln(DocumentCount / number of corpus documents containing the token).
        /// </summary>
        /// <param name="token">The token to look up.</param>
        /// <param name="retokenize">
        /// When true, the token is trimmed, lower-cased and lemmatised first, exactly as words
        /// are when documents are added, so raw words can be passed in.
        /// </param>
        /// <returns>
        /// The IDF value, or 0 when the token is not in the corpus or the corpus data is
        /// degenerate (no documents, or a non-positive stored count).
        /// </returns>
        /// <exception cref="ArgumentNullException"><paramref name="token"/> is null.</exception>
        public double CalculateInverseDocumentFrequency(string token, bool retokenize = false)
        {
            if (token == null)
            {
                throw new ArgumentNullException("token");
            }

            if (retokenize)
            {
                token = NormaliseAndLemmatise(token);
            }

            // A single TryGetValue avoids the ContainsKey + indexer double lookup.
            int documentsContainingToken;
            if (!this.CorpusFrequency.TryGetValue(token, out documentsContainingToken)
                || documentsContainingToken <= 0
                || this.DocumentCount <= 0)
            {
                // Guards the log against 0/x (-Infinity) and x/0 (+Infinity/NaN).
                return 0d;
            }

            return Math.Log(this.DocumentCount / (double)documentsContainingToken);
        }

        /// <summary>
        /// Trims, lower-cases (invariant culture) and lemmatises a single word. This is the
        /// shared normalisation used for both corpus building and IDF look-ups.
        /// </summary>
        private string NormaliseAndLemmatise(string word)
        {
            return lemmatizer.Lemmatize(word.Trim().ToLowerInvariant());
        }

        /// <summary>
        /// Splits a sentence on word boundaries, keeps only fragments containing at least one
        /// letter or digit, and normalises/lemmatises each of them.
        /// </summary>
        private IEnumerable<string> SplitAndLemmatise(string sentence)
        {
            // A fragment with a letter or digit is never null/whitespace, so this single
            // filter covers both conditions.
            return wordBoundaryRegex.Split(sentence)
                .Where(w => w.Any(c => char.IsLetterOrDigit(c)))
                .Select(w => NormaliseAndLemmatise(w));
        }

        /// <summary>
        /// Converts a TfIdf dictionary into a dense vector indexed by
        /// <see cref="DistinctWords"/>. Terms unknown to the corpus are dropped, and
        /// corpus words absent from <paramref name="tfIdf"/> are 0.
        /// </summary>
        /// <param name="tfIdf">Term to TfIdf value map, e.g. from <see cref="CalculateTfIdf"/>.</param>
        /// <returns>A vector whose length is the number of distinct corpus words.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="tfIdf"/> is null.</exception>
        public double[] ConvertTfIdfToVector(Dictionary<string, double> tfIdf)
        {
            if (tfIdf == null)
            {
                throw new ArgumentNullException("tfIdf");
            }

            var vector = new double[this.DistinctWords.Count];
            foreach (var item in tfIdf)
            {
                // Look up the position of the word in the corpus and set the vector value.
                int index;
                if (this.DistinctWords.TryGetValue(item.Key, out index))
                {
                    vector[index] = item.Value;
                }
            }

            return vector;
        }

        /// <summary>
        /// Computes a lower-case hex SHA-256 hash of a vector. Each element's IEEE 754 bytes
        /// are hashed in little-endian order, so the result is the same on every architecture.
        /// </summary>
        /// <param name="vector">The vector to hash.</param>
        /// <returns>64 lower-case hexadecimal characters.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="vector"/> is null.</exception>
        public static string CalculateVectorHash(double[] vector)
        {
            if (vector == null)
            {
                throw new ArgumentNullException("vector");
            }

            using (var sha256 = SHA256.Create())
            {
                byte[] hash = sha256.ComputeHash(ConvertVectorToBytes(vector));
                var hex = new StringBuilder(hash.Length * 2);
                foreach (byte b in hash)
                {
                    hex.Append(b.ToString("x2", CultureInfo.InvariantCulture));
                }

                return hex.ToString();
            }
        }

        /// <summary>
        /// Serialises a vector into a byte array: 8 little-endian bytes per element, in
        /// element order.
        /// </summary>
        private static byte[] ConvertVectorToBytes(double[] vector)
        {
            var buffer = new byte[vector.Length * sizeof(double)];
            for (int i = 0; i < vector.Length; i++)
            {
                byte[] bytes = BitConverter.GetBytes(vector[i]);

                // GetBytes uses machine byte order; normalise so hashes are portable.
                if (!BitConverter.IsLittleEndian)
                {
                    Array.Reverse(bytes);
                }

                Buffer.BlockCopy(bytes, 0, buffer, i * sizeof(double), bytes.Length);
            }

            return buffer;
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment