Last active
May 21, 2024 10:17
-
-
Save revsic/56fc069cf2096a8ce4ef1bf664bcf306 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Copyright (C) https://github.com/praat/praat | |
This program is free software: you can redistribute it and/or modify | |
it under the terms of the GNU General Public License as published by | |
the Free Software Foundation, either version 3 of the License, or | |
(at your option) any later version. | |
This program is distributed in the hope that it will be useful, | |
but WITHOUT ANY WARRANTY; without even the implied warranty of | |
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | |
GNU General Public License for more details. | |
You should have received a copy of the GNU General Public License | |
along with this program. If not, see <http://www.gnu.org/licenses/> | |
""" | |
from typing import Optional | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import torchaudio.functional as AF | |
class NACF(nn.Module):
    """Normalized autocorrelation-based pitch estimation, reimplementation of Praat.

    Pipeline of `forward`: resample to `down_sr`, cut into Hann-windowed frames,
    compute the autocorrelation of each frame normalized by the autocorrelation
    of the window, collect up to `k` local maxima per frame as pitch candidates,
    and pick one candidate per frame with a Viterbi search. The selected track
    is optionally smoothed by a median filter.
    """

    def __init__(self,
                 sr: int,
                 frame_time: float = 0.01,
                 freq_min: float = 75.,
                 freq_max: float = 600.,
                 down_sr: int = 16000,
                 k: int = 15,
                 thold_silence: float = 0.03,
                 thold_voicing: float = 0.45,
                 cost_octave: float = 0.01,
                 cost_jump: float = 0.35,
                 cost_vuv: float = 0.14,
                 median_win: Optional[int] = 5):
        """Initializer.
        Args:
            sr: sampling rate.
            frame_time: duration of the frame.
            freq_min, freq_max: frequency min and max.
            down_sr: downsampling sr for fast computation.
            k: the maximum number of the candidates.
            thold_silence: silence threshold, scales the relative frame intensity
                in the strength of the unvoiced candidate.
            thold_voicing: voicing threshold, base strength of the unvoiced
                candidate; peaks below half of it are not taken as candidates.
            cost_octave: weight of the log2-frequency penalty on candidates.
            cost_jump: weight of the log2-frequency jump between adjacent frames.
            cost_vuv: cost of a transition between voiced and unvoiced frames.
            median_win: size of the median filter applied on the estimated pitch,
                the filter is skipped if None (or not larger than 1).
        """
        super().__init__()
        self.sr = sr
        self.strides = int(down_sr * frame_time)
        self.fmin, self.fmax = freq_min, freq_max
        # lag range (in samples of `down_sr`) of the pitch candidates
        self.tmax = int(down_sr // freq_min)
        self.tmin = int(down_sr // freq_max)
        self.down_sr = down_sr
        self.k = min(max(k, int(freq_max / freq_min)), self.tmax - self.tmin)
        # set windows based on tau-max, next power of two plus one
        self.w = int(2 ** np.ceil(np.log2(self.tmax))) + 1
        self.register_buffer(
            'window', torch.hann_window(self.w), persistent=False)
        # [w + 1], zero-padded spectrum of the window
        r = torch.fft.rfft(self.window, self.w * 2, dim=-1)
        # [w], autocorrelation of the window, normalized to 1 at lag zero
        ws = torch.fft.irfft(r * r.conj(), dim=-1)[:self.w]
        ws = ws / ws[0]
        self.register_buffer('nacw', ws, persistent=False)
        # alias
        self.thold_silence = thold_silence
        self.thold_voicing = thold_voicing
        self.cost_octave = cost_octave
        # correction
        c = 1.  # c = 0.01 * down_sr
        self.cost_jump = cost_jump * c
        self.cost_vuv = cost_vuv * c
        self.median_win = median_win

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Estimate the pitch from the input signal.
        Args:
            inputs: [torch.float32; [..., T]], input signal, [-1, 1]-ranged.
        Returns:
            [torch.float32; [..., S]], estimated pitch sequence,
                zero on the frames decided as unvoiced.
        """
        if self.sr != self.down_sr:
            x = AF.resample(inputs, self.sr, self.down_sr)
        else:
            # nothing to resample, do not touch the resampler
            x = inputs
        # [..., S(=T / strides + 1), w]
        frames = F.pad(x, [0, self.w]).unfold(-1, self.w, self.strides)
        # [..., S, w], remove DC and apply the window
        frames = (frames - frames.mean(dim=-1, keepdim=True)) * self.window
        # [...], clamped for preventing 0 / 0 on the all-zero inputs
        global_peak = inputs.abs().amax(dim=-1).clamp_min(
            torch.finfo(frames.dtype).tiny)
        # [..., S]
        local_peak = frames.abs().amax(dim=-1)
        # [..., S], relative intensity, saturated at 1
        intensity = (local_peak / global_peak[..., None]).clamp_max(1.)
        # [..., S, w + 1]
        spec = torch.fft.rfft(frames, self.w * 2, dim=-1)
        # [..., S, w]
        acf = torch.fft.irfft(spec * spec.conj(), dim=-1)[..., :self.w]
        # [..., S, w], normalized autocorrelation
        # (all-zero frames become NaN here, they are masked as unvoiced below)
        nacf = acf / (acf[..., :1] * self.nacw)
        # [..., S, tmax + 1]
        nacf = nacf[..., :self.tmax + 1]
        # [..., S, tmax], x[i + 1] - x[i]
        d = nacf.diff(dim=-1)
        # [..., S, tmax - 1], inc & dec
        localmax = (d[..., :-1] >= 0) & (d[..., 1:] <= 0)
        # [..., S, tmax - 1]
        flag = localmax & (nacf[..., 1:-1] > 0.5 * self.thold_voicing)
        # [..., S, tmax - 1], parabolic interpolation
        n, c, p = nacf[..., 2:], nacf[..., 1:-1], nacf[..., :-2]
        dr = 0.5 * (n - p) / (2. * c - n - p)
        # [tmax - 1], index 0 corresponds to lag 1
        a = torch.arange(self.tmax - 1, device=dr.device)
        # [..., S, tmax - 1]
        freqs = self.down_sr / (1 + (dr + a).clamp_min(0.))
        ## TODO: sinc interpolation, depth=30
        logits = nacf[..., 1:self.tmax]
        # reflect logits of high values (for short windows)
        logits = logits.where(logits <= 1., 1 / logits)
        # additional penalty
        logits = logits - self.cost_octave * (self.fmin / freqs).log2()
        # masking the lags which are not candidates
        FLOOR = -1e5
        logits = logits.masked_fill(~flag, FLOOR)
        # [..., S, k], [..., S, k], topk
        logits, indices = logits.topk(self.k, dim=-1)
        # [..., S, k]
        freqs = freqs.gather(-1, indices)
        ## TODO: maximize sinc interpolation, depth=4
        # [..., S], strength of the unvoiced candidate
        logits_uv = 2. - intensity / (
            self.thold_silence / (1. + self.thold_voicing))
        logits_uv = self.thold_voicing + logits_uv.clamp_min(0.)
        # [..., S, k], NaN frequencies fall to unvoiced
        voiced = (logits > FLOOR) & (freqs < self.fmax)
        # [..., S, k], local score of each candidate
        # (the fmin-referenced octave penalty above is already included in `logits`)
        delta = torch.where(
            ~voiced,
            logits_uv[..., None],
            logits - self.cost_octave * (self.fmax / freqs).log2())
        # [..., S - 1, k, k], transition cost from previous(-2) to current(-1)
        trans = self.cost_jump * (
            freqs[..., :-1, :, None] / freqs[..., 1:, None, :]).log2().abs()
        # both voiceless
        trans.masked_fill_(
            ~voiced[..., :-1, :, None] & ~voiced[..., 1:, None, :], 0.)
        # voice transition
        trans.masked_fill_(
            voiced[..., :-1, :, None] != voiced[..., 1:, None, :], self.cost_vuv)
        # [..., S], selected candidate of each frame
        states = self._viterbi(delta, trans)
        # [..., S, 1], sampling
        freqs = freqs.gather(-1, states[..., None])
        # masking unvoiced
        freqs.masked_fill_(~voiced.gather(-1, states[..., None]), 0.)
        # [..., S]
        f0 = freqs.squeeze(dim=-1)
        # median pool
        if self.median_win is not None and self.median_win > 1:
            # replication, asymmetric on even windows for preserving the length
            lead = self.median_win // 2
            trail = (self.median_win - 1) // 2
            f0 = torch.cat(
                [f0[..., :1]] * lead + [f0] + [f0[..., -1:]] * trail, dim=-1)
            f0 = f0.unfold(-1, self.median_win, 1).median(dim=-1).values
        return f0

    @staticmethod
    def _viterbi(delta: torch.Tensor, trans: torch.Tensor) -> torch.Tensor:
        """Find the path maximizing the sum of local scores minus transition costs.
        Args:
            delta: [torch.float32; [..., S, k]], local scores.
            trans: [torch.float32; [..., S - 1, k, k]], transition costs,
                indexed by previous and current candidates.
        Returns:
            [torch.long; [..., S]], index of the selected candidates.
        """
        steps = delta.shape[-2]
        # [..., k]
        value = delta[..., 0, :]
        # [..., S, k], back-pointers, kept in integer
        ptrs = torch.zeros(delta.shape, dtype=torch.long, device=delta.device)
        for i in range(1, steps):
            # [..., k(prev), k(cur)]
            scores = (
                value[..., None]
                - trans[..., i - 1, :, :] + delta[..., i, None, :])
            # [..., k], best previous candidate for each current one
            value, ptrs[..., i, :] = scores.max(dim=-2)
        # [..., S], backtracking
        states = torch.zeros(
            delta.shape[:-1], dtype=torch.long, device=delta.device)
        # initial state
        states[..., -1] = value.argmax(dim=-1)
        for i in range(steps - 2, -1, -1):
            # [...]
            states[..., i] = ptrs[..., i + 1, :].gather(
                -1, states[..., i + 1, None]).squeeze(dim=-1)
        return states
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Copyright (C) https://github.com/praat/praat | |
This program is free software: you can redistribute it and/or modify | |
it under the terms of the GNU General Public License as published by | |
the Free Software Foundation, either version 3 of the License, or | |
(at your option) any later version. | |
This program is distributed in the hope that it will be useful, | |
but WITHOUT ANY WARRANTY; without even the implied warranty of | |
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | |
GNU General Public License for more details. | |
You should have received a copy of the GNU General Public License | |
along with this program. If not, see <http://www.gnu.org/licenses/> | |
""" | |
import warnings | |
from typing import Callable, Optional, Tuple, Union | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
class PitchShift(nn.Module):
    """Pitch shift with formant correction, based on TD-PSOLA, reimplementation of Praat.
    """

    def __init__(self,
                 sr: int,
                 floor: float = 60,
                 bins_per_octave: int = 12,
                 window_fn: Callable[[int, torch.device], torch.Tensor] = torch.hann_window):
        """Initializer.
        Args:
            sr: sampling rate.
            floor: floor value of fundamental frequency.
            bins_per_octave: the number of the bins in an octave.
            window_fn: window function, default hann window,
                called as `window_fn(length, device=device)`.
        """
        super().__init__()
        self.sr = sr
        self.floor = floor
        self.bins_per_octave = bins_per_octave
        self.window_fn = window_fn

    def forward(self,
                snd: torch.Tensor,
                pitch: torch.Tensor,
                steps: int,
                pitch_range: float) -> torch.Tensor:
        """Manipulate the fundamental frequency.
        Args:
            snd: [torch.float32; [T]], audio signal, [-1, 1]-ranged.
            pitch: [torch.float32; [S]], fundamental frequencies.
            steps: the number of the steps to shift.
            pitch_range: ranging factor.
        Returns:
            [torch.float32; [T]], resampled,
                the input itself (with a warning) if nothing is voiced.
        """
        if not (pitch > 1e-5).any():
            warnings.warn('all unvoiced, pitch is all zero')
            return snd
        # [T], sample-level pitch
        f0 = F.interpolate(pitch[None, None], size=len(snd), mode='linear')[0, 0]
        # [P], find all peaks in voiced segment
        peaks = self.find_allpeaks(snd, f0)
        if peaks is None:
            warnings.warn('peak not found, maybe given is all unvoiced.')
            return snd
        # nonzero median
        median = pitch[pitch > 0.].median().item()
        # scaling factor
        factor = 2 ** (steps / self.bins_per_octave)
        # shift
        f0, median = f0 * factor, median * factor
        # rerange around the median, unvoiced samples stay zero
        f0 = torch.where(
            f0 > 0.,
            median + (f0 - median) * pitch_range,
            torch.zeros_like(f0))
        # resample
        return self.psola(snd, f0, peaks)

    def find_voiced_segment(self, f0: torch.Tensor, i: int = 0) -> Optional[Tuple[int, int]]:
        """Find voiced segment starting from `i`.
        Args:
            f0: [torch.float32; [T]], fundamental frequencies, hertz-level.
            i: starting index.
        Returns:
            segment left and right if voiced segment exist (half inclusive range)
        """
        tail = f0[i:]
        # if f0 tensor is empty
        if len(tail) == 0:
            return None
        # next voiced interval
        flag = (tail > 0.).long()
        # force the first label to unvoiced, so that the first run is
        # always unvoiced and the second run is the voiced segment
        flag[0] = 0
        # if all unvoiced
        if not flag.any():
            return None
        # if found, first index of the voiced run
        left = i + int(flag.argmax())
        # count the numbers, [1] is the length of the first voiced run
        _, counts = flag.unique_consecutive(return_counts=True)
        right = left + int(counts[1])
        return left, right

    def find_allpeaks(self, signal: torch.Tensor, f0: torch.Tensor) -> Optional[torch.Tensor]:
        """Compute all periods from signal.
        Args:
            signal: [torch.float32; [T]], speech signal, [-1, 1]-ranged.
            f0: [torch.float32; [T]], fundamental frequencies, hertz-level.
        Returns:
            [torch.long; [P]], found peaks in ascending order,
                None if nothing found.
        """
        length = len(signal)
        global_peak = float(signal.abs().max())

        def period_at(i: int) -> int:
            # period in samples, f0 is floored for bounding the window size
            return max(int(self.sr / f0[i].clamp_min(self.floor)), 1)

        def find_peak(i: int, dir_: str) -> Tuple[int, float, int, float]:
            """Search the most correlated period around `i` on the given side.
            Returns:
                period, correlation (-1 if nothing to find),
                new position and the absolute peak of the matched period.
            """
            w = period_at(i)
            # start of the template, a period centered on `i`
            s = max(i - w // 2, 0)
            if dir_ == 'left':
                # -1.25 - 0.5, -0.8 - 0.5
                cand_l, cand_r = max(int(i - 1.75 * w), 0), max(int(i - 1.3 * w), 0)
            elif dir_ == 'right':
                # 0.8 - 0.5, 1.25 - 0.5
                cand_l, cand_r = int(i + 0.3 * w), int(i + 0.75 * w)
            else:
                raise ValueError(f'unknown direction: {dir_!r}')
            # if nothing to find, or the template is truncated by the end of signal
            if cand_l == cand_r or length - cand_l < w or s + w > length:
                return w, -1., i, 0.
            # [cand(=cand_r - cand_l + 1), w]
            seg = signal[cand_l:cand_r + w].unfold(-1, w, 1)
            # [w]
            template = F.normalize(signal[s:s + w], dim=-1)
            # [cand], cosine similarity
            corr = torch.matmul(
                F.normalize(seg, dim=-1), template[:, None]).squeeze(dim=-1)
            r = int(corr.argmax())
            # add bias (cand_l - s)
            return w, float(corr[r]), i + (r + cand_l) - s, float(seg[r].abs().max())

        added_right, i, peaks = float('-inf'), 0, []
        while True:
            voiced = self.find_voiced_segment(f0, i)
            if voiced is None:
                break
            # if exist
            left, right = voiced
            # middle interval
            middle = (left + right) // 2
            # find first extremum in a period around the middle
            w = period_at(middle)
            s = max(middle - w // 2, 0)
            window = signal[s:s + w]
            minima, maxima = float(window.min()), float(window.max())
            # if all same
            if minima == maxima:
                i = middle
            elif abs(minima) > abs(maxima):
                i = s + int(window.argmin())
            else:
                i = s + int(window.argmax())
            # NOTE(review): the seed extremum itself is not appended to the peaks,
            # only the positions found by the searches below - confirm against Praat.
            backup = i
            # left interval search
            while True:
                w, corr, i, peak = find_peak(i, 'left')
                if corr == -1.:
                    # nothing found, skip a period
                    i -= w
                if i < left:
                    if corr > 0.7 and peak > 0.023333 * global_peak and i - added_right > 0.8 * w:
                        peaks.append(i)
                    break
                if corr > 0.3 and (peak == 0. or peak > 0.01 * global_peak):
                    # do not overlap with the last peak of the previous segment
                    if i - added_right > 0.8 * w:
                        peaks.append(i)
            i = backup
            # right interval search
            while True:
                w, corr, i, peak = find_peak(i, 'right')
                if corr == -1.:
                    # nothing found, skip a period
                    i += w
                # half-exclusive
                if i >= right:
                    if corr > 0.7 and peak > 0.023333 * global_peak:
                        peaks.append(i)
                        added_right = i
                    break
                if corr > 0.3 and (peak == 0. or peak > 0.01 * global_peak):
                    peaks.append(i)
                    added_right = i
            # to next interval
            i = right
        if len(peaks) == 0:
            return None
        # sort the point
        return torch.tensor(
            sorted(peaks), dtype=torch.long, device=signal.device).clamp(0, length - 1)

    def psola(self, signal: torch.Tensor, pitch: torch.Tensor, peaks: torch.Tensor) -> torch.Tensor:
        """Pitch-synchronous overlap and add.
        Args:
            signal: [torch.float32; [T]], speech signal, [-1, 1]-ranged.
            pitch: [torch.float32; [T]], fundamental frequencies, hertz-level.
            peaks: [torch.long; [P]], peaks, sample indices in ascending order.
        Returns:
            [torch.float32; [T]], resampled.
        """
        device = signal.device
        voiced_pitch = pitch[pitch > 0]
        # nothing to resynthesize
        if voiced_pitch.numel() == 0:
            return signal.clone()
        # maximum distance of the peaks regarded as adjacent
        max_w = float(1.25 * self.sr / voiced_pitch.min())
        # T
        timesteps, = signal.shape
        # [T]
        new_signal = torch.zeros_like(signal)
        # windows of double size, keyed by the half width (python integer)
        cache = {}

        def full_window(half: int) -> torch.Tensor:
            if half not in cache:
                cache[half] = self.window_fn(half * 2, device=device)
            return cache[half]

        def cached_window(left: int, right: int) -> torch.Tensor:
            # rising half of width `left`, falling half of width `right`
            return torch.cat(
                [full_window(left)[:left], full_window(right)[right:]], dim=0)

        i = 0
        while i < timesteps:
            voiced = self.find_voiced_segment(pitch, i)
            if voiced is None:
                break
            # if voice found
            left_v, right_v = voiced
            if i < left_v:
                # copy noise, do not cache the window
                window = self.window_fn(left_v - i, device=device)
                new_signal[i:left_v] += window * signal[i:left_v]
            while left_v < right_v:
                # find nearest peak
                p = (peaks - left_v).abs().argmin().item()
                center = int(peaks[p])
                period = int(self.sr / pitch[left_v].clamp_min(self.floor))
                # width
                left_w = right_w = period // 2
                # clamping for aliasing
                if p > 0:
                    gap = center - int(peaks[p - 1])
                    if gap <= max_w:
                        left_w = min(gap, left_w)
                if p < len(peaks) - 1:
                    gap = int(peaks[p + 1]) - center
                    if gap <= max_w:
                        right_w = min(gap, right_w)
                # clamping for sampling
                left_w = min(left_w, center)
                right_w = max(min(right_w, timesteps - center), 0)
                # offset to index
                left_i, right_i = center - left_w, center + right_w
                # copy, centered on the current position
                s = left_v - (right_i - left_i) // 2
                s, r = max(s, 0), max(-s, 0)
                intval = min(right_i - left_i - r, timesteps - s)
                # for safety
                if intval <= 0:
                    break
                seg = cached_window(left_w, right_w) * signal[left_i:right_i]
                new_signal[s:s + intval] = seg[r:r + intval]
                # next
                left_v += intval
            # next segment
            i = right_v
        # copy last noise
        new_signal[i:] = signal[i:]
        return new_signal
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment