Last active
February 13, 2023 14:08
-
-
Save revsic/9e6153848798c4e8551e5df51ba25526 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Optional, Tuple | |
import numpy as np | |
import torch | |
from yin import YIN, localmin, parabolic_interp | |
def viterbi(log_prob: torch.Tensor, log_trans: torch.Tensor, log_init: torch.Tensor) -> torch.Tensor:
    """Viterbi algorithm.

    Finds the most likely state sequence given per-step emission
    log-likelihoods, a time-invariant transition matrix and an initial
    state distribution, all expressed in log-space.
    Args:
        log_prob: [torch.float32; [..., S, bins]], conditional log-likelihood.
        log_trans: [torch.float32; [bins, bins]], transition log-probability,
            log_trans[i, j] is the log-probability of moving from state i to j.
        log_init: [torch.float32; [bins]], log-probability of initial state.
    Returns:
        states: [torch.long; [..., S]], index sequence.
    """
    # S
    steps = log_prob.shape[-2]
    if steps == 0:
        # nothing to decode, keep the leading batch shape
        return torch.zeros(log_prob.shape[:-1], dtype=torch.long, device=log_prob.device)
    # [..., bins]
    value = log_prob[..., 0, :] + log_init
    # [..., S, bins], back-pointers; integer dtype since they hold state indices
    ptrs = torch.zeros(log_prob.shape, dtype=torch.long, device=log_prob.device)
    for i in range(1, steps):
        # [..., bins(from), bins(to)] -> best predecessor of every destination state
        max_val, ptrs[..., i, :] = (value[..., None] + log_trans).max(dim=-2)
        # [..., bins]
        value = log_prob[..., i, :] + max_val
    # [..., S], backtracking
    states = torch.zeros(log_prob.shape[:-1], dtype=torch.long, device=log_prob.device)
    # most likely final state
    states[..., -1] = value.argmax(dim=-1)
    for i in range(steps - 2, -1, -1):
        # [...], follow the back-pointer of the next step
        states[..., i] = ptrs[..., i + 1, :].gather(
            -1, states[..., i + 1, None]).squeeze(dim=-1)
    return states
class pYIN(YIN):
    """pYIN-based pitch estimation algorithm.

    Scores every CMND trough under a beta-distributed set of harmonicity
    thresholds, quantizes the candidate lags into pitch bins and decodes the
    most likely voiced/unvoiced pitch path with a Viterbi search.
    """
    def __init__(self,
                 sr: int,
                 frame_time: float = 0.01,
                 freq_min: float = 75.,
                 freq_max: float = 600.,
                 down_sr: int = 16000,
                 bins_per_octave: int = 12,
                 bins_per_semitone: int = 100,
                 beta_parameters: Tuple[int, float] = (2, 18.),
                 num_thresholds: int = 100,
                 switch_prob: float = 0.01,
                 lambda_: float = 2.,
                 max_octave_trans: float = 35.92,
                 no_trough_prob: float = 0.01,
                 median_win: Optional[int] = 3):
        """Initializer.
        Args:
            sr: sampling rate.
            frame_time: duration of the frame.
            freq_min, freq_max: frequency min and max.
            down_sr: downsampling sr for fast computation.
            bins_per_octave: the number of semitones in an octave.
            bins_per_semitone: the number of the pitch bins in a semitone.
            beta_parameters: tuple of alpha, beta of the beta-distribution
                prior over the harmonicity thresholds.
            num_thresholds: the number of the harmonicity thresholds.
            switch_prob: probability to switch between voiced and unvoiced.
            lambda_: lambda value of boltzmann distribution (truncated discrete exponential)
                over the rank of the troughs below a threshold.
            max_octave_trans: maximum octave steps of transition per second.
            no_trough_prob: scale of the prior mass moved to the global minimum
                for the thresholds which the global minimum does not fall below.
            median_win: length of the window for median smoothing of the
                decoded pitch, None to disable.
        """
        super().__init__(
            sr,
            frame_time,
            freq_min,
            freq_max,
            threshold=None,
            median_win=None,
            down_sr=down_sr)
        # median smoothing is applied by `pYIN.sample`, not by the parent
        self.median_win = median_win
        # pYIN parameters
        self.fmin = freq_min
        self.bins_per_octave = bins_per_octave
        self.bins_per_semitone = bins_per_semitone
        self.lambda_ = lambda_
        self.no_trough_prob = no_trough_prob
        self.register_pyin_state(
            frame_time,
            freq_min,
            freq_max,
            bins_per_octave,
            bins_per_semitone,
            beta_parameters,
            num_thresholds,
            switch_prob,
            max_octave_trans)

    def sample(self, cmnd: torch.Tensor) -> torch.Tensor:
        """Sample pitch frequency based on Viterbi-path searching.
        Args:
            cmnd: [torch.float32; [..., T / strides, tau_max - tau_min]],
                framed cumulative mean normalized difference,
                index `j` of the last axis corresponds to the lag `tau_min + j`.
        Returns:
            [torch.float32; [..., S]], pitch sequence, 0 on unvoiced frames.
                With `median_win`, S = T / strides - median_win + 1.
        """
        device = cmnd.device
        # [..., T / strides, tau_max - tau_min]
        lmin = localmin(cmnd)
        # [..., T / strides, tau_max - tau_min, num_thresholds], troughs below each threshold
        tholds = lmin[..., None] & (cmnd[..., None] < self.thresholds[1:])
        # [..., T / strides, tau_max - tau_min, num_thresholds], 1-based trough rank
        positions = torch.cumsum(tholds, dim=-2)
        # boltzmann prior, truncated exponential
        # P(k; lambda, N) = (1 - e^-lambda) e^(-lambda k) / (1 - e^(-lambda N))
        k, N = positions - 1, positions[..., -1, :]
        _fact = np.expm1(-self.lambda_) / (-self.lambda_ * N).expm1()
        prior = _fact[..., None, :] * (-self.lambda_ * k).exp()
        # masking, also discards the non-finite values where N == 0
        prior[~tholds] = 0.
        # [..., T / strides, tau_max - tau_min], marginalize over the thresholds
        probs = torch.matmul(prior, self.beta_probs[:, None]).squeeze(dim=-1)
        ## add prob to global minima if no candidates below the threshold
        ## else add prob to each candidates below the threshold
        # [..., T / strides], lag index of the deepest trough
        global_min = cmnd.masked_fill(~lmin, np.inf).argmin(dim=-1)
        # alias
        num_tholds = tholds.shape[-1]
        # [..., T / strides, 1, num_thresholds], threshold of global min
        holds = tholds.gather(
            -2,
            global_min[..., None, None].repeat(
                [1] * cmnd.dim() + [num_tholds]))
        # [..., T / strides], the number of thresholds the global min does not fall below
        below_min = torch.count_nonzero(~holds.squeeze(dim=-2), dim=-1)
        # [num_thresholds]
        a = torch.arange(num_tholds, device=device)
        # add no_trough_prob x (beta mass of those thresholds) to the global min
        probs.scatter_add_(
            -1,
            global_min[..., None],
            self.no_trough_prob * (
                (a < below_min[..., None]).float() * self.beta_probs).sum(dim=-1, keepdim=True))
        # [tau_max - tau_min]
        tau = torch.arange(self.tau_max - self.tau_min, device=device)
        # [..., T / strides, tau_max - tau_min]
        pshifts = parabolic_interp(cmnd)
        # refining peak, index `j` of the CMND corresponds to the lag `tau_min + j`
        period = tau + self.tau_min + pshifts
        # [..., T / strides, tau_max - tau_min], to frequency
        f0s = self.down_sr / period.clamp_min(1e-5)
        # [..., T / strides, tau_max - tau_min], quantize
        bins = self.bins_per_octave * self.bins_per_semitone * (f0s / self.fmin).log2()
        bins = bins.round().clamp(0, self.total_bins - 1).long()
        # [..., T / strides, 2 x total_bins], observation probs
        observed = torch.zeros(*bins.shape[:-1], self.total_bins * 2, device=device)
        # accumulate, so lags sharing a bin (including the zero-probability
        # non-trough lags) cannot overwrite each other
        observed.scatter_add_(-1, bins, probs)
        # voice probs
        voiced = observed[..., :self.total_bins].sum(dim=-1).clamp(0., 1.)
        observed[..., self.total_bins:] = (1 - voiced[..., None]) / self.total_bins
        # path search
        # [..., T / strides]
        states = viterbi(
            observed.clamp_min(1e-7).log(),
            self.transition.clamp_min(1e-7).log(),
            self.p_init.clamp_min(1e-7).log())
        # convert to frequency
        f0 = self.freqs[states % self.total_bins]
        # if in voice
        voiced_flag = states < self.total_bins
        # unvoice to zero
        f0[~voiced_flag] = 0.
        # median pool, shortens the sequence by median_win - 1
        if self.median_win is not None:
            f0 = torch.median(
                f0.unfold(-1, self.median_win, 1),
                dim=-1).values
        return f0

    def register_pyin_state(self,
                            frame_time: float,
                            fmin: float,
                            fmax: float,
                            bins_per_octave: int,
                            bins_per_semitone: int,
                            beta_parameters: Tuple[int, float],
                            num_thresholds: int,
                            switch_prob: float,
                            max_octave_trans: float):
        """Register the non-persistent buffers used by `sample`.

        Buffers: `thresholds`, `beta_probs`, `transition`, `p_init`, `freqs`;
        also sets `self.total_bins`.
        """
        # [num_thresholds + 1]
        self.register_buffer(
            'thresholds',
            torch.linspace(0., 1., num_thresholds + 1),
            persistent=False)
        # beta-distribution prior
        import scipy.stats
        alpha, beta = beta_parameters
        beta_cdf = torch.tensor(
            scipy.stats.beta.cdf(self.thresholds.numpy(), alpha, beta), dtype=torch.float32)
        # [num_thresholds], probability mass between consecutive thresholds
        self.register_buffer(
            'beta_probs',
            beta_cdf.diff(),
            persistent=False)
        # the number of the possible bins, +1 to include the bin at the upper edge
        total_bins = int(
            bins_per_octave * bins_per_semitone * np.log2(fmax / fmin)) + 1
        self.total_bins = total_bins
        # maximum semitone steps of transition per frame
        max_semitones_per_frame = round(
            max_octave_trans * bins_per_octave * frame_time)
        # local transition matrix with triangular window
        ## transition[i, j] = 0 if |i - j| > width // 2
        ## transition[i, i] is maximal
        w = max_semitones_per_frame * bins_per_semitone + 1
        # [total_bins]
        idx = torch.arange(total_bins)
        # [total_bins, total_bins], bin-transition probs
        grid = (w // 2 + 1 - (idx - idx[:, None]).abs()).clamp_min(0)
        # normalize every row, so the outgoing probabilities of each bin sum to 1
        transition = grid / grid.sum(dim=-1, keepdim=True)
        # self-loop transition matrix
        ## transition[i, i] = 1 - p, transition[i, j] = p for i != j
        # [2, 2], voice-transition probs
        t_switch = torch.full((2, 2), switch_prob)
        t_switch[0, 0] = 1 - switch_prob
        t_switch[1, 1] = 1 - switch_prob
        # [2 x total_bins, 2 x total_bins], apply voice-probs
        self.register_buffer(
            'transition',
            torch.kron(t_switch, transition),
            persistent=False)
        # [2 x total_bins], initial probs, unvoiced
        p_init = torch.zeros(2 * total_bins)
        p_init[total_bins:] = 1 / total_bins
        self.register_buffer('p_init', p_init, persistent=False)
        # [total_bins], center frequency of every pitch bin
        self.register_buffer(
            'freqs',
            fmin * 2 ** (idx / (bins_per_octave * bins_per_semitone)),
            persistent=False)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Optional | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import torchaudio.functional as AF | |
def localmin(x: torch.Tensor) -> torch.Tensor:
    """Mark the local minima of `x` along its last axis.

    A position is a minimum when it is not higher than its left neighbour
    and not higher than its right neighbour. The first and last elements
    are compared only against their single neighbour, and plateaus count.
    Args:
        x: [torch.float32; [..., T]], input tensor.
    Returns:
        [bool; [..., T]], local minima.
    """
    step = x.diff(dim=-1)
    # [..., T], x[i] - x[i - 1], zero in front of the first element
    from_left = F.pad(step, [1, 0])
    # [..., T], x[i + 1] - x[i], zero after the last element
    to_right = F.pad(step, [0, 1])
    return (from_left <= 0) & (to_right >= 0)
def parabolic_interp(x: torch.Tensor) -> torch.Tensor:
    """Parabolic interpolation for smoothing difference function.

    Fits a parabola through every (previous, current, next) triple and
    returns the offset of its vertex from the current index. Offsets with
    magnitude >= 1 (and degenerate flat triples) are replaced by 0.
    Args:
        x: [torch.float32; [..., C]], input tensor.
    Returns:
        [torch.float32; [..., C]], parabolic shifts, 0 on both edges.
    """
    if x.shape[-1] < 3:
        # no element has two neighbours, so no shift can be estimated
        return torch.zeros_like(x)
    # [..., C - 2], previous, current, next
    p, c, n = x[..., :-2], x[..., 1:-1], x[..., 2:]
    # second and first derivative of the fitted parabola
    a = n + p - 2 * c
    b = 0.5 * (n - p)
    # |b| < |a| <=> |shift| < 1, also rules out a == 0
    valid = b.abs() < a.abs()
    # safe denominator avoids inf/nan from the discarded entries
    denom = torch.where(valid, a, torch.ones_like(a))
    # [..., C - 2]
    shifts = torch.where(valid, -b / denom, torch.zeros_like(a))
    # [..., C], edge
    return F.pad(shifts, [1, 1])
class YIN(nn.Module):
    """YIN-based pitch estimation algorithm.
    """
    def __init__(self,
                 sr: int,
                 frame_time: float = 0.01,
                 freq_min: float = 75.,
                 freq_max: float = 600.,
                 threshold: Optional[float] = 0.2,
                 median_win: Optional[int] = 3,
                 down_sr: int = 16000):
        """Initializer.
        Args:
            sr: sampling rate of the input audio.
            frame_time: hop between consecutive frames, in seconds.
            freq_min, freq_max: frequency min and max, they bound the searched lags.
            threshold: harmonicity threshold on the CMND, used by `YIN.sample`.
            median_win: length of the window for median smoothing, None to disable.
            down_sr: downsampling sr for fast computation.
        """
        super().__init__()
        self.sr = sr
        # hop size, in samples of `down_sr`
        self.strides = int(down_sr * frame_time)
        # searched lag range, in samples of `down_sr`
        self.tau_max = int(down_sr // freq_min)
        self.tau_min = int(down_sr // freq_max)
        self.threshold = threshold
        self.median_win = median_win
        self.down_sr = down_sr

    @classmethod
    def cmnd(cls, signal: torch.Tensor, tmax: int, tmin: int) -> torch.Tensor:
        """Cumulative mean normalized difference.
        Args:
            signal: [torch.float32; [..., W]], input signal, requires W >= tmax.
            tmax, tmin: maximum, minimum value of the time-lag.
        Returns:
            [torch.float32; [..., tmax - tmin]], CMND,
                index `j` of the last axis corresponds to the lag `tmin + j`.
        """
        # one-based
        # d[tau]
        #   = sum_{j=1}^{W-tau} (x[j] - x[j + tau])^2
        #   = sum_{j=1}^{W-tau} (x[j]^2 - 2x[j]x[j + tau] + x[j + tau]^2)
        #   = c[W - tau] - 2 * a[tau] + (c[W] - c[tau])
        # where c[k] = sum_{j=1}^k x[j]^2, c[0] = 0
        #       a[tau] = sum_{j=1}^{W-tau} x[j]x[j + tau]
        # W
        w = signal.shape[-1]
        # [..., W + 1]
        fft = torch.fft.rfft(signal, w * 2, dim=-1)
        # [..., W x 2], zero-padding makes it the linear autocorrelation for lags [0, W)
        corr = torch.fft.irfft(fft * fft.conj(), dim=-1)
        # [..., W + 1], energy[k] = c[k], leading zero for c[0]
        energy = F.pad(signal.square().cumsum(dim=-1), [1, 0])
        # [..., tmax], difference d[0], ..., d[tmax - 1]
        diff = (
            # c[W - tau]
            torch.flip(energy[..., -tmax:], dims=(-1,))
            # -2 x a[tau]
            - 2 * corr[..., :tmax]
            # + (c[W] - c[tau])
            + energy[..., -1:] - energy[..., :tmax])
        # [..., tmax - 1], drop d[0] (always zero) and divide by the running sum
        cumdiff = diff[..., 1:] / (diff[..., 1:].cumsum(dim=-1) + 1e-7)
        # normalize, d'[tau] = tau x d[tau] / sum_{j=1}^{tau} d[j]
        cumdiff = cumdiff * torch.arange(1, tmax, device=cumdiff.device)
        # [..., tmax - tmin], d'[0] = 1 by definition, index i <-> lag i before slicing
        return F.pad(cumdiff, [1, 0], value=1.)[..., tmin:]

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Estimate the pitch from the input signal based on YIN.
        Args:
            inputs: [torch.float32; [..., T]], input audio sampled at `sr`.
        Returns:
            [torch.float32; [..., S]], pitch frequency, as produced by `sample`.
        """
        down = AF.resample(inputs, self.sr, self.down_sr)
        # set windows based on tau-max, next power of two plus one
        w = int(2 ** np.ceil(np.log2(self.tau_max))) + 1
        # [..., T / strides, windows]
        frames = F.pad(down, [0, w]).unfold(-1, w, self.strides)
        # [..., T / strides, tau_max - tau_min], cumulative mean normalized difference
        cmnd = YIN.cmnd(frames, self.tau_max, self.tau_min)
        # sampling
        return self.sample(cmnd)

    def sample(self, cmnd: torch.Tensor) -> torch.Tensor:
        """Sample pitch frequency locally.

        Picks, per frame, the first local minimum at or after the first lag
        whose CMND falls below `threshold`. Frames without such a minimum,
        or whose minimum lies on the first/last lag, are unvoiced (0).
        Args:
            cmnd: [torch.float32; [..., T / strides, tau_max - tau_min]],
                framed cumulative mean normalized difference,
                index `j` of the last axis corresponds to the lag `tau_min + j`.
        Returns:
            [torch.float32; [..., S]], pitch sequence, 0 on unvoiced frames.
                With `median_win`, S = T / strides - median_win + 1.
        """
        num_taus = self.tau_max - self.tau_min
        # [..., T / strides, tau_max - tau_min]
        below = cmnd < self.threshold
        # [..., T / strides], first lag index below the threshold
        thold = below.long().argmax(dim=-1)
        # if not found, start the search past the last lag
        thold[~below.any(dim=-1)] = num_taus
        # [..., T / strides, tau_max - tau_min] switch mask
        thold = thold[..., None] <= torch.arange(num_taus, device=cmnd.device)
        # [..., T / strides, tau_max - tau_min]
        lmin = localmin(cmnd)
        # [..., T / strides], first local minimum after the crossing, 0 if none
        tau = (thold & lmin).long().argmax(dim=-1)
        # a minimum on the last lag is a boundary effect, treat as not found
        tau[tau == num_taus - 1] = 0
        # [..., T / strides, tau_max - tau_min]
        pshifts = parabolic_interp(cmnd)
        # refining peak, index `tau` corresponds to the lag `tau_min + tau`
        period = tau + self.tau_min + pshifts.gather(-1, tau[..., None]).squeeze(dim=-1)
        # [..., T / strides], to frequency
        pitch = torch.where(
            tau > 0,
            self.down_sr / period,
            torch.tensor(0., device=tau.device))
        # median pool, shortens the sequence by median_win - 1
        if self.median_win is not None:
            pitch = torch.median(
                pitch.unfold(-1, self.median_win, 1),
                dim=-1).values
        return pitch
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment