@himkt
Created September 23, 2017 08:21
import collections
import math

import tqdm


class HMM:
    def __init__(self):
        pass

    def fit(self, Xs, Ys):
        label_transition = collections.defaultdict(int)
        word_emission = collections.defaultdict(int)
        label_occurrence = collections.defaultdict(int)

        for X, Y in tqdm.tqdm(zip(Xs, Ys)):
            maxlen = len(X)
            label_occurrence['BOS'] += 1
            prev_label = 'BOS'

            for i in range(maxlen):
                current_word = X[i]
                current_label = Y[i]
                # label occurrence
                label_occurrence[current_label] += 1
                # word occurrence given label
                word_emission[f'{current_label}_{current_word}'] += 1
                # label transition
                label_transition[f'{prev_label}_{current_label}'] += 1
                prev_label = current_label

            # close the sequence with the end-of-sentence marker so that
            # P(EOS | last_label), used in forward(), can be estimated
            label_transition[f'{prev_label}_EOS'] += 1

        # maximum-likelihood estimate: P(current | prev) = c(prev_current) / c(prev)
        label_transition_prob = collections.defaultdict(float)
        for label_transition_key, freq in label_transition.items():
            context, target = label_transition_key.split('_')
            Z = label_occurrence[context]
            label_transition_prob[label_transition_key] = freq / Z

        # maximum-likelihood estimate: P(word | label) = c(label_word) / c(label)
        word_emission_prob = collections.defaultdict(float)
        for word_emission_key, freq in word_emission.items():
            context, target = word_emission_key.split('_')
            Z = label_occurrence[context]
            word_emission_prob[word_emission_key] = freq / Z

        self.n_label = len(label_occurrence)
        self.labels = list(label_occurrence.keys())
        self.label_transition_prob = label_transition_prob
        self.word_emission_prob = word_emission_prob

    def predict(self, X):
        best_score, best_edge = self.forward(X)
        return self.viterbi_decode(best_edge, len(X))

    def forward(self, X):
        len_seq = len(X)
        labels = self.labels
        label_transition_prob = self.label_transition_prob
        word_emission_prob = self.word_emission_prob

        best_score = {}
        best_edge = {}
        best_score['0_BOS'] = 0
        best_edge['0_BOS'] = None

        # best_score[f'{i}_{label}'] is the lowest cost of any path that
        # ends in `label` after consuming words X[0..i-1]
        for position in range(len_seq):
            for prev in labels:
                if f'{position}_{prev}' not in best_score:
                    continue
                for current in labels:
                    if f'{prev}_{current}' not in label_transition_prob:
                        continue
                    # 0.01 is crude additive smoothing; it keeps -log
                    # finite for unseen emissions and transitions
                    _val = -math.log(0.01 + word_emission_prob[f'{current}_{X[position]}'])  # NOQA
                    _val += -math.log(0.01 + label_transition_prob[f'{prev}_{current}'])  # NOQA
                    _score = best_score[f'{position}_{prev}'] + _val
                    # keep the lowest-cost (= most probable) incoming edge
                    if f'{position+1}_{current}' not in best_score or \
                            _score < best_score[f'{position+1}_{current}']:
                        best_score[f'{position+1}_{current}'] = _score
                        best_edge[f'{position+1}_{current}'] = f'{position}_{prev}'  # NOQA

        # final transition into the end-of-sentence marker
        for prev in labels:
            if f'{len_seq}_{prev}' not in best_score:
                continue
            _val = -math.log(0.01 + label_transition_prob[f'{prev}_EOS'])
            _score = best_score[f'{len_seq}_{prev}'] + _val
            if f'{len_seq+1}_EOS' not in best_score or \
                    _score < best_score[f'{len_seq+1}_EOS']:
                best_score[f'{len_seq+1}_EOS'] = _score
                best_edge[f'{len_seq+1}_EOS'] = f'{len_seq}_{prev}'

        return best_score, best_edge

    def viterbi_decode(self, best_edge, len_seq):
        # walk the back-pointers from EOS to BOS, then reverse
        labels = []
        edge = best_edge[f'{len_seq+1}_EOS']
        while edge != '0_BOS':
            position, label = edge.split('_', 1)
            labels.append(label)
            edge = best_edge[edge]
        labels.reverse()
        return labels
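
For reference, a minimal usage sketch (not part of the original gist): the toy sentences and label names below are made up, and it assumes Xs and Ys are parallel lists of token sequences and label sequences.

# Usage sketch with illustrative toy data.
Xs = [['time', 'flies', 'fast'], ['he', 'runs', 'fast']]
Ys = [['NOUN', 'VERB', 'ADV'], ['PRON', 'VERB', 'ADV']]

model = HMM()
model.fit(Xs, Ys)
# returns one predicted label per token, e.g. ['NOUN', 'VERB', 'ADV']
print(model.predict(['she', 'flies', 'fast']))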