Skip to content

Instantly share code, notes, and snippets.

@thomashikaru
Created October 10, 2021 19:05
Show Gist options
  • Save thomashikaru/560c41ce4f2a834e254ad39ecedd46e0 to your computer and use it in GitHub Desktop.
Save thomashikaru/560c41ce4f2a834e254ad39ecedd46e0 to your computer and use it in GitHub Desktop.
from collections import defaultdict
class BigramModel:
    """Maximum-likelihood bigram model over tokenized sentences.

    Counts are stored in ``self.d`` as a nested mapping
    ``d[previous_token][next_token] -> count``.
    """

    def __init__(self):
        # Start with an empty table so relative_freq() is usable (returning
        # 0.0) even before train() has been called.
        self.d = defaultdict(lambda: defaultdict(int))

    def train(self, training_set):
        """Count every pair of adjacent tokens in ``training_set``.

        Any counts from a previous call are discarded.

        Args:
            training_set: iterable of sentences, each a sliceable sequence
                of tokens (e.g. ``"<s> the cat </s>".split()``).
        """
        self.d = defaultdict(lambda: defaultdict(int))
        for sent in training_set:
            for w1, w2 in zip(sent[:-1], sent[1:]):
                self.d[w1][w2] += 1

    def relative_freq(self, context, word):
        """Return count(context, word) / count(context, any next token).

        This is a read-only query: unlike indexing the defaultdicts with
        ``[]``, it never inserts entries for unseen tokens.

        Args:
            context: the preceding token.
            word: the token whose relative frequency after ``context`` is
                requested.

        Returns:
            The relative frequency as a float in [0, 1]. Returns 0.0 if
            ``context`` was never observed followed by another token, since
            the ratio would otherwise divide by zero.
        """
        # dict.get does not trigger defaultdict.__missing__, so no entries
        # are created as a side effect of querying.
        followers = self.d.get(context)
        if not followers:
            return 0.0
        total = sum(followers.values())
        return followers.get(word, 0) / total
if __name__ == "__main__":
    # Tiny demo corpus; each sentence is wrapped in <s> ... </s> markers.
    corpus = [
        "<s> the fox jumps over the dog </s>",
        "<s> the cat jumps over the fox </s>",
        "<s> the cat eats cat food </s>",
        "<s> the fox steals cat food </s>",
    ]
    model = BigramModel()
    model.train([line.split() for line in corpus])

    # (context, word) pairs to report, printed in this order.
    queries = (
        ("the", "cat"),
        ("the", "fox"),
        ("fox", "jumps"),
    )
    for context, word in queries:
        print(model.relative_freq(context, word))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment