// Implementation of a Kneser-Ney language model used for smoothing.
using System; | |
using System.Collections.Generic; | |
using System.Linq; | |
using System.Text; | |
namespace GM.NLP.Smoothing
{
    /// <summary>
    /// Implementation of an interpolated Kneser-Ney language model.
    ///
    /// https://en.wikipedia.org/wiki/Kneser–Ney_smoothing
    /// </summary>
    /// <remarks>
    /// Text is lower-cased (culture-invariant), split into sentences on ". " and into words on spaces,
    /// and every sentence is surrounded by the sentence separator token. The unigram level uses
    /// continuation counts (number of distinct predecessors of a word); every higher order uses
    /// absolutely-discounted raw n-gram counts interpolated with the next lower order.
    /// </remarks>
    public class KneserNey
    {
        /// <summary>
        /// A constant which denotes the discount value subtracted from the count of each n-gram, usually between 0 and 1.
        /// </summary>
        private const double d = 0.75;

        /// <summary>
        /// Order of the model: each word is conditioned on the previous n - 1 words.
        /// </summary>
        public readonly int n;

        /// <summary>
        /// Training corpus as a flat token sequence; the sentence separator precedes and follows every sentence.
        /// </summary>
        private readonly string[] corpus;

        /// <summary>
        /// Distinct words of the corpus (the sentence separator included).
        /// </summary>
        private readonly string[] corpusDistinct;

        /// <summary>
        /// For each word of the corpus, the number of distinct words that directly precede it.
        /// </summary>
        private readonly Dictionary<string, int> continuationCounts;

        /// <summary>
        /// Number of distinct ordered pairs of consecutive words in the corpus (bigram types).
        /// </summary>
        private readonly int distinctPairCount;

        /// <summary>
        /// Token inserted between sentences, both in the training corpus and in evaluated text.
        /// </summary>
        private readonly string sentenceSeparator;

        /// <summary>
        /// Builds the model from the provided training text.
        /// </summary>
        /// <param name="n">Order of the model; must be at least 1.</param>
        /// <param name="corpus">Training text (can be multiple sentences).</param>
        /// <param name="sentenceSeparator">Token used to mark sentence boundaries.</param>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="n"/> is less than 1.</exception>
        /// <exception cref="ArgumentNullException"><paramref name="corpus"/> or <paramref name="sentenceSeparator"/> is null.</exception>
        /// <exception cref="ArgumentException">The corpus contains no sentences or no words.</exception>
        public KneserNey(int n, string corpus, string sentenceSeparator = "<s/>")
        {
            if (n < 1) {
                throw new ArgumentOutOfRangeException("n", n, "The n-gram order must be at least 1.");
            }
            if (corpus == null) {
                throw new ArgumentNullException("corpus");
            }
            if (sentenceSeparator == null) {
                throw new ArgumentNullException("sentenceSeparator");
            }
            this.n = n;
            this.sentenceSeparator = sentenceSeparator;
            this.corpus = CreateWords(corpus, sentenceSeparator);
            corpusDistinct = this.corpus.Distinct().ToArray();
            continuationCounts = GetContinuationCounts(this.corpus);
            // every (predecessor, word) type is counted exactly once across all words
            distinctPairCount = continuationCounts.Values.Sum();
        }

        /// <summary>
        /// Calculates the probability for the provided word/text (can be multiple sentences).
        /// </summary>
        /// <param name="text">Text to score; it is tokenized exactly like the training corpus.</param>
        /// <param name="useLog">When true, the natural logarithm of the probability is returned (a value &lt;= 0).</param>
        /// <returns>The probability of the text, or its natural logarithm.</returns>
        public double P(string text, bool useLog = true)
        {
            string[] words = CreateWords(text, sentenceSeparator);
            return P(words, useLog);
        }

        /// <summary>
        /// Calculates perplexity for the provided text (the average of the per-sentence perplexities).
        /// </summary>
        public double PP(string text)
        {
            string[] words = CreateWords(text, sentenceSeparator);
            return PP(words);
        }

        /// <summary>
        /// Calculates perplexity for the provided words: the perplexity of each sentence is computed
        /// separately and the results are averaged.
        /// </summary>
        /// <remarks>
        /// A sentence spans from one separator to the next, both separators included, and its length
        /// (used as the normalizer) counts both of them. The computation is done in log space so that
        /// long sentences do not underflow to a probability of 0 (which would give an infinite perplexity).
        /// </remarks>
        private double PP(string[] w)
        {
            double perplexitySum = 0;
            int sentenceCount = 0;
            int start = 0;
            while (w.Length - start > 1) {
                int end = Array.IndexOf(w, sentenceSeparator, start + 1);
                if (end < 0) {
                    // no closing separator: treat the remainder as the last sentence
                    end = w.Length - 1;
                }
                int length = end - start + 1;
                string[] sentence = new string[length];
                Array.Copy(w, start, sentence, 0, length);
                double logP = P(sentence, true);
                // P^(-1/N) == exp(-ln(P) / N)
                perplexitySum += Math.Exp(-logP / length);
                ++sentenceCount;
                // the closing separator is also the opening separator of the next sentence
                start = end;
            }
            return perplexitySum / sentenceCount;
        }

        /// <summary>
        /// Calculates the probability for the provided words.
        /// </summary>
        /// <remarks>
        /// Every word at position n - 1 or later is predicted from the n - 1 words before it; the first
        /// n - 1 tokens are not scored.
        /// </remarks>
        private double P(string[] w, bool useLog)
        {
            double result = useLog ? 0 : 1;
            int contextLength = n - 1;
            string[] context = new string[contextLength];
            for (int i = w.Length - 1; i >= contextLength; --i) {
                // context = w[i - contextLength .. i - 1]
                Array.Copy(w, i - contextLength, context, 0, contextLength);
                double p = PKN(w[i], context);
                if (p == 0) {
                    // the word never occurs in the corpus: fall back to a uniform probability over the vocabulary
                    p = 1.0 / corpusDistinct.Length;
                }
                if (useLog) {
                    result += Math.Log(p);
                } else {
                    result *= p;
                }
            }
            return result;
        }

        /// <summary>
        /// Calculates how probable it is that the word 'wi' will follow the words 'wi_n'.
        /// </summary>
        /// <remarks>
        /// P(wi | ctx) = max(c(ctx wi) - d, 0) / Σc(ctx ·) + λ(ctx) · P(wi | ctx without its first word),
        /// with λ(ctx) = d · |{w : c(ctx w) &gt; 0}| / Σc(ctx ·). If the context never occurs, nothing is
        /// discounted and the lower-order estimate is returned unchanged.
        /// </remarks>
        private double PKN(string wi, params string[] wi_n)
        {
            if (wi_n.Length == 0) {
                return PKN(wi);
            }

            int wiCount;
            int followerTotal;
            int distinctFollowers;
            CountFollowers(wi_n, wi, out wiCount, out followerTotal, out distinctFollowers);

            string[] lowerContext = new string[wi_n.Length - 1];
            Array.Copy(wi_n, 1, lowerContext, 0, lowerContext.Length);
            double lower = PKN(wi, lowerContext);

            if (followerTotal == 0) {
                return lower;
            }

            double discounted = Math.Max(wiCount - d, 0) / followerTotal;
            double lambda = d * distinctFollowers / followerTotal;
            return discounted + lambda * lower;
        }

        /// <summary>
        /// How likely it is to see the word 'wi' in an unfamiliar context: the number of distinct words
        /// that precede 'wi' divided by the number of distinct bigram types in the corpus.
        /// </summary>
        private double PKN(string wi)
        {
            int predecessors;
            if (!continuationCounts.TryGetValue(wi, out predecessors)) {
                return 0;
            }
            return predecessors / (double)distinctPairCount;
        }

        /// <summary>
        /// Scans the corpus once for every occurrence of 'context' that is followed by a word.
        /// </summary>
        /// <param name="context">Words that must directly precede the counted follower.</param>
        /// <param name="wi">Word whose own follow count is reported in <paramref name="wiCount"/>.</param>
        /// <param name="wiCount">c(context wi).</param>
        /// <param name="total">Σ over all words w of c(context w).</param>
        /// <param name="distinct">Number of distinct words w with c(context w) &gt; 0.</param>
        private void CountFollowers(string[] context, string wi, out int wiCount, out int total, out int distinct)
        {
            HashSet<string> followers = new HashSet<string>();
            wiCount = 0;
            total = 0;
            for (int i = context.Length; i < corpus.Length; ++i) {
                if (!IsPrecededBy(i, context)) {
                    continue;
                }
                string follower = corpus[i];
                ++total;
                followers.Add(follower);
                if (follower == wi) {
                    ++wiCount;
                }
            }
            distinct = followers.Count;
        }

        /// <summary>
        /// True when corpus[position - context.Length .. position - 1] equals 'context'.
        /// The caller guarantees position &gt;= context.Length.
        /// </summary>
        private bool IsPrecededBy(int position, string[] context)
        {
            int offset = position - context.Length;
            for (int k = 0; k < context.Length; ++k) {
                if (corpus[offset + k] != context[k]) {
                    return false;
                }
            }
            return true;
        }

        /// <summary>
        /// Maps every word of 'words' (except a word that only appears first) to the number of distinct
        /// words that directly precede it. Pairs are ordered, and a word repeated twice in a row counts
        /// as its own predecessor.
        /// </summary>
        private static Dictionary<string, int> GetContinuationCounts(string[] words)
        {
            Dictionary<string, HashSet<string>> predecessors = new Dictionary<string, HashSet<string>>();
            for (int i = 1; i < words.Length; ++i) {
                HashSet<string> set;
                if (!predecessors.TryGetValue(words[i], out set)) {
                    set = new HashSet<string>();
                    predecessors.Add(words[i], set);
                }
                set.Add(words[i - 1]);
            }
            return predecessors.ToDictionary(entry => entry.Key, entry => entry.Value.Count);
        }

        /// <summary>
        /// Separates the provided text to words and inserts the specified sentence separator before
        /// and after every sentence.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="text"/> is null.</exception>
        /// <exception cref="ArgumentException">The text contains no sentences or no words.</exception>
        private static string[] CreateWords(string text, string sentenceSeparator)
        {
            if (text == null) {
                throw new ArgumentNullException("text");
            }
            // line breaks become spaces first, so ". " is the only sentence boundary left to split on;
            // invariant lower-casing keeps tokenization independent of the current culture
            text = text.ToLowerInvariant().Replace("\r", " ").Replace("\n", " ");
            string[] sentences = text.Split(new string[] { ". " }, StringSplitOptions.RemoveEmptyEntries);
            if (sentences.Length == 0) {
                throw new ArgumentException("Provided text contains zero sentences.", "text");
            }
            List<string> corpus = new List<string>();
            corpus.Add(sentenceSeparator);
            foreach (string rawSentence in sentences) {
                string sentence = Clean(rawSentence).Trim();
                if (sentence.Length == 0) {
                    continue;
                }
                string[] words = sentence.Split(new char[] { ' ' }, StringSplitOptions.RemoveEmptyEntries);
                if (words.Length == 0) {
                    continue;
                }
                corpus.AddRange(words);
                corpus.Add(sentenceSeparator);
            }
            if (corpus.Count == 1) {
                throw new ArgumentException("Provided text contains zero words.", "text");
            }
            return corpus.ToArray();
        }

        /// <summary>
        /// Cleans the provided sentence, leaving only letters, apostrophes and spaces. Digits and punctuation are discarded.
        /// </summary>
        private static string Clean(string sentence)
        {
            // replace dot with space in case there is a case of: "word1.word2"
            sentence = sentence.Replace('.', ' ');
            StringBuilder sb = new StringBuilder(sentence.Length);
            foreach (char c in sentence) {
                if (char.IsLetter(c) || c == '\'' || c == ' ') {
                    sb.Append(c);
                }
            }
            return sb.ToString();
        }
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment