Code to perform the examples
import sys
import json

import numpy as np
import torch
from torch.distributions import Normal
from scipy.io.wavfile import write

from flowtron import Flowtron
from data import Data, load_wav_to_torch

sys.path.insert(0, "tacotron2")
sys.path.insert(0, "tacotron2/waveglow")
# WaveGlow is never called directly below, but the class must be importable so
# that torch.load can unpickle the WaveGlow checkpoint
from glow import WaveGlow
def get_models_and_trainset(seed, waveglow_path, flowtron_path, model_config, data_config):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    # Load WaveGlow (the vocoder) and run it in half precision, except for the
    # invertible 1x1 convolutions, which are kept in float32
    waveglow = torch.load(waveglow_path)['model'].cuda().eval()
    waveglow.cuda().half()
    for k in waveglow.convinv:
        k.float()
    waveglow.eval()

    # Load Flowtron
    model = Flowtron(**model_config).cuda()
    state_dict = torch.load(flowtron_path, map_location='cpu')['state_dict']
    model.load_state_dict(state_dict)
    model.eval()
    print("Loaded checkpoint '{}'".format(flowtron_path))

    # Build the dataset helper (used below for text and mel preprocessing)
    ignore_keys = ['training_files', 'validation_files']
    trainset = Data(
        data_config['training_files'],
        **dict((k, v) for k, v in data_config.items() if k not in ignore_keys))
    return waveglow, model, trainset

def mels_to_wav(waveglow, mels, filepath, sampling_rate):
    # Taken from inference.py
    # TODO why sigma 0.8 (also in inference.py)? In the paper it says:
    # "During inference we used sigma = 0.7" (p. 4)
    audio = waveglow.infer(mels.half(), sigma=0.8).float()
    audio = audio.cpu().numpy()[0]
    # Normalize audio for now
    audio = audio / np.abs(audio).max()
    write(filepath, sampling_rate, audio)

def tile(a, dim, n_tile):
    init_dim = a.size(dim)
    repeat_idx = [1] * a.dim()
    repeat_idx[dim] = n_tile
    a = a.repeat(*repeat_idx)
    order_index = torch.LongTensor(
        np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)]))
    return torch.index_select(a, dim, order_index)
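
# A small worked example (hypothetical values, not in the original gist) of
# what `tile` does: tile(torch.tensor([[1, 2, 3]]), dim=1, n_tile=2) first
# repeats the block to [[1, 2, 3, 1, 2, 3]] and then reorders it via
# index_select to [[1, 1, 2, 2, 3, 3]] -- every element is duplicated in place,
# like np.repeat, rather than the whole block being tiled as
# torch.Tensor.repeat would do.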

def prepare_text(text):
    # Taken from inference.py; relies on the global `trainset` created below
    text = trainset.get_text(text).cuda()
    return text[None]
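
# For illustration (shapes inferred from how the result is used later):
# `get_text` returns a 1-D LongTensor of symbol IDs and `[None]` prepends a
# batch dimension, so prepare_text('Hi') should yield a CUDA tensor of shape
# (1, n_symbols).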

def synthesize_text(text, sigma, filename):
    # Reset the seeds so that repeated calls sample comparable residuals
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)
    # Taken from inference.py
    with torch.no_grad():
        # Residual of shape (batch=1, n_mel_channels=80, N_FRAMES), drawn from
        # a standard normal and scaled by sigma
        residual = torch.cuda.FloatTensor(1, 80, N_FRAMES).normal_() * sigma
        mels, attentions = model.infer(residual, speaker_vecs, text)
        mels_to_wav(waveglow, mels, filename, data_config['sampling_rate'])
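
# Note (my reading, not stated in the original gist): because the residual is a
# standard normal scaled by sigma, sigma=0.0 forces an all-zero residual, which
# makes the first run of Experiment 1 below a deterministic "mean" utterance.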

##################
# Default settings
##################
FLOWTRON_PATH = 'models/flowtron_ljs.pt'
WAVEGLOW_PATH = 'models/waveglow_256channels_v4.pt'
SEED = 1234
SPEAKER_ID = 0
N_FRAMES = 400
# `GATE_THRESHOLD` is not needed here, since 0.5 is the default value in `model.infer()`
# GATE_THRESHOLD = 0.5

# Read the configuration
with open('config.json') as f:
    config = json.load(f)
data_config = config["data_config"]
model_config = config["model_config"]

# cudnn settings for reproducible results
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = False  # do NOT use the cudnn auto-tuner to pick the fastest algorithm for this hardware
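# For stricter run-to-run reproducibility one could additionally set
# torch.backends.cudnn.deterministic = True; this is an addition of mine, the
# original gist only disables the benchmark mode.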

# Get the models and the dataset helper
waveglow, model, trainset = get_models_and_trainset(SEED, WAVEGLOW_PATH, FLOWTRON_PATH, model_config, data_config)

# Prepare the speaker embedding
speaker_vecs = trainset.get_speaker_id(SPEAKER_ID).cuda()
speaker_vecs = speaker_vecs[None]

##################
# EXPERIMENTS
##################
SIGMA = 0.5

##########################
# Experiment 1: variation
text = prepare_text('How much variation is there?')
for sigma in [0.0, 0.5, 1.0]:
    filename = 'results/variation_%.1f.wav' % sigma
    synthesize_text(text, sigma, filename)

#################################################
# Experiment 2: effect of (missing) punctuation
sentences = {
    'well_known': 'It is well known that deep generative models have a rich latent space.',
    'dogs': 'Dogs are sitting by the door.',
}
for effect, sentence in sentences.items():
    text = prepare_text(sentence)
    synthesize_text(text, SIGMA, 'results/%s_with_dot.wav' % effect)
    text_without_dot = prepare_text(sentence.replace('.', ''))
    synthesize_text(text_without_dot, SIGMA, 'results/%s_no_dot.wav' % effect)

##########################################
# Experiment 3: Emotional prosody transfer
text = prepare_text('Humans are walking on the street.')
synthesize_text(text, SIGMA, 'results/baseline_humans.wav')  # Synthesize baseline

# Taken from karkirowle's gist:
with torch.no_grad():
    # Start with an empty residual
    residual_accumulator = torch.zeros((1, 80, N_FRAMES)).to("cuda")

    # Load the reference recording and compute its mel spectrogram
    audio, _ = load_wav_to_torch('data/ravdess_surprised_prior.wav')
    mel = trainset.get_mel(audio).to(device="cuda")
    # Add a batch dimension, required by model.forward()
    mel = mel[None]

    # out_lens describes the length of the output, i.e. the audio (in mel frames)
    out_lens = torch.LongTensor(1).to(device="cuda")
    out_lens[0] = mel.size(2)

    utterance_text = prepare_text('Kids are talking by the door.')
    # in_lens describes the length of the text input
    in_lens = torch.LongTensor([utterance_text.shape[1]]).to(device="cuda")

    # Run Flowtron forward to obtain the residual (latent) for the reference
    # recording, conditioned on the speaker and the utterance text
    residual, _, _, _, _, _, _ = model.forward(mel, speaker_vecs, utterance_text, in_lens, out_lens)
    # Reorder dimensions from (time, batch, mel) to (batch, mel, time)
    residual = residual.permute(1, 2, 0)

    # Make sure the time dimension of the residual has at most `N_FRAMES` elements
    residual = residual[:, :, :N_FRAMES]
    # Extend the residual with copies of itself if it is shorter than `N_FRAMES`
    if residual.shape[2] < N_FRAMES:
        num_tile = int(np.ceil(N_FRAMES / residual.shape[2]))
        # I used tiling instead of replication
        residual = tile(residual.cpu(), 2, num_tile).to("cuda")
    # Add the residual to the accumulator
    residual_accumulator = residual_accumulator + residual[:, :, :N_FRAMES]
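
# Note (my assumption, not in the original gist): with a single reference file
# the accumulator simply ends up holding that file's residual; the
# sum-into-accumulator pattern only becomes meaningful if several emotional
# reference recordings are processed in a loop and their residuals are
# accumulated (and ideally averaged) before sampling.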

#####################################
# Sample WITHOUT averaging over time
#####################################
dist = Normal(residual_accumulator, SIGMA)
z_style = dist.sample()
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
mels, attentions = model.infer(z_style, speaker_vecs, text)
mels_to_wav(waveglow, mels, 'results/surprised_humans_transfer_without_time_avg.wav', data_config['sampling_rate'])

#################################
# Sample WITH averaging over time
#################################
# Average the residual over time, then draw N_FRAMES independent samples around
# that time-independent mean; shapes go (1, 80) -> (N_FRAMES, 1, 80) -> (1, 80, N_FRAMES)
residual_accumulator = residual_accumulator.mean(dim=2)
dist = Normal(residual_accumulator, SIGMA)
z_style = dist.sample((N_FRAMES,)).permute(1, 2, 0)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
mels, attentions = model.infer(z_style, speaker_vecs, text)
mels_to_wav(waveglow, mels, 'results/surprised_humans_transfer_time_avg.wav', data_config['sampling_rate'])