Code to perform examples
import sys
import numpy as np
import torch
import json
from flowtron import Flowtron
from data import Data
from torch.distributions import Normal
from data import load_wav_to_torch
sys.path.insert(0, "tacotron2")
sys.path.insert(0, "tacotron2/waveglow")
from glow import WaveGlow
from scipy.io.wavfile import write
def get_models_and_train(seed, waveglow_path, flowtron_path, model_config, data_config):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    # load waveglow
    waveglow = torch.load(waveglow_path)['model'].cuda().eval()
    waveglow.cuda().half()
    for k in waveglow.convinv:
        k.float()
    waveglow.eval()

    # load flowtron
    model = Flowtron(**model_config).cuda()
    state_dict = torch.load(flowtron_path, map_location='cpu')['state_dict']
    model.load_state_dict(state_dict)
    model.eval()
    print("Loaded checkpoint '{}'".format(flowtron_path))

    ignore_keys = ['training_files', 'validation_files']
    trainset = Data(
        data_config['training_files'],
        **dict((k, v) for k, v in data_config.items() if k not in ignore_keys))
    return waveglow, model, trainset
def mels_to_wav(waveglow, mels, filepath, sampling_rate):
    # Taken from inference.py
    # TODO why sigma=0.8 (also in inference.py)? The paper says: "During inference we used sigma = 0.7" (p. 4)
    audio = waveglow.infer(mels.half(), sigma=0.8).float()
    audio = audio.cpu().numpy()[0]
    # normalize audio for now
    audio = audio / np.abs(audio).max()
    write(filepath, sampling_rate, audio)
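# Background note (assumption, not in the original gist): in WaveGlow, `sigma` is the
# standard deviation of the Gaussian latent that is passed through the inverse flow to
# produce audio; lower values tend to give smoother but duller output, higher values
# noisier output, which is what the TODO above is questioning.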
def tile(a, dim, n_tile):
    init_dim = a.size(dim)
    repeat_idx = [1] * a.dim()
    repeat_idx[dim] = n_tile
    a = a.repeat(*repeat_idx)
    order_index = torch.LongTensor(np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)]))
    return torch.index_select(a, dim, order_index)
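# Illustration (hypothetical example, not in the original gist): `tile` behaves like
# torch.repeat_interleave(a, n_tile, dim=dim), i.e. each slice along `dim` is repeated
# `n_tile` times consecutively. For instance:
#   _demo = torch.arange(3).view(1, 1, 3)   # tensor([[[0, 1, 2]]])
#   tile(_demo, 2, 2)                        # -> tensor([[[0, 0, 1, 1, 2, 2]]])
# A (1, 80, 10)-frame residual tiled with tile(residual, 2, 40) thus becomes (1, 80, 400).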
def prepare_text(text):
    # Taken from inference.py
    text = trainset.get_text(text).cuda()
    return text[None]
def synthesize_text(text, sigma, filename):
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)
    # Taken from inference.py
    with torch.no_grad():
        residual = torch.cuda.FloatTensor(1, 80, N_FRAMES).normal_() * sigma
        mels, attentions = model.infer(residual, speaker_vecs, text)
        mels_to_wav(waveglow, mels, filename, data_config['sampling_rate'])
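# Note (assumption, not in the original gist): with the standard LJSpeech data_config
# (sampling_rate=22050, hop_length=256), N_FRAMES = 400 mel frames correspond to roughly
# 400 * 256 / 22050 ≈ 4.6 s of audio, which bounds the length of each synthesized utterance.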
##################
# Default settings
##################
FLOWTRON_PATH = 'models/flowtron_ljs.pt'
WAVEGLOW_PATH = 'models/waveglow_256channels_v4.pt'
SEED = 1234
SPEAKER_ID = 0
N_FRAMES = 400
# `GATE_THRESHOLD` is not needed here, since 0.5 is the default value in `model.infer()`
# GATE_THRESHOLD = 0.5
# Read the configuration
with open('config.json') as f:
    data = f.read()
config = json.loads(data)
data_config = config["data_config"]
model_config = config["model_config"]
# cuDNN settings
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = False  # do NOT use the cuDNN auto-tuner to pick algorithms for this hardware
# Get the models
waveglow, model, trainset = get_models_and_train(SEED, WAVEGLOW_PATH, FLOWTRON_PATH, model_config, data_config)
# Prepare speaker, text, utterance text
speaker_vecs = trainset.get_speaker_id(SPEAKER_ID).cuda()
speaker_vecs = speaker_vecs[None]
##################
# EXPERIMENTS
##################
SIGMA = 0.5
##########################
# Experiment 1: variation
text = prepare_text('How much variation is there?')
for sigma in [0.0, 0.5, 1.0]:
    filename = 'results/variation_%.1f.wav' % sigma
    synthesize_text(text, sigma, filename)
#################################################
# Experiment 2: effect of (missing) punctuation
sentences = {
    'well_known': 'It is well known that deep generative models have a rich latent space.',
    'dogs': 'Dogs are sitting by the door.',
}
for effect, sentence in sentences.items():
    text = prepare_text(sentence)
    synthesize_text(text, SIGMA, 'results/%s_with_dot.wav' % effect)
    text_without_dot = prepare_text(sentence.replace('.', ''))
    synthesize_text(text_without_dot, SIGMA, 'results/%s_no_dot.wav' % effect)
##########################################
# Experiment 3: Emotional prosody transfer
text = prepare_text('Humans are walking on the street.')
synthesize_text(text, SIGMA, 'results/baseline_humans.wav') # Synthesize baseline
# Taken from karkirowle's gist:
with torch.no_grad():
    # Start with an empty residual accumulator
    residual_accumulator = torch.zeros((1, 80, N_FRAMES)).to("cuda")

    # Load the reference recording and compute its mel spectrogram
    audio, _ = load_wav_to_torch('data/ravdess_surprised_prior.wav')
    mel = trainset.get_mel(audio).to(device="cuda")
    # Add a batch dimension (model.forward expects batched input)
    mel = mel[None]

    # out_lens holds the length of the output, i.e. the number of mel frames
    out_lens = torch.LongTensor(1).to(device="cuda")
    out_lens[0] = mel.size(2)

    utterance_text = prepare_text('Kids are talking by the door.')
    # in_lens holds the length of the text input
    in_lens = torch.LongTensor([utterance_text.shape[1]]).to(device="cuda")

    # Forward pass: compute the latent residual (z) of the reference mel spectrogram given the utterance text
    residual, _, _, _, _, _, _ = model.forward(mel, speaker_vecs, utterance_text, in_lens, out_lens)
    # Reorder dimensions to (batch, n_mel_channels, frames)
    residual = residual.permute(1, 2, 0)

    # Keep at most `N_FRAMES` frames of the residual
    residual = residual[:, :, :N_FRAMES]
    # Fill the vector with copies of the residual if it contains fewer than `N_FRAMES` frames
    if residual.shape[2] < N_FRAMES:
        num_tile = int(np.ceil(N_FRAMES / residual.shape[2]))
        # I used tiling instead of replication
        residual = tile(residual.cpu(), 2, num_tile).to("cuda")
    # Accumulate the residual
    residual_accumulator = residual_accumulator + residual[:, :, :N_FRAMES]
    #####################################
    # Sample without averaging over time
    #####################################
    dist = Normal(residual_accumulator, SIGMA)
    z_style = dist.sample()
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)
    mels, attentions = model.infer(z_style, speaker_vecs, text)
    mels_to_wav(waveglow, mels, 'results/surprised_humans_transfer_without_time_avg.wav', data_config['sampling_rate'])

    #################################
    # Sample WITH averaging over time
    #################################
    residual_accumulator = residual_accumulator.mean(dim=2)
    dist = Normal(residual_accumulator, SIGMA)
    z_style = dist.sample((N_FRAMES,)).permute(1, 2, 0)
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)
    mels, attentions = model.infer(z_style, speaker_vecs, text)
    mels_to_wav(waveglow, mels, 'results/surprised_humans_transfer_time_avg.wav', data_config['sampling_rate'])
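# Assumed setup (not part of the original gist): 'config.json' from the Flowtron
# repository in the working directory, the pretrained checkpoints under 'models/',
# the reference recording under 'data/', and an existing 'results/' directory,
# since scipy's write() does not create missing folders.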