# HMM on GPU with pytorch in 100 lines
# Source: GitHub gist yuq-1s/c190a5b01ffb284e7c823d07303c62cc by @yuq-1s (last active July 2, 2021 12:55).
'''
This proof-of-concept follows the approach of [PRML](https://www.microsoft.com/en-us/research/people/cmbishop/prml-book/) (Bishop, chapter 13).
It extends the plain HMM so that the transition matrix and the emission matrix at each step are computed from that step's features `xs`.
To get an ordinary HMM, give every step the same feature vector `x`.
`HMM.predict()` uses formula (13.44) in PRML, the predictive distribution of the next observation given the previously observed `y`s.
If you have no observed `y`s and only `x`s, you can use `model.trans(x).view(T, N, model.H, model.H).softmax(dim=3)` (with `x` laid out as [T, N, D]) as the transition matrices to get a predicted sequence.
`gamma` here is the posterior probability of the hidden states given the observations.
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
def my_div(a, b):
    """Divide ``a`` by ``b`` after nudging near-zero denominators away from zero.

    Entries of ``b`` whose absolute value is below 1e-8 get 1e-6 added before
    the division, so the quotient stays finite where a normalizer underflows.
    ``b`` is NOT modified: callers pass views such as ``c.view(-1, 1)`` or
    ``cs[1:]``, and an in-place update would silently rewrite their tensors.

    a: numerator, a tensor or Python scalar broadcastable against ``b``.
    b: denominator tensor.
    Returns ``a / b'`` where ``b'`` is the nudged copy of ``b`` (same dtype as ``b``).
    """
    # torch.where + a Python-scalar offset keeps b's dtype (e.g. float16 stays float16).
    safe_b = torch.where(b.abs() < 1e-8, b + 1e-6, b)
    return a / safe_b
def alpha(a, T, E):
    """One unnormalized forward step of the HMM.

    a: [N, H] forward message of the previous step.
    T: [N, H, H] transition matrices into the current step; the contraction
       sums over the middle axis, i.e. T[n, j, k] weights moving from state j
       to state k.
    E: [N, H] emission probabilities of the observed symbol at the current
       step (it is multiplied element-wise with the [N, H] predicted state
       distribution, so it cannot carry a separate symbol axis).
    Returns the [N, H] tensor E * (a @ T), computed per batch row.
    """
    predicted_state = torch.einsum('bi,bij->bj', a, T)
    return predicted_state * E
def alpha_hat_and_c(a, T, E):
    """Scaled forward step (PRML section 13.2.4, scaling factors).

    a: [N, H] previous scaled forward message alpha_hat.
    T: [N, H, H] transition matrices into the current step.
    E: [N, H] emission probabilities of the current observation.
    Returns (alpha_hat, c):
      alpha_hat -- the [N, H] unnormalized message alpha(a, T, E) divided
                   row-wise by its sum (near-zero sums are guarded by my_div);
      c         -- the [N] row sums, i.e. the per-step scaling factors.
    """
    unscaled = alpha(a, T, E)
    scaling = unscaled.sum(dim=1)
    # A separate keepdim sum is used as the denominator so that my_div never
    # sees (and can never alter) the tensor returned as the scaling factor.
    normalized = my_div(unscaled, unscaled.sum(dim=1, keepdim=True))
    return normalized, scaling
def beta_hat(b, T, E, c_t_plus_one):
    """Scaled backward step (PRML section 13.2.4, scaling factors).

    b: [N, H] scaled backward message of step t+1.
    T: [N, H, H] transition matrices from step t to step t+1.
    E: [N, H] emission probabilities of the observation at step t+1.
    c_t_plus_one: [N] forward scaling factors of step t+1.
    Returns beta(b, T, E) divided row-wise by c_t_plus_one (via my_div).
    """
    unscaled = beta(b, T, E)
    column = c_t_plus_one.view(-1, 1)
    return my_div(unscaled, column)
def beta(b, T, E):
    """One unnormalized backward step of the HMM.

    b: [N, H] backward message of step t+1.
    T: [N, H, H] transition matrices from step t to step t+1.
    E: [N, H] emission probabilities of the observation at step t+1.
    Returns [N, H] with out[n, i] = sum_j b[n, j] * T[n, i, j] * E[n, j].
    """
    contraction = 'bk,bjk,bk->bj'
    return torch.einsum(contraction, b, T, E)
class HMM(nn.Module):
    '''HMM whose transition and emission matrices are predicted from per-step features.

    For a feature vector x_t, ``self.trans`` yields transition logits reshaped to
    [H, H] (softmax over the last axis, so row j is p(z_t | z_{t-1} = j)) and
    ``self.emit`` yields emission logits reshaped to [H, C] (row h is
    p(y_t | z_t = h)). ``self.init`` holds the logits of the initial-state
    distribution pi. Training alternates a Baum-Welch E-step (scaled
    forward-backward, PRML section 13.2.4) with a gradient M-step on the
    expected complete-data log-likelihood returned by ``forward``.
    '''

    def __init__(self, H, C, D):
        ''' H: hidden dimension
        C: emission dimension
        D: feature dimension
        '''
        super().__init__()
        self.H, self.C, self.D = H, C, D
        self.trans = nn.Linear(D, H*H)
        self.emit = nn.Linear(D, H*C)
        # TODO: init as nn.Linear
        # Logits of the initial-state distribution; softmax is applied where used.
        self.init = nn.Parameter(torch.rand(H), requires_grad=True)

    @property
    def device(self):
        '''Device holding the model parameters.'''
        return self.trans.weight.device

    def forward(self, x, y):
        ''' x: [N, T, D] features
        y: [N, T] observed symbols in [0, C)
        Returns the negated expected complete-data log-likelihood (PRML 13.17)
        divided by the batch size N. The E-step posteriors are computed without
        gradient, so backpropagating this loss performs the M-step.
        '''
        x, y = x.transpose(0, 1), y.transpose(0, 1)
        T, N, _ = x.shape
        log_trans = self.trans(x).view(T, N, self.H, self.H).log_softmax(dim=3)
        log_emit = self.emit(x).view(T, N, self.H, self.C).log_softmax(dim=3)
        # log p(y_t | z_t = h) of the observed symbol: [T, N, H].
        one_hot_y = F.one_hot(y, num_classes=self.C).to(log_emit.dtype)
        log_emit_prob = torch.einsum('tnhc,tnc->tnh', log_emit, one_hot_y)
        with torch.no_grad():  # Baum-Welch E-step
            trans, emit_prob = log_trans.exp(), log_emit_prob.exp()
            # gamma: [T, N, H]; xi: [T-1, N, H, H]
            gamma, xi = self.gamma_and_xi_stable(trans, emit_prob)
        log_prob = torch.einsum('nh,h->', gamma[0], self.init.log_softmax(dim=0))
        if T > 1:
            # log_trans[t] is the transition INTO step t; step 0 has none.
            log_prob = log_prob + torch.einsum('tnjk,tnjk->', xi, log_trans[1:])
        log_prob = log_prob + torch.einsum('tnk,tnk->', gamma, log_emit_prob)
        return -log_prob / N

    def alpha_hats(self, Ts, Es):
        ''' Scaled forward pass.
        Ts: [T, N, H, H] transition matrices; Ts[t] maps z_{t-1} to z_t, so Ts[0]
            is not used (no transition precedes the first step).
        Es: [T, N, H] emission probabilities of the observed symbols.
        Returns alpha_hats [T, N, H] (rows normalized to 1) and scaling factors
        cs [T, N].
        '''
        N = Ts.shape[1]
        # First step: alpha(z_1) = pi * p(y_1 | z_1) (PRML 13.37). This matches the
        # gamma[0] . log(pi) term of the objective in forward().
        first = self.init.softmax(dim=0).expand(N, -1) * Es[0]
        alpha_hats = [my_div(first, first.sum(dim=1, keepdim=True))]
        cs = [first.sum(dim=1)]
        for Tr, Em in zip(Ts[1:], Es[1:]):
            a_hat, c = alpha_hat_and_c(alpha_hats[-1], Tr, Em)
            alpha_hats.append(a_hat)
            cs.append(c)
        return torch.stack(alpha_hats), torch.stack(cs)

    def gamma_and_xi_stable(self, Trs, Ems):
        ''' Posterior marginals from the scaled forward-backward pass.
        Trs: [T, N, H, H] transition matrices; Ems: [T, N, H] emission probabilities.
        Returns gamma [T, N, H] = alpha_hat * beta_hat and
        xi [T-1, N, H, H] = alpha_hat_{t-1}(i) E_t(j) A_t(i, j) beta_hat_t(j) / c_t.
        '''
        alpha_hats, cs = self.alpha_hats(Trs, Ems)
        beta_hats = self.beta_hats(cs, Trs, Ems)
        xi = torch.einsum('tni,tnj,tnij,tnj,tn->tnij', alpha_hats[:-1], Ems[1:], Trs[1:], beta_hats[1:], my_div(1, cs[1:]))
        return alpha_hats * beta_hats, xi

    def beta_hats(self, c, Trs, Ems):
        ''' Scaled backward pass, computed iteratively from the last step.
        c: [T, N] forward scaling factors; Trs, Ems as in alpha_hats.
        Returns beta_hats [T, N, H] with beta_hats[T-1] = 1 and
        beta_hats[t] = beta_hat(beta_hats[t+1], Trs[t+1], Ems[t+1], c[t+1]).
        '''
        T, N = Trs.shape[0], Trs.shape[1]
        betas = [torch.ones(N, self.H, device=Ems.device, dtype=Ems.dtype)]
        # Iterating avoids Python recursion-depth limits and repeated torch.cat copies.
        for t in range(T - 2, -1, -1):
            betas.append(beta_hat(betas[-1], Trs[t + 1], Ems[t + 1], c[t + 1]))
        betas.reverse()
        return torch.stack(betas)

    def predict(self, x, y):
        ''' One-step-ahead prediction of every observation (PRML 13.44).
        x: [N, T, D] features; y: [N, T] observed symbols in [0, C).
        Returns [N, T] integer tensor: the argmax over c of p(y_t = c | y_<t, x).
        y_t itself is never used when predicting position t.
        '''
        N, T = x.shape[0], x.shape[1]
        x, y = x.transpose(0, 1), y.transpose(0, 1)
        with torch.no_grad():
            trans = self.trans(x).view(T, N, self.H, self.H).softmax(dim=3)
            emit = self.emit(x).view(T, N, self.H, self.C).softmax(dim=3)
            one_hot_y = F.one_hot(y, num_classes=self.C).to(emit.dtype)
            emit_prob = torch.einsum('tnhc,tnc->tnh', emit, one_hot_y)
            alpha_hats, _ = self.alpha_hats(trans, emit_prob)
            # State distribution before seeing y_t: pi at t=0, alpha_hat_{t-1} @ A_t after.
            prior_first = self.init.softmax(dim=0).expand(1, N, -1)
            prior_rest = torch.einsum('tni,tnij->tnj', alpha_hats[:-1], trans[1:])
            priors = torch.cat((prior_first, prior_rest))
            predicted = torch.einsum('tnh,tnhc->tnc', priors, emit).argmax(dim=2)
        return predicted.transpose(0, 1)
if __name__ == '__main__':
    # Demo on random data: fit the model, then score one-step-ahead predictions.
    H, C, T, D, N = 4, 3, 7, 10, 128
    # Fall back to CPU so the demo also runs on machines without CUDA.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    xs = torch.randn(N, T, D, device=device)
    ys = torch.randint(C, (N, T), device=device)
    lr, num_epochs = 5e-3, 100
    model = HMM(H, C, D).to(device)
    model.train()
    optim = torch.optim.Adam(lr=lr, params=model.parameters())
    for _ in range(num_epochs):
        optim.zero_grad()
        loss = model(xs, ys)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5.)
        optim.step()  # Baum-Welch M-step
    model.eval()
    test_xs, test_ys = torch.randn(N, T, D, device=device), torch.randint(C, (N, T), device=device)
    predicted_ys = model.predict(test_xs, test_ys)
    # Fraction of positions predicted correctly (same value sklearn's accuracy_score gives).
    accuracy = (predicted_ys == test_ys).sum().item() / test_ys.numel()
    print(accuracy)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment