Benchmarking several feature extraction methods
import os
import timeit
print(f"Cores: {os.sched_getaffinity(0)}")
import cupy as cp
import numpy as np
import torch
def stft(
    array_module,  # numpy or cupy module namespace
    input_tensor: np.ndarray,
    n_fft: int,
    hop_length: int = None,
    win_length: int = None,
    window: np.ndarray = None,
    center: bool = True,
    mode: str = "reflect",
    normalized: bool = False,
    onesided: bool = None,
    return_complex: bool = None,
):
    # Default initialization for hop_length and win_length
    hop_length = hop_length if hop_length is not None else n_fft // 4
    win_length = win_length if win_length is not None else n_fft
    input_is_complex = np.iscomplexobj(input_tensor)
    window_is_complex = window is not None and np.iscomplexobj(window)
    # Determine if the output should be complex; mirroring torch.stft, real
    # inputs require return_complex to be passed explicitly
    if return_complex is None:
        if not (input_is_complex or window_is_complex):
            raise ValueError(
                "stft requires the return_complex parameter for real inputs."
            )
        return_complex = True
    # Input checks
    if not np.issubdtype(input_tensor.dtype, np.floating) and not input_is_complex:
        raise ValueError(
            f"stft: expected a tensor of floating point or complex values, got {input_tensor.dtype}"
        )
    if input_tensor.ndim > 2 or input_tensor.ndim < 1:
        raise ValueError(
            f"stft: expected a 1D or 2D tensor, but got {input_tensor.ndim}D tensor"
        )
    # Handle 1D input
    if input_tensor.ndim == 1:
        input_tensor = np.expand_dims(input_tensor, axis=0)
        input_tensor_1d = True
    else:
        input_tensor_1d = False
    # Center padding if required
    if center:
        pad_amount = n_fft // 2
        input_tensor = np.pad(
            input_tensor, ((0, 0), (pad_amount, pad_amount)), mode=mode
        )
    batch, length = input_tensor.shape
    # Additional input checks
    if n_fft <= 0 or n_fft > length:
        raise ValueError(f"stft: expected 0 < n_fft <= {length}, but got n_fft={n_fft}")
    if hop_length <= 0:
        raise ValueError(
            f"stft: expected hop_length > 0, but got hop_length={hop_length}"
        )
    if win_length <= 0 or win_length > n_fft:
        raise ValueError(
            f"stft: expected 0 < win_length <= n_fft, but got win_length={win_length}"
        )
    window_ = None
    if window is not None:
        if window.ndim != 1 or window.shape[0] != win_length:
            raise ValueError(
                f"stft: expected a 1D window tensor of size equal to win_length={win_length}, "
                f"but got window with size {window.shape}"
            )
        # Handle padding of the window if necessary
        if win_length < n_fft:
            left = (n_fft - win_length) // 2
            window_ = array_module.zeros(n_fft, dtype=window.dtype)
            window_[left : left + win_length] = window
        else:
            window_ = window
    # Calculate the number of frames
    n_frames = 1 + (length - n_fft) // hop_length
    # Time to columns: build the overlapping frames as a strided view (no copy yet)
    input_tensor = np.lib.stride_tricks.as_strided(
        input_tensor,
        (batch, n_frames, n_fft),
        (
            input_tensor.strides[0],
            hop_length * input_tensor.strides[1],
            input_tensor.strides[1],
        ),
    )
    # asarray also moves the frames to the GPU when array_module is CuPy
    if window_ is not None:
        input_tensor = array_module.asarray(input_tensor) * window_
    else:
        input_tensor = array_module.asarray(input_tensor)
    # FFT and transpose
    complex_fft = input_is_complex or window_is_complex
    onesided = onesided if onesided is not None else not complex_fft
    if normalized:
        norm = "ortho"
    else:
        norm = None
    if complex_fft:
        if onesided:
            raise ValueError(
                "Cannot have onesided output if window or input is complex"
            )
        output = array_module.fft.fft(input_tensor, n=n_fft, axis=-1, norm=norm)
    else:
        output = array_module.fft.rfft(input_tensor, n=n_fft, axis=-1, norm=norm)
    output = output.transpose((0, 2, 1))
    if input_tensor_1d:
        output = output.squeeze(0)
    return output if return_complex else array_module.real(output)
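
# Quick sanity check (not part of the original gist, illustrative only): the
# generic stft() above is written to mirror torch.stft with center=True and a
# periodic Hann window, so on a short random signal the two should agree to
# float32 precision.
_x = np.random.randn(16000).astype(np.float32)
_win = np.hanning(400 + 1)[:-1].astype(np.float32)
_ref = torch.stft(
    torch.from_numpy(_x), 400, 160, window=torch.from_numpy(_win), return_complex=True
)
_out = stft(np, _x, 400, 160, window=_win, return_complex=True)
assert np.allclose(_ref.numpy(), _out, atol=1e-4)
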
def get_mel_filters(sr, n_fft, n_mels=128):
    """
    Implementation of librosa.filters.mel in PyTorch
    """
    # Initialize the weights
    n_mels = int(n_mels)
    # Center freqs of each FFT bin
    fftfreqs = torch.fft.rfftfreq(n=n_fft, d=1.0 / sr)
    # 'Center freqs' of mel bands - uniformly spaced between limits
    min_mel = 0.0
    max_mel = 45.245640471924965
    mels = torch.linspace(min_mel, max_mel, n_mels + 2)
    # Fill in the linear scale
    f_min = 0.0
    f_sp = 200.0 / 3
    freqs = f_min + f_sp * mels
    # And now the nonlinear scale
    min_log_hz = 1000.0  # beginning of log region (Hz)
    min_log_mel = (min_log_hz - f_min) / f_sp  # same (Mels)
    logstep = torch.log(torch.tensor(6.4)) / 27.0  # step size for log region
    # If we have vector data, vectorize
    log_t = mels >= min_log_mel
    freqs[log_t] = min_log_hz * torch.exp(logstep * (mels[log_t] - min_log_mel))
    mel_f = freqs
    fdiff = torch.diff(mel_f)
    ramps = mel_f.view(-1, 1) - fftfreqs.view(1, -1)
    lower = -ramps[:-2] / fdiff[:-1].unsqueeze(1)
    upper = ramps[2:] / fdiff[1:].unsqueeze(1)
    # Intersect them with each other and zero, vectorized across all i
    weights = torch.maximum(torch.zeros_like(lower), torch.minimum(lower, upper))
    # Slaney-style mel is scaled to be approx constant energy per channel
    enorm = 2.0 / (mel_f[2 : n_mels + 2] - mel_f[:n_mels])
    weights *= enorm.unsqueeze(1)
    return weights
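
# Aside (not in the original gist): the magic constant 45.2456... above is the
# Slaney-scale mel value of 8000 Hz, the Nyquist frequency for 16 kHz audio.
# Below 1 kHz the scale is linear (1000 Hz maps to mel 15); above it the scale
# is logarithmic with step log(6.4) / 27, so mel(8000) = 15 + log(8) / (log(6.4) / 27).
_mel_8khz = 15.0 + np.log(8000.0 / 1000.0) / (np.log(6.4) / 27.0)
assert abs(_mel_8khz - 45.245640471924965) < 1e-9
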
def fe_torch(
    waveform,
    mel_filters,
    n_samples=480000,
    hop_length=160,
    n_fft=400,
    sampling_rate=16000,
    padding=True,
    chunk_length=None,
    to_cpu=False,
):
    """
    Compute the log-Mel spectrogram of the provided audio.
    """
    if chunk_length is not None:
        n_samples = chunk_length * sampling_rate
        nb_max_frames = n_samples // hop_length
    if waveform.dtype != torch.float32:
        waveform = waveform.to(torch.float32)
    # waveform = (
    #     waveform.to(self.device)
    #     if self.device == "cuda" and not waveform.is_cuda
    #     else waveform
    # )
    if padding:
        waveform = torch.nn.functional.pad(waveform, (0, n_samples))
    window = torch.hann_window(n_fft).to(waveform.device)
    stft = torch.stft(waveform, n_fft, hop_length, window=window, return_complex=True)
    magnitudes = stft[..., :-1].abs() ** 2
    mel_spec = mel_filters.to(waveform.device) @ magnitudes
    log_spec = torch.clamp(mel_spec, min=1e-10).log10()
    log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
    log_spec = (log_spec + 4.0) / 4.0
    # When the model is running on multiple GPUs, the output should be moved
    # to the CPU since we don't know which GPU will handle the next job.
    return log_spec.cpu() if to_cpu else log_spec
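
# Illustrative usage (not part of the original gist): with the defaults, a 30 s
# clip at 16 kHz (480000 samples) is padded by another n_samples, torch.stft
# yields 1 + 960000 // 160 = 6001 frames, the last frame is dropped, and the
# log-Mel output has shape (n_mels, 6000), e.g.:
# feats = fe_torch(torch.randn(480000), get_mel_filters(16000, 400, 128))
# assert feats.shape == (128, 6000)
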
def fe_np(
    waveform,
    mel_filters,
    n_samples=480000,
    hop_length=160,
    n_fft=400,
    sampling_rate=16000,
    padding=True,
    chunk_length=None,
    to_cpu=False,
):
    """
    Compute the log-Mel spectrogram of the provided audio.
    """
    if chunk_length is not None:
        n_samples = chunk_length * sampling_rate
        nb_max_frames = n_samples // hop_length
    if waveform.ndim == 1:
        waveform = np.expand_dims(waveform, axis=0)
    if waveform.dtype != np.float32:
        waveform = waveform.astype(np.float32)
    if padding:
        waveform = np.pad(waveform, ((0, 0), (0, n_samples)))
    window = np.hanning(n_fft + 1)[:-1].astype("float32")
    stft_output = stft(
        np, waveform, n_fft, hop_length, window=window, return_complex=True
    ).astype("complex64")
    magnitudes = np.abs(stft_output[..., :-1]) ** 2
    mel_spec = mel_filters @ magnitudes
    log_spec = np.log10(np.clip(mel_spec, 1e-10, None))
    log_spec = np.maximum(log_spec, log_spec.max() - 8.0)
    log_spec = (log_spec + 4.0) / 4.0
    return log_spec
def fe_cp(
    waveform,
    mel_filters,
    n_samples=480000,
    hop_length=160,
    n_fft=400,
    sampling_rate=16000,
    padding=True,
    chunk_length=None,
    to_cpu=False,
):
    """
    Compute the log-Mel spectrogram of the provided audio.
    """
    if chunk_length is not None:
        n_samples = chunk_length * sampling_rate
        nb_max_frames = n_samples // hop_length
    if waveform.ndim == 1:
        waveform = np.expand_dims(waveform, axis=0)
    if waveform.dtype != np.float32:
        waveform = waveform.astype(np.float32)
    if padding:
        waveform = np.pad(waveform, ((0, 0), (0, n_samples)))
    window = cp.hanning(n_fft + 1)[:-1].astype("float32")
    stft_output = stft(
        cp, waveform, n_fft, hop_length, window=window, return_complex=True
    ).astype("complex64")
    magnitudes = cp.abs(stft_output[..., :-1]) ** 2
    mel_spec = mel_filters @ magnitudes
    log_spec = cp.log10(cp.clip(mel_spec, 1e-10, None))
    log_spec = cp.maximum(log_spec, log_spec.max() - 8.0)
    log_spec = (log_spec + 4.0) / 4.0
    return log_spec
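
# Note (not in the original gist): fe_cp keeps the waveform as a NumPy array on
# the host; the generic stft() moves the windowed frames to the GPU via
# array_module.asarray, so the host-to-device copy is included in the timing.
# mel_filters, however, must already be a CuPy array (the benchmark below does
# this with cp.asarray); mixing a NumPy filterbank with CuPy magnitudes in the
# matrix product would typically raise a TypeError.
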
class FeatureExtractor:
    def __init__(
        self,
        feature_size=80,
        sampling_rate=16000,
        hop_length=160,
        chunk_length=30,
        n_fft=400,
    ):
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.chunk_length = chunk_length
        self.n_samples = chunk_length * sampling_rate
        self.nb_max_frames = self.n_samples // hop_length
        self.time_per_frame = hop_length / sampling_rate
        self.sampling_rate = sampling_rate
        self.mel_filters = self.get_mel_filters(
            sampling_rate, n_fft, n_mels=feature_size
        )

    def get_mel_filters(self, sr, n_fft, n_mels=128, dtype=np.float32):
        # Initialize the weights
        n_mels = int(n_mels)
        weights = np.zeros((n_mels, int(1 + n_fft // 2)), dtype=dtype)
        # Center freqs of each FFT bin
        fftfreqs = np.fft.rfftfreq(n=n_fft, d=1.0 / sr)
        # 'Center freqs' of mel bands - uniformly spaced between limits
        min_mel = 0.0
        max_mel = 45.245640471924965
        mels = np.linspace(min_mel, max_mel, n_mels + 2)
        mels = np.asanyarray(mels)
        # Fill in the linear scale
        f_min = 0.0
        f_sp = 200.0 / 3
        freqs = f_min + f_sp * mels
        # And now the nonlinear scale
        min_log_hz = 1000.0  # beginning of log region (Hz)
        min_log_mel = (min_log_hz - f_min) / f_sp  # same (Mels)
        logstep = np.log(6.4) / 27.0  # step size for log region
        # If we have vector data, vectorize
        log_t = mels >= min_log_mel
        freqs[log_t] = min_log_hz * np.exp(logstep * (mels[log_t] - min_log_mel))
        mel_f = freqs
        fdiff = np.diff(mel_f)
        ramps = np.subtract.outer(mel_f, fftfreqs)
        for i in range(n_mels):
            # lower and upper slopes for all bins
            lower = -ramps[i] / fdiff[i]
            upper = ramps[i + 2] / fdiff[i + 1]
            # .. then intersect them with each other and zero
            weights[i] = np.maximum(0, np.minimum(lower, upper))
        # Slaney-style mel is scaled to be approx constant energy per channel
        enorm = 2.0 / (mel_f[2 : n_mels + 2] - mel_f[:n_mels])
        weights *= enorm[:, np.newaxis]
        return weights
    def fram_wave(self, waveform, center=True):
        """
        Transform a raw waveform into a list of smaller waveforms.
        The window length defines how much of the signal is contained in each
        frame (smaller waveform), while the hop length defines the step between
        the beginnings of consecutive frames.
        Centering is done by reflecting the waveform, which is first centered
        around `frame_idx * hop_length`.
        """
        frames = []
        for i in range(0, waveform.shape[0] + 1, self.hop_length):
            half_window = (self.n_fft - 1) // 2 + 1
            if center:
                start = i - half_window if i > half_window else 0
                end = (
                    i + half_window
                    if i < waveform.shape[0] - half_window
                    else waveform.shape[0]
                )
                frame = waveform[start:end]
                if start == 0:
                    padd_width = (-i + half_window, 0)
                    frame = np.pad(frame, pad_width=padd_width, mode="reflect")
                elif end == waveform.shape[0]:
                    padd_width = (0, (i - waveform.shape[0] + half_window))
                    frame = np.pad(frame, pad_width=padd_width, mode="reflect")
            else:
                frame = waveform[i : i + self.n_fft]
                frame_width = frame.shape[0]
                if frame_width < waveform.shape[0]:
                    frame = np.lib.pad(
                        frame,
                        pad_width=(0, self.n_fft - frame_width),
                        mode="constant",
                        constant_values=0,
                    )
            frames.append(frame)
        return np.stack(frames, 0)
    def stft(self, frames, window):
        """
        Calculates the complex Short-Time Fourier Transform (STFT) of the given framed signal.
        Should give the same results as `torch.stft`.
        """
        frame_size = frames.shape[1]
        fft_size = self.n_fft
        if fft_size is None:
            fft_size = frame_size
        if fft_size < frame_size:
            raise ValueError("FFT size must be greater than or equal to the frame size")
        # number of FFT bins to store
        num_fft_bins = (fft_size >> 1) + 1
        data = np.empty((len(frames), num_fft_bins), dtype=np.complex64)
        fft_signal = np.zeros(fft_size)
        for f, frame in enumerate(frames):
            if window is not None:
                np.multiply(frame, window, out=fft_signal[:frame_size])
            else:
                fft_signal[:frame_size] = frame
            data[f] = np.fft.fft(fft_signal, axis=0)[:num_fft_bins]
        return data.T
    def __call__(self, waveform, padding=True, chunk_length=None):
        """
        Compute the log-Mel spectrogram of the provided audio; gives similar results
        to Whisper's original torch implementation within a 1e-5 tolerance.
        """
        if chunk_length is not None:
            self.n_samples = chunk_length * self.sampling_rate
            self.nb_max_frames = self.n_samples // self.hop_length
        if padding:
            waveform = np.pad(waveform, [(0, self.n_samples)])
        window = np.hanning(self.n_fft + 1)[:-1]
        frames = self.fram_wave(waveform)
        stft = self.stft(frames, window=window)
        magnitudes = np.abs(stft[:, :-1]) ** 2
        filters = self.mel_filters
        mel_spec = filters @ magnitudes
        log_spec = np.log10(np.clip(mel_spec, a_min=1e-10, a_max=None))
        log_spec = np.maximum(log_spec, log_spec.max() - 8.0)
        log_spec = (log_spec + 4.0) / 4.0
        return log_spec
window = torch.hann_window(400)
n_fft = 400
hop_length = 160
audio = torch.randn((30 * 1 * 16000)).float()
mel_filters = get_mel_filters(16000, n_fft, 128)
old_fe = FeatureExtractor(128)
# torch on CPU
print("Torch on CPU:")
print(f"{timeit.timeit(lambda: fe_torch(audio, mel_filters), number=1000):.2f} ms")
audio_cuda = audio.cuda()
mel_filters_cuda = mel_filters.cuda()
# torch on GPU
print("Torch on GPU:")
print(
f"{timeit.timeit(lambda: fe_torch(audio_cuda, mel_filters_cuda), number=1000):.2f} ms"
)
audio = audio.numpy()
mel_filters = mel_filters.numpy()
# New numpy on CPU
print("New numpy on CPU:")
print(f"{timeit.timeit(lambda: fe_np(audio, mel_filters), number=100) * 10 :.2f} ms")
mel_filters = cp.asarray(mel_filters)
# New CuPy on GPU
print("New CuPy on GPU:")
print(f"{timeit.timeit(lambda: fe_cp(audio, mel_filters), number=1000):.2f} ms")
# Old numpy on CPU
print("Old numpy on CPU:")
print(f"{timeit.timeit(lambda: old_fe(audio), number=100) * 10:.2f} ms")