-
-
Save brentspell/3ad8c46b68a49f5da1bc14817b905ced to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" a simple gmm-based speech trimmer
Copyright © 2022 Brent M. Spell
Licensed under the MIT License (the "License"). You may not use this package
except in compliance with the License. You may obtain a copy of the License at
https://opensource.org/licenses/MIT
Unless required by applicable law or agreed to in writing, software distributed
under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
"""
import numpy as np | |
import torch | |
def _frame_power_db(frames, ref_db):
    """Return the RMS level of each frame in dB, with the peak frame at ref_db."""
    peak_energy = frames.square().mean(-1).sqrt().max()
    ref_energy = 10.0 ** (ref_db / 20.0)
    # an all-zero signal makes the ratio infinite; nan_to_num keeps it finite
    gain = torch.nan_to_num(ref_energy / peak_energy)
    energy = (frames * gain).square().mean(-1).sqrt()
    # the small offset keeps log10 finite for frames with zero energy
    return 20.0 * (energy + 1e-5).log10()


def _fit_power_modes(power, noise_estimate, signal_estimate):
    """Fit a two-component GMM to the frame powers and return its two means."""
    # imported locally: the module-level imports do not include scikit-learn
    import sklearn.mixture

    gmm = sklearn.mixture.GaussianMixture(
        n_components=2,
        init_params="random",
        means_init=np.array([noise_estimate, signal_estimate])[..., np.newaxis],
    ).fit(power.detach().unsqueeze(-1).cpu().numpy())
    return gmm.means_[..., 0]


def _pad_keep_mask(keep, pad_length, fs, frame_stride):
    """Extend every kept frame to its neighbors using a box filter."""
    # filter width in frames, forced to be odd so that it is centered
    pad_frames = int(pad_length * fs / frame_stride) * 2
    pad_frames += 1 - pad_frames % 2
    pad_weight = 1.0 / pad_frames
    pad_filter = torch.full([1, 1, pad_frames], pad_weight).to(keep.device)
    smoothed = torch.nn.functional.conv1d(
        keep.type(torch.float32).view(1, 1, -1),
        pad_filter,
        padding="same",
    ).flatten()
    # a frame is kept when at least one frame under the filter was kept
    return smoothed >= pad_weight


def _synthesis_window(window, frame_stride):
    """Return the analysis window divided by its overlapped squared sum."""
    frame_length = len(window)
    # number of strides needed to cover one frame (ceiling division)
    overlap = -(-frame_length // frame_stride)
    denom = np.pad(window ** 2.0, [(0, overlap * frame_stride - frame_length)])
    # sum the squared window over all frames that overlap each stride offset
    denom = denom.reshape([overlap, frame_stride]).sum(0, keepdims=True)
    denom = np.tile(denom, [overlap, 1])
    denom = denom.reshape([overlap * frame_stride])[:frame_length]
    # without overlap the hann endpoints give 0 / 0; emit 0 there instead of NaN
    return np.divide(window, denom, out=np.zeros_like(window), where=denom > 0)


def trim(
    x: torch.Tensor,
    fs: float,
    block_length: float = 0.025,
    block_overlap: float = 0.6,
    ref_db: float = -18.0,
    signal_estimate: float = -20.0,
    noise_estimate: float = -60.0,
    pad_length: float = 0.25,
    edge_only: bool = False,
) -> torch.Tensor:
    """Trim low-power regions from a signal using a two-component GMM.

    The signal is split into overlapping hann-windowed frames, a GMM is fit
    to the per-frame power (dB), and frames whose power is above the midpoint
    of the two GMM means are kept. The kept frames are then recombined with
    an overlap-and-add operation.

    Args:
        x: signal to trim; the framing, masking and reconstruction steps
            assume a 1-D tensor of samples.
        fs: sample rate of x, used to convert block_length and pad_length
            into sample/frame counts.
        block_length: analysis frame length (multiplied by fs to get samples).
        block_overlap: fraction of each frame shared with the next frame.
        ref_db: level in dB to which the peak frame RMS is normalized before
            the power is computed.
        signal_estimate: initial GMM mean (dB) for the signal component.
        noise_estimate: initial GMM mean (dB) for the noise component.
        pad_length: amount of context (multiplied by fs to get samples) kept
            on each side of a frame classified as signal.
        edge_only: if True, everything between the first and last kept frame
            is retained, so only the edges are trimmed.

    Returns:
        x itself when it is shorter than one frame or when the upper GMM mean
        is above -10 dB; an empty tensor when no frame is kept; otherwise the
        1-D signal reconstructed from the kept frames.
    """
    frame_length = int(block_length * fs)
    frame_stride = int((1 - block_overlap) * frame_length)

    # a signal shorter than one frame cannot be framed, so there is nothing
    # to trim; pass it through like the other "cannot trim" case below
    if x.shape[-1] < frame_length:
        return x

    # frame the signal into overlapping windows and apply a hann window function
    window = np.hanning(frame_length)
    frames = x.unfold(-1, size=frame_length, step=frame_stride)
    frames = frames * torch.from_numpy(window).to(frames.device)

    # compute the normalized power and fit the GMM to extract the modes
    power = _frame_power_db(frames, ref_db)
    modes = _fit_power_modes(power, noise_estimate, signal_estimate)

    # leave the signal untouched if the upper mode is above -10 dB
    if max(modes) > -10:
        return x

    # use the midpoint of the modes as the cutoff, then smooth the keep vector
    keep = _pad_keep_mask(power > modes.mean(), pad_length, fs, frame_stride)

    # if we are only trimming edges, ignore interior silence
    if edge_only:
        start = keep.int().argmax()
        stop = len(keep) - keep.flip(-1).int().argmax()
        keep[start:stop] = True

    # nothing was classified as signal, so the trimmed result is empty
    if not bool(keep.any()):
        return frames.new_zeros(0)

    # discard any trimmed frames and apply the synthesis window
    trimmed = frames[keep]
    synthesis = _synthesis_window(window, frame_stride)
    trimmed = trimmed * torch.from_numpy(synthesis).to(trimmed.device)

    # reconstruct the trimmed audio signal with overlap-and-add
    return torch.nn.functional.fold(
        trimmed.transpose(0, 1).unsqueeze(0),
        output_size=(1, (trimmed.shape[0] - 1) * frame_stride + frame_length),
        kernel_size=(1, frame_length),
        stride=frame_stride,
    ).flatten()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment