-
-
Save zaptrem/94d10c5d76d2f601841e9f8e8bf4859a to your computer and use it in GitHub Desktop.
MRSTFT For RFWave
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torchaudio | |
import matplotlib.pyplot as plt | |
import numpy as np | |
from rfwave.heads import RFSTFTHead | |
from rfwave.spectral_ops import ISTFT, STFT | |
class MultiResolutionRFSTFTHead(RFSTFTHead):
    """RFSTFT head whose spectrogram is stitched from several STFT resolutions.

    ``num_bands`` STFT/ISTFT pairs are created with FFT sizes spaced
    geometrically from ``max_fft_size`` (band 0) down to ``min_fft_size``
    (last band), all sharing one ``hop_length``. Band ``i`` contributes a
    contiguous run of ``band_sizes[i]`` frequency bins taken from resolution
    ``i``: the lowest bins for the first band, the highest bins for the last
    band, and a run centred on the band's share of the spectrum otherwise.
    """

    def __init__(self, dim: int, num_bands: int, min_fft_size: int, max_fft_size: int,
                 hop_length: int, padding: str = "same"):
        """Build one STFT/ISTFT pair per band.

        Args:
            dim: Forwarded to ``RFSTFTHead.__init__``.
            num_bands: Number of resolutions / frequency bands (>= 1).
            min_fft_size: FFT size targeted for the last band.
            max_fft_size: FFT size used for the first band and passed to the
                parent constructor.
            hop_length: Hop length shared by every resolution.
            padding: Forwarded to the parent and to every STFT/ISTFT module.

        Raises:
            ValueError: If ``num_bands`` is smaller than 1.
        """
        if num_bands < 1:
            raise ValueError(f"num_bands must be >= 1, got {num_bands}")
        super().__init__(dim, max_fft_size, hop_length, padding=padding)
        self.num_bands = num_bands
        self.hop_length = hop_length

        if num_bands == 1:
            # The geometric spacing divides by (num_bands - 1); a single band
            # simply uses the largest resolution.
            raw_sizes = [max_fft_size]
        else:
            ratio = min_fft_size / max_fft_size
            raw_sizes = [
                int(max_fft_size * ratio ** (i / (num_bands - 1)))
                for i in range(num_bands)
            ]
        # Round odd sizes up to even and never go below two hops per window.
        self.fft_sizes = [max(size + (size % 2), 2 * hop_length) for size in raw_sizes]
        self.window_sizes = self.fft_sizes

        self.stft_modules = torch.nn.ModuleList([
            STFT(n_fft=fft_size, hop_length=hop_length, win_length=window_size, padding=padding)
            for fft_size, window_size in zip(self.fft_sizes, self.window_sizes)
        ])
        self.istft_modules = torch.nn.ModuleList([
            ISTFT(n_fft=fft_size, hop_length=hop_length, win_length=window_size, padding=padding)
            for fft_size, window_size in zip(self.fft_sizes, self.window_sizes)
        ])

        # Each resolution contributes a 1/num_bands share of its own bins.
        self.band_sizes = [(fft_size // 2 + 1) // num_bands for fft_size in self.fft_sizes]
        # Overwrites the value set by the parent with twice the stitched bin
        # count. NOTE(review): how the parent consumes n_fft is not visible
        # here -- confirm this is the intended convention.
        self.n_fft = sum(self.band_sizes) * 2

    def _band_slice(self, index: int, n_bins: int) -> slice:
        """Return the frequency-bin slice that band ``index`` occupies.

        The slice always spans exactly ``band_sizes[index]`` bins so that
        ``get_spec`` and ``get_wave`` agree on the layout even when the band
        size is odd.
        """
        band_size = self.band_sizes[index]
        last = len(self.band_sizes) - 1
        if index == 0:
            # Lowest frequency band: anchored at the bottom of the spectrum.
            start = 0
        elif index == last:
            # Highest frequency band: anchored at the top of the spectrum.
            start = n_bins - band_size
        else:
            # Middle bands: centred on the band's own share of the spectrum.
            # For the single middle band of a 3-band setup this is
            # n_bins // 2, i.e. the spectrum midpoint.
            center = ((2 * index + 1) * n_bins) // (2 * (last + 1))
            start = center - band_size // 2
        return slice(start, start + band_size)

    def get_spec(self, audio):
        """Compute the stitched multi-resolution spectrogram of ``audio``.

        Every STFT module is applied to ``audio`` and the band slice of each
        result is concatenated along dim 1.
        NOTE(review): assumes each STFT output is (batch, bins, frames) with
        ``fft_size // 2 + 1`` bins -- confirm against ``STFT``.
        """
        bands = []
        for i, stft_module in enumerate(self.stft_modules):
            stft = stft_module(audio)
            bands.append(stft[:, self._band_slice(i, stft.shape[1]), :])
        return torch.cat(bands, dim=1)

    def get_wave(self, S):
        """Reconstruct a waveform from a stitched spectrogram ``S``.

        Each band is copied back into an otherwise zero spectrum of its own
        resolution, inverted with the matching ISTFT module, and the
        resulting waveforms are summed.
        NOTE(review): the output buffer has ``frames * hop_length`` samples,
        which presumes every ISTFT returns exactly that length -- confirm for
        the padding mode in use.
        """
        batch, frames = S.shape[0], S.shape[-1]
        result = torch.zeros(batch, frames * self.hop_length, device=S.device)
        current_bin = 0
        for i, istft_module in enumerate(self.istft_modules):
            band_size = self.band_sizes[i]
            n_bins = self.fft_sizes[i] // 2 + 1
            full_stft = torch.zeros((batch, n_bins, frames), dtype=S.dtype, device=S.device)
            full_stft[:, self._band_slice(i, n_bins), :] = S[:, current_bin:current_bin + band_size, :]
            result += istft_module(full_stft)
            current_bin += band_size
        return result
# Everything below is inference-only round-tripping; no gradients are needed.
torch.set_grad_enabled(False)


def compute_metrics(original, reconstructed):
    """Compare two waveforms over their common length.

    Both inputs are indexed as ``[:, :length]`` and cropped to the shorter
    length along dim 1.

    Returns:
        Tuple of Python floats ``(mse, var_original, var_reconstructed)``:
        the mean squared error between the cropped signals and the variance
        (``torch.var``) of each cropped signal.
    """
    min_length = min(original.shape[1], reconstructed.shape[1])
    original = original[:, :min_length]
    reconstructed = reconstructed[:, :min_length]
    mse = torch.mean((original - reconstructed) ** 2)
    return mse.item(), torch.var(original).item(), torch.var(reconstructed).item()


def main():
    """Round-trip ``test.wav`` through both heads and report the degradation.

    The audio is repeatedly passed through STFT -> ISTFT with the
    multi-resolution head and with a plain ``RFSTFTHead``; the final
    spectrograms are plotted, the reconstructions are saved, and MSE /
    variance figures are printed.
    """
    waveform, sample_rate = torchaudio.load("test.wav")
    print(f"Loaded audio shape: {waveform.shape}, Sample rate: {sample_rate}")
    if waveform.shape[0] > 1:
        # Down-mix multi-channel audio to mono.
        waveform = torch.mean(waveform, dim=0, keepdim=True)
        print(f"Converted to mono. New shape: {waveform.shape}")

    num_bands = 2
    min_fft_size = 1024
    max_fft_size = 4096
    hop_length = 512
    dim = 1024

    multi_res_head = MultiResolutionRFSTFTHead(dim, num_bands, min_fft_size, max_fft_size, hop_length)
    original_head = RFSTFTHead(dim, 2048, hop_length)

    # Number of STFT -> ISTFT round trips; every label and output filename
    # below is derived from this value so they cannot drift out of sync.
    num_iterations = 100
    label = f"{num_iterations} iterations"
    suffix = f"{num_iterations}_iterations"

    multi_audio = waveform.clone()
    original_audio = waveform.clone()
    stft_multi = stft_original = None
    for i in range(num_iterations):
        print(f"Iteration {i+1}/{num_iterations}")
        # Multi-resolution
        stft_multi = multi_res_head.get_spec(multi_audio)
        print(f"STFT shape: {stft_multi.shape}")
        multi_audio = multi_res_head.get_wave(stft_multi)
        # Original
        stft_original = original_head.get_spec(original_audio)
        original_audio = original_head.get_wave(stft_original)
    print(f"Final multi-resolution audio shape: {multi_audio.shape}")
    print(f"Final original audio shape: {original_audio.shape}")

    # Plot the spectrograms produced in the last iteration.
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 12))
    mag_multi = torch.abs(stft_multi[0]).numpy()
    ax1.imshow(np.log1p(mag_multi), aspect='auto', origin='lower', interpolation='nearest')
    ax1.set_title(f'Multi-resolution STFT (after {label})')
    ax1.set_ylabel('Frequency')
    mag_original = torch.abs(stft_original[0]).numpy()
    ax2.imshow(np.log1p(mag_original), aspect='auto', origin='lower', interpolation='nearest')
    ax2.set_title(f'Original STFT (after {label})')
    ax2.set_ylabel('Frequency')
    ax2.set_xlabel('Time')
    plt.tight_layout()
    plt.savefig(f'stft_comparison_{suffix}.png')
    plt.close()

    # Save reconstructed audio
    torchaudio.save(f"recon_multi_{suffix}.wav", multi_audio, sample_rate)
    torchaudio.save(f"recon_original_{suffix}.wav", original_audio, sample_rate)

    # Compute MSE and variance for both methods
    mse_multi, var_waveform, var_reconstructed_multi = compute_metrics(waveform, multi_audio)
    mse_original, _, var_reconstructed_original = compute_metrics(waveform, original_audio)
    print(f"\nMulti-resolution RFSTFTHead ({label}):")
    print(f"MSE: {mse_multi}")
    print(f"Variance of waveform: {var_waveform}")
    print(f"Variance of reconstructed: {var_reconstructed_multi}")
    print(f"\nOriginal RFSTFTHead ({label}):")
    print(f"MSE: {mse_original}")
    print(f"Variance of waveform: {var_waveform}")
    print(f"Variance of reconstructed: {var_reconstructed_original}")


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment