Skip to content

Instantly share code, notes, and snippets.

@istupakov
Last active June 3, 2024 11:22
Show Gist options
  • Save istupakov/d91ba7a058f512eae8c9ff7aed2e9a5a to your computer and use it in GitHub Desktop.
Save istupakov/d91ba7a058f512eae8c9ff7aed2e9a5a to your computer and use it in GitHub Desktop.
Online STFT in Python (NumPy)
import numpy as np
import numpy.typing as npt
class OnlineStft:
    """Frame-by-frame (streaming) short-time Fourier transform and its inverse.

    ``forward`` consumes ``n_hop`` new samples per call and keeps the most
    recent ``n_fft`` samples per channel. It windows them with a periodic
    Hamming window and returns their one-sided spectrum. ``backward``
    inverse-transforms one such spectrum, weights it with a synthesis window
    and overlap-adds it into an internal buffer. It emits ``n_hop`` finished
    samples per call.

    Calling ``backward(forward(frame))`` for every consecutive frame
    reproduces the input signal delayed by ``n_fft - n_hop`` samples, for any
    hop in ``1..n_fft``. This works because the synthesis window is normalized
    so that analysis * synthesis weights overlap-add to one.
    """

    # Number of one-sided frequency bins: n_fft // 2 + 1.
    n_freq: int

    def __init__(self, n_fft: int, n_hop: int, n_channels: int):
        """Create a transformer with all-zero signal history.

        Args:
            n_fft: FFT / window length in samples.
            n_hop: Number of new samples per frame; must satisfy
                ``0 < n_hop <= n_fft``.
            n_channels: Number of independent channels (columns) per frame.

        Raises:
            ValueError: If ``n_hop`` is outside ``1..n_fft``.
        """
        if not 0 < n_hop <= n_fft:
            raise ValueError(f"n_hop must be in 1..n_fft ({n_fft}), got {n_hop}")
        self._hop = n_hop
        self._n_fft = n_fft
        self.n_freq = n_fft // 2 + 1
        # Periodic Hamming window: an (n_fft + 1)-point window minus its last sample.
        self._win_a = np.hamming(n_fft + 1)[:-1]
        # A sample emitted at output position q received contributions from
        # buffer positions q + j * n_hop (all j with q + j * n_hop < n_fft).
        # The normalizer is therefore the sum of squared analysis-window
        # values over the residue class n % n_hop. A circular np.roll by
        # multiples of n_hop only equals this when n_fft % n_hop == 0.
        n_blocks = -(-n_fft // n_hop)  # ceil(n_fft / n_hop)
        squared = np.zeros(n_blocks * n_hop)
        squared[:n_fft] = self._win_a**2
        overlap = squared.reshape(n_blocks, n_hop).sum(axis=0)
        self._win_s = self._win_a / overlap[np.arange(n_fft) % n_hop]
        # Sliding analysis history and pending overlap-add output, one column per channel.
        self._input = np.zeros((n_fft, n_channels))
        self._output = np.zeros((n_fft, n_channels))

    def forward(self, frame: npt.NDArray[np.float64]) -> npt.NDArray[np.complex128]:
        """Push ``n_hop`` new samples and return the spectrum of the current window.

        Args:
            frame: Array of shape ``(n_hop, n_channels)`` with the newest samples.

        Returns:
            Complex array of shape ``(n_freq, n_channels)``: the ``rfft`` along
            axis 0 of the last ``n_fft`` samples times the analysis window.

        Raises:
            ValueError: If ``frame`` does not have shape ``(n_hop, n_channels)``.
        """
        frame = np.asarray(frame)
        expected = (self._hop, self._input.shape[1])
        if frame.shape != expected:
            raise ValueError(f"frame must have shape {expected}, got {frame.shape}")
        # Drop the oldest n_hop samples and append the new ones at the end.
        self._input[: -self._hop] = self._input[self._hop :]
        self._input[-self._hop :] = frame
        return np.fft.rfft(self._input * self._win_a[:, None], axis=0)

    def backward(self, image: npt.NDArray[np.complex128]) -> npt.NDArray[np.float64]:
        """Overlap-add one spectrum and return the next ``n_hop`` finished samples.

        Args:
            image: Complex array of shape ``(n_freq, n_channels)``, e.g. the
                (possibly modified) output of ``forward``.

        Returns:
            A new float array of shape ``(n_hop, n_channels)``. It is a copy,
            so later calls do not modify it.

        Raises:
            ValueError: If ``image`` does not have shape ``(n_freq, n_channels)``.
        """
        image = np.asarray(image)
        expected = (self.n_freq, self._output.shape[1])
        if image.shape != expected:
            raise ValueError(f"image must have shape {expected}, got {image.shape}")
        # Advance the overlap-add buffer by one hop; the vacated tail starts at zero.
        self._output[: -self._hop] = self._output[self._hop :]
        self._output[-self._hop :] = 0
        # n= is required: without it irfft returns 2 * (n_freq - 1) samples,
        # which is n_fft - 1 when n_fft is odd.
        self._output += (
            np.fft.irfft(image, n=self._n_fft, axis=0) * self._win_s[:, None]
        )
        # Copy: the buffer is shifted in place on the next call, which would
        # otherwise overwrite data the caller already received.
        return self._output[: self._hop].copy()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment