"""Prepare acoustic/linguistic/duration features.
usage:
prepare_features.py [options] <DATA_ROOT> <DST_ROOT>
options:
--overwrite Overwrite files
--max_num_files=<N> Max num files to be collected. [default: -1]
-h, --help show this help message and exit
"""
from __future__ import division, print_function, absolute_import

import os
import re
import warnings

import numpy as np
import pysptk
import pyworld
import soundfile
from docopt import docopt
from nnmnkwii.datasets import FileSourceDataset
from nnmnkwii.datasets.jsut import WavFileDataSource
from nnmnkwii.frontend import merlin as fe
from nnmnkwii.io import hts
from nnmnkwii.preprocessing import inv_scale, scale
from nnmnkwii.preprocessing.f0 import interp1d
from sklearn.utils.extmath import _incremental_mean_and_var
from tqdm import tqdm
# Overridden from the --max_num_files command-line option at runtime.
max_num_files = None
# Regex used to identify silence labels (e.g. "silB", "silE").
silence = re.compile('sil*')
# Mel-cepstrum analysis order (60 coefficients including the 0th).
order = 59
# Phoneme inventory of the JSUT labels; list position defines the integer index.
phonemes = [':', 'a', 'a:', 'b', 'by', 'ch', 'd', 'dy', 'e', 'e:', 'f', 'g', 'gy',
            'h', 'hy', 'i', 'i:', 'j', 'k', 'ky', 'm', 'my', 'n', 'N', 'ny', 'o',
            'o:', 'p', 'py', 'q', 'r', 'ry', 's', 'sh', 't', 'ts', 'ty', 'u', 'u:',
            'w', 'y', 'z', 'zy', 'sp']
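

# Phoneme-level linguistic features: each label in a .lab file is mapped to
# its integer index in `phonemes`, with silence labels removed.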
class LinguisticSource(WavFileDataSource):
    def __init__(self, data_root, *args, **kwargs):
        super(LinguisticSource, self).__init__(data_root, *args, **kwargs)
        self.phoneme_dict = {p: i for i, p in enumerate(phonemes)}

    def collect_files(self):
        files = [f.replace('wav', 'lab') for f in super(LinguisticSource, self).collect_files()]
        if max_num_files is not None and max_num_files > 0:
            return files[:max_num_files]
        else:
            return files

    def collect_features(self, path):
        labels = hts.load(path)
        features = np.delete(labels.contexts, labels.silence_phone_indices(silence), axis=0)
        return np.vectorize(self.phoneme_dict.get)(features).astype(np.float32)
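

# Phoneme durations extracted from the label timestamps, again with
# silence labels removed.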
class DurationSource(WavFileDataSource):
    def collect_files(self):
        files = [f.replace('wav', 'lab') for f in super(DurationSource, self).collect_files()]
        if max_num_files is not None and max_num_files > 0:
            return files[:max_num_files]
        else:
            return files

    def collect_features(self, path):
        labels = hts.load(path)
        return np.delete(fe.duration_features(labels),
                         labels.silence_phone_indices(silence), axis=0).astype(np.float32)
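

# Frame-level acoustic features from WORLD/SPTK analysis: mel-cepstrum (mgc),
# interpolated log-F0 (lf0), a voiced/unvoiced flag (vuv), and band
# aperiodicity (bap), stacked per frame.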
class AcousticSource(WavFileDataSource):
    def collect_files(self):
        wav_files = super(AcousticSource, self).collect_files()
        lab_files = [f.replace('wav', 'lab') for f in wav_files]
        if max_num_files is not None and max_num_files > 0:
            return wav_files[:max_num_files], lab_files[:max_num_files]
        else:
            return wav_files, lab_files

    def collect_features(self, wav_path, lab_path):
        x, fs = soundfile.read(wav_path)
        # WORLD analysis: F0, spectral envelope, and aperiodicity.
        f0, sp, ap = pyworld.wav2world(x, fs)
        # Band aperiodicity and mel-cepstrum.
        bap = pyworld.code_aperiodicity(ap, fs)
        mgc = pysptk.sp2mc(sp, order=order, alpha=pysptk.util.mcepalpha(fs))
        # Continuous log-F0 with a voiced/unvoiced flag; unvoiced segments
        # are filled in by linear interpolation.
        f0 = f0[:, None]
        lf0 = f0.copy()
        nonzero_indices = np.nonzero(f0)
        lf0[nonzero_indices] = np.log(f0[nonzero_indices])
        vuv = (lf0 != 0).astype(np.float32)
        lf0 = interp1d(lf0, kind="slinear")
        features = np.hstack((mgc, lf0, vuv, bap))
        # Trim to the label length and drop silence frames.
        labels = hts.load(lab_path)
        return np.delete(features[:labels.num_frames()],
                         labels.silence_frame_indices(silence),
                         axis=0).astype(np.float32)
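

# Main flow: (1) split utterances into "to process" and "already processed",
# (2) extract acoustic features while updating normalization statistics
# incrementally, (3) rescale previously saved files to the new statistics,
# (4) pack linguistic/duration/acoustic features into per-utterance .npz files.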
if __name__ == "__main__":
args = docopt(__doc__)
data_root = args['<DATA_ROOT>']
dst_root = args['<DST_ROOT>']
max_num_files = int(args['--max_num_files'])
overwrite = args['--overwrite']
linguistic_source = FileSourceDataset(LinguisticSource(data_root, subsets='all'))
duration_source = FileSourceDataset(DurationSource(data_root, subsets='all'))
acoustic_source = FileSourceDataset(AcousticSource(data_root, subsets='all'))
get_name = lambda idx: os.path.join(dst_root, os.path.splitext(os.path.basename(linguistic_source.collected_files[idx][0]))[0] + '.npz')
    process_indices, rescale_indices = [], []
    for idx in tqdm(range(len(linguistic_source))):
        if not overwrite and os.path.exists(get_name(idx)):
            rescale_indices.append(idx)
        else:
            process_indices.append(idx)
    if len(process_indices) != len(linguistic_source):
        warnings.warn('Only {}/{} wav files will be processed; the rest will be rescaled.'.format(
            len(process_indices), len(linguistic_source)))
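
    # Normalization statistics (mean/var over all acoustic frames), updated
    # incrementally (via sklearn's private _incremental_mean_and_var helper)
    # so the script can resume from a previous partial run.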
    mean, var, count = 0, 0, 0
    norm_path = os.path.join(dst_root, 'norm.npz')
    if not overwrite and os.path.exists(norm_path):
        norm = np.load(norm_path)
        mean, var, count = norm['mean'], norm['var'], norm['count']
    if len(rescale_indices):
        # Resuming assumes norm.npz was saved by the previous run.
        init_mean, init_std = mean.copy(), np.sqrt(var)
    for idx in tqdm(process_indices):
        acoustic = acoustic_source[idx]
        np.savez_compressed(get_name(idx), audio_features=acoustic)
        mean, var, count = _incremental_mean_and_var(acoustic, mean, var, count)
    std = np.sqrt(var)
    np.savez_compressed(norm_path, mean=mean, var=var, count=count)
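
    # Files from a previous run were normalized with the old statistics;
    # undo that scaling and re-apply the updated mean/std.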
    for idx in tqdm(rescale_indices):
        data = dict(np.load(get_name(idx)))
        data['audio_features'] = scale(inv_scale(data['audio_features'], init_mean, init_std), mean, std)
        np.savez_compressed(get_name(idx), **data)
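
    # Pack everything into the final per-utterance .npz, normalizing the
    # acoustic features with the (now final) statistics.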
    for idx in tqdm(process_indices):
        name = get_name(idx)
        acoustic = np.load(name)['audio_features']
        np.savez_compressed(name,
                            file_id=os.path.splitext(os.path.basename(name))[0],
                            phonemes=linguistic_source[idx],
                            durations=duration_source[idx],
                            audio_features=scale(acoustic, mean, std))