Skip to content

Instantly share code, notes, and snippets.

@revsic
Created September 8, 2021 15:58
Show Gist options
  • Save revsic/323d269ddf4ea2f9b1075a4ec2847abf to your computer and use it in GitHub Desktop.
Save revsic/323d269ddf4ea2f9b1075a4ec2847abf to your computer and use it in GitHub Desktop.
monotonic alignment search
def search(log_prob: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Monotonic alignment search (MAS) from Glow-TTS.

    For every batch element, finds the path through the [S, T] grid that
    maximizes the summed log-probability, subject to: the path starts at
    (0, 0), ends at (textlen - 1, last frame), and at each frame either stays
    on the same text position or advances by exactly one.

    Args:
        log_prob: [B, S, T], log-probability of aligning text position s with
            frame t.
        mask: [B, S, T], attention mask (bool or 0/1 numeric). Presumably the
            outer product of an end-padded text mask and an end-padded frame
            mask -- TODO confirm against callers. The text length of each
            element is read from the first frame column, ``mask[:, :, 0]``.

    Returns:
        [B, S, T], hard 0/1 alignment multiplied by ``mask``. Elements whose
        first mask column is empty (e.g. fully masked padding entries) get an
        all-zero alignment.

    Raises:
        AssertionError: if backtracking walks above text position 0, which
            happens e.g. when ``log_prob`` contains NaN.
    """
    bsize, seqlen, timestep = log_prob.shape
    device = log_prob.device
    # [B, S, T], True if the best path into (s, t) came from (s, t - 1)
    # (stay), False if it came from (s - 1, t - 1) (advance). Kept as bool
    # instead of long: same information, 8x less memory.
    direction = torch.zeros_like(log_prob, dtype=torch.bool)
    # [B, S], best cumulative log-probability per text position, up to the
    # most recently processed frame.
    prob = torch.zeros(bsize, seqlen, device=device)
    # [1, S]
    arange = torch.arange(seqlen, device=device)[None]
    for j in range(timestep):
        # [B, S], score of arriving from the previous text position;
        # position 0 has no predecessor.
        prev = F.pad(prob[:, :-1], [1, 0], value=-np.inf)
        # [B, S], ties prefer staying on the same text position.
        direction[:, :, j] = prob >= prev
        # [B, S]
        prob = torch.maximum(prob, prev) + log_prob[:, :, j]
        # text position s is unreachable before frame s
        prob.masked_fill_(arange > j, -np.inf)
    # Masked cells always "stay", so backtracking through padded frames keeps
    # the index fixed until it reaches the last valid frame.
    direction.masked_fill_(~mask.to(torch.bool), True)
    # [B, S, T]
    attn = torch.zeros_like(log_prob)
    # [B], last valid text position. Clamped to 0 so elements with an empty
    # first mask column backtrack harmlessly along row 0; the final
    # multiplication by mask zeros them out instead of raising.
    index = (mask[..., 0].sum(dim=-1).long() - 1).clamp(min=0)
    # [B]
    batch = torch.arange(bsize, device=device)
    for j in reversed(range(timestep)):
        # Explicit raise rather than `assert` so the guard survives `python -O`;
        # a negative index would otherwise silently wrap around.
        if (index < 0).any():
            raise AssertionError('negative index approached')
        attn[batch, index, j] = 1
        # stay (True) keeps the index, advance (False) moves it up by one
        index = index + direction[batch, index, j].long() - 1
    # [B, S, T]
    return attn * mask
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment