Skip to content

Instantly share code, notes, and snippets.

Last active June 12, 2023 20:24
Show Gist options
  • Save fauxneticien/9976752d7c11619c720e99d6ef8e1d7a to your computer and use it in GitHub Desktop.
Save fauxneticien/9976752d7c11619c720e99d6ef8e1d7a to your computer and use it in GitHub Desktop.
Lhotse token collator for CTC
# Modified version with diff (see history)
class TokenCollater:
    """Collate the supervision text of cuts into padded character-index tensors.

    Each cut's supervision texts are joined with a single space and split into
    characters. Optional begin/end-of-sequence symbols are wrapped around each
    sequence, and sequences are right-padded with the pad symbol to equal
    length. Call ``.inverse(tokens_batch, tokens_lens)`` to reconstruct the
    batch as strings.

    Vocabulary order: the pad symbol (always index 0), then the unk, bos and
    eos symbols when enabled, then the sorted set of characters found in the
    ``cuts`` passed to the constructor.

    >>> token_collater = TokenCollater(cuts)
    >>> tokens_batch, tokens_lens = token_collater(cuts.subset(first=32))
    >>> original_sentences = token_collater.inverse(tokens_batch, tokens_lens)

    ``__call__`` returns:
        tokens_batch: int64 (Long) tensor of shape (B, L)
            B: batch dimension, number of input sentences
            L: length of the longest sentence plus the <bos>/<eos> symbols
        tokens_lens: IntTensor of shape (B,)
            Length of each sentence after adding <eos> and <bos>
            but before padding.
    """

    def __init__(
        self,
        cuts: "CutSet",
        add_eos: bool = True,
        add_bos: bool = True,
        add_unk: bool = True,
        pad_symbol: str = "<pad>",
        bos_symbol: str = "<bos>",
        eos_symbol: str = "<eos>",
        unk_symbol: str = "<unk>",
    ):
        """Build the vocabulary from the characters of ``cuts``' supervisions.

        Args:
            cuts: iterable of cuts; each must expose ``.supervisions``, each
                supervision a ``.text`` string.
            add_eos / add_bos: append/prepend the eos/bos symbol to every
                encoded sequence.
            add_unk: reserve an unk symbol and map unseen characters to it.
            pad_symbol, bos_symbol, eos_symbol, unk_symbol: special symbols.
        """
        self.pad_symbol = pad_symbol
        self.bos_symbol = bos_symbol
        self.eos_symbol = eos_symbol
        self.unk_symbol = unk_symbol
        self.add_eos = add_eos
        self.add_bos = add_bos
        self.add_unk = add_unk

        specials = (
            [pad_symbol]
            + ([unk_symbol] if add_unk else [])
            + ([bos_symbol] if add_bos else [])
            + ([eos_symbol] if add_eos else [])
        )
        # Collect characters from exactly the text __call__ encodes (all
        # supervisions joined by a space); reading only supervisions[0] left
        # later supervisions' characters and the joining space unencodable.
        tokens = {char for cut in cuts for char in self._cut_text(cut)}
        # Exclude special symbols so a one-character special symbol cannot
        # occupy two vocabulary slots.
        tokens_unique = specials + sorted(tokens - set(specials))

        self.token2idx = {token: idx for idx, token in enumerate(tokens_unique)}
        self.idx2token = list(tokens_unique)

    @staticmethod
    def _cut_text(cut) -> str:
        """Return the text encoded for ``cut``: its supervision texts joined by a space."""
        return " ".join(supervision.text for supervision in cut.supervisions)

    def _encode(self, token: str) -> int:
        """Map ``token`` to its index, falling back to the unk symbol if enabled.

        Raises:
            KeyError: if ``token`` is unknown and ``add_unk`` is False.
        """
        idx = self.token2idx.get(token)
        if idx is not None:
            return idx
        if self.add_unk:
            return self.token2idx[self.unk_symbol]
        raise KeyError(token)

    def __call__(self, cuts: "CutSet") -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode the supervision text of ``cuts`` as a padded index batch.

        Returns:
            ``(tokens_batch, tokens_lens)`` as described in the class docstring.

        Raises:
            KeyError: if a character is not in the vocabulary and the collater
                was built with ``add_unk=False``.
        """
        token_sequences = [self._cut_text(cut) for cut in cuts]
        # default=0 keeps an empty batch from raising ValueError in max().
        max_len = max((len(seq) for seq in token_sequences), default=0)
        prefix = [self.bos_symbol] if self.add_bos else []
        suffix = [self.eos_symbol] if self.add_eos else []
        seqs = [
            prefix + list(seq) + suffix + [self.pad_symbol] * (max_len - len(seq))
            for seq in token_sequences
        ]
        tokens_batch = torch.tensor(
            [[self._encode(token) for token in seq] for seq in seqs],
            dtype=torch.long,
        )
        tokens_lens = torch.IntTensor(
            [
                len(seq) + int(self.add_eos) + int(self.add_bos)
                for seq in token_sequences
            ]
        )
        return tokens_batch, tokens_lens

    def inverse(
        self, tokens_batch: torch.LongTensor, tokens_lens: torch.IntTensor
    ) -> List[str]:
        """Decode a batch produced by ``__call__`` back into strings.

        The bos/eos symbols (when enabled) and any padding beyond each row's
        length are dropped; the remaining indices are mapped through
        ``idx2token`` and concatenated. Accepts tensors or plain int lists.
        """
        start = 1 if self.add_bos else 0
        stop_offset = int(self.add_eos)
        sentences = [
            "".join(
                self.idx2token[int(idx)] for idx in row[start : int(end) - stop_offset]
            )
            for row, end in zip(tokens_batch, tokens_lens)
        ]
        return sentences
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment