Skip to content

Instantly share code, notes, and snippets.

@TeaPoly
Last active March 3, 2022 13:15
Show Gist options
  • Save TeaPoly/98a9224f5a4e847b9e6937aaf8efe2d3 to your computer and use it in GitHub Desktop.
Save TeaPoly/98a9224f5a4e847b9e6937aaf8efe2d3 to your computer and use it in GitHub Desktop.
The GCC-PHAT algorithm is applied to align the far-end and near-end signals, implemented in PyTorch.
#!/usr/bin/env python3
# Copyright 2022 Lucky Wong
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
"""The GCC-PHAT algorithm is applied to align the far end and near end signals.
Ref: Weighted Recursive Least Square Filter and Neural Network based Residual Echo Suppression for the AEC-Challenge
Link: http://arxiv.org/abs/2102.08551
"""
import math
from collections import Counter
from typing import Tuple, List, Optional

import torch
def gcc_phat_frame(
    probe_fft: torch.Tensor, refence_fft: torch.Tensor,
    last_cross_corr: Optional[torch.Tensor] = None,
    sample_rate: int = 16000,
    smooth_parameter: float = 0.8,
    window_stride: int = -1
) -> Tuple[int, torch.Tensor]:
    """Estimate the relative delay between one probe frame and one reference frame.

    Uses the generalized cross-correlation with phase transform (GCC-PHAT),
    optionally smoothed recursively with the previous frame's cross-spectrum.

    Args:
        probe_fft (torch.Tensor): ``rfft`` of the (windowed) probe frame. The
            argmax is taken over the flattened inverse transform, so a 1-D
            spectrum of a single frame is presumably expected.
        refence_fft (torch.Tensor): ``rfft`` of the (windowed) reference frame,
            same shape as ``probe_fft``.
        last_cross_corr (Optional[torch.Tensor]): smoothed cross-spectrum
            returned by the previous call, or ``None`` to disable smoothing
            (e.g. for the first frame).
        sample_rate (int): sample rate in Hz, used to convert lags to ms.
        smooth_parameter (float): weight ``a`` of ``last_cross_corr`` in the
            recursive average ``a * last + (1 - a) * current``.
        window_stride (int): only lags ``[0, window_stride)`` of the inverse
            transform are searched, i.e. the probe is assumed to lag the
            reference. With the default ``-1`` every lag except the last one
            is searched.

    Returns:
        Tuple[int, torch.Tensor]: the relative delay in ms (truncated toward
        zero) and the smoothed cross-spectrum to pass as ``last_cross_corr``
        on the next call.
    """
    # Cross-power spectrum of the two frames.
    cross_corr = probe_fft * torch.conj(refence_fft)
    # Recursive (exponential) smoothing across frames.
    if last_cross_corr is not None:
        smooth_cross_corr = (smooth_parameter * last_cross_corr
                             + (1 - smooth_parameter) * cross_corr)
    else:
        smooth_cross_corr = cross_corr
    # PHAT weighting keeps only the phase. The magnitude is clamped away from
    # zero so an empty bin (e.g. digital silence) contributes 0 instead of
    # 0/0 = NaN. A single NaN bin would otherwise turn every sample of the
    # inverse transform into NaN and make argmax meaningless.
    magnitude = torch.abs(smooth_cross_corr)
    magnitude = magnitude.clamp_min(torch.finfo(magnitude.dtype).tiny)
    gcc = torch.fft.irfft(smooth_cross_corr / magnitude)
    # Restrict the search to the admissible lags.
    gcc = gcc[:window_stride]
    max_index = torch.argmax(gcc).item()
    # Lag in samples -> milliseconds (truncated toward zero).
    tau = int(max_index / sample_rate * 1000.)
    return tau, smooth_cross_corr
class GccPhatAlign:
    """The GCC-PHAT algorithm is applied to align the far end and near end signals.

    Ref: Weighted Recursive Least Square Filter and Neural Network based Residual Echo Suppression for the AEC-Challenge
    Link: http://arxiv.org/abs/2102.08551

    Both signals are split into Hamming-windowed frames of
    ``2 * window_stride`` samples with a hop of ``window_stride`` samples.
    Each frame is zero-padded to the next power of two before the FFT. A delay
    is estimated per frame with ``gcc_phat_frame`` (smoothed across frames),
    and the most frequent per-frame delay is returned.

    Args:
        window_stride_ms (int): window stride (hop) duration in ms. Only lags
            shorter than one stride are searched, so this also bounds the
            detectable delay.
        fs (int): sample rate in Hz.
        smooth_parameter (float): smoothing parameter for the cross-spectrum.
    """

    def __init__(self, window_stride_ms: int = 500, fs: int = 16000, smooth_parameter: float = 0.8):
        """Construct a GccPhatAlign object."""
        self.window_stride = int(window_stride_ms / 1000. * fs)
        self.fs = fs
        self.window_len = self.window_stride * 2
        # FFT size: next power of two >= window_len (frames are zero-padded).
        self.n_fft = int(2 ** math.ceil(math.log2(self.window_len)))
        self.window = torch.hamming_window(
            self.window_len, dtype=torch.float32)
        self.smooth_parameter = smooth_parameter

    def __call__(self, probe, ref):
        """Estimate the dominant relative delay of ``probe`` with respect to ``ref``.

        Args:
            probe (torch.Tensor): Input signal. Framing is done along dim 0, so a
                1-D signal is presumably expected. int16 is converted to float32.
            ref (torch.Tensor): Reference signal, same conventions as ``probe``.

        Returns:
            int: the most frequent per-frame delay in ms. Ties go to the delay
            that occurred first.
        """
        if probe.dtype == torch.int16:
            probe = probe.to(dtype=torch.float32)
        if ref.dtype == torch.int16:
            ref = ref.to(dtype=torch.float32)
        # Overlapping frames of window_len samples with a hop of window_stride.
        # Tensor.unfold raises if a signal is shorter than window_len.
        probe_frames = probe.unfold(0, self.window_len, self.window_stride)
        refence_frames = ref.unfold(0, self.window_len, self.window_stride)
        probe_fft = torch.fft.rfft(
            probe_frames * self.window, n=self.n_fft)
        refence_fft = torch.fft.rfft(
            refence_frames * self.window, n=self.n_fft)
        # The signals may differ in length; compare only frames present in both.
        frame_num = min(probe_fft.size(0), refence_fft.size(0))
        last_cross_corr = None
        tau_list = []
        for i in range(frame_num):
            # Relative delay of this frame; the smoothed cross-spectrum is
            # carried over to the next frame.
            tau, last_cross_corr = gcc_phat_frame(
                probe_fft[i],
                refence_fft[i],
                last_cross_corr,
                sample_rate=self.fs,
                smooth_parameter=self.smooth_parameter,
                window_stride=self.window_stride
            )
            tau_list.append(tau)
        # A delay is available every window_stride; report the most frequent one.
        return self._most_frequent(tau_list)

    @staticmethod
    def _most_frequent(values):
        """Return the most common element of ``values``; ties go to the first seen.

        Raises:
            ValueError: if ``values`` is empty.
        """
        if not values:
            raise ValueError("no per-frame delays to choose from")
        # Counter.most_common keeps first-encountered order among equal counts,
        # so this matches ``max(values, key=values.count)`` in O(n).
        return Counter(values).most_common(1)[0][0]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment