Skip to content

Instantly share code, notes, and snippets.

@Hiroshiba
Created January 7, 2020 16:02
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Hiroshiba/32ea4c6df80e2c42883a59a5c31e2af9 to your computer and use it in GitHub Desktop.
Save Hiroshiba/32ea4c6df80e2c42883a59a5c31e2af9 to your computer and use it in GitHub Desktop.
JVSデータセットの無音ラベルがどれくらい合ってるのか評価
from pathlib import Path
from typing import Union, Sequence
import librosa
import numpy
from librosa.effects import _signal_to_frame_nonsilent
from matplotlib import pyplot
from tqdm import tqdm
class Phoneme:
    """One labelled phoneme segment read from a Julius-style label file.

    Equality and hashing consider only the phoneme name, not its timing.
    """

    def __init__(
        self,
        name: str,
        start: Union[float, None] = None,
        end: Union[float, None] = None,
    ) -> None:
        self.name = name
        # Segment boundaries as parsed from the label file; presumably seconds
        # (process() multiplies them by 100 to get 10 ms frame indices) — confirm.
        self.start = start
        self.end = end

    def __eq__(self, other: object):
        # Compare by name only. For non-Phoneme operands defer to Python's
        # default handling instead of raising AttributeError on `other.name`.
        if not isinstance(other, Phoneme):
            return NotImplemented
        return self.name == other.name

    def __hash__(self):
        # Defining __eq__ alone sets __hash__ to None (unhashable); hash
        # consistently with __eq__ so phonemes can go into sets/dict keys.
        return hash(self.name)

    def __repr__(self):
        # Use the real attribute/constructor name (`name`) in the repr.
        return f'Phoneme(name={self.name!r}, start={self.start}, end={self.end})'

    @property
    def duration(self):
        """Return ``end - start``; requires both boundaries to be set."""
        return self.end - self.start

    @property
    def is_silence(self):
        """True for the silence/pause labels ``'pau'`` and ``'sil'``."""
        return self.name in ('pau', 'sil')

    @classmethod
    def parse(cls, s: str):
        """Parse one ``"<start> <end> <name>"`` whitespace-separated label line."""
        words = s.split()
        return cls(
            start=float(words[0]),
            end=float(words[1]),
            name=words[2],
        )

    @classmethod
    def load_julius_list(cls, path: Union[str, Path]):
        """Read a label file and return its phonemes in file order.

        Blank and whitespace-only lines are skipped (they would otherwise
        make ``parse`` fail with IndexError).
        """
        return [
            cls.parse(s)
            for s in Path(path).read_text().splitlines()
            if s.strip()
        ]
def calc_pr(silent_phoneme: numpy.ndarray, silent_wave: numpy.ndarray):
    """Compute recall and precision of label-derived silence vs. wave-derived silence.

    ``silent_wave`` is treated as the reference: recall is the fraction of
    wave-silent frames that the labels also mark silent, precision is the
    fraction of label-silent frames that the wave also finds silent.

    Args:
        silent_phoneme: per-frame silence mask derived from the labels
            (any array-like; interpreted as booleans).
        silent_wave: per-frame silence mask derived from the audio.

    Returns:
        ``(recall, precision)`` as floats. A value is ``nan`` when its
        denominator has no silent frames (instead of a divide-by-zero warning).
    """
    silent_phoneme = numpy.asarray(silent_phoneme, dtype=bool)
    silent_wave = numpy.asarray(silent_wave, dtype=bool)

    overlap = int(numpy.count_nonzero(silent_phoneme & silent_wave))
    n_wave = int(numpy.count_nonzero(silent_wave))
    n_phoneme = int(numpy.count_nonzero(silent_phoneme))

    r = overlap / n_wave if n_wave > 0 else float('nan')
    p = overlap / n_phoneme if n_phoneme > 0 else float('nan')
    return r, p
def process(path_wav: Path, path_label: Path, top_db: float):
    """Build per-frame silence masks for one utterance from audio and from labels.

    Args:
        path_wav: wave file; loaded at its native sampling rate.
        path_label: Julius-style label file matching ``path_wav``.
        top_db: threshold passed to librosa's non-silence detector.

    Returns:
        ``(silent_p, silent_w)``: boolean masks with one entry per analysis
        frame (hop of ``rate // 100`` samples); True means silent.
        ``silent_p`` comes from the labels, ``silent_w`` from the audio.
    """
    # wave: detect silence directly from the signal at its native rate.
    w, rate = librosa.load(path_wav, sr=None)
    hop_length = rate // 100  # ~10 ms hop
    # NOTE(review): private librosa helper (leading underscore) — its
    # signature/behaviour may change between librosa versions.
    silent_w = ~ _signal_to_frame_nonsilent(
        w,
        frame_length=rate // 10,  # ~100 ms analysis window
        hop_length=hop_length,
        top_db=top_db,
    )

    # Frames per second actually produced by the hop above. This equals 100
    # only when `rate` is a multiple of 100, so derive it rather than hard-code.
    frames_per_second = rate / hop_length

    def to_frame(t: float) -> int:
        # Round away float noise before truncating, so e.g. 0.29 * 100 ==
        # 28.999999999999996 maps to frame 29 instead of 28; otherwise floor.
        return int(round(t * frames_per_second, 6))

    # phoneme: start all-silent, then clear every frame covered by a
    # non-silence phoneme (end frame inclusive).
    silent_p = numpy.ones(shape=silent_w.shape, dtype=bool)
    for p in Phoneme.load_julius_list(path_label):
        if p.is_silence:
            continue
        silent_p[to_frame(p.start):to_frame(p.end) + 1] = False
    return silent_p, silent_w
def each_threshold(path_labels: Sequence[Path], path_wavs: Sequence[Path], top_db: float):
    """Compute silence masks for every label/wave pair at one ``top_db`` threshold.

    Files are paired by position, so both sequences must be ordered the
    same way (e.g. both sorted by name).

    Returns:
        ``(silent_phonemes, silent_waves)``: lists of per-utterance masks as
        returned by ``process``, in input order.

    Raises:
        ValueError: if the two sequences differ in length. ``zip`` would
            otherwise silently drop trailing files and hide a missing file.
    """
    path_labels = list(path_labels)
    path_wavs = list(path_wavs)
    if len(path_labels) != len(path_wavs):
        raise ValueError(
            f'number of label files ({len(path_labels)}) does not match '
            f'number of wave files ({len(path_wavs)})'
        )

    silent_phonemes = []
    silent_waves = []
    for path_label, path_wav in zip(path_labels, path_wavs):
        silent_p, silent_w = process(path_label=path_label, path_wav=path_wav, top_db=top_db)
        silent_phonemes.append(silent_p)
        silent_waves.append(silent_w)
    return silent_phonemes, silent_waves
def show_silent(path_wav: Path, silent_phoneme: numpy.ndarray, silent_wave: numpy.ndarray, output: Union[str, Path] = None):
    """Save a three-panel figure comparing the waveform with both silence masks.

    Args:
        path_wav: wave file to plot (top panel).
        silent_phoneme: label-derived silence mask (middle panel).
        silent_wave: audio-derived silence mask (bottom panel).
        output: destination image; defaults to ``<wav stem>.svg`` in the
            current directory.
    """
    if output is None:
        output = f'{path_wav.stem}.svg'

    figure = pyplot.figure(figsize=[10, 6])
    try:
        # wave: load at the native rate (as process() does) so that taking
        # every (rate // 100)-th sample lands on the same ~10 ms grid as the
        # silence masks; the default load would resample and skew the axis.
        pyplot.subplot(3, 1, 1)
        w, rate = librosa.load(path_wav, sr=None)
        pyplot.plot(w[::rate // 100])

        # silent
        pyplot.subplot(3, 1, 2)
        pyplot.plot(silent_phoneme)
        pyplot.title('silent phoneme')
        pyplot.subplot(3, 1, 3)
        pyplot.plot(silent_wave)
        pyplot.title('silent wave')

        # save
        pyplot.tight_layout()
        pyplot.savefig(output)
    finally:
        # Release the figure; this is called once per poorly-matching file,
        # and open figures would otherwise accumulate.
        pyplot.close(figure)
def main(
    path_root: Union[str, Path] = None,
    speaker: str = 'jvs009',
    top_dbs: Sequence[float] = (40,),
    path_pr: Union[str, Path] = None,
):
    """Evaluate how well one JVS speaker's silence labels match detected silence.

    For each ``top_db`` threshold, this computes per-file recall/precision.
    It prints and plots every file where either value is below 0.9, then
    computes the pooled values over all files.

    Args:
        path_root: JVS root directory; defaults to
            ``~/Downloads/jvs_ver1/jvs_ver1``.
        speaker: speaker directory name under ``path_root``.
        top_dbs: thresholds to evaluate; pass e.g. ``range(0, 101, 10)``
            together with ``path_pr`` for a sweep.
        path_pr: if given, save recall/precision vs. ``top_db`` plots there.

    Returns:
        ``(recall, precision)``: lists with one pooled value per threshold.

    Raises:
        FileNotFoundError: if no label or wave files are found.
        ValueError: if label and wave file counts differ (pairing is positional).
    """
    if path_root is None:
        path_root = Path.home() / 'Downloads' / 'jvs_ver1' / 'jvs_ver1'
    path_human = Path(path_root) / speaker
    path_labels = sorted((path_human / 'parallel100' / 'lab' / 'mon').glob('VOICEACTRESS100_*.lab'))
    path_wavs = sorted((path_human / 'parallel100' / 'wav24kHz16bit').glob('VOICEACTRESS100_*.wav'))

    # Fail fast with a clear message instead of numpy.concatenate's
    # "need at least one array" on an empty/wrong dataset path.
    if not path_labels or not path_wavs:
        raise FileNotFoundError(f'no VOICEACTRESS100_* label/wave files found under {path_human}')
    # Files are paired by position, so a missing file would misalign pairs.
    if len(path_labels) != len(path_wavs):
        raise ValueError(
            f'number of label files ({len(path_labels)}) does not match '
            f'number of wave files ({len(path_wavs)}) under {path_human}'
        )

    recall = []
    precision = []
    for top_db in tqdm(top_dbs):
        silent_phonemes, silent_waves = each_threshold(path_labels=path_labels, path_wavs=path_wavs, top_db=top_db)

        # Report and plot individual files that agree poorly.
        for path_wav, silent_phoneme, silent_wave in zip(path_wavs, silent_phonemes, silent_waves):
            r, p = calc_pr(silent_phoneme=silent_phoneme, silent_wave=silent_wave)
            if r < 0.9 or p < 0.9:
                print(f'{path_wav.stem}: r={r:.2f}: p={p:.2f}')
                show_silent(path_wav=path_wav, silent_phoneme=silent_phoneme, silent_wave=silent_wave)

        # Pooled score over all frames of all files.
        r, p = calc_pr(
            silent_phoneme=numpy.concatenate(silent_phonemes),
            silent_wave=numpy.concatenate(silent_waves),
        )
        recall.append(r)
        precision.append(p)

    # PR curve (optional)
    if path_pr is not None:
        figure = pyplot.figure(figsize=[12, 5])
        try:
            pyplot.subplot(1, 2, 1)
            pyplot.plot(list(top_dbs), recall)
            pyplot.subplot(1, 2, 2)
            pyplot.plot(list(top_dbs), precision)
            pyplot.savefig(path_pr)
        finally:
            pyplot.close(figure)

    return recall, precision
if __name__ == '__main__':
    # Script entry point: run the evaluation with its default settings.
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment