Benchmarking several feature extraction methods
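The script below compares four ways of computing Whisper-style log-Mel spectrogram features on 30 s of 16 kHz audio: torch.stft on CPU and GPU (fe_torch), a strided NumPy STFT (fe_np), the same STFT dispatched through CuPy (fe_cp), and the older frame-by-frame NumPy extractor (FeatureExtractor). A few illustrative correctness checks are interleaved before the timings.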
import os
import timeit

print(f"Cores: {len(os.sched_getaffinity(0))}")  # CPU cores available to this process

import cupy as cp
import numpy as np
import torch
def stft(
    array_module,  # numpy or cupy; used for array creation and FFTs
    input_tensor: np.ndarray,
    n_fft: int,
    hop_length: int = None,
    win_length: int = None,
    window: np.ndarray = None,
    center: bool = True,
    mode: str = "reflect",
    normalized: bool = False,
    onesided: bool = None,
    return_complex: bool = None,
):
    # Default initialization for hop_length and win_length
    hop_length = hop_length if hop_length is not None else n_fft // 4
    win_length = win_length if win_length is not None else n_fft
    input_is_complex = np.iscomplexobj(input_tensor)
    window_is_complex = window is not None and np.iscomplexobj(window)
    # Determine if the output should be complex; mirror torch.stft, which
    # requires an explicit return_complex for real inputs
    if return_complex is None:
        if not (input_is_complex or window_is_complex):
            raise ValueError(
                "stft requires the return_complex parameter for real inputs."
            )
        return_complex = True
    # Input checks
    if not np.issubdtype(input_tensor.dtype, np.floating) and not input_is_complex:
        raise ValueError(
            f"stft: expected a tensor of floating point or complex values, got {input_tensor.dtype}"
        )
    if input_tensor.ndim > 2 or input_tensor.ndim < 1:
        raise ValueError(
            f"stft: expected a 1D or 2D tensor, but got {input_tensor.ndim}D tensor"
        )
    # Handle 1D input
    if input_tensor.ndim == 1:
        input_tensor = np.expand_dims(input_tensor, axis=0)
        input_tensor_1d = True
    else:
        input_tensor_1d = False
    # Center padding if required
    if center:
        pad_amount = n_fft // 2
        input_tensor = np.pad(
            input_tensor, ((0, 0), (pad_amount, pad_amount)), mode=mode
        )
    batch, length = input_tensor.shape
    # Additional input checks
    if n_fft <= 0 or n_fft > length:
        raise ValueError(f"stft: expected 0 < n_fft <= {length}, but got n_fft={n_fft}")
    if hop_length <= 0:
        raise ValueError(
            f"stft: expected hop_length > 0, but got hop_length={hop_length}"
        )
    if win_length <= 0 or win_length > n_fft:
        raise ValueError(
            f"stft: expected 0 < win_length <= n_fft, but got win_length={win_length}"
        )
    if window is not None:
        if window.ndim != 1 or window.shape[0] != win_length:
            raise ValueError(
                f"stft: expected a 1D window tensor of size equal to win_length={win_length}, "
                f"but got window with size {window.shape}"
            )
        # Pad the window out to n_fft if necessary
        if win_length < n_fft:
            left = (n_fft - win_length) // 2
            window_ = array_module.zeros(n_fft, dtype=window.dtype)
            window_[left : left + win_length] = window
        else:
            window_ = window
    else:
        window_ = None
    # Calculate the number of frames
    n_frames = 1 + (length - n_fft) // hop_length
    # Time to columns: frame the signal as a strided view (no copy). The
    # benchmark passes a host (NumPy) array here even on the CuPy path.
    input_tensor = np.lib.stride_tricks.as_strided(
        input_tensor,
        (batch, n_frames, n_fft),
        (
            input_tensor.strides[0],
            hop_length * input_tensor.strides[1],
            input_tensor.strides[1],
        ),
    )
    # Move to the target backend (the GPU on the CuPy path), then window
    input_tensor = array_module.asarray(input_tensor)
    if window_ is not None:
        input_tensor = input_tensor * window_
    # FFT and transpose
    complex_fft = input_is_complex or window_is_complex
    onesided = onesided if onesided is not None else not complex_fft
    norm = "ortho" if normalized else None
    if complex_fft:
        if onesided:
            raise ValueError(
                "Cannot have onesided output if window or input is complex"
            )
        output = array_module.fft.fft(input_tensor, n=n_fft, axis=-1, norm=norm)
    else:
        output = array_module.fft.rfft(input_tensor, n=n_fft, axis=-1, norm=norm)
    output = output.transpose((0, 2, 1))
    if input_tensor_1d:
        output = output.squeeze(0)
    if return_complex:
        return output
    # As in torch.stft(..., return_complex=False): stack the real and
    # imaginary parts along a trailing axis
    return array_module.stack((output.real, output.imag), axis=-1)
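# Illustrative self-check (not part of the benchmark): the generic STFT above
# should match torch.stft on a real float32 input. `_x` and `_win` are
# throwaway names introduced only for this check.
_x = np.random.randn(16000).astype(np.float32)
_win = np.hanning(400 + 1)[:-1].astype(np.float32)
_ref = torch.stft(
    torch.from_numpy(_x), 400, 160, window=torch.from_numpy(_win), return_complex=True
).numpy()
_out = stft(np, _x, 400, 160, window=_win, return_complex=True)
assert np.allclose(_out, _ref, atol=1e-4)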
def get_mel_filters(sr, n_fft, n_mels=128):
    """
    Implementation of librosa.filters.mel in PyTorch
    """
    # Initialize the weights
    n_mels = int(n_mels)
    # Center freqs of each FFT bin
    fftfreqs = torch.fft.rfftfreq(n=n_fft, d=1.0 / sr)
    # 'Center freqs' of mel bands - uniformly spaced between limits
    min_mel = 0.0
    max_mel = 45.245640471924965
    mels = torch.linspace(min_mel, max_mel, n_mels + 2)
    # Fill in the linear scale
    f_min = 0.0
    f_sp = 200.0 / 3
    freqs = f_min + f_sp * mels
    # And now the nonlinear scale
    min_log_hz = 1000.0  # beginning of log region (Hz)
    min_log_mel = (min_log_hz - f_min) / f_sp  # same (Mels)
    logstep = torch.log(torch.tensor(6.4)) / 27.0  # step size for log region
    # If we have vector data, vectorize
    log_t = mels >= min_log_mel
    freqs[log_t] = min_log_hz * torch.exp(logstep * (mels[log_t] - min_log_mel))
    mel_f = freqs
    fdiff = torch.diff(mel_f)
    ramps = mel_f.view(-1, 1) - fftfreqs.view(1, -1)
    lower = -ramps[:-2] / fdiff[:-1].unsqueeze(1)
    upper = ramps[2:] / fdiff[1:].unsqueeze(1)
    # Intersect them with each other and zero, vectorized across all i
    weights = torch.maximum(torch.zeros_like(lower), torch.minimum(lower, upper))
    # Slaney-style mel is scaled to be approx constant energy per channel
    enorm = 2.0 / (mel_f[2 : n_mels + 2] - mel_f[:n_mels])
    weights *= enorm.unsqueeze(1)
    return weights
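# Optional cross-check against librosa (illustrative; skipped when librosa is
# not installed): with librosa's defaults (htk=False, Slaney normalization)
# the filterbank above should be reproduced to float32 precision.
try:
    import librosa

    assert np.allclose(
        get_mel_filters(16000, 400, 128).numpy(),
        librosa.filters.mel(sr=16000, n_fft=400, n_mels=128),
        atol=1e-4,
    )
except ImportError:
    pass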
def fe_torch(
    waveform,
    mel_filters,
    n_samples=480000,
    hop_length=160,
    n_fft=400,
    sampling_rate=16000,
    padding=True,
    chunk_length=None,
    to_cpu=False,
):
    """
    Compute the log-Mel spectrogram of the provided audio.
    """
    if chunk_length is not None:
        n_samples = chunk_length * sampling_rate
    if waveform.dtype is not torch.float32:
        waveform = waveform.to(torch.float32)
    if padding:
        waveform = torch.nn.functional.pad(waveform, (0, n_samples))
    window = torch.hann_window(n_fft).to(waveform.device)
    stft = torch.stft(waveform, n_fft, hop_length, window=window, return_complex=True)
    magnitudes = stft[..., :-1].abs() ** 2
    mel_spec = mel_filters.to(waveform.device) @ magnitudes
    log_spec = torch.clamp(mel_spec, min=1e-10).log10()
    log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
    log_spec = (log_spec + 4.0) / 4.0
    # When the model is running on multiple GPUs, the output should be moved
    # to the CPU since we don't know which GPU will handle the next job.
    return log_spec.cpu() if to_cpu else log_spec
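# Shape sanity check (illustrative): with the default 30 s / 16 kHz settings,
# padding appends another n_samples of silence and the final STFT frame is
# dropped, so the features come out as (n_mels, 2 * n_samples / hop_length).
assert fe_torch(torch.zeros(480000), get_mel_filters(16000, 400, 128)).shape == (
    128,
    6000,
)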
def fe_np(
    waveform,
    mel_filters,
    n_samples=480000,
    hop_length=160,
    n_fft=400,
    sampling_rate=16000,
    padding=True,
    chunk_length=None,
    to_cpu=False,
):
    """
    Compute the log-Mel spectrogram of the provided audio.
    """
    if chunk_length is not None:
        n_samples = chunk_length * sampling_rate
    if waveform.ndim == 1:
        waveform = np.expand_dims(waveform, axis=0)
    if waveform.dtype != np.float32:
        waveform = waveform.astype(np.float32)
    if padding:
        # Pad with n_samples of silence, matching fe_torch and fe_cp
        waveform = np.pad(waveform, ((0, 0), (0, n_samples)))
    window = np.hanning(n_fft + 1)[:-1].astype("float32")
    stft_output = stft(
        np, waveform, n_fft, hop_length, window=window, return_complex=True
    ).astype("complex64")
    magnitudes = np.abs(stft_output[..., :-1]) ** 2
    mel_spec = mel_filters @ magnitudes
    log_spec = np.log10(np.clip(mel_spec, 1e-10, None))
    log_spec = np.maximum(log_spec, log_spec.max() - 8.0)
    log_spec = (log_spec + 4.0) / 4.0
    return log_spec
def fe_cp(
    waveform,
    mel_filters,
    n_samples=480000,
    hop_length=160,
    n_fft=400,
    sampling_rate=16000,
    padding=True,
    chunk_length=None,
    to_cpu=False,
):
    """
    Compute the log-Mel spectrogram of the provided audio.
    """
    if chunk_length is not None:
        n_samples = chunk_length * sampling_rate
    if waveform.ndim == 1:
        waveform = np.expand_dims(waveform, axis=0)
    if waveform.dtype != np.float32:
        waveform = waveform.astype(np.float32)
    if padding:
        waveform = np.pad(waveform, ((0, 0), (0, n_samples)))
    window = cp.hanning(n_fft + 1)[:-1].astype("float32")
    stft_output = stft(
        cp, waveform, n_fft, hop_length, window=window, return_complex=True
    ).astype("complex64")
    magnitudes = cp.abs(stft_output[..., :-1]) ** 2
    mel_spec = mel_filters @ magnitudes
    log_spec = cp.log10(cp.clip(mel_spec, 1e-10, None))
    log_spec = cp.maximum(log_spec, log_spec.max() - 8.0)
    log_spec = (log_spec + 4.0) / 4.0
    return cp.asnumpy(log_spec) if to_cpu else log_spec
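# Cross-check (illustrative): the NumPy and CuPy paths should agree closely on
# random input; the tolerance is loose because near-zero spectral bins amplify
# float32 rounding in log space. `_sig` and `_fb` are throwaway names.
_sig = np.random.randn(16000).astype(np.float32)
_fb = get_mel_filters(16000, 400, 128).numpy()
assert np.allclose(
    fe_np(_sig, _fb), cp.asnumpy(fe_cp(_sig, cp.asarray(_fb))), atol=1e-3
)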
class FeatureExtractor:
    def __init__(
        self,
        feature_size=80,
        sampling_rate=16000,
        hop_length=160,
        chunk_length=30,
        n_fft=400,
    ):
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.chunk_length = chunk_length
        self.n_samples = chunk_length * sampling_rate
        self.nb_max_frames = self.n_samples // hop_length
        self.time_per_frame = hop_length / sampling_rate
        self.sampling_rate = sampling_rate
        self.mel_filters = self.get_mel_filters(
            sampling_rate, n_fft, n_mels=feature_size
        )
    def get_mel_filters(self, sr, n_fft, n_mels=128, dtype=np.float32):
        # Initialize the weights
        n_mels = int(n_mels)
        weights = np.zeros((n_mels, int(1 + n_fft // 2)), dtype=dtype)
        # Center freqs of each FFT bin
        fftfreqs = np.fft.rfftfreq(n=n_fft, d=1.0 / sr)
        # 'Center freqs' of mel bands - uniformly spaced between limits
        min_mel = 0.0
        max_mel = 45.245640471924965
        mels = np.linspace(min_mel, max_mel, n_mels + 2)
        mels = np.asanyarray(mels)
        # Fill in the linear scale
        f_min = 0.0
        f_sp = 200.0 / 3
        freqs = f_min + f_sp * mels
        # And now the nonlinear scale
        min_log_hz = 1000.0  # beginning of log region (Hz)
        min_log_mel = (min_log_hz - f_min) / f_sp  # same (Mels)
        logstep = np.log(6.4) / 27.0  # step size for log region
        # If we have vector data, vectorize
        log_t = mels >= min_log_mel
        freqs[log_t] = min_log_hz * np.exp(logstep * (mels[log_t] - min_log_mel))
        mel_f = freqs
        fdiff = np.diff(mel_f)
        ramps = np.subtract.outer(mel_f, fftfreqs)
        for i in range(n_mels):
            # lower and upper slopes for all bins
            lower = -ramps[i] / fdiff[i]
            upper = ramps[i + 2] / fdiff[i + 1]
            # .. then intersect them with each other and zero
            weights[i] = np.maximum(0, np.minimum(lower, upper))
        # Slaney-style mel is scaled to be approx constant energy per channel
        enorm = 2.0 / (mel_f[2 : n_mels + 2] - mel_f[:n_mels])
        weights *= enorm[:, np.newaxis]
        return weights
    def fram_wave(self, waveform, center=True):
        """
        Transform a raw waveform into a list of smaller waveforms.
        The window length defines how much of the signal is contained in each
        frame (smaller waveform), while the hop length defines the step between
        the beginnings of consecutive frames.
        Centering is done by reflecting the waveform, which is first centered
        around `frame_idx * hop_length`.
        """
        frames = []
        for i in range(0, waveform.shape[0] + 1, self.hop_length):
            half_window = (self.n_fft - 1) // 2 + 1
            if center:
                start = i - half_window if i > half_window else 0
                end = (
                    i + half_window
                    if i < waveform.shape[0] - half_window
                    else waveform.shape[0]
                )
                frame = waveform[start:end]
                if start == 0:
                    pad_width = (-i + half_window, 0)
                    frame = np.pad(frame, pad_width=pad_width, mode="reflect")
                elif end == waveform.shape[0]:
                    pad_width = (0, (i - waveform.shape[0] + half_window))
                    frame = np.pad(frame, pad_width=pad_width, mode="reflect")
            else:
                frame = waveform[i : i + self.n_fft]
                frame_width = frame.shape[0]
                if frame_width < waveform.shape[0]:
                    frame = np.lib.pad(
                        frame,
                        pad_width=(0, self.n_fft - frame_width),
                        mode="constant",
                        constant_values=0,
                    )
            frames.append(frame)
        return np.stack(frames, 0)
    def stft(self, frames, window):
        """
        Calculates the complex Short-Time Fourier Transform (STFT) of the given framed signal.
        Should give the same results as `torch.stft`.
        """
        frame_size = frames.shape[1]
        fft_size = self.n_fft
        if fft_size is None:
            fft_size = frame_size
        if fft_size < frame_size:
            raise ValueError("FFT size must be greater than or equal to the frame size")
        # number of FFT bins to store
        num_fft_bins = (fft_size >> 1) + 1
        data = np.empty((len(frames), num_fft_bins), dtype=np.complex64)
        fft_signal = np.zeros(fft_size)
        for f, frame in enumerate(frames):
            if window is not None:
                np.multiply(frame, window, out=fft_signal[:frame_size])
            else:
                fft_signal[:frame_size] = frame
            data[f] = np.fft.fft(fft_signal, axis=0)[:num_fft_bins]
        return data.T
    def __call__(self, waveform, padding=True, chunk_length=None):
        """
        Compute the log-Mel spectrogram of the provided audio; gives results
        similar to Whisper's original torch implementation, within a 1e-5
        tolerance.
        """
        if chunk_length is not None:
            self.n_samples = chunk_length * self.sampling_rate
            self.nb_max_frames = self.n_samples // self.hop_length
        if padding:
            waveform = np.pad(waveform, [(0, self.n_samples)])
        window = np.hanning(self.n_fft + 1)[:-1]
        frames = self.fram_wave(waveform)
        stft = self.stft(frames, window=window)
        magnitudes = np.abs(stft[:, :-1]) ** 2
        filters = self.mel_filters
        mel_spec = filters @ magnitudes
        log_spec = np.log10(np.clip(mel_spec, a_min=1e-10, a_max=None))
        log_spec = np.maximum(log_spec, log_spec.max() - 8.0)
        log_spec = (log_spec + 4.0) / 4.0
        return log_spec
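# Equivalence spot-check (illustrative): per the docstring above, the old
# frame-based extractor should closely match the new NumPy path on random
# input; `_fe` and `_sig2` are throwaway names for this check only, and the
# tolerance is loose because the old path works in float64.
_fe = FeatureExtractor(128)
_sig2 = np.random.randn(16000).astype(np.float32)
assert np.allclose(_fe(_sig2), fe_np(_sig2, _fe.mel_filters)[0], atol=1e-3)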
n_fft = 400
hop_length = 160
audio = torch.randn(30 * 1 * 16000).float()
mel_filters = get_mel_filters(16000, n_fft, 128)
old_fe = FeatureExtractor(128)
# Note on units: timeit(number=1000) returns total seconds for 1000 runs,
# which is numerically the mean per-run time in milliseconds; the number=100
# timings below are multiplied by 10 for the same reason.
# torch on CPU
print("Torch on CPU:")
print(f"{timeit.timeit(lambda: fe_torch(audio, mel_filters), number=1000):.2f} ms")
audio_cuda = audio.cuda()
mel_filters_cuda = mel_filters.cuda()
# torch on GPU: synchronize inside the timed call so kernel execution,
# not just kernel launch, is measured; one warm-up call first
fe_torch(audio_cuda, mel_filters_cuda)

def _torch_gpu():
    fe_torch(audio_cuda, mel_filters_cuda)
    torch.cuda.synchronize()

print("Torch on GPU:")
print(f"{timeit.timeit(_torch_gpu, number=1000):.2f} ms")
audio = audio.numpy()
mel_filters = mel_filters.numpy()
# New numpy on CPU
print("New numpy on CPU:")
print(f"{timeit.timeit(lambda: fe_np(audio, mel_filters), number=100) * 10:.2f} ms")
mel_filters = cp.asarray(mel_filters)
# New CuPy on GPU, synchronized for the same reason as the torch GPU case
fe_cp(audio, mel_filters)

def _cupy_gpu():
    fe_cp(audio, mel_filters)
    cp.cuda.Stream.null.synchronize()

print("New CuPy on GPU:")
print(f"{timeit.timeit(_cupy_gpu, number=1000):.2f} ms")
# Old numpy on CPU
print("Old numpy on CPU:")
print(f"{timeit.timeit(lambda: old_fe(audio), number=100) * 10:.2f} ms")