Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save taroushirani/bfba9ac5dc7bfe8dd403868e213e5187 to your computer and use it in GitHub Desktop.
Save taroushirani/bfba9ac5dc7bfe8dd403868e213e5187 to your computer and use it in GitHub Desktop.
Tempo-shift data augmentation with preserved consonant duration
#! /usr/bin/python
import argparse
from glob import glob
import logging
import os
from os.path import join, basename, splitext
import re
import sys
from tqdm import tqdm
import librosa
import soundfile as sf
import numpy as np
import pyrubberband as pyrb
from nnmnkwii.io import hts
from nnsvs.io.hts import get_note_indices
def get_parser():
    """Build the command-line parser for the augmentation script.

    Positional arguments are the source directory, the destination directory
    and the scale factor (applied to label durations; tempo values are divided
    by it). ``--debug`` switches logging to DEBUG level.
    """
    cli = argparse.ArgumentParser(
        description="Consonant-invariant tempo-shift data augmentation",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    positional_specs = (
        ("src_dir", str, "Source directory"),
        ("dest_dir", str, "Destination directory"),
        ("scale", float, "Scale of tempo conversion"),
    )
    for arg_name, arg_type, arg_help in positional_specs:
        cli.add_argument(arg_name, type=arg_type, help=arg_help)
    cli.add_argument("--debug", action="store_true", help="Debug Mode")
    return cli
def _is_full_context(labels):
    """Return True when *labels* looks like a full-context label file.

    Only the context string of the first label is inspected: it is considered
    full-context if it contains the "@" delimiter (presumably absent from
    monophone labels).
    """
    assert isinstance(labels, hts.HTSLabelFile)
    first_context = labels[0][-1]
    return "@" in first_context
def _is_vowel(phoneme):
return phoneme in ["a", "i", "u", "e", "o", "N", "I", "U"]
def _is_pau(phoneme):
return phoneme == "pau"
def _is_br(phoneme):
return phoneme == "br"
def _is_special(phoneme):
    """Return True for the "special" symbols "cl" and "br".

    These keep their original length during conversion and may only appear
    as the last phoneme of a note.
    """
    if phoneme == "cl":
        return True
    return _is_br(phoneme)
def _is_consonant(phoneme):
    """Return True if *phoneme* is treated as a consonant.

    Every symbol that is not a vowel, not "pau" and not a special symbol
    ("cl"/"br") counts as a consonant.
    """
    return not (_is_vowel(phoneme) or _is_pau(phoneme) or _is_special(phoneme))
def _has_vowel_reduction(phonemes):
    """Return True if no phoneme in *phonemes* is a vowel or "pau".

    Such a note has nothing that can absorb the rescaled duration, and the
    mono-label conversion rejects it as vowel reduction.
    """
    return not any(_is_pau(symbol) or _is_vowel(symbol) for symbol in phonemes)
def _convert_mono_labels(mono_labels, note_indices, scale, tolerance=5):
    """Rescale monophone label timings by *scale* while keeping consonant lengths.

    Each note (the phonemes from one entry of *note_indices* up to the next)
    is stretched as a whole to ``int(note_duration * scale)``. Consonants and
    the special symbols ("cl"/"br") keep their original duration. The
    remaining time (``residue``) goes to the note's vowel or, for rest notes,
    to "pau".

    Args:
        mono_labels: monophone ``hts.HTSLabelFile`` to convert.
        note_indices: label indices at which each note starts (the caller
            obtains them with ``get_note_indices`` on the full-context labels).
            The code assumes the last entry is the index of the final label.
        scale: factor applied to every note duration.
        tolerance: maximum allowed difference, in label time units, between
            ``int(original_end * scale)`` and the new end time.

    Returns:
        A new ``hts.HTSLabelFile`` with the same phonemes and rescaled times.

    Raises:
        RuntimeError: when a note has no vowel or "pau" (vowel reduction), a
            rest note has an unexpected length, a special symbol is not the
            last phoneme of its note, or the accumulated error exceeds
            *tolerance*.
    """
    logging.debug(f"scale: {scale}")
    logging.debug(f"mono_labels.contexts: {mono_labels.contexts}")
    logging.debug(f"phoneme of note_indices: {[mono_labels.contexts[i] for i in note_indices]}")
    new_mono_labels = hts.HTSLabelFile()
    for idx in range(len(note_indices)):
        note_index = note_indices[idx]
        logging.debug(f"idx: {idx}, note_index: {note_index}")
        # The final label forms a note of its own and is scaled as a whole.
        # NOTE(review): presumably raises IndexError if this is also the first
        # label, because new_mono_labels is still empty.
        if note_index == len(mono_labels) -1:
            logging.debug("Last pau")
            note_duration = mono_labels.end_times[note_index] - mono_labels.start_times[note_index]
            new_mono_labels.append([new_mono_labels.end_times[-1], \
                new_mono_labels.end_times[-1] + int(note_duration * scale), \
                mono_labels.contexts[note_index]], strict=False)
        else:
            # The note spans the labels up to (not including) the next note's start.
            phoneme_num_in_note = note_indices[idx+1]-note_index
            logging.debug(f"range(note_index, phoneme_num_in_note): {range(note_index, note_index+phoneme_num_in_note)}")
            phonemes_in_note = [mono_labels.contexts[i] for i in range(note_index, note_index+phoneme_num_in_note)]
            logging.debug(f"phoneme in notes: {phonemes_in_note}")
            if _has_vowel_reduction(phonemes_in_note):
                raise RuntimeError(f"Vowel reduction detected: {phonemes_in_note}")
            note_duration = mono_labels.end_times[note_index + phoneme_num_in_note -1] - mono_labels.start_times[note_index]
            # Target length of the whole note. Fixed-length phonemes are
            # subtracted from it and what remains goes to the vowel / "pau".
            residue = int(note_duration * scale)
            logging.debug(f"note_duration: {note_duration}")
            for pos in range(phoneme_num_in_note):
                logging.debug(f"note_index+pos: {note_index+pos}")
                if _is_pau(mono_labels.contexts[note_index+pos]):
                    # Rest note: either "pau" alone or "pau" followed by "br".
                    if note_index == 0:
                        logging.debug("First 'pau'")
                        start_time = 0
                    else:
                        start_time = new_mono_labels.end_times[-1]
                    if phoneme_num_in_note == 1:
                        logging.debug("Current note consits of only 'pau'")
                        new_mono_labels.append([start_time,\
                            start_time + residue, \
                            mono_labels.contexts[note_index]], strict=False)
                    elif phoneme_num_in_note == 2:
                        logging.debug("Current note consists of 'pau' 'br'")
                        # Reserve the breath's original length. The "br" label
                        # itself is appended on the next iteration (special branch).
                        br_duration = mono_labels.end_times[note_index+1] - mono_labels.start_times[note_index+1]
                        residue-= br_duration
                        new_mono_labels.append([start_time, \
                            start_time + residue, \
                            mono_labels.contexts[note_index]], strict=False)
                    else:
                        raise RuntimeError(f"The impossible phoneme_num_in_note: {phoneme_num_in_note}")
                elif _is_consonant(mono_labels.contexts[note_index+pos]):
                    logging.debug("[:consonant:]")
                    # Consonants keep their original duration.
                    # NOTE(review): raises IndexError if the very first label is a
                    # consonant, because new_mono_labels is still empty.
                    phoneme_duration = mono_labels.end_times[note_index+pos] - mono_labels.start_times[note_index+pos]
                    new_mono_labels.append([new_mono_labels.end_times[-1], \
                        new_mono_labels.end_times[-1] + phoneme_duration, \
                        mono_labels.contexts[note_index+pos]], strict=False)
                    residue-= phoneme_duration
                elif _is_vowel(mono_labels.contexts[note_index+pos]):
                    logging.debug("[:vowel:]")
                    if pos == phoneme_num_in_note - 1:
                        logging.debug("Current note ends with [:vowel:]")
                        # The vowel absorbs everything not used by the consonants.
                        new_mono_labels.append([new_mono_labels.end_times[-1], \
                            new_mono_labels.end_times[-1] + residue, \
                            mono_labels.contexts[note_index+pos]], strict=False)
                    else:
                        logging.debug("Current note ends with [:special:]")
                        # Assumes the following phoneme is "cl"/"br". Its original
                        # length is reserved here, and it is appended next iteration.
                        # NOTE(review): residue is not reduced by the vowel's own
                        # length, so a second vowel in the same note would receive
                        # the same residue again.
                        special_duration = mono_labels.end_times[note_index+pos+1] - mono_labels.start_times[note_index+pos+1]
                        residue-=special_duration
                        new_mono_labels.append([new_mono_labels.end_times[-1], \
                            new_mono_labels.end_times[-1] + residue, \
                            mono_labels.contexts[note_index+pos]], strict=False)
                elif _is_special(mono_labels.contexts[note_index+pos]):
                    logging.debug("[:special:]")
                    # "cl"/"br" must close the note and keep their original length.
                    if pos != phoneme_num_in_note - 1:
                        raise RuntimeError(f"[:special:] does not located as the last phoneme.")
                    phoneme_duration = mono_labels.end_times[note_index+pos] - mono_labels.start_times[note_index+pos]
                    new_mono_labels.append([new_mono_labels.end_times[-1], \
                        new_mono_labels.end_times[-1] + phoneme_duration, \
                        mono_labels.contexts[note_index+pos]], strict=False)
                else:
                    # NOTE(review): effectively unreachable. _is_consonant accepts
                    # everything that is not a vowel, "pau" or special symbol.
                    raise RuntimeError(f"Unknown phoneme: {mono_labels.contexts[note_index+pos]}")
    logging.debug(f"int(mono_labels.end_times[-1]*scale): {int(mono_labels.end_times[-1]*scale)}")
    logging.debug(f"new_mono_labels.end_times[-1]: {new_mono_labels.end_times[-1]}")
    assert len(mono_labels) == len(new_mono_labels)
    # Per-note int() truncation accumulates. Reject results whose total length
    # drifts too far from the uniformly scaled original.
    if abs(int(mono_labels.end_times[-1]*scale) - new_mono_labels.end_times[-1]) > tolerance:
        raise RuntimeError(f"Cumulative error exceed the tolerance.")
    for i in range(len(mono_labels)):
        logging.debug(f"{mono_labels.start_times[i]} {mono_labels.end_times[i]} {mono_labels.contexts[i]} | {new_mono_labels.start_times[i]} {new_mono_labels.end_times[i]} {new_mono_labels.contexts[i]} | {mono_labels.end_times[i] - mono_labels.start_times[i]} {new_mono_labels.end_times[i] - new_mono_labels.start_times[i]}")
    return new_mono_labels
def _convert_full_labels(full_labels, scale):
    """Rescale full-context label times and the tempo/length fields in contexts.

    Start and end times are multiplied by *scale* and truncated to int.
    Inside each context string:

    * tempo fields d5, e5, f5 are divided by *scale* and rounded;
    * length fields d7, e7, f7, e12/e13, e20/e21, e31/e32, e37/e38, e43/e44
      and e51/e52 are multiplied by *scale*, truncated, and clamped to >= 1.

    A field is located by its surrounding delimiter pair. Non-numeric values
    such as "xx" are not matched and are left unchanged.

    NOTE(review): ``re.search`` returns the first occurrence of the delimiter
    pair anywhere in the context. If the intended field is "xx", a later field
    sharing the same delimiters could be rewritten instead. Confirm against
    the label format.

    Args:
        full_labels: full-context ``hts.HTSLabelFile``.
        scale: factor applied to durations (tempo is divided by it).

    Returns:
        A new ``hts.HTSLabelFile`` with rescaled times and contexts.
    """
    # (field id, regex-escaped left delimiter, regex-escaped right delimiter)
    tempo_fields = [("d5", "%", r"\|"), ("e5", "~", "!"), ("f5", r"\$", r"\$")]
    length_fields = [
        ("d7", "&", ";"),
        ("e7", "@", "#"),
        ("f7", r"\+", "%"),
        ("e12", r"\|", r"\["),
        ("e13", r"\[", "&"),
        ("e20", "_", ";"),
        ("e21", ";", r"\$"),
        ("e31", "~", "="),
        ("e32", "=", "@"),
        ("e37", "#", r"\|"),
        ("e38", r"\|", r"\|"),
        ("e43", r"\+", r"\["),
        ("e44", r"\[", ";"),
        ("e51", r"\^", "@"),
        ("e52", "@", r"\["),
    ]

    def numeric_field(text, left, right):
        # Return (matched token, digits) for the first numeric field framed by
        # the delimiters, or None when there is none.
        hit = re.search(f"{left}([0-9]+){right}", text)
        if hit is None:
            return None
        assert len(hit.groups()) == 1
        token = hit.group(0)
        digits = token[1:-1]
        if len(digits) == 0:
            return None
        return token, digits

    start_list, end_list, context_list = [], [], []
    for begin, finish, context in full_labels:
        start_list.append(int(begin * scale))
        end_list.append(int(finish * scale))

        for field_id, left, right in tempo_fields:
            found = numeric_field(context, left, right)
            if found is None:
                continue
            token, num = found
            new_num = int(round(float(num) / scale))
            logging.debug(f"id: {field_id}, old_tempo: {num}, new_tempo: {new_num}")
            # token[0] / token[-1] are the literal (unescaped) delimiters.
            context = context.replace(token, f"{token[0]}{new_num}{token[-1]}", 1)

        for field_id, left, right in length_fields:
            found = numeric_field(context, left, right)
            if found is None:
                continue
            token, num = found
            # Keep lengths strictly positive after scaling.
            new_num = max(int(float(num) * scale), 1)
            logging.debug(f"id: {field_id}, old_length(by 0.01 sec): {num}, new_length(by 0.01 sec): {new_num}")
            context = context.replace(token, f"{token[0]}{new_num}{token[-1]}", 1)

        context_list.append(context)

    new_full_labels = hts.HTSLabelFile()
    new_full_labels.start_times = start_list
    new_full_labels.end_times = end_list
    new_full_labels.contexts = context_list
    assert len(full_labels) == len(new_full_labels)
    assert int(full_labels.end_times[-1]*scale) == new_full_labels.end_times[-1]
    return new_full_labels
def _convert_wav(wav, sr, mono_labels, new_mono_labels, scale, tolerance=5):
    """Time-stretch *wav* so its phoneme boundaries follow *new_mono_labels*.

    Each phoneme end time in *mono_labels* is mapped to the matching end time
    in *new_mono_labels*. Both are converted from label units (1e-7 s) to
    sample frames, and the resulting map goes to
    ``pyrubberband.timemap_stretch``.

    Args:
        wav: audio samples (as returned by ``librosa.load``).
        sr: sampling rate of *wav* in Hz.
        mono_labels: original monophone labels aligned with *wav*.
        new_mono_labels: rescaled labels, one per entry of *mono_labels*.
        scale: global stretch factor, used for the final anchor when the
            labels extend past the end of the audio.
        tolerance: unused; kept for interface compatibility.

    Returns:
        The stretched audio returned by ``pyrb.timemap_stretch``.
    """
    import shutil

    logging.debug(f"wav.shape: {wav.shape}")
    time_map = []
    for idx in range(len(mono_labels)):
        end_frame = int(mono_labels.end_times[idx] * 1e-7 * sr)
        new_end_frame = int(new_mono_labels.end_times[idx] * 1e-7 * sr)
        logging.debug(f"phoneme: {mono_labels.contexts[idx]}, end_frame: {end_frame}, new_end_frame: {new_end_frame}")
        time_map.append([end_frame, new_end_frame])

    # timemap_stretch requires the last source anchor to equal len(wav).
    n_samples = wav.shape[0]
    last_src = time_map[-1][0]
    if last_src < n_samples:
        logging.debug(f"time_map[-1][0]: {last_src} is smaller than wav.shape[0]: {n_samples}")
        wav = wav[0:last_src]
    elif last_src > n_samples:
        logging.debug(f"time_map[-1][0]: {last_src} is bigger than wav.shape[0]: {n_samples}")
        # Anchors beyond the audio cannot be honoured. Clamping only the last
        # one would leave earlier out-of-range anchors and make the map
        # non-monotonic, so drop them all and pin the end of the audio instead.
        time_map = [pair for pair in time_map if pair[0] < n_samples]
        time_map.append([n_samples, int(n_samples * scale)])

    # pyrubberband reads the executable name from the module global
    # ``__RUBBERBAND_UTIL`` of its ``pyrb`` submodule. Assigning it on the
    # package object (``pyrb.__RUBBERBAND_UTIL = ...``) has no effect.
    # Select the R3-engine binary when it is installed.
    backend = getattr(pyrb, "pyrb", pyrb)
    if shutil.which("rubberband-r3") is not None:
        setattr(backend, "__RUBBERBAND_UTIL", "rubberband-r3")
    else:
        logging.warning("rubberband-r3 not found; using the default rubberband executable")
    new_wav = pyrb.timemap_stretch(wav, sr, time_map)
    logging.debug(f"new_wav.shape: {new_wav.shape}")
    return new_wav
if __name__ == "__main__":
    # Convert every song under <src_dir> (labels/full, labels/mono, wav) and
    # write the augmented copies under <dest_dir> with a "citsda_<scale>" suffix.
    args = get_parser().parse_args(sys.argv[1:])
    if args.debug:
        logging.basicConfig(level=logging.DEBUG)
    src_dir = args.src_dir
    dest_dir = args.dest_dir
    scale = args.scale
    new_mono_lab_dir = join(dest_dir, "labels", "mono")
    new_full_lab_dir = join(dest_dir, "labels", "full")
    new_wav_dir = join(dest_dir, "wav")
    for out_dir in (new_mono_lab_dir, new_full_lab_dir, new_wav_dir):
        os.makedirs(out_dir, exist_ok=True)
    # e.g. scale 1.2 -> "citsda_1_2"
    postfix = f"citsda_{str(scale).replace('.', '_')}"
    logging.debug(f"postfix: {postfix}")
    full_lab_files = sorted(glob(join(src_dir, "labels", "full", "*.lab")))
    # Names of the songs whose augmented files were actually written.
    # Songs skipped because of an error are not included.
    song_list = []
    for full_lab_file in tqdm(full_lab_files):
        logging.debug(f"full_lab_file: {full_lab_file}")
        song_name = splitext(basename(full_lab_file))[0]
        full_labels = hts.load(full_lab_file)
        assert _is_full_context(full_labels)
        new_full_labels = _convert_full_labels(full_labels, scale)
        note_indices = get_note_indices(full_labels)
        mono_lab_file = join(src_dir, "labels", "mono", f"{song_name}.lab")
        logging.debug(f"mono_lab_file: {mono_lab_file}")
        mono_labels = hts.load(mono_lab_file)
        assert not _is_full_context(mono_labels)
        try:
            new_mono_labels = _convert_mono_labels(mono_labels, note_indices, scale)
        except RuntimeError as e:
            # Report on stderr through tqdm so the progress bar and the final
            # song list printed on stdout stay intact.
            tqdm.write(f"ERROR: {song_name}: {e}", file=sys.stderr)
            continue
        wav_file = join(src_dir, "wav", f"{song_name}.wav")
        wav, sr = librosa.load(wav_file, sr=None)
        logging.debug(f"wav.shape: {wav.shape}")
        try:
            new_wav = _convert_wav(wav, sr, mono_labels, new_mono_labels, scale)
        except ValueError as e:
            # pyrubberband validates the time map (per its documentation it
            # raises ValueError for e.g. a non-monotonic map). Skip this song
            # instead of aborting the whole batch.
            tqdm.write(f"ERROR: {song_name}: {e}", file=sys.stderr)
            continue
        new_mono_lab_file = join(new_mono_lab_dir, f"{song_name}_{postfix}.lab")
        with open(new_mono_lab_file, "w") as of:
            of.write(str(new_mono_labels))
        new_full_lab_file = join(new_full_lab_dir, f"{song_name}_{postfix}.lab")
        with open(new_full_lab_file, "w") as of:
            of.write(str(new_full_labels))
        new_wav_file = join(new_wav_dir, f"{song_name}_{postfix}.wav")
        logging.debug(f"new_wav.shape: {new_wav.shape}")
        sf.write(new_wav_file, new_wav, sr, format="WAV")
        song_list.append(song_name)
    print(song_list)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment