Skip to content

Instantly share code, notes, and snippets.

@brentspell
Last active December 16, 2023 09:21
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save brentspell/3ad8c46b68a49f5da1bc14817b905ced to your computer and use it in GitHub Desktop.
Save brentspell/3ad8c46b68a49f5da1bc14817b905ced to your computer and use it in GitHub Desktop.
""" a simple gmm-based speech trimmer
Copyright © 2022 Brent M. Spell
Licensed under the MIT License (the "License"). You may not use this package
except in compliance with the License. You may obtain a copy of the License at
https://opensource.org/licenses/MIT
Unless required by applicable law or agreed to in writing, software distributed
under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
"""
import numpy as np
import torch
def trim(
    x,
    fs,
    block_length=0.025,
    block_overlap=0.6,
    ref_db=-18.0,
    signal_estimate=-20.0,
    noise_estimate=-60.0,
    pad_length=0.25,
    edge_only=False,
):
    """Remove silent regions from a speech signal.

    The signal is framed into overlapping Hann-windowed blocks. Each block's
    RMS level is measured in dB after normalizing the loudest block to
    ``ref_db``. The levels are clustered into a noise mode and a signal mode
    with a two-component Gaussian mixture. Blocks above the midpoint of the
    two modes are kept, the kept mask is dilated by ``pad_length`` seconds,
    and the kept blocks are re-assembled with weighted overlap-add.

    Args:
        x: 1-D torch tensor of audio samples.
        fs: sample rate of ``x`` in Hz.
        block_length: analysis block duration in seconds.
        block_overlap: fraction of each block shared with the next block.
        ref_db: level in dB that the loudest block is normalized to.
        signal_estimate: initial guess for the signal mode, in dB.
        noise_estimate: initial guess for the noise mode, in dB.
        pad_length: seconds of context kept around detected speech.
        edge_only: if True, only leading and trailing silence is removed;
            every block between the first and last kept block is retained.

    Returns:
        ``x`` itself, unchanged, when it is too short to produce two blocks
        or when the fitted signal mode is above -10 dB (degenerate fit).
        Otherwise, a new tensor holding the trimmed signal. It is float64,
        because the NumPy Hann window is float64 and the product is promoted.
        It is empty if no block is kept.

    Raises:
        ValueError: if ``block_length`` / ``block_overlap`` give a block or
            hop of fewer than one sample.
    """
    frame_length = int(block_length * fs)
    frame_stride = int((1 - block_overlap) * frame_length)
    if frame_length < 1 or frame_stride < 1:
        raise ValueError(
            "block_length and block_overlap must give a block and a hop of "
            "at least one sample"
        )
    # the two-component mixture needs at least two frames to fit
    if x.shape[-1] < frame_length + frame_stride:
        return x

    # frame the signal into overlapping windows and apply a hann window function
    window = np.hanning(frame_length)
    frames = x.unfold(-1, size=frame_length, step=frame_stride)
    frames = frames * torch.from_numpy(window).to(frames.device)

    # compute the peak RMS energy of the frames and calculate normalized power
    peak_energy = frames.square().mean(-1).sqrt().max()
    ref_energy = 10.0 ** (ref_db / 20.0)
    # all-zero input gives an infinite gain; nan_to_num makes it finite
    gain = torch.nan_to_num(ref_energy / peak_energy)
    energy = (frames * gain).square().mean(-1).sqrt()
    # the 1e-5 floor bounds the level of silent frames at -100 dB
    power = 20.0 * (energy + 1e-5).log10()

    # fit a two-component GMM to the power vector and extract the modes
    modes = _fit_gmm_modes(
        power.detach().cpu().numpy(),
        [noise_estimate, signal_estimate],
    )
    # The loudest block is normalized to ref_db, so a signal mode above
    # -10 dB indicates a degenerate fit. A mixture component that receives
    # no samples collapses to a mean of about 0 dB (e.g. all-silent input).
    # Do not trim in that case.
    if modes.max() > -10:
        return x

    # use the midpoint of the modes as the cutoff
    keep = power > float(modes.mean())
    if not bool(keep.any()):
        # nothing above the cutoff: everything is trimmed
        return frames.new_zeros(0)

    # Dilate the keep mask with a box filter. A frame survives if any frame
    # within pad_length on either side was kept. Thresholding at half of one
    # filter weight (instead of exactly one weight) tolerates rounding in the
    # convolution.
    pad_frames = 2 * int(pad_length * fs / frame_stride) + 1
    pad_weight = 1.0 / pad_frames
    pad_filter = torch.full([1, 1, pad_frames], pad_weight, device=keep.device)
    smoothed = torch.nn.functional.conv1d(
        keep.to(torch.float32).view(1, 1, -1),
        pad_filter,
        padding="same",
    ).flatten()
    keep = smoothed >= 0.5 * pad_weight

    # if we are only trimming edges, ignore interior silence
    if edge_only:
        kept = keep.nonzero().flatten()
        keep[int(kept[0]) : int(kept[-1]) + 1] = True

    # discard trimmed frames and apply the overlap-add synthesis window
    synthesis = torch.from_numpy(_synthesis_window(window, frame_stride))
    trimmed = frames[keep] * synthesis.to(frames.device)

    # reconstruct the trimmed audio signal
    n_kept = trimmed.shape[0]
    return torch.nn.functional.fold(
        trimmed.transpose(0, 1).unsqueeze(0),
        output_size=(1, (n_kept - 1) * frame_stride + frame_length),
        kernel_size=(1, frame_length),
        stride=frame_stride,
    ).flatten()


def _fit_gmm_modes(samples, means_init, max_iter=100, tol=1e-3, reg_covar=1e-6):
    """Fit a 1-D Gaussian mixture with EM and return the component means.

    The defaults match ``sklearn.mixture.GaussianMixture`` (``max_iter=100``,
    ``tol=1e-3``, ``reg_covar=1e-6``). Initialization is deterministic: equal
    weights, the given means, and the sample variance for every component.

    Args:
        samples: array of scalar observations (flattened).
        means_init: initial component means, one per component.
        max_iter: maximum number of EM iterations.
        tol: stop when the mean log-likelihood changes by less than this.
        reg_covar: value added to every variance to keep it positive.

    Returns:
        1-D NumPy array of fitted means, in the same order as ``means_init``.
    """
    x = np.asarray(samples, dtype=np.float64).reshape(-1, 1)
    means = np.asarray(means_init, dtype=np.float64).reshape(1, -1)
    n_components = means.shape[1]
    weights = np.full(n_components, 1.0 / n_components)
    variances = np.full(n_components, x.var() + reg_covar)
    prev_bound = -np.inf
    for _ in range(max_iter):
        # E-step: joint log-density of each sample under each component
        log_prob = (
            np.log(weights)
            - 0.5 * np.log(2.0 * np.pi * variances)
            - 0.5 * (x - means) ** 2 / variances
        )
        log_norm = np.logaddexp.reduce(log_prob, axis=1, keepdims=True)
        resp = np.exp(log_prob - log_norm)
        # M-step: the epsilon keeps an empty component from dividing by zero
        counts = resp.sum(axis=0) + 10.0 * np.finfo(np.float64).eps
        weights = counts / x.shape[0]
        means = (resp * x).sum(axis=0, keepdims=True) / counts
        variances = (resp * (x - means) ** 2).sum(axis=0) / counts + reg_covar
        bound = float(log_norm.mean())
        if abs(bound - prev_bound) < tol:
            break
        prev_bound = bound
    return means[0]


def _synthesis_window(window, frame_stride):
    """Return ``window`` normalized for unity-gain weighted overlap-add.

    Each tap is divided by the sum of the squared window over all taps that
    share its phase modulo ``frame_stride``. Those are the taps that land on
    the same output sample when frames are overlap-added. Where that sum is
    zero (possible at the Hann window's zero end-points when frames barely
    overlap), the result is 0 rather than NaN.
    """
    frame_length = window.shape[0]
    overlap = -(-frame_length // frame_stride)  # ceil(frame_length / frame_stride)
    denom = np.pad(window**2.0, [(0, overlap * frame_stride - frame_length)])
    denom = denom.reshape([overlap, frame_stride]).sum(0, keepdims=True)
    denom = np.tile(denom, [overlap, 1]).reshape([overlap * frame_stride])
    denom = denom[:frame_length]
    return np.divide(window, denom, out=np.zeros_like(window), where=denom > 0)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment