Skip to content

Instantly share code, notes, and snippets.

@zaptrem

zaptrem/test.py Secret

Created June 28, 2024 23:28
Show Gist options
  • Save zaptrem/94d10c5d76d2f601841e9f8e8bf4859a to your computer and use it in GitHub Desktop.
Save zaptrem/94d10c5d76d2f601841e9f8e8bf4859a to your computer and use it in GitHub Desktop.
Multi-resolution STFT (MRSTFT) for RFWave
import torch
import torchaudio
import matplotlib.pyplot as plt
import numpy as np
from rfwave.heads import RFSTFTHead
from rfwave.spectral_ops import ISTFT, STFT
class MultiResolutionRFSTFTHead(RFSTFTHead):
    """RFSTFTHead that assembles one spectrogram from several STFT resolutions.

    The frequency axis is split into ``num_bands`` bands:

    * Band 0 (lowest frequencies) is cut from the largest FFT (about ``max_fft_size``).
    * The last band (highest frequencies) is cut from the smallest FFT (about
      ``min_fft_size``).
    * FFT sizes in between are spaced geometrically.

    All resolutions share ``hop_length``, so every band has the same number of
    frames. ``get_spec`` concatenates the bands along dim 1. ``get_wave`` undoes
    this: it zero-fills each band back into a full-size spectrum for its own
    resolution, runs that resolution's ISTFT, and sums the per-band waveforms.
    """

    def __init__(self, dim: int, num_bands: int, min_fft_size: int, max_fft_size: int, hop_length: int, padding: str = "same"):
        if num_bands < 1:
            raise ValueError(f"num_bands must be >= 1, got {num_bands}")
        super().__init__(dim, max_fft_size, hop_length, padding=padding)
        self.num_bands = num_bands
        self.hop_length = hop_length

        # Geometric spacing from max_fft_size (band 0) down to min_fft_size (last band).
        # A single band has nothing to interpolate; it would otherwise divide by zero.
        if num_bands == 1:
            raw_sizes = [max_fft_size]
        else:
            ratio = min_fft_size / max_fft_size
            raw_sizes = [int(max_fft_size * ratio ** (i / (num_bands - 1))) for i in range(num_bands)]
        # Round odd sizes up to even. Also clamp to at least 2 * hop_length: windows equal
        # the FFT size, so this keeps consecutive frames overlapping by >= hop_length.
        self.fft_sizes = [max(size + (size % 2), 2 * hop_length) for size in raw_sizes]
        self.window_sizes = self.fft_sizes

        self.stft_modules = torch.nn.ModuleList([
            STFT(n_fft=fft_size, hop_length=hop_length, win_length=window_size, padding=padding)
            for fft_size, window_size in zip(self.fft_sizes, self.window_sizes)
        ])
        self.istft_modules = torch.nn.ModuleList([
            ISTFT(n_fft=fft_size, hop_length=hop_length, win_length=window_size, padding=padding)
            for fft_size, window_size in zip(self.fft_sizes, self.window_sizes)
        ])

        # Each resolution contributes floor(1 / num_bands) of its (fft_size // 2 + 1) bins.
        self.band_sizes = [(fft_size // 2 + 1) // num_bands for fft_size in self.fft_sizes]
        # Overrides the n_fft set by the parent constructor. Presumably this makes
        # parent-derived sizes track the concatenated bin count; confirm against RFSTFTHead.
        self.n_fft = sum(self.band_sizes) * 2

    def _band_slice(self, i: int) -> slice:
        """Return the bin slice of resolution ``i``'s spectrum that forms band ``i``.

        Both ``get_spec`` and ``get_wave`` use this slice, so extraction and
        re-insertion always agree. The slice is exactly ``band_sizes[i]`` bins long.
        """
        # Assumes a one-sided STFT with fft_size // 2 + 1 bins (get_wave builds its
        # zero buffer with the same size). TODO confirm against rfwave's STFT.
        n_bins = self.fft_sizes[i] // 2 + 1
        size = self.band_sizes[i]
        if i == 0:
            start = 0  # lowest band starts at bin 0
        elif i == self.num_bands - 1:
            start = n_bins - size  # highest band ends at the last bin
        else:
            # Centre middle band i at fraction (i + 0.5) / num_bands of its spectrum.
            # Different middle bands therefore cover different frequencies. For 3 bands
            # this is the midpoint n_bins // 2, the same centre as the original code.
            centre = (2 * i + 1) * n_bins // (2 * self.num_bands)
            start = centre - size // 2
        return slice(start, start + size)

    def get_spec(self, audio: torch.Tensor) -> torch.Tensor:
        """Return the multi-resolution spectrogram of ``audio``.

        Each STFT resolution contributes its band (see ``_band_slice``). The bands
        are concatenated along dim 1, lowest frequency band first, giving
        ``sum(self.band_sizes)`` bins.
        """
        bands = []
        for i, stft_module in enumerate(self.stft_modules):
            spec = stft_module(audio)
            bands.append(spec[:, self._band_slice(i), :])
        return torch.cat(bands, dim=1)

    def get_wave(self, S: torch.Tensor) -> torch.Tensor:
        """Invert ``get_spec``: rebuild a waveform from a concatenated multi-band spectrogram.

        ``S`` is read band by band from the start of dim 1, using ``self.band_sizes``.
        Each band is placed back at its position in an otherwise-zero spectrum for
        its resolution and inverted with the matching ISTFT. The per-band waveforms
        are summed.
        """
        result = None
        bin_offset = 0
        for i, istft_module in enumerate(self.istft_modules):
            band_size = self.band_sizes[i]
            full_stft = torch.zeros(
                (S.shape[0], self.fft_sizes[i] // 2 + 1, S.shape[-1]), dtype=S.dtype, device=S.device
            )
            full_stft[:, self._band_slice(i), :] = S[:, bin_offset:bin_offset + band_size, :]
            wave = istft_module(full_stft)
            # Accumulate directly so the output length and dtype come from the ISTFT
            # itself, rather than from an assumed T * hop_length float32 buffer.
            result = wave if result is None else result + wave
            bin_offset += band_size
        return result
def compute_metrics(original, reconstructed):
    """Compare ``reconstructed`` against ``original`` over their overlapping samples.

    Both tensors are sliced as ``x[:, :n]`` (2-D, time on dim 1). They are cut to
    the shorter length first, because the STFT/ISTFT round trip may change the
    signal length.

    Returns:
        ``(mse, var_original, var_reconstructed)`` as Python floats. The variances
        use ``torch.var``'s default estimator.
    """
    min_length = min(original.shape[1], reconstructed.shape[1])
    original = original[:, :min_length]
    reconstructed = reconstructed[:, :min_length]
    mse = torch.mean((original - reconstructed) ** 2)
    return mse.item(), torch.var(original).item(), torch.var(reconstructed).item()


def main(audio_path="test.wav", num_iterations=100):
    """Round-trip an audio file through two heads repeatedly and report the drift.

    Loads ``audio_path`` and downmixes it to mono. Then runs ``num_iterations``
    rounds of ``get_spec`` -> ``get_wave`` through two heads:

    * a ``MultiResolutionRFSTFTHead``;
    * a plain ``RFSTFTHead`` (n_fft=2048) as the baseline.

    Saves a spectrogram comparison plot and both reconstructions to the working
    directory. Prints MSE and variance against the input.

    Raises:
        ValueError: if ``num_iterations`` < 1 (there would be no spectrogram to plot).
    """
    if num_iterations < 1:
        raise ValueError(f"num_iterations must be >= 1, got {num_iterations}")

    waveform, sample_rate = torchaudio.load(audio_path)
    print(f"Loaded audio shape: {waveform.shape}, Sample rate: {sample_rate}")
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)
        print(f"Converted to mono. New shape: {waveform.shape}")

    num_bands = 2
    min_fft_size = 1024
    max_fft_size = 4096
    hop_length = 512
    dim = 1024
    multi_res_head = MultiResolutionRFSTFTHead(dim, num_bands, min_fft_size, max_fft_size, hop_length)
    original_head = RFSTFTHead(dim, 2048, hop_length)

    multi_audio = waveform.clone()
    original_audio = waveform.clone()
    # Scope autograd off to this analysis. Disabling it at import time would leak
    # into any module that imports this file.
    with torch.no_grad():
        for i in range(num_iterations):
            print(f"Iteration {i+1}/{num_iterations}")
            stft_multi = multi_res_head.get_spec(multi_audio)
            print(f"STFT shape: {stft_multi.shape}")
            multi_audio = multi_res_head.get_wave(stft_multi)

            stft_original = original_head.get_spec(original_audio)
            original_audio = original_head.get_wave(stft_original)
    print(f"Final multi-resolution audio shape: {multi_audio.shape}")
    print(f"Final original audio shape: {original_audio.shape}")

    # Plot the log-magnitude spectrograms from the final iteration.
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 12))
    mag_multi = torch.abs(stft_multi[0]).numpy()
    ax1.imshow(np.log1p(mag_multi), aspect='auto', origin='lower', interpolation='nearest')
    ax1.set_title(f'Multi-resolution STFT (after {num_iterations} iterations)')
    ax1.set_ylabel('Frequency')
    mag_original = torch.abs(stft_original[0]).numpy()
    ax2.imshow(np.log1p(mag_original), aspect='auto', origin='lower', interpolation='nearest')
    ax2.set_title(f'Original STFT (after {num_iterations} iterations)')
    ax2.set_ylabel('Frequency')
    ax2.set_xlabel('Time')
    plt.tight_layout()
    plt.savefig(f'stft_comparison_{num_iterations}_iterations.png')
    plt.close(fig)

    torchaudio.save(f"recon_multi_{num_iterations}_iterations.wav", multi_audio, sample_rate)
    torchaudio.save(f"recon_original_{num_iterations}_iterations.wav", original_audio, sample_rate)

    mse_multi, var_waveform, var_reconstructed_multi = compute_metrics(waveform, multi_audio)
    mse_original, _, var_reconstructed_original = compute_metrics(waveform, original_audio)
    print(f"\nMulti-resolution RFSTFTHead ({num_iterations} iterations):")
    print(f"MSE: {mse_multi}")
    print(f"Variance of waveform: {var_waveform}")
    print(f"Variance of reconstructed: {var_reconstructed_multi}")
    print(f"\nOriginal RFSTFTHead ({num_iterations} iterations):")
    print(f"MSE: {mse_original}")
    print(f"Variance of waveform: {var_waveform}")
    print(f"Variance of reconstructed: {var_reconstructed_original}")


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment