Skip to content

Instantly share code, notes, and snippets.

@lzqlzzq
Last active January 18, 2024 03:13
Show Gist options
  • Save lzqlzzq/c5ba6f5cca60819f270721cc4374f529 to your computer and use it in GitHub Desktop.
Save lzqlzzq/c5ba6f5cca60819f270721cc4374f529 to your computer and use it in GitHub Desktop.
Trainable STFT (Short-Time Fourier Transform) module in PyTorch
from typing import Callable, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
def get_fourier_basis(win_length, window_func=torch.hann_window):
    """Build windowed real-DFT kernels used to initialise the trainable STFT.

    Args:
        win_length: Frame length in samples (also the DFT size N).
        window_func: Callable mapping a length to a 1-D window tensor;
            ``torch.hann_window`` by default.

    Returns:
        ``(n_basis, kernels)`` where ``n_basis = win_length // 2 + 1`` and
        ``kernels`` has shape ``(2 * n_basis, 1, win_length)``. The first
        ``n_basis`` rows are ``window * cos(2*pi*k*n/N)`` and the remaining
        rows are ``-window * sin(2*pi*k*n/N)``. The dot product of a frame with
        row k (resp. n_basis + k) is therefore the real (resp. imaginary) part
        of bin k of the one-sided DFT of the windowed frame.
    """
    n_basis = win_length // 2 + 1
    taper = window_func(win_length)
    sample_idx = torch.arange(win_length).float()
    bin_idx = torch.arange(n_basis).float()
    # Angle matrix of shape (n_basis, win_length), shared by the cos and sin rows.
    phase = 2 * np.pi * bin_idx[:, None] / win_length * sample_idx[None, :]
    cos_rows = torch.cos(phase) * taper
    sin_rows = -torch.sin(phase) * taper
    # Stack real rows on top of imaginary rows, then add the Conv1d in-channel axis.
    kernels = torch.cat([cos_rows, sin_rows], dim=0).unsqueeze(1)
    return n_basis, kernels
class TrainableSTFT(nn.Module):
    """Magnitude STFT whose analysis kernels are learnable ``nn.Conv1d`` weights.

    Two bias-free convolutions (real / imaginary) are initialised with the
    windowed Fourier basis from ``get_fourier_basis``. Before any training,
    ``forward`` therefore returns the magnitude of a framed one-sided DFT.
    Frames start at sample 0 and no padding is applied inside ``forward``.

    Args:
        win_length: Frame / kernel length in samples.
        hop_length: Stride between frames. Defaults to ``win_length // 2``
            when ``None`` (or any other falsy value such as 0).
        window_func: Callable ``window_func(win_length) -> Tensor`` used to
            taper the basis; ``torch.hann_window`` by default.
    """

    def __init__(self,
                 win_length: int,
                 hop_length: Optional[int] = None,
                 window_func: Callable[[int], torch.Tensor] = torch.hann_window):
        super().__init__()
        self.win_length = win_length
        self.hop_length = hop_length or self.win_length // 2
        n_basis, basis = get_fourier_basis(win_length, window_func)
        # Initialize nn.Conv1d layers with the real / imaginary Fourier basis.
        self.conv_real = nn.Conv1d(1, n_basis, self.win_length, stride=self.hop_length, bias=False)
        self.conv_imag = nn.Conv1d(1, n_basis, self.win_length, stride=self.hop_length, bias=False)
        with torch.no_grad():
            self.conv_real.weight[:, :] = basis[:n_basis, :]
            self.conv_imag.weight[:, :] = basis[n_basis:, :]
        # Make the convolution layers trainable (explicit for clarity; conv
        # weights are nn.Parameters, which already require grad by default).
        self.conv_real.weight.requires_grad = True
        self.conv_imag.weight.requires_grad = True

    def pad(self, input_signal):
        """Right-pad the last axis with zeros so the frames cover every sample.

        The result has ``(n - 1) * hop_length + win_length`` samples, where
        ``n = max(1, ceil((L - win_length) / hop_length) + 1)``. A trailing
        partial frame is thus completed, and a signal shorter than one window
        is padded up to exactly one window. ``forward`` does not call this;
        apply it explicitly when trailing samples must not be dropped.

        Args:
            input_signal: Tensor whose last axis is time (length ``L``).

        Returns:
            Zero-padded tensor (same leading dims, longer or equal last axis).
        """
        length = input_signal.size(-1)
        # Ceil division (``-(-a // b)``) counts a trailing partial frame as a
        # full frame; floor division here would always give zero padding.
        num_frames = 1 + max(0, -(-(length - self.win_length) // self.hop_length))
        pad_size = (num_frames - 1) * self.hop_length + self.win_length - length
        return F.pad(input_signal, (0, pad_size))

    @property
    def feature_size(self):
        """Number of frequency bins per frame: ``win_length // 2 + 1``."""
        return self.win_length // 2 + 1

    def forward(self, input_signal):
        """Return the magnitude spectrogram of ``input_signal``.

        Args:
            input_signal: Tensor of shape ``(..., L)``, e.g. ``(B, C, L)``. Every
                leading index is transformed independently. ``L`` must be at
                least ``win_length``.

        Returns:
            Tensor of shape ``(..., num_frames, feature_size)`` with
            ``num_frames = (L - win_length) // hop_length + 1``.
        """
        *leading, length = input_signal.shape
        mono = input_signal.reshape(-1, 1, length)
        real_part = self.conv_real(mono).transpose(-1, -2)
        imag_part = self.conv_imag(mono).transpose(-1, -2)
        power = real_part.square() + imag_part.square()
        # d/dx sqrt(x) is infinite at x == 0, so all-zero frames (e.g. silence)
        # would yield NaN gradients. Route exact zeros through a dummy value
        # and mask them back to 0: the forward result equals sqrt(power), but
        # the gradient at zero becomes 0 instead of NaN.
        nonzero = power > 0
        safe_power = torch.where(nonzero, power, torch.ones_like(power))
        magnitude = torch.where(nonzero, safe_power.sqrt(), torch.zeros_like(power))
        return magnitude.reshape(*leading, -1, self.feature_size)
class TrainableMel(nn.Module):
    """Trainable mel spectrogram: ``TrainableSTFT`` magnitudes followed by a
    learnable, bias-free linear mel filterbank.

    The filterbank weight is initialised from
    ``torchaudio.functional.melscale_fbanks`` and is trained together with the
    STFT kernels.

    Args:
        win_length: STFT frame length in samples.
        sample_rate: Sampling rate in Hz, used to place the mel filters.
        n_mels: Number of mel bands (output features per frame).
        hop_length: STFT hop; ``None`` lets ``TrainableSTFT`` pick its default.
        window_func: Window callable forwarded to ``TrainableSTFT``.
        f_min: Lowest filter frequency in Hz (default 0).
        f_max: Highest filter frequency in Hz; ``None`` means
            ``sample_rate // 2`` (the previous hard-coded value).
        norm: Filter normalisation passed to ``melscale_fbanks``
            (default ``"slaney"``, as before).
        mel_scale: Mel scale passed to ``melscale_fbanks`` (default ``"htk"``,
            which is also torchaudio's default).
    """

    def __init__(self,
                 win_length: int,
                 sample_rate: int,
                 n_mels: int,
                 hop_length: Optional[int] = None,
                 window_func: Callable[[int], torch.Tensor] = torch.hann_window,
                 f_min: float = 0.0,
                 f_max: Optional[float] = None,
                 norm: Optional[str] = "slaney",
                 mel_scale: str = "htk"):
        super().__init__()
        self.stft_module = TrainableSTFT(win_length, hop_length, window_func)
        n_freqs = self.stft_module.feature_size
        if f_max is None:
            f_max = sample_rate // 2
        self.mel_fbank = nn.Linear(n_freqs, n_mels, bias=False)
        with torch.no_grad():
            # melscale_fbanks returns (n_freqs, n_mels); nn.Linear stores its
            # weight as (out_features, in_features) = (n_mels, n_freqs).
            fbank = torchaudio.functional.melscale_fbanks(n_freqs,
                                                          f_min,
                                                          f_max,
                                                          n_mels,
                                                          sample_rate,
                                                          norm=norm,
                                                          mel_scale=mel_scale)
            self.mel_fbank.weight[:, :] = fbank.transpose(-1, -2)
        self.mel_fbank.weight.requires_grad = True

    def forward(self, x):
        """Return the mel spectrogram of ``x``.

        ``x`` is passed unchanged to ``TrainableSTFT``. The linear filterbank
        then acts on the last (frequency) axis, so the output has the STFT
        output's shape with the last dimension replaced by ``n_mels``.
        """
        spec = self.stft_module(x)
        return self.mel_fbank(spec)
class TrainableISTFT(nn.Module):
    """Learnable overlap-add synthesis from a (magnitude) spectrogram.

    A single bias-free ``nn.ConvTranspose1d`` maps each frame of
    ``win_length // 2 + 1`` bins to ``win_length`` samples and overlap-adds
    frames spaced ``hop_length`` apart. Its weights are initialised with only
    the windowed *cosine* rows of ``get_fourier_basis``.

    NOTE(review): at initialisation this is not an exact inverse of
    ``TrainableSTFT``. When fed ``TrainableSTFT`` output it receives
    magnitudes only (no phase), only the cosine basis is used, and no 1/N,
    one-sided (x2) or window-overlap normalisation is applied. Any
    reconstruction quality has to be learned.

    Args:
        win_length: Frame / kernel length in samples.
        hop_length: Stride between frames. Defaults to ``win_length // 2``
            when ``None`` (or any other falsy value such as 0).
        window_func: Callable ``window_func(win_length) -> Tensor`` used to
            taper the basis; ``torch.hann_window`` by default.
    """

    def __init__(self,
                 win_length: int,
                 hop_length: Optional[int] = None,
                 window_func: Callable[[int], torch.Tensor] = torch.hann_window):
        super().__init__()
        self.win_length = win_length
        self.hop_length = hop_length or win_length // 2
        n_basis, basis = get_fourier_basis(win_length, window_func)
        # Initialize nn.ConvTranspose1d with the windowed cosine (real) basis;
        # its weight layout is (in_channels=n_basis, out_channels=1, kernel).
        self.conv_transpose = nn.ConvTranspose1d(n_basis, 1, self.win_length, stride=self.hop_length, bias=False)
        with torch.no_grad():
            self.conv_transpose.weight[:, :] = basis[:n_basis, :]  # Using only real part of the basis
        # Make the convolution layer trainable (already the default for parameters).
        self.conv_transpose.weight.requires_grad = True

    def forward(self, input_spec):
        """Synthesise a time-domain signal from ``input_spec``.

        Args:
            input_spec: Tensor of shape ``(..., num_frames, n_bins)``, e.g.
                ``(B, C, num_frames, n_bins)``, with
                ``n_bins == win_length // 2 + 1``. Leading indices are
                processed independently.

        Returns:
            Tensor of shape ``(..., (num_frames - 1) * hop_length + win_length)``.
        """
        *leading, num_frames, n_bins = input_spec.shape
        # ConvTranspose1d expects (batch, channels=bins, time=frames).
        spec = input_spec.reshape(-1, num_frames, n_bins).transpose(-1, -2)
        # Use ConvTranspose1d for overlap-add in the time domain.
        output_signal = self.conv_transpose(spec)
        return output_signal.reshape(*leading, -1)
from matplotlib import pyplot as plt
if __name__ == '__main__':
    # Demo: run a random (batch, channels, samples) signal through the
    # trainable STFT and back through the trainable ISTFT, printing shapes.
    # The modules are built before the signal is drawn, so the RNG is consumed
    # in the same order as before.
    frame_len, hop = 2048, 512
    analysis = TrainableSTFT(win_length=frame_len, hop_length=hop)
    synthesis = TrainableISTFT(win_length=frame_len, hop_length=hop)
    batch = torch.randn(4, 2, 4096)  # Example signal
    spectrum = analysis(batch)
    print("stft_output:", spectrum.shape)
    reconstructed = synthesis(spectrum)
    print("istft_output:", reconstructed.shape)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment