Created
September 3, 2015 15:43
-
-
Save se4u/78830b6741de14a09455 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
'''
Filename    : HMMGibbs.py
Description : Gibbs sampling based inference for a binary noisy-channel HMM.
Author      : Pushpendre Rastogi
Created     : Thu Sep  3 02:22:31 2015 (-0400)
Last-Updated: Thu Sep  3 11:38:18 2015 (-0400)
          By: Pushpendre Rastogi
    Update #: 15
'''
import numpy as np

# Short alias for numpy's exponential, used to turn scores into weights.
exp = np.exp

# Index 2 of the generation table doubles as the start symbol (row) and
# the stop symbol (column); the real bits occupy indices 0 and 1.
START_IDX = 2
STOP_IDX = 2
class TrainingExample(object):
    """A single (sent, received) pair produced by the noisy channel.

    ``sent`` holds the generated bits and ``recv`` the bits observed after
    transmission; both are converted to numpy arrays on construction.
    """

    def __init__(self, sent, recv):
        self.sent, self.recv = np.array(sent), np.array(recv)

    def __repr__(self):
        # numpy may wrap long arrays across lines; strip those newlines so
        # each field stays on its own line of the repr.
        sent_txt, recv_txt = (str(arr).replace('\n', '')
                              for arr in (self.sent, self.recv))
        return 'TrainingExample(\n\tsent=%s,\n\trecv=%s)' % (sent_txt,
                                                             recv_txt)
class BinaryHMM(object):
    """Binary noisy-channel HMM with a Gibbs-sampling decoder.

    Sentences are generated by a first-order Markov chain over the bits
    {0, 1} (table ``G``, with index 2 acting as start/stop), and every
    generated bit is passed through a channel (table ``T``) that may flip it.
    """

    def __init__(self, T=None, G=None):
        """Store the channel and generation tables.

        :param T: ``T[i, j]`` is the probability that bit ``j`` is received
            when bit ``i`` was sent. Defaults to a fixed 2x2 table.
        :param G: ``G[i, j]`` is the probability of generating ``j`` given
            that ``i`` was generated previously; column 2 is the stop sign and
            row 2 is the start sign. Defaults to a fixed 3x3 table; a custom
            table is expected to follow the same layout.
        """
        self.T = (np.array([[0.8, 0.2],
                            [0.1, 0.9]])
                  if T is None
                  else T)
        self.G = (np.array([[0.7, 0.2, 0.1],
                            [0.19, 0.8, 0.01],
                            [0.5, 0.5, 0]])
                  if G is None
                  else G)
        # (Re)bind the module-level start/stop indices read by the
        # generation methods below.
        global START_IDX
        global STOP_IDX
        START_IDX, STOP_IDX = 2, 2

    @staticmethod
    def sample(pdf):
        """Draw an index from the discrete distribution ``pdf``.

        Inverse-CDF sampling with a single ``np.random.rand()`` draw.
        Returns ``None`` for an empty ``pdf``. If floating-point rounding
        leaves the cumulative sum below the draw, the last index is returned.
        """
        r = np.random.rand()
        pdf_sum = 0
        idx = None
        for idx, prob in enumerate(pdf):
            pdf_sum += prob
            # Strict '>' because r lies in [0, 1): with '>=' an exact r == 0
            # would select a leading zero-probability index.
            if pdf_sum > r:
                break
        return idx

    def generate_bit(self, prev_bit):
        """Sample the symbol that follows ``prev_bit`` according to ``G``."""
        pdf = self.G[prev_bit]
        return self.sample(pdf)

    def transmit(self, bit):
        """Sample the received bit for a sent ``bit`` according to ``T``."""
        pdf = self.T[bit]
        return self.sample(pdf)

    def generate_sentence(self):
        """Generate one sentence and its channel-corrupted copy.

        Symbols are drawn starting from ``START_IDX`` until ``STOP_IDX`` is
        drawn; the stop symbol is not stored. Each stored bit is transmitted
        once, so ``sent`` and ``recv`` always have equal length (zero if the
        very first draw is the stop symbol).

        :returns: a ``TrainingExample``.
        """
        sent, recv = [], []
        prev_bit = START_IDX
        while True:
            cur_bit = self.generate_bit(prev_bit)
            # Check for stop BEFORE transmitting: T has no row for the stop
            # symbol, so transmitting it would raise IndexError.
            if cur_bit == STOP_IDX:
                break
            sent.append(cur_bit)
            recv.append(self.transmit(cur_bit))
            prev_bit = cur_bit
        return TrainingExample(sent, recv)

    def generate(self, num_sentences=10):
        """Return a list of ``num_sentences`` independently generated examples."""
        return [self.generate_sentence()
                for _ in range(num_sentences)]

    def post_process(self, Y_hat_t):
        """Print every recorded Gibbs state and return the final one."""
        for e in Y_hat_t:
            print(e)
        return Y_hat_t[-1]

    def GibbsInfer(self, training_example, quota=5):
        """Estimate the sent bits of ``training_example`` by Gibbs sampling.

        The chain starts from the received bits (i.e. it assumes no bit was
        flipped). It then performs ``quota`` sweeps over the positions. Each
        single-site update resamples one bit and records a fresh copy of the
        state.

        :param training_example: object exposing ``recv``, the received bits.
        :param quota: number of full sweeps over the sequence.
        :returns: the result of ``post_process`` on all recorded states,
            i.e. the final state.
        """
        X = training_example.recv

        # Gibbs sampling needs an estimate of p(y_j | y_/j, X).
        def p(j, Y_hat, X):
            '''Return [p(y_j = 0), p(y_j = 1)] under a toy first-order CRF.

            Features:
            - cur bit: the received bit X[j]
            - prev bit: the previous received bit X[j - 1]
            - prev signal: the previous estimated sent bit Y_hat[j - 1]
            - cur bit, prev signal (omitted for now)
            - prev bit, prev signal (omitted for now)

            NOTE(review): only the left neighbour is used, so this is not the
            full Markov-blanket conditional of a chain model.
            '''
            cur_bit = X[j]
            # Position 0 has no predecessor. None matches neither bit, whereas
            # X[-1] / Y_hat[-1] would wrap around to the end of the sequence.
            prev_bit = X[j - 1] if j > 0 else None
            prev_signal = Y_hat[j - 1] if j > 0 else None

            def simple_scorer(bit):
                return ((0.9 if cur_bit == bit else -0.1)
                        + (0.9 if prev_bit == bit else 0)
                        + (0.9 if prev_signal == bit else 0))

            # score_i = unnormalised weight for y_hat[j] being i
            score_0 = exp(simple_scorer(0))
            score_1 = exp(simple_scorer(1))
            sum_scores = score_0 + score_1
            return np.array([score_0 / sum_scores, score_1 / sum_scores])

        # At iteration 0 we guess that the channel transmitted everything
        # perfectly: none of the bits were flipped.
        Y_hat_t = [X]
        # Vanilla (systematic-scan) Gibbs: cycle through the positions in order.
        for _ in range(quota):
            for j in range(len(X)):
                # Copy so that previously recorded states are not mutated.
                y_hat = np.array(Y_hat_t[-1])
                y_hat[j] = self.sample(p(j, y_hat, X))
                Y_hat_t.append(y_hat)
        return self.post_process(Y_hat_t)

    def test(self):
        """Placeholder; does nothing."""
        pass
import unittest


class TestBinaryHMM(unittest.TestCase):
    """Smoke tests for BinaryHMM sentence generation and Gibbs inference."""

    def setUp(self):
        # Seed numpy's global RNG so any failure is reproducible.
        np.random.seed(0)
        self.obj = BinaryHMM()

    def test_generate_sentence(self):
        example = self.obj.generate_sentence()
        # Every generated bit has exactly one transmitted counterpart.
        self.assertEqual(len(example.sent), len(example.recv))
        self.assertTrue(set(example.sent.tolist()) <= set([0, 1]))
        self.assertTrue(set(example.recv.tolist()) <= set([0, 1]))

    def test_GibbsInfer(self):
        training_example = self.obj.generate_sentence()
        print(training_example)
        y_hat = self.obj.GibbsInfer(training_example)
        # The returned state is a full-length guess at the sent bits.
        self.assertEqual(len(y_hat), len(training_example.recv))


if __name__ == '__main__':
    unittest.main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment