An implementation of soft self-attention over input channels, which helps improve multi-channel KWS performance while reducing computational complexity. Inspired by the paper "Joint Ego-Noise Suppression and Keyword Spotting on Sweeping Robots".
#!/usr/bin/env python3
# Copyright 2022 Lucky Wong
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Multi-channel Attention layer definition."""
import torch
class MultiChannelAttention(torch.nn.Module):
'''Multi-channel Attention layer.
Soft self-attention is helpful to improve multi-channel KWS performance
as well as reduce computational complexity.
Args:
input_size (int): The number of input features.
hidden_dim (int): The number of hidden features.
Ref: JOINT EGO-NOISE SUPPRESSION AND KEYWORD SPOTTING ON SWEEPING ROBOTS
https://ieeexplore.ieee.org/document/9747084
'''
def __init__(self, input_size: int, hidden_dim: int):
"""Construct an MultiChannelAttention object."""
super().__init__()
self.att_weight = torch.nn.Sequential(
torch.nn.Linear(input_size, hidden_dim),
torch.nn.Tanh(),
torch.nn.Linear(hidden_dim, 1, bias=False),
)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Transform multi-channel input features into a single channel.

        Args:
            x (torch.Tensor): Multi-channel input feature tensor
                (#batch, time, channel, size).

        Returns:
            torch.Tensor: Single-channel transformed feature tensor
                (#batch, time, size).
        """
        # Score each channel, then normalize across the channel axis:
        # (b, t, c, d) -> (b, t, c, 1)
        g = torch.softmax(self.att_weight(x), dim=-2)
        # Weighted sum over channels collapses (b, t, c, d) -> (b, t, d).
        return torch.sum(g * x, dim=-2)


if __name__ == "__main__":
channel = 6
batch_size = 32
seq_len = 100
feature_dim = 80
hidden_dim = 128
f = torch.randn(batch_size, seq_len, channel, feature_dim)
att = MultiChannelAttention(feature_dim, hidden_dim)
f_new = att(f)
print(f.size(), f_new.size())
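    # Optional sanity check (not part of the original gist): the module output
    # should equal an explicit softmax-weighted sum over the channel axis,
    # written out here step by step for clarity.
    scores = att.att_weight(f)               # (batch, time, channel, 1) raw scores
    weights = torch.softmax(scores, dim=-2)  # normalize over channels; sums to 1
    manual = torch.sum(weights * f, dim=-2)  # (batch, time, feature_dim)
    assert torch.allclose(f_new, manual)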