@zilunpeng
Created March 23, 2021 22:13
Calls the C++ Viterbi decoder (CpuViterbiPath) from Python. The code below is part of utils.py (https://git.io/JYeHy).
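```python
import torch

# Assumed import path: the wav2letter Python bindings, as used by fairseq at the time.
from wav2letter.criterion import CpuViterbiPath, get_data_ptr_as_bytes


def decode(self, emissions):
    # emissions: (B, T, N) = (batch, frames, labels). The C++ code reads the raw
    # buffer, so it must be a contiguous float32 tensor on the CPU.
    emissions = emissions.float().cpu().contiguous()
    B, T, N = emissions.size()
    if self.asg_transitions is None:
        # No learned transitions (e.g. CTC): all-zero N x N transition scores.
        transitions = torch.FloatTensor(N, N).zero_()
    else:
        transitions = torch.FloatTensor(self.asg_transitions).view(N, N)
    # Output buffer for the best label index at each frame, filled in by C++.
    viterbi_path = torch.IntTensor(B, T)
    # Scratch memory whose size is reported by the C++ side.
    workspace = torch.ByteTensor(CpuViterbiPath.get_workspace_size(B, T, N))
    CpuViterbiPath.compute(
        B,
        T,
        N,
        get_data_ptr_as_bytes(emissions),
        get_data_ptr_as_bytes(transitions),
        get_data_ptr_as_bytes(viterbi_path),
        get_data_ptr_as_bytes(workspace),
    )
    # One hypothesis per utterance; Viterbi decoding reports no score here (0).
    return [
        [{"tokens": self.get_tokens(viterbi_path[b].tolist()), "score": 0}]
        for b in range(B)
    ]
```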
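For readers without the compiled bindings, the pure-Python sketch below shows the kind of computation a call like CpuViterbiPath.compute performs for one utterance: a max-sum dynamic program over emission and transition scores, followed by backtracking. It is an illustration, not the wav2letter implementation. The index order transitions[cur][prev] is an assumption, and collapse is a hypothetical stand-in for get_tokens in a CTC-style setup (merge repeated indices, then drop the blank). With all-zero transitions, as in the asg_transitions is None branch, the best path reduces to a per-frame argmax.

```python
from itertools import groupby
from typing import List, Sequence


def viterbi_path(emissions: Sequence[Sequence[float]],
                 transitions: Sequence[Sequence[float]]) -> List[int]:
    """Best label sequence for a single utterance.

    emissions[t][j]: score of label j at frame t (T x N).
    transitions[cur][prev]: score for moving from label prev to label cur (N x N).
    The [cur][prev] index order is an assumption, not taken from the C++ code.
    """
    T, N = len(emissions), len(emissions[0])
    score = list(emissions[0])          # best score of any path ending in label j at t = 0
    backptr: List[List[int]] = []       # backptr[t - 1][j]: best previous label for j at t
    for t in range(1, T):
        new_score, bp = [], []
        for j in range(N):
            best_prev = max(range(N), key=lambda i: score[i] + transitions[j][i])
            new_score.append(score[best_prev] + transitions[j][best_prev] + emissions[t][j])
            bp.append(best_prev)
        score = new_score
        backptr.append(bp)
    # Start from the best final label and follow the back-pointers.
    path = [max(range(N), key=lambda j: score[j])]
    for bp in reversed(backptr):
        path.append(bp[path[-1]])
    path.reverse()
    return path


def collapse(path: Sequence[int], blank: int) -> List[int]:
    """Hypothetical get_tokens stand-in: merge repeated labels, then drop the blank."""
    return [label for label, _ in groupby(path) if label != blank]


if __name__ == "__main__":
    emissions = [[0.1, 0.9, 0.0], [0.2, 0.7, 0.1], [0.8, 0.1, 0.1], [0.0, 0.3, 0.7]]
    zeros = [[0.0] * 3 for _ in range(3)]
    path = viterbi_path(emissions, zeros)
    print(path)                     # [1, 1, 0, 2]
    print(collapse(path, blank=0))  # [1, 2]
```

Non-zero transitions are where Viterbi differs from a per-frame argmax: a penalty on switching labels can keep the path on one label through a frame whose emission mildly prefers another.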