import torch
import torchaudio
import matplotlib.pyplot as plt
import numpy as np
from rfwave.heads import RFSTFTHead
from rfwave.spectral_ops import ISTFT, STFT
class MultiResolutionRFSTFTHead(RFSTFTHead):
    """RFSTFTHead variant whose spectrogram is stitched from several STFT resolutions.

    ``num_bands`` FFT sizes are spaced geometrically from ``max_fft_size``
    (band 0) down to ``min_fft_size`` (last band). Each resolution contributes
    one slice of its own frequency axis: the first band takes its lowest bins,
    the last band takes its highest bins, and every band in between takes a
    slice centred on the middle of its axis. The slices are concatenated along
    dim 1 by ``get_spec`` and scattered back and summed by ``get_wave``.

    Args:
        dim: Forwarded unchanged to ``RFSTFTHead.__init__``.
        num_bands: Number of resolutions / frequency bands (must be >= 1).
        min_fft_size: FFT size used for the last band.
        max_fft_size: FFT size used for the first band and for the base class.
        hop_length: Hop shared by every resolution so frame counts line up.
        padding: Forwarded to the base class and to each STFT/ISTFT module.

    Raises:
        ValueError: If ``num_bands`` is smaller than 1.
    """

    def __init__(self, dim: int, num_bands: int, min_fft_size: int, max_fft_size: int, hop_length: int, padding: str = "same"):
        if num_bands < 1:
            raise ValueError(f"num_bands must be >= 1, got {num_bands}")
        super().__init__(dim, max_fft_size, hop_length, padding=padding)
        self.num_bands = num_bands
        self.hop_length = hop_length

        if num_bands == 1:
            # A single band has no interpolation to do; the geometric formula
            # below would divide by zero.
            raw_sizes = [int(max_fft_size)]
        else:
            ratio = min_fft_size / max_fft_size
            raw_sizes = [
                int(max_fft_size * ratio ** (i / (num_bands - 1)))
                for i in range(num_bands)
            ]
        # Round odd sizes up to even and never go below two hops.
        self.fft_sizes = [max(size + (size % 2), 2 * hop_length) for size in raw_sizes]
        self.window_sizes = self.fft_sizes

        self.stft_modules = torch.nn.ModuleList([
            STFT(n_fft=fft_size, hop_length=hop_length, win_length=window_size, padding=padding)
            for fft_size, window_size in zip(self.fft_sizes, self.window_sizes)
        ])
        self.istft_modules = torch.nn.ModuleList([
            ISTFT(n_fft=fft_size, hop_length=hop_length, win_length=window_size, padding=padding)
            for fft_size, window_size in zip(self.fft_sizes, self.window_sizes)
        ])

        # Each resolution contributes an equal share of its own one-sided bins.
        self.band_sizes = [(fft_size // 2 + 1) // num_bands for fft_size in self.fft_sizes]
        # NOTE(review): this overrides the n_fft set by the base class after the
        # base class has already been constructed with max_fft_size; presumably
        # base-class code derives its bin count from n_fft - confirm.
        self.n_fft = sum(self.band_sizes) * 2

    def _band_bounds(self, index: int, num_bins: int):
        """Return ``(start, stop)`` of band ``index`` on an axis of ``num_bins`` bins.

        The span always holds exactly ``self.band_sizes[index]`` bins so that
        ``get_spec`` and ``get_wave`` agree even when the band size is odd.
        """
        band_size = self.band_sizes[index]
        if index == 0:  # Lowest frequency band
            return 0, band_size
        if index == len(self.band_sizes) - 1:  # Highest frequency band
            return num_bins - band_size, num_bins
        # Middle bands: centred on the middle bin.
        start = num_bins // 2 - band_size // 2
        return start, start + band_size

    def get_spec(self, audio: torch.Tensor) -> torch.Tensor:
        """Compute the stitched multi-resolution spectrogram of ``audio``.

        Every STFT module is applied to ``audio``; its output is indexed as a
        3-D tensor with frequency on dim 1, and the selected band of each is
        concatenated along that dim.

        Returns:
            Tensor with ``sum(self.band_sizes)`` entries on dim 1.
        """
        results = []
        for i, stft_module in enumerate(self.stft_modules):
            stft = stft_module(audio)
            start, stop = self._band_bounds(i, stft.shape[1])
            results.append(stft[:, start:stop, :])
        return torch.cat(results, dim=1)

    def get_wave(self, S: torch.Tensor) -> torch.Tensor:
        """Invert a spectrogram laid out like the output of ``get_spec``.

        Each band of ``S`` is written into an otherwise-zero spectrum of the
        matching resolution, inverted with that resolution's ISTFT module, and
        the resulting waveforms are summed.

        Returns:
            Tensor of shape ``(S.shape[0], S.shape[-1] * self.hop_length)``.
        """
        # Assumes every ISTFT module returns exactly frames * hop_length samples
        # (otherwise the accumulation below fails) - TODO confirm for paddings
        # other than the default.
        result = torch.zeros(S.shape[0], S.shape[-1] * self.hop_length, device=S.device)
        current_bin = 0
        for i, istft_module in enumerate(self.istft_modules):
            band_size = self.band_sizes[i]
            stft_band = S[:, current_bin:current_bin + band_size, :]
            full_stft = torch.zeros(
                (S.shape[0], self.fft_sizes[i] // 2 + 1, S.shape[-1]),
                dtype=S.dtype,
                device=S.device,
            )
            start, stop = self._band_bounds(i, full_stft.shape[1])
            full_stft[:, start:stop, :] = stft_band
            result += istft_module(full_stft)
            current_bin += band_size
        return result
def compute_metrics(original, reconstructed):
    """Compare two 2-D waveforms over their common leading samples.

    Both tensors are truncated along dim 1 to the shorter length before
    comparing.

    Returns:
        Tuple ``(mse, var_original, var_reconstructed)`` of Python floats.
    """
    min_length = min(original.shape[1], reconstructed.shape[1])
    original = original[:, :min_length]
    reconstructed = reconstructed[:, :min_length]
    mse = torch.mean((original - reconstructed) ** 2)
    return mse.item(), torch.var(original).item(), torch.var(reconstructed).item()


if __name__ == "__main__":
    # Round-trip experiment: repeatedly analyse and resynthesise the same clip
    # with both heads and compare how far each drifts from the input.
    waveform, sample_rate = torchaudio.load("test.wav")
    print(f"Loaded audio shape: {waveform.shape}, Sample rate: {sample_rate}")

    # Down-mix multi-channel audio to mono.
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)
        print(f"Converted to mono. New shape: {waveform.shape}")

    num_bands = 2
    min_fft_size = 1024
    max_fft_size = 4096
    hop_length = 512
    dim = 1024

    # MultiResolutionRFSTFTHead
    multi_res_head = MultiResolutionRFSTFTHead(dim, num_bands, min_fft_size, max_fft_size, hop_length)
    # Original RFSTFTHead
    original_head = RFSTFTHead(dim, 2048, hop_length)

    # Number of STFT -> ISTFT round trips applied to each copy of the audio.
    num_iterations = 100

    multi_audio = waveform.clone()
    original_audio = waveform.clone()

    # No gradients are needed here; disabling them also keeps the tensors
    # convertible to numpy below.
    with torch.no_grad():
        for i in range(num_iterations):
            print(f"Iteration {i+1}/{num_iterations}")
            # Multi-resolution
            stft_multi = multi_res_head.get_spec(multi_audio)
            print(f"STFT shape: {stft_multi.shape}")
            multi_audio = multi_res_head.get_wave(stft_multi)
            # Original
            stft_original = original_head.get_spec(original_audio)
            original_audio = original_head.get_wave(stft_original)

    print(f"Final multi-resolution audio shape: {multi_audio.shape}")
    print(f"Final original audio shape: {original_audio.shape}")

    # Plotting: log-magnitude of the last spectrogram produced by each head.
    # NOTE(review): the figure is built but never saved or shown in the code
    # visible here - confirm whether a savefig/show call is missing.
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 12))
    mag_multi = torch.abs(stft_multi[0]).numpy()
    ax1.imshow(np.log1p(mag_multi), aspect='auto', origin='lower', interpolation='nearest')
    ax1.set_title(f'Multi-resolution STFT (after {num_iterations} iterations)')
    mag_original = torch.abs(stft_original[0]).numpy()
    ax2.imshow(np.log1p(mag_original), aspect='auto', origin='lower', interpolation='nearest')
    ax2.set_title(f'Original STFT (after {num_iterations} iterations)')

    # Save reconstructed audio.
    # NOTE: the file names are kept as-is for anything that consumes them, even
    # though the iteration count is num_iterations rather than 50.
    torchaudio.save("recon_multi_50_iterations.wav", multi_audio, sample_rate)
    torchaudio.save("recon_original_50_iterations.wav", original_audio, sample_rate)

    # Compute MSE and variance for both methods
    mse_multi, var_waveform, var_reconstructed_multi = compute_metrics(waveform, multi_audio)
    mse_original, _, var_reconstructed_original = compute_metrics(waveform, original_audio)

    print(f"\nMulti-resolution RFSTFTHead ({num_iterations} iterations):")
    print(f"MSE: {mse_multi}")
    print(f"Variance of waveform: {var_waveform}")
    print(f"Variance of reconstructed: {var_reconstructed_multi}")
    print(f"\nOriginal RFSTFTHead ({num_iterations} iterations):")
    print(f"MSE: {mse_original}")
    print(f"Variance of waveform: {var_waveform}")
    print(f"Variance of reconstructed: {var_reconstructed_original}")
