@TeaPoly
Created May 30, 2023 02:53
Deep model with built-in self-attention alignment for acoustic echo cancellation
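
Below is the PyTorch implementation: two masking helpers (a padding mask and a streaming chunk mask) followed by the attention-based alignment block from the paper.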
#!/usr/bin/python
# -*- coding: utf-8 -*-
# Copyright 2023 Lucky Wong
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional, Tuple

import torch


def make_pad_mask(lengths: torch.Tensor, max_len: Optional[int] = None) -> torch.Tensor:
    """Make mask tensor containing indices of padded part.

    See description of make_non_pad_mask.

    Args:
        lengths (torch.Tensor): Batch of lengths (B,).
        max_len (int, optional): Length of the time axis. Defaults to
            the longest sequence in the batch.

    Returns:
        torch.Tensor: Mask tensor containing indices of padded part.

    Examples:
        >>> lengths = [5, 3, 2]
        >>> make_pad_mask(lengths)
        masks = [[0, 0, 0, 0, 0],
                 [0, 0, 0, 1, 1],
                 [0, 0, 1, 1, 1]]
    """
    batch_size = int(lengths.size(0))
    if max_len is None:
        max_len = int(lengths.max().item())
    seq_range = torch.arange(0, max_len, dtype=torch.int64, device=lengths.device)
    seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
    seq_length_expand = lengths.unsqueeze(-1)
    mask = seq_range_expand >= seq_length_expand
    return mask


def subsequent_chunk_mask(
    size: int,
    chunk_size: int,
    left_chunk_size: int = -1,
    device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
    """Create mask for subsequent steps (size, size) with chunk size.

    This is for a streaming encoder.

    Args:
        size (int): size of mask
        chunk_size (int): size of chunk
        left_chunk_size (int): number of history frames visible to the
            left of the current chunk; -1 means the full history
        device (torch.device): "cpu" or "cuda" or torch.Tensor.device

    Returns:
        torch.Tensor: mask

    Examples:
        >>> subsequent_chunk_mask(4, 2, left_chunk_size=1)
        [[1, 1, 0, 0],
         [1, 1, 0, 0],
         [0, 1, 1, 1],
         [0, 1, 1, 1]]
    """
    ret = torch.zeros(size, size, device=device, dtype=torch.bool)
    for i in range(size):
        # Each step may see everything up to the end of its own chunk ...
        ending = min((i // chunk_size + 1) * chunk_size, size)
        ret[i, 0:ending] = True
        # ... but no further back than left_chunk_size frames before the
        # start of the current chunk.
        if left_chunk_size != -1:
            left_start = max(0, (i // chunk_size) * chunk_size - left_chunk_size)
            ret[i, 0:left_start] = False
    return ret


class AttAlignBlock(torch.nn.Module):
    """Attention Align Block.

    Reference: Deep model with built-in self-attention alignment
        for acoustic echo cancellation
    Link: https://arxiv.org/pdf/2208.11308.pdf

    Args:
        mdim (int): The number of microphone features.
        fdim (int): The number of farend features.
        pdim (int): The projection size.
        chunk_size (int): The chunk size.
        max_delay_blocks (int): The maximum number of history frames
            (delay blocks) the alignment attention may look back.
    """

    def __init__(
        self,
        mdim: int = 2048,
        fdim: int = 256,
        pdim: int = 64,
        chunk_size: int = 1,
        max_delay_blocks: int = 80,
    ):
        """Construct an AttAlignBlock object."""
        super().__init__()
        self.chunk_size = chunk_size
        self.max_delay_blocks = max_delay_blocks
        self.pdim = pdim
        self.linear_q = torch.nn.Linear(mdim, pdim)
        self.linear_k = torch.nn.Linear(fdim, pdim)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute scaled dot product attention.

        Args:
            query (torch.Tensor): Microphone features tensor (#batch, time, mdim).
            key (torch.Tensor): Farend features tensor (#batch, time, fdim).
            lengths (torch.Tensor): Lengths tensor (#batch,).

        Returns:
            torch.Tensor: Aligned farend tensor (#batch, time, fdim).
            torch.Tensor: Attention weights (#batch, time, time).
        """
        q = self.linear_q(query)  # (#batch, time, pdim)
        k = self.linear_k(key)  # (#batch, time, pdim)
        # (#batch, time, pdim) x (#batch, pdim, time) -> (#batch, time, time)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.pdim)

        # Mask out padded frames and frames outside the allowed delay window.
        masks = ~make_pad_mask(lengths, key.size(1)).unsqueeze(1).to(
            query.device
        )  # (B, 1, L)
        chunk_masks = subsequent_chunk_mask(
            query.size(1), self.chunk_size, self.max_delay_blocks, query.device
        )  # (L, L)
        chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)
        chunk_masks = masks & chunk_masks  # (B, L, L)
        chunk_masks = ~chunk_masks  # True marks positions to mask out

        scores = scores.masked_fill(chunk_masks, -float("inf"))
        # Zero fully masked rows after softmax so padding cannot leak NaNs.
        attn = torch.softmax(scores, dim=-1).masked_fill(
            chunk_masks, 0.0
        )  # (#batch, time, time)
        # (#batch, time, time) x (#batch, time, fdim) -> (#batch, time, fdim)
        return torch.matmul(attn, key), attn
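

if __name__ == "__main__":
    # A minimal usage sketch (not part of the original gist): run random
    # microphone / far-end features through the block and check shapes.
    # The feature sizes below are illustrative assumptions only.
    torch.manual_seed(0)

    batch, time = 2, 100
    mdim, fdim = 512, 256
    block = AttAlignBlock(mdim=mdim, fdim=fdim, pdim=64,
                          chunk_size=1, max_delay_blocks=80)

    mic = torch.randn(batch, time, mdim)  # microphone features (query)
    far = torch.randn(batch, time, fdim)  # far-end features (key/value)
    lengths = torch.tensor([100, 80])  # valid frames per utterance

    aligned, attn = block(mic, far, lengths)
    print(aligned.shape)  # torch.Size([2, 100, 256])
    print(attn.shape)  # torch.Size([2, 100, 100])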