Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
Implementation of a Kneser-Ney smoothed n-gram language model.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
namespace GM.NLP.Smoothing
{
    /// <summary>
    /// Implementation of an interpolated Kneser-Ney n-gram language model.
    ///
    /// https://en.wikipedia.org/wiki/Kneser–Ney_smoothing
    /// </summary>
    /// <remarks>
    /// Text is lower-cased, split into sentences on ". " and into words on whitespace. A sentence separator token is
    /// placed before the first sentence and after every sentence.
    ///
    /// The lowest order uses continuation counts: the number of distinct words preceding a word. These are precomputed
    /// in the constructor.
    ///
    /// Higher orders use raw corpus counts gathered by scanning the corpus on demand.
    /// NOTE(review): textbook Kneser-Ney also uses continuation counts for intermediate orders when n >= 3.
    /// </remarks>
    public class KneserNey
    {
        /// <summary>
        /// A constant which denotes the discount value subtracted from the count of each n-gram, usually between 0 and 1.
        /// </summary>
        private const double d = 0.75;
        /// <summary>
        /// Order of the model (the n of n-gram): each word is conditioned on the n - 1 words before it.
        /// </summary>
        public readonly int n;
        /// <summary>
        /// Training corpus as a word sequence, with the sentence separator around every sentence.
        /// </summary>
        private readonly string[] corpus;
        /// <summary>
        /// Distinct words of the corpus (including the sentence separator).
        /// </summary>
        private readonly string[] corpusDistinct;
        /// <summary>
        /// Number of distinct ordered pairs of consecutive words in the corpus (the number of bigram types).
        /// </summary>
        private readonly int distinctPairCount;
        /// <summary>
        /// For each word, the number of distinct words that directly precede it somewhere in the corpus.
        /// </summary>
        private readonly Dictionary<string, int> continuationCounts;
        private readonly string sentenceSeparator;

        /// <summary>
        /// Builds the model from a training text.
        /// </summary>
        /// <param name="n">Order of the model; must be at least 1.</param>
        /// <param name="corpus">Training text.</param>
        /// <param name="sentenceSeparator">Token inserted between sentences.</param>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="n"/> is less than 1.</exception>
        /// <exception cref="ArgumentNullException"><paramref name="corpus"/> or <paramref name="sentenceSeparator"/> is null.</exception>
        /// <exception cref="ArgumentException">The corpus contains no sentences or no words.</exception>
        public KneserNey(int n,string corpus,string sentenceSeparator="<s/>")
        {
            // Validate up front: n < 1 used to be accepted here and only fail later inside P (new string[-1]).
            if(n < 1)
                throw new ArgumentOutOfRangeException("n", n, "The n-gram order must be at least 1.");
            if(corpus == null)
                throw new ArgumentNullException("corpus");
            if(sentenceSeparator == null)
                throw new ArgumentNullException("sentenceSeparator");
            this.n = n;
            this.sentenceSeparator = sentenceSeparator;
            this.corpus = CreateWords(corpus, sentenceSeparator);
            corpusDistinct = this.corpus.Distinct().ToArray();
            int bigramTypeCount;
            continuationCounts = GetContinuationCounts(this.corpus, out bigramTypeCount);
            distinctPairCount = bigramTypeCount;
        }

        /// <summary>
        /// Calculates the probability of the provided word/text (can be multiple sentences).
        /// </summary>
        /// <param name="text">Text to score; it is tokenized exactly like the training corpus.</param>
        /// <param name="useLog">
        /// When true (default), returns the sum of the natural logarithms of the per-word probabilities. This is
        /// normally negative and does not underflow on long texts. When false, returns the product of the probabilities.
        /// </param>
        /// <exception cref="ArgumentNullException"><paramref name="text"/> is null.</exception>
        /// <exception cref="ArgumentException"><paramref name="text"/> contains no sentences or no words.</exception>
        public double P(string text,bool useLog=true)
        {
            string[] words = CreateWords(text, sentenceSeparator);
            return P(words, useLog);
        }

        /// <summary>
        /// Calculates the perplexity of the provided text: the perplexity of every sentence, averaged over sentences.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="text"/> is null.</exception>
        /// <exception cref="ArgumentException"><paramref name="text"/> contains no sentences or no words.</exception>
        public double PP(string text)
        {
            string[] words = CreateWords(text, sentenceSeparator);
            return PP(words);
        }

        /// <summary>
        /// Calculates perplexity for the provided words.
        /// </summary>
        /// <remarks>
        /// The words are expected to start with the sentence separator. Each sentence is taken together with both of
        /// its surrounding separators and scored as P^(-1/length), where length is the number of tokens in that span.
        /// The sentence perplexities are then averaged.
        /// </remarks>
        private double PP(string[] w)
        {
            List<double> perplexities = new List<double>();
            int start = 0;
            // A sentence's opening separator is the previous sentence's closing separator.
            while(start < w.Length - 1) {
                int end = Array.IndexOf(w, sentenceSeparator, start + 1);
                if(end < 0)
                    end = w.Length - 1; // no closing separator: treat the remainder as the last sentence
                int length = end - start + 1;
                string[] sentence = new string[length];
                Array.Copy(w, start, sentence, 0, length);
                // Work in log space. The raw product of many probabilities underflows to 0,
                // and Math.Pow(0, negative) would then report an infinite perplexity.
                double logP = P(sentence, true);
                perplexities.Add(Math.Exp(-logP / length));
                start = end;
            }
            return perplexities.Sum() / perplexities.Count;
        }

        /// <summary>
        /// Calculates the probability of the provided words.
        /// </summary>
        /// <remarks>
        /// The result is the product of P(w[i] | w[i-n+1..i-1]) for every i from n - 1 to the last index. With
        /// <paramref name="useLog"/> it is instead the sum of the natural logarithms of those probabilities. The first
        /// n - 1 words serve only as context.
        /// </remarks>
        private double P(string[] w,bool useLog)
        {
            double result = useLog ? 0 : 1;
            int nm1 = n - 1;
            string[] context = new string[nm1];
            for(int i = w.Length - 1; i >= nm1; --i) {
                // context = w[i - nm1 .. i - 1]
                Array.Copy(w, i - nm1, context, 0, nm1);
                double p = PKN(w[i], context);
                if(p == 0)
                    // The word never occurs after any other word in the corpus. Use a uniform probability over the
                    // vocabulary so that a single unknown word does not make the whole text impossible (log 0).
                    p = 1.0 / corpusDistinct.Length;
                if(useLog)
                    result += Math.Log(p);
                else
                    result *= p;
            }
            return result;
        }

        /// <summary>
        /// Calculates how probable it is that the word 'wi' follows the words 'wi_n', using interpolated Kneser-Ney:
        /// P(wi | wi_n) = max(c(wi_n wi) - d, 0) / c(wi_n •) + d * |{w : c(wi_n w) > 0}| / c(wi_n •) * P(wi | wi_n[1..]).
        /// Here c(wi_n •) is the number of times any word follows wi_n.
        /// </summary>
        private double PKN(string wi, params string[] wi_n)
        {
            if(wi_n.Length == 0)
                return PKN(wi);
            Dictionary<string, int> followers = GetFollowerCounts(wi_n);
            // Lower-order estimate: drop the oldest word of the context.
            double lowerOrder = PKN(wi, wi_n.Skip(1).ToArray());
            // c(wi_n •): how many times any word follows wi_n
            int total = followers.Values.Sum();
            if(total == 0)
                // Unseen (or never-continued) context: nothing to discount, so back off completely.
                return lowerOrder;
            int count;
            followers.TryGetValue(wi, out count);
            double discounted = Math.Max(count - d, 0) / total;
            // lambda(wi_n): the probability mass removed by discounting, redistributed over the lower order.
            // Floating-point division on purpose; int / int truncates the weight.
            double lambda = d * followers.Count / total;
            return discounted + lambda * lowerOrder;
        }

        /// <summary>
        /// How likely it is to see the word 'wi' in an unfamiliar context: the number of distinct words preceding
        /// 'wi' divided by the number of distinct bigram types in the corpus.
        /// </summary>
        private double PKN(string wi)
        {
            int top;
            continuationCounts.TryGetValue(wi, out top); // stays 0 when wi never follows another word
            return top / (double)distinctPairCount;
        }

        /// <summary>
        /// Counts, for every word, how many times it directly follows the word sequence 'context' in the corpus.
        /// </summary>
        private Dictionary<string, int> GetFollowerCounts(string[] context)
        {
            Dictionary<string, int> followers = new Dictionary<string, int>();
            for(int i = context.Length; i < corpus.Length; ++i) {
                if(!MatchesAt(i - context.Length, context))
                    continue;
                int count;
                followers.TryGetValue(corpus[i], out count);
                followers[corpus[i]] = count + 1;
            }
            return followers;
        }

        /// <summary>
        /// Whether the corpus contains the sequence 'words' starting at index 'start'.
        /// </summary>
        private bool MatchesAt(int start, string[] words)
        {
            for(int j = 0; j < words.Length; ++j) {
                if(corpus[start + j] != words[j])
                    return false;
            }
            return true;
        }

        /// <summary>
        /// For every word, counts the distinct words that directly precede it. Also returns the number of distinct
        /// ordered pairs of consecutive words (bigram types), which is the sum of those counts.
        /// </summary>
        private static Dictionary<string, int> GetContinuationCounts(string[] words, out int bigramTypeCount)
        {
            Dictionary<string, HashSet<string>> predecessors = new Dictionary<string, HashSet<string>>();
            for(int i = 1; i < words.Length; ++i) {
                HashSet<string> preceding;
                if(!predecessors.TryGetValue(words[i], out preceding)) {
                    preceding = new HashSet<string>();
                    predecessors.Add(words[i], preceding);
                }
                // Ordered pairs, including a word following itself, to match the numerator used in PKN(string).
                preceding.Add(words[i - 1]);
            }
            Dictionary<string, int> counts = new Dictionary<string, int>(predecessors.Count);
            bigramTypeCount = 0;
            foreach(KeyValuePair<string, HashSet<string>> entry in predecessors) {
                counts.Add(entry.Key, entry.Value.Count);
                bigramTypeCount += entry.Value.Count;
            }
            return counts;
        }

        /// <summary>
        /// Separates the provided text to words and inserts the specified sentence separator between sentences.
        /// The separator is also placed before the first sentence and after the last one.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="text"/> is null.</exception>
        /// <exception cref="ArgumentException">The text contains no sentences or no words.</exception>
        private static string[] CreateWords(string text,string sentenceSeparator)
        {
            if(text == null)
                throw new ArgumentNullException("text");
            // Invariant lower-casing keeps tokenization independent of the current thread culture.
            // Line breaks become spaces, so ". " is the only sentence delimiter left to split on.
            text = text.ToLowerInvariant().Replace("\r", " ").Replace("\n", " ");
            string[] sentences = text.Split(new string[] { ". " }, StringSplitOptions.RemoveEmptyEntries);
            if(sentences.Length == 0)
                throw new ArgumentException("Provided text contains zero sentences.");
            List<string> words = new List<string>();
            words.Add(sentenceSeparator);
            foreach(string rawSentence in sentences) {
                string sentence = Clean(rawSentence).Trim();
                string[] sentenceWords = sentence.Split(new char[] { ' ' }, StringSplitOptions.RemoveEmptyEntries);
                if(sentenceWords.Length == 0)
                    continue;
                words.AddRange(sentenceWords);
                words.Add(sentenceSeparator);
            }
            if(words.Count == 1)
                throw new ArgumentException("Provided text contains zero words.");
            return words.ToArray();
        }

        /// <summary>
        /// Cleans the provided sentence, leaving only letters, apostrophes and spaces.
        /// Dots and any other whitespace (e.g. tabs) become spaces; digits and other punctuation are discarded.
        /// </summary>
        private static string Clean(string sentence)
        {
            StringBuilder sb = new StringBuilder(sentence.Length);
            foreach(char c in sentence) {
                if(char.IsLetter(c) || c == '\'')
                    sb.Append(c);
                else if(c == '.' || char.IsWhiteSpace(c))
                    // "word1.word2" and "word1\tword2" must still yield two words
                    sb.Append(' ');
            }
            return sb.ToString();
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.