"""
Copyright (C) https://github.com/praat/praat
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>
"""
from typing import Optional
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio.functional as AF
class NACF(nn.Module):
"""Normalized autocorrelation-based Pitch estimation, reimplementation of Praat.
"""
def __init__(self,
sr: int,
frame_time: float = 0.01,
freq_min: float = 75.,
freq_max: float = 600.,
down_sr: int = 16000,
k: int = 15,
thold_silence: float = 0.03,
thold_voicing: float = 0.45,
cost_octave: float = 0.01,
cost_jump: float = 0.35,
cost_vuv: float = 0.14,
median_win: Optional[int] = 5):
"""Initializer.
Args:
sr: sampling rate.
frame_time: duration of the frame.
freq_min, freq_max: frequency min and max.
down_sr: downsampling sr for fast computation.
k: the maximum number of the candidates.
"""
super().__init__()
self.sr = sr
self.strides = int(down_sr * frame_time)
self.fmin, self.fmax = freq_min, freq_max
self.tmax = int(down_sr // freq_min)
self.tmin = int(down_sr // freq_max)
self.down_sr = down_sr
self.k = min(max(k, int(freq_max / freq_min)), self.tmax - self.tmin)
# set windows based on tau-max
self.w = int(2 ** np.ceil(np.log2(self.tmax))) + 1
self.register_buffer(
'window', torch.hann_window(self.w), persistent=False)
# [w + 1], normalized autocorrelation of the window
r = torch.fft.rfft(self.window, self.w * 2, dim=-1)
# [w]
ws = torch.fft.irfft(r * r.conj(), dim=-1)[:self.w]
ws = ws / ws[0]
self.register_buffer('nacw', ws, persistent=False)
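        # NOTE: the frame autocorrelation is later divided by `nacw`, the window's
        # own autocorrelation, to undo the tapering bias of the Hann window
        # (Boersma, 1993, the method behind Praat's autocorrelation tracker)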
# alias
self.thold_silence = thold_silence
self.thold_voicing = thold_voicing
self.cost_octave = cost_octave
# correction
c = 1. # c = 0.01 * down_sr
self.cost_jump = cost_jump * c
self.cost_vuv = cost_vuv * c
self.median_win = median_win
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
"""Estimate the pitch from the input signal.
Args:
inputs: [torch.float32; [..., T]], input signal, [-1, 1]-ranged.
Returns:
[torch.float32; [..., S]], estimated pitch sequence.
"""
x = AF.resample(inputs, self.sr, self.down_sr)
# [..., T / strides, w]
frames = F.pad(x, [0, self.w]).unfold(-1, self.w, self.strides)
# [..., T / strides, w]
frames = (frames - frames.mean(dim=-1)[..., None]) * self.window
# [...]
global_peak = inputs.abs().amax(dim=-1)
# [..., T / strides]
local_peak = frames.abs().amax(dim=-1)
# [..., T / strides]
intensity = torch.where(
local_peak > global_peak[..., None],
torch.tensor(1., device=x.device),
local_peak / global_peak[..., None])
# [..., T / strides, w + 1]
fft = torch.fft.rfft(frames, self.w * 2, dim=-1)
# [..., T / strides, w]
acf = torch.fft.irfft(fft * fft.conj(), dim=-1)[..., :self.w]
# [..., T / strides, w], normalized autocorrelation
nacf = acf / (acf[..., :1] * self.nacw)
# [..., T / strides, tmax + 1]
nacf = nacf[..., :self.tmax + 1]
# [..., T / strides, tmax], x[i + 1] - x[i]
d = nacf.diff(dim=-1)
# [..., T / strides, tmax - 1], inc & dec
localmax = (d[..., :-1] >= 0) & (d[..., 1:] <= 0)
# [..., T / strides, tmax - 1]
flag = localmax & (nacf[..., 1:-1] > 0.5 * self.thold_voicing)
# [..., T / strides, tmax - 1], parabolic interpolation
n, c, p = nacf[..., 2:], nacf[..., 1:-1], nacf[..., :-2]
dr = 0.5 * (n - p) / (2. * c - n - p)
# [tmax - 1]
a = torch.arange(self.tmax - 1, device=dr.device)
# [..., T / strides, tmax - 1]
freqs = self.down_sr / (1 + (dr + a).clamp_min(0.))
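        # the refined lag is (a + 1) + dr samples, with dr in (-0.5, 0.5) at the
        # vertex of the parabola through (p, c, n); frequency is then down_sr / lag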
## TODO: sinc interpolation, depth=30
logits = nacf[..., 1:self.tmax]
# reflect logits of high values (for short windows)
logits = logits.where(logits <= 1., 1 / logits)
# additional penalty
logits = logits - self.cost_octave * (self.fmin / freqs).log2()
# masking
FLOOR = -1e5
logits.masked_fill_(~flag, FLOOR)
# [..., T / strides, k], [..., T / strides, k], topk
logits, indices = logits.topk(self.k, dim=-1)
# [..., T / strides, k]
freqs = freqs.gather(-1, indices)
        ## TODO: maximize sinc interpolation, depth=4
        # interpolation is skipped here; the top-k candidates are used as-is
# [..., T / strides]
logits_uv = 2. - intensity / (
self.thold_silence / (1. + self.thold_voicing))
logits_uv = self.thold_voicing + logits_uv.clamp_min(0.)
# [..., T / strides, k]
voiced = (logits > FLOOR) & (freqs < self.fmax)
# [..., T / strides, k]
delta = torch.where(
~voiced,
logits_uv[..., None],
logits - self.cost_octave * (self.fmax / freqs).log2())
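        # `delta` above mirrors Praat's path-finder strengths: the candidate
        # correlation minus an octave cost toward the ceiling for voiced
        # candidates, or the unvoiced strength `logits_uv` otherwise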
# [..., T / strides - 1, k, k]
trans = self.cost_jump * (
freqs[..., :-1, :, None] / freqs[..., 1:, None, :]).log2().abs()
# both voiceless
trans.masked_fill_(
~voiced[..., :-1, :, None] & ~voiced[..., 1:, None, :], 0.)
# voice transition
trans.masked_fill_(
voiced[..., :-1, :, None] != voiced[..., 1:, None, :], self.cost_vuv)
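        # the remaining (voiced-to-voiced) entries keep the octave-jump cost
        # cost_jump * |log2(f1 / f2)| computed above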
# S(=T / strides)
steps = delta.shape[-2]
# [..., k]
value = delta[..., 0, :]
        # [..., S, k], backpointers (long, so indices avoid float round-trips)
        ptrs = torch.zeros_like(delta, dtype=torch.long)
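        # Viterbi forward pass: accumulate the best cumulative strength per
        # candidate and keep argmax backpointers for backtracking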
for i in range(1, steps):
# [..., k]
value, ptrs[..., i, :] = (
value[..., None]
- trans[..., i - 1, :, :] + delta[..., i, None, :]).max(dim=-2)
# [..., S], backtracking
states = torch.zeros_like(delta[..., 0], dtype=torch.long)
        # final state
states[..., -1] = value.argmax(dim=-1)
for i in range(steps - 2, -1, -1):
# [...]
states[..., i] = ptrs[..., i + 1, :].gather(
-1, states[..., i + 1, None]).squeeze(dim=-1)
# [..., T / strides, 1], sampling
freqs = freqs.gather(-1, states[..., None])
# masking unvoiced
freqs.masked_fill_(~voiced.gather(-1, states[..., None]), 0.)
# [..., T / strides]
f0 = freqs.squeeze(dim=-1)
# median pool
if self.median_win is not None:
w = self.median_win // 2
# replication
f0 = torch.cat(
[f0[..., :1]] * w + [f0] + [f0[..., -1:]] * w, dim=-1)
f0 = torch.median(
f0.unfold(-1, self.median_win, 1),
dim=-1).values
return f0
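

# A minimal usage sketch (not part of the original gist); the test tone and
# sampling rate below are assumptions for illustration only.
if __name__ == '__main__':
    sr = 44100
    t = torch.arange(sr) / sr
    # [1, T], a one-second 220 Hz sine
    wav = torch.sin(2 * torch.pi * 220. * t)[None]
    nacf = NACF(sr)
    # [1, S], expected to hover near 220 Hz on voiced frames, 0 elsewhere
    print(nacf(wav))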
"""
Copyright (C) https://github.com/praat/praat
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>
"""
import warnings
from typing import Callable, Literal, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
class PitchShift(nn.Module):
"""Pitch shift with formant correction, based on TD-PSOLA, reimplementation of Praat.
"""
def __init__(self,
sr: int,
floor: float = 60,
bins_per_octave: int = 12,
window_fn: Callable[[int, torch.device], torch.Tensor] = torch.hann_window):
"""Initializer.
Args:
sr: sampling rate.
floor: floor value of fundamental frequency.
bins_per_octave: the number of the bins in an octave.
window_fn: window function, default hann window.
"""
super().__init__()
self.sr = sr
self.floor = floor
self.bins_per_octave = bins_per_octave
self.window_fn = window_fn
def forward(self,
snd: torch.Tensor,
pitch: torch.Tensor,
steps: int,
pitch_range: float) -> torch.Tensor:
"""Manipulate the fundamental frequency.
Args:
snd: [torch.float32; [T]], audio signal, [-1, 1]-ranged.
pitch: [torch.float32; [S]], fundamental frequencies.
            steps: the number of steps to shift, in units of 1 / `bins_per_octave` octave.
            pitch_range: scaling factor of the pitch range around the median.
Returns:
[torch.float32; [T]], resampled.
"""
if not (pitch > 1e-5).any():
            warnings.warn('all frames are unvoiced; returning the input unchanged')
return snd
# [T]
f0 = F.interpolate(pitch[None, None], size=len(snd), mode='linear')[0, 0]
# [P], find all peaks in voiced segment
peaks = self.find_allpeaks(snd, f0)
if peaks is None:
            warnings.warn('no peaks found; the signal may be entirely unvoiced')
return snd
# nonzero median
median = pitch[pitch > 0.].median().item()
# scaling factor
factor = 2 ** (steps / self.bins_per_octave)
# shift
f0, median = f0 * factor, median * factor
# rerange
f0 = torch.where(
f0 > 0.,
median + (f0 - median) * pitch_range,
0.)
# resample
return self.psola(snd, f0, peaks)
def find_voiced_segment(self, f0: torch.Tensor, i: int = 0) -> Optional[Tuple[int, int]]:
"""Find voiced segment starting from `i`.
Args:
f0: [torch.float32; [T]], fundamental frequencies, hertz-level.
i: starting index.
        Returns:
            left and right boundaries if a voiced segment exists (half-open range), otherwise None.
        """
# if f0 tensor is empty
if len(f0[i:]) == 0:
return None
# next voiced interval
flag = (f0[i:] > 0.).long()
        # force the first label to unvoiced so `unique_consecutive` starts with an unvoiced run
        flag[0] = 0
# if all unvoiced
if not flag.any():
return None
# if found
left = i + flag.argmax()
        # run-length counts; the second run is the voiced segment
        _, (_, cnt, *_) = flag.unique_consecutive(return_counts=True)
right = left + cnt
return left.item(), right.item()
def find_allpeaks(self, signal: torch.Tensor, f0: torch.Tensor) -> Optional[torch.Tensor]:
"""Compute all periods from signal.
Args:
signal: [torch.float32; [T]], speech signal, [-1, 1]-ranged.
f0: [torch.float32; [T]], fundamental frequencies, hertz-level.
Returns:
[torch.long; [P]], found peaks.
"""
# []
global_peak = signal.abs().max()
        def find_peak(i: int, dir_: Literal['left', 'right']):
# find peaks
w = int(self.sr / f0[i].clamp_min(self.floor))
s = max(i - w // 2, 0)
if dir_ == 'left':
# -1.25 - 0.5, -0.8 - 0.5
cand_l, cand_r = max(int(i - 1.75 * w), 0), max(int(i - 1.3 * w), 0)
if dir_ == 'right':
# 0.8 - 0.5, 1.25 - 0.5
cand_l, cand_r = int(i + 0.3 * w), int(i + 0.75 * w)
# if nothing to find
if cand_l == cand_r or len(signal) - cand_l < w:
return w, -1, i, 0
# [cand(=cand_r - cand_l), w]
seg = signal[cand_l:cand_r + w].unfold(-1, w, 1)
# [cand]
corr = torch.matmul(
F.normalize(seg, dim=-1),
F.normalize(signal[s:s + w], dim=-1)[:, None]).squeeze(dim=-1)
# []
max_corr, r = corr.max(), corr.argmax()
peak = seg[r].abs().max()
# add bias (cand_l - s)
return w, max_corr, i + (r + cand_l) - s, peak
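        # for each voiced segment: seed at the strongest extremum near the middle,
        # then march one period at a time to the left and to the right, keeping
        # positions whose correlation with the window at the current peak and
        # whose amplitude pass the acceptance thresholds (0.7 / 0.3 correlation,
        # fractions of the global peak)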
added_right, i, peaks = -1e308, 0, []
while True:
voiced = self.find_voiced_segment(f0, i)
if voiced is None:
break
# if exist
left, right = voiced
# middle interval
middle = (left + right) // 2
# find first extremum
w = int(self.sr / f0[middle])
s = max(middle - w // 2, 0)
# []
minima, imin = signal[s:s + w].min(), signal[s:s + w].argmin()
maxima, imax = signal[s:s + w].max(), signal[s:s + w].argmax()
# if all same
if minima == maxima:
i = middle
else:
i = s + (imin if abs(minima) > abs(maxima) else imax)
backup = i
# left interval search
while True:
w, corr, i, peak = find_peak(i, 'left')
if corr == -1.:
i -= w
if i < left:
if corr > 0.7 and peak > 0.023333 * global_peak and i - added_right > 0.8 * w:
peaks.append(i)
break
if corr > 0.3 and (peak == 0. or peak > 0.01 * global_peak):
if i - added_right > 0.8 * w:
peaks.append(i)
i = backup
# right interval search
while True:
w, corr, i, peak = find_peak(i, 'right')
if corr == -1.:
i += w
# half-exclusive
if i >= right:
if corr > 0.7 and peak > 0.023333 * global_peak:
peaks.append(i)
added_right = i
break
if corr > 0.3 and (peak == 0. or peak > 0.01 * global_peak):
peaks.append(i)
added_right = i
# to next interval
i = right
if len(peaks) == 0:
return None
        # sort and clamp the peak positions (cast to plain ints; `find_peak` may
        # return either tensors or python ints)
        return torch.tensor(
            sorted(int(p) for p in peaks), device=signal.device).clamp(0, len(signal) - 1)
def psola(self, signal: torch.Tensor, pitch: torch.Tensor, peaks: torch.Tensor) -> torch.Tensor:
"""Pitch-synchronous overlap and add.
Args:
signal: [torch.float32; [T]], speech signal, [-1, 1]-ranged.
pitch: [torch.float32; [T]], fundamental frequencies, hertz-level.
            peaks: [torch.long; [P]], peak positions, in samples.
Returns:
[torch.float32; [T]], resampled.
"""
device = signal.device
max_w = 1.25 * self.sr / pitch[pitch > 0].min()
# T
timesteps, = signal.shape
# [T]
new_signal = torch.zeros_like(signal)
cache = {}
def cached_window(left: int, right: int) -> torch.Tensor:
nonlocal device, cache
if left not in cache:
cache[left] = self.window_fn(left * 2, device=device)
if right not in cache:
cache[right] = self.window_fn(right * 2, device=device)
return torch.cat([cache[left][:left], cache[right][right:]], dim=0)
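        # the asymmetric bell is the rising half of a (2 * left)-point window
        # glued to the falling half of a (2 * right)-point one, so each copied
        # period tapers to zero at both ends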
i = 0
while i < len(signal):
voiced = self.find_voiced_segment(pitch, i)
if voiced is None:
break
# if voice found
left_v, right_v = voiced
if i < left_v:
# copy noise, do not cache the window
window = self.window_fn(left_v - i, device=device)
new_signal[i:left_v] += window * signal[i:left_v]
while left_v < right_v:
# find nearest peak
p = (peaks - left_v).abs().argmin().item()
period = int(self.sr / pitch[left_v].clamp_min(self.floor))
# width
left_w, right_w = period // 2, period // 2
# clamping for aliasing
if p > 0 and peaks[p] - peaks[p - 1] <= max_w:
left_w = min(peaks[p] - peaks[p - 1], left_w)
if p < len(peaks) - 1 and peaks[p + 1] - peaks[p] <= max_w:
right_w = min(peaks[p + 1] - peaks[p], right_w)
                # clamping for sampling (cast to int; the mins above may yield tensors)
                left_w = int(min(left_w, peaks[p]))
                right_w = int(max(min(right_w, timesteps - peaks[p]), 0))
# offset to index
left_i, right_i = int(peaks[p] - left_w), int(peaks[p] + right_w)
# copy
s = left_v - (right_i - left_i) // 2
s, r = max(s, 0), max(-s, 0)
intval = min(right_i - left_i - r, timesteps - s)
# for safety
if intval == 0:
break
                seg = cached_window(left_w, right_w) * signal[left_i:right_i]
                # overlap-add: bells from adjacent periods overlap, so accumulate
                new_signal[s:s + intval] += seg[r:r + intval]
# next
left_v += intval
# next segment
i = right_v
# copy last noise
new_signal[i:] = signal[i:]
return new_signal
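

# A minimal usage sketch (not part of the original gist); the constant pitch
# track and hop size below are assumptions for illustration only.
if __name__ == '__main__':
    sr = 16000
    t = torch.arange(sr) / sr
    # [T], a one-second 110 Hz sine
    snd = torch.sin(2 * torch.pi * 110. * t)
    # [S], a constant 110 Hz track at 10 ms hops (e.g. from the NACF module above)
    pitch = torch.full((100,), 110.)
    shifter = PitchShift(sr)
    # shift up by four steps (a major third at 12 bins per octave)
    shifted = shifter(snd, pitch, steps=4, pitch_range=1.)
    print(shifted.shape)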