Skip to content

Instantly share code, notes, and snippets.

@tz579

tz579/test.py Secret

Last active April 4, 2019 06:25
Show Gist options
  • Save tz579/7067976e1e78a9d0d0897adb446acb2c to your computer and use it in GitHub Desktop.
Save tz579/7067976e1e78a9d0d0897adb446acb2c to your computer and use it in GitHub Desktop.
Inverse STFT: tz579's modification/simplification of the work by keunwoochoi & faroit
import test_scipy
import test_librosa
import test_torch
import utils
import numpy as np
import matplotlib.pyplot as plt
# Backend modules under comparison; each exposes module-level ``stft`` and ``istft``.
stfts = [test_torch, test_scipy, test_librosa]
istfts = [test_torch, test_scipy, test_librosa]

if __name__ == "__main__":
    # Run every (forward, inverse) backend pair on the same test sine and
    # print utils.rms(original, reconstruction) for each pair, so the
    # libraries' STFT conventions can be compared against each other.
    s = utils.sine(dtype=np.float32)
    for forward_method in stfts:
        # The forward transform does not depend on which inverse is being
        # tested, so compute it once per forward backend rather than once
        # per (forward, inverse) pair.
        X = forward_method.stft(s)
        for inverse_method in istfts:
            x = inverse_method.istft(X)
            print(
                forward_method.__name__,
                "-->",
                inverse_method.__name__,
                utils.rms(s, x),
            )
import librosa
import numpy as np
import utils
def stft(x, n_fft=2048, n_hopsize=1024, center=True, window='hann', dtype=np.complex64):
    """Forward STFT of ``x`` computed by librosa.

    Parameters
    ----------
    x : signal passed straight to ``librosa.core.stft``.
    n_fft : FFT / frame size.
    n_hopsize : hop between frames (librosa's ``hop_length``).
    center : forwarded to librosa; padding (when applied) uses
        ``pad_mode='constant'``.
    window : window specification forwarded to librosa.
    dtype : complex dtype of the returned matrix.

    Returns the complex STFT matrix produced by librosa.
    """
    librosa_kwargs = dict(
        n_fft=n_fft,
        hop_length=n_hopsize,
        pad_mode='constant',
        center=center,
        window=window,
        dtype=dtype,
    )
    return librosa.core.stft(x, **librosa_kwargs)
def istft(X, rate=44100, n_fft=2048, n_hopsize=1024, center=True, window='hann', dtype=np.float32):
    """Inverse STFT of ``X`` computed by librosa.

    ``rate`` and ``n_fft`` are accepted for signature parity with the other
    backends but are not forwarded to librosa (librosa derives the FFT size
    from the number of frequency rows in ``X``).
    Only ``n_hopsize`` (as ``hop_length``), ``center``, ``window`` and
    ``dtype`` are passed through.
    """
    forwarded = {
        'hop_length': n_hopsize,
        'center': center,
        'window': window,
        'dtype': dtype,
    }
    return librosa.core.istft(X, **forwarded)
def spectrogram(X, power):
    """Element-wise ``|X| ** power`` of a (complex) STFT matrix.

    ``power=1`` yields a magnitude spectrogram, ``power=2`` a power one.
    """
    magnitude = np.abs(X)
    return magnitude ** power
if __name__ == "__main__":
    # Smoke test: forward then inverse transform of the test sine; prints the
    # STFT shape and utils.rms(original, reconstruction).
    signal = utils.sine()
    spec = stft(signal)
    print(spec.shape)
    reconstruction = istft(spec, rate=44100)
    print(utils.rms(signal, reconstruction))
import scipy.signal
import numpy as np
import utils
import librosa
import torch
def _window_sum(window, n_fft):
    """Return the sum of the analysis window scipy.signal.stft will use.

    scipy's default ``scaling='spectrum'`` divides each STFT frame by this
    sum. String/tuple specs are expanded exactly as scipy does (via
    ``get_window``); array windows are used as given.
    """
    if isinstance(window, (str, tuple)):
        return scipy.signal.get_window(window, n_fft).sum()
    return np.asarray(window).sum()


def stft(x, n_fft=2048, n_hopsize=1024, window='hann'):
    """Forward STFT via scipy, rescaled to an unnormalized (librosa-style) STFT.

    scipy normalizes the STFT by ``sum(window)``; multiplying that factor back
    makes the output comparable to the librosa/torch backends. (The previous
    factor ``n_hopsize`` only coincided with ``sum(window)`` for a Hann
    window with ``n_hopsize == n_fft // 2``.)

    Returns the complex STFT matrix with shape [freq, time].
    """
    _, _, X = scipy.signal.stft(
        x,
        nperseg=n_fft,
        noverlap=n_fft - n_hopsize,
        window=window,
        padded=True,
    )
    return X * _window_sum(window, n_fft)


def istft(X, rate=44100, n_fft=2048, n_hopsize=1024, window='hann'):
    """Inverse of ``stft`` above via scipy.

    ``X`` is expected in the unnormalized scaling produced by ``stft``; the
    window-sum factor is divided out before handing it to scipy, which
    expects its own ``scaling='spectrum'`` convention.

    Returns the reconstructed time-domain signal (scipy's ``boundary=True``
    trims the centering padding).
    """
    _, audio = scipy.signal.istft(
        X / _window_sum(window, n_fft),
        rate,
        nperseg=n_fft,
        noverlap=n_fft - n_hopsize,
        window=window,
        boundary=True,
    )
    return audio
def spectrogram(X, power):
    """Return ``abs(X) ** power`` element-wise for an STFT matrix ``X``."""
    mag = np.abs(X)
    result = mag ** power
    return result
if __name__ == "__main__":
    # Self-check: STFT the test sine, print the STFT shape, invert it and
    # print utils.rms(original, reconstruction).
    sig = utils.sine()
    Z = stft(sig)
    print(Z.shape)
    rec = istft(Z, rate=44100)
    print(utils.rms(sig, rec))
import torch
import numpy as np
import utils
def stft(sig_vec, n_fft=None, hop_length=None, window=torch.hann_window, out_type="numpy"):
    """Short-time Fourier transform via ``torch.stft``.

    Default settings are consistent with librosa.core.spectrum._spectrogram:
    center=True, normalized=False, onesided=True, pad_mode='reflect'.

    Parameters
    ----------
    sig_vec : array-like or torch.Tensor
        Signal of shape [time] or [batch, time]. Non-tensor input is converted
        to a float32 tensor; tensor input keeps its device and (floating)
        dtype, and a 1-D tensor gets a leading batch axis.
    n_fft : int, optional
        FFT / frame size, default 2048.
    hop_length : int, optional
        Hop between frames, default ``n_fft // 2``.
    window : callable
        Called as ``window(n_fft)``; the result is moved to the input's
        device and dtype.
    out_type : {"numpy", "torch"}
        "numpy": complex NumPy array, [freq, time] for a single signal.
        "torch": real NumPy array [2, freq, time] holding the real and
        imaginary parts (despite the name, this is also a NumPy array).

    Raises
    ------
    ValueError
        If ``out_type`` is neither "numpy" nor "torch".
    """
    if out_type not in ("numpy", "torch"):
        raise ValueError(f"out_type must be 'numpy' or 'torch', got {out_type!r}")

    if not isinstance(sig_vec, torch.Tensor):
        sig_vec = torch.from_numpy(np.ascontiguousarray(np.atleast_2d(sig_vec))).float()
    else:
        if sig_vec.dim() == 1:
            sig_vec = sig_vec.unsqueeze(0)  # [time] -> [1, time]
        if not sig_vec.is_floating_point():
            sig_vec = sig_vec.float()

    if n_fft is None:
        n_fft = 2048
    if hop_length is None:
        hop_length = int(n_fft // 2)

    # Match the window to the signal so float64 input does not hit a dtype mismatch.
    window_stft = window(n_fft).to(device=sig_vec.device, dtype=sig_vec.dtype)
    spec = torch.stft(
        sig_vec,
        n_fft=n_fft,
        hop_length=hop_length,
        window=window_stft,
        center=True,
        normalized=False,
        onesided=True,
        pad_mode='reflect',
        return_complex=True,  # required for real input on current torch
    )  # complex [batch, freq, time]

    # (real, imag) on a trailing axis: [batch, freq, time, 2] -> [batch, time, freq, 2]
    stft_mat = torch.view_as_real(spec).transpose(1, 2)
    # Drop only a size-1 batch axis, so a single frame or bin is never squeezed away;
    # NumPy's .T then reverses the axes to [2, freq, time].
    out_np = stft_mat.squeeze(0).detach().cpu().numpy().T
    if out_type == "torch":
        return out_np
    return out_np[0, ...] + out_np[1, ...] * 1j  # combine real and imaginary parts
def istft(stft_mat, hop_length=None, window=torch.hann_window):
    """Inverse STFT by windowed overlap-add, following librosa's istft.

    Assumes the forward transform used center=True, normalized=False,
    onesided=True (the librosa.core.spectrum._spectrogram defaults).

    Parameters
    ----------
    stft_mat : array-like or torch.Tensor
        One of:
        - a complex array [freq, time] or [batch, freq, time], converted to a
          float32 tensor;
        - a real tensor [batch, freq, time, 2] or [freq, time, 2] holding
          (real, imag) pairs;
        - a complex tensor [batch, freq, time] or [freq, time].
    hop_length : int, optional
        Hop between frames, default ``n_fft // 2``.
    window : callable
        Called as ``window(n_fft)`` to build the synthesis window.

    Returns
    -------
    numpy.ndarray
        Reconstructed signal: [time] for a single item, else [batch, time].

    n_fft is inferred as ``2 * (n_freq - 1)`` (always even). Each frame is
    inverse-FFT'd, windowed and overlap-added. The sum is divided by the
    overlap-added squared window, then ``n_fft // 2`` centering samples are
    trimmed from both ends. Positions where the squared-window sum is
    numerically zero are left unnormalized instead of producing NaN, as in
    librosa: http://librosa.github.io/librosa/_modules/librosa/core/spectrum.html#istft
    """
    if not isinstance(stft_mat, torch.Tensor):
        spec_np = np.asarray(stft_mat)
        parts = np.stack([np.real(spec_np), np.imag(spec_np)], axis=-1)  # [..., freq, time, 2]
        stft_mat = torch.from_numpy(parts).float()
    elif stft_mat.is_complex():
        stft_mat = torch.view_as_real(stft_mat)
    if stft_mat.dim() == 3:
        stft_mat = stft_mat.unsqueeze(0)  # add batch axis -> [batch, freq, time, 2]

    n_fft = 2 * (stft_mat.shape[-3] - 1)
    if hop_length is None:
        hop_length = int(n_fft // 2)
    n_frames = stft_mat.shape[-2]
    n_samples = n_fft + hop_length * (n_frames - 1)  # length before unpadding

    # Inverse real FFT of all frames at once: [batch, freq, time] complex -> [batch, n_fft, time]
    frames = torch.fft.irfft(torch.view_as_complex(stft_mat.contiguous()), n=n_fft, dim=-2)
    win = window(n_fft).to(device=frames.device, dtype=frames.dtype)
    frames = frames * win.unsqueeze(-1)  # synthesis window on every frame

    sig_vec = torch.zeros(frames.shape[0], n_samples, dtype=frames.dtype, device=frames.device)
    win_sq_sum = torch.zeros(n_samples, dtype=frames.dtype, device=frames.device)
    win_sq = win ** 2
    for i in range(n_frames):
        start = i * hop_length
        sig_vec[:, start:start + n_fft] += frames[:, :, i]
        win_sq_sum[start:start + n_fft] += win_sq

    # Normalize only where the overlapped squared window is not ~0 (avoids 0/0 = NaN).
    nonzero = win_sq_sum > torch.finfo(win_sq_sum.dtype).tiny
    sig_vec = sig_vec / torch.where(nonzero, win_sq_sum, torch.ones_like(win_sq_sum))

    half = n_fft // 2
    sig_vec = sig_vec[:, half:n_samples - half]  # undo center=True padding
    return sig_vec.squeeze(0).detach().cpu().numpy()
def spectogram(X, power=1):
    """Return the spectrogram ``|X| ** power``.

    Accepted inputs:
    - a real tensor whose last axis holds (real, imag) pairs, e.g.
      [batch, freq, time, 2];
    - a complex tensor;
    - a complex NumPy array, such as this module's ``stft`` returns with the
      default ``out_type="numpy"`` (treated element-wise as complex values).
    """
    if isinstance(X, torch.Tensor):
        if X.is_complex():
            return X.abs().pow(power)
        # |z| ** p == (re**2 + im**2) ** (p / 2)
        return X.pow(2).sum(dim=-1).pow(power / 2.0)
    return np.abs(X) ** power


# Correctly spelled alias, matching the ``spectrogram`` name used by the sibling backends.
spectrogram = spectogram
if __name__ == "__main__":
    # Smoke test: STFT the test sine, print the STFT shape, then print
    # utils.rms(original, reconstruction) after inverting it.
    test_signal = utils.sine()
    spec = stft(test_signal)
    print(spec.shape)
    reconstructed = istft(spec)
    print(utils.rms(test_signal, reconstructed))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment