Created
September 23, 2017 08:21
-
-
Save himkt/0c3315e84605fb0f086467a9d1365eec to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import collections | |
import tqdm | |
import math | |
class HMM:
    """First-order hidden Markov model tagger trained by maximum likelihood.

    ``fit`` estimates transition and emission probabilities from aligned
    word/label sequences; ``predict`` runs Viterbi decoding (in negative
    log space, with additive smoothing) to recover the most likely label
    sequence for a new observation sequence.

    Trellis keys use the format ``'{position}_{label}'``; labels are
    assumed not to contain underscores (words may — emission keys are
    split with ``maxsplit=1``).
    """

    # Additive smoothing constant applied before taking logs, so unseen
    # transitions/emissions never produce -log(0).
    SMOOTHING = 0.01

    def __init__(self):
        pass

    def fit(self, Xs, Ys):
        """Estimate model parameters from parallel sequences.

        Parameters
        ----------
        Xs : iterable of observation (word) sequences
        Ys : iterable of label sequences, aligned element-wise with Xs

        Sets ``self.label_transition_prob``, ``self.word_emission_prob``
        (both ``defaultdict(float)`` keyed ``'{context}_{target}'``),
        ``self.labels`` and ``self.n_label``.
        """
        transition_count = collections.defaultdict(int)
        emission_count = collections.defaultdict(int)
        label_count = collections.defaultdict(int)

        for X, Y in zip(Xs, Ys):
            prev_label = 'BOS'
            label_count['BOS'] += 1
            for word, label in zip(X, Y):
                label_count[label] += 1
                emission_count[f'{label}_{word}'] += 1
                transition_count[f'{prev_label}_{label}'] += 1
                prev_label = label
            # Count the transition into the end marker so the decoder's
            # final step is scored from real statistics (was missing).
            transition_count[f'{prev_label}_EOS'] += 1

        # Normalize counts by the frequency of the conditioning label.
        self.label_transition_prob = collections.defaultdict(float)
        for key, freq in transition_count.items():
            context, _ = key.split('_', 1)
            self.label_transition_prob[key] = freq / label_count[context]

        self.word_emission_prob = collections.defaultdict(float)
        for key, freq in emission_count.items():
            # maxsplit=1: words may themselves contain underscores.
            context, _ = key.split('_', 1)
            self.word_emission_prob[key] = freq / label_count[context]

        self.n_label = len(label_count)
        self.labels = list(label_count)  # includes 'BOS'

    def predict(self, X):
        """Return the most likely label sequence for observation sequence X."""
        _, best_edge = self.forward(X)
        return self.viterbi_decode(best_edge)

    def forward(self, X):
        """Viterbi forward pass over observation sequence X.

        Returns ``(best_score, best_edge)`` where ``best_score`` maps a
        trellis key ``'{position}_{label}'`` to the accumulated negative
        log probability of the best path reaching it, and ``best_edge``
        maps each key to its best predecessor key (``None`` for the
        start node ``'0_BOS'``).  The end node is
        ``'{len(X)+1}_EOS'``.
        """
        length = len(X)
        eps = self.SMOOTHING
        transition = self.label_transition_prob
        emission = self.word_emission_prob

        best_score = {'0_BOS': 0.0}
        best_edge = {'0_BOS': None}

        for i in range(length):
            for prev in self.labels:
                prev_key = f'{i}_{prev}'
                if prev_key not in best_score:
                    continue  # node unreachable
                for curr in self.labels:
                    if f'{prev}_{curr}' not in transition:
                        continue  # transition never observed
                    score = best_score[prev_key]
                    score += -math.log(eps + transition[f'{prev}_{curr}'])
                    # Emission of word i by the label entering column i+1.
                    score += -math.log(eps + emission.get(f'{curr}_{X[i]}', 0.0))
                    curr_key = f'{i + 1}_{curr}'
                    # Keep the SMALLER negative log probability (best path).
                    if curr_key not in best_score or score < best_score[curr_key]:
                        best_score[curr_key] = score
                        best_edge[curr_key] = prev_key

        # Final transition into the EOS node (no emission), considering
        # every label reachable in the last column.
        eos_key = f'{length + 1}_EOS'
        for prev in self.labels:
            prev_key = f'{length}_{prev}'
            if prev_key not in best_score:
                continue
            score = best_score[prev_key]
            score += -math.log(eps + transition[f'{prev}_EOS'])
            if eos_key not in best_score or score < best_score[eos_key]:
                best_score[eos_key] = score
                best_edge[eos_key] = prev_key

        return best_score, best_edge

    def viterbi_decode(self, best_edge):
        """Follow back-pointers from the EOS node to recover the path.

        Returns the label sequence, excluding the BOS and EOS markers.
        """
        # The EOS node has the largest position index in the trellis.
        key = max(best_edge, key=lambda k: int(k.split('_', 1)[0]))
        path = []
        while key is not None:
            _, label = key.split('_', 1)
            path.append(label)
            key = best_edge.get(key)
        path.reverse()
        return path[1:-1]  # strip 'BOS' and 'EOS'
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment