Created
July 30, 2020 09:21
-
-
Save polvanrijn/62462965e1823002892b6329a02a31f9 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import sys | |
import numpy as np | |
import torch | |
import json | |
from glob import glob | |
from flowtron import Flowtron | |
from data import Data | |
from torch.distributions import Normal | |
from data import load_wav_to_torch | |
sys.path.insert(0, "tacotron2") | |
sys.path.insert(0, "tacotron2/waveglow") | |
from glow import WaveGlow | |
from scipy.io.wavfile import write | |
def get_models_and_train(seed, waveglow_path, flowtron_path, model_config, data_config):
    """Seed torch, load the WaveGlow vocoder and Flowtron model, and build the training dataset.

    Args:
        seed: seed applied to both the CPU and CUDA torch generators.
        waveglow_path: checkpoint whose ``'model'`` entry is the pickled WaveGlow module.
        flowtron_path: checkpoint whose ``'state_dict'`` entry holds Flowtron weights.
        model_config: keyword arguments for ``Flowtron(...)``.
        data_config: must contain ``'training_files'``; every other key except
            ``'training_files'`` / ``'validation_files'`` is passed to ``Data`` as a
            keyword argument.

    Returns:
        ``(waveglow, model, trainset)``:
        - ``waveglow`` is on CUDA in eval mode, half precision except its ``convinv``
          layers (kept in float32);
        - ``model`` is the Flowtron module on CUDA in eval mode;
        - ``trainset`` is the ``Data`` object built from ``data_config``.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    # load waveglow
    # NOTE: torch.load unpickles arbitrary objects; only load trusted checkpoints.
    waveglow = torch.load(waveglow_path)['model']
    waveglow.cuda().half()
    # keep the `convinv` layers in float32 while the rest of the vocoder runs in half
    for k in waveglow.convinv:
        k.float()
    waveglow.eval()

    # load flowtron: weights are mapped to CPU first, then copied into the CUDA model
    model = Flowtron(**model_config).cuda()
    state_dict = torch.load(flowtron_path, map_location='cpu')['state_dict']
    model.load_state_dict(state_dict)
    model.eval()
    print("Loaded checkpoint '{}'".format(flowtron_path))

    # the file lists are not Data keyword arguments; the training list is positional
    ignore_keys = ['training_files', 'validation_files']
    trainset = Data(
        data_config['training_files'],
        **{k: v for k, v in data_config.items() if k not in ignore_keys})
    return waveglow, model, trainset
def mels_to_wav(waveglow, mels, filepath, sampling_rate, sigma=0.8):
    """Vocode mel spectrograms with WaveGlow and write the first output to a WAV file.

    Taken from inference.py.

    Args:
        waveglow: vocoder exposing ``infer(mels, sigma=...)``; ``mels`` is cast to
            half precision before the call and the result back to float32.
        mels: mel-spectrogram batch; only the first item of the vocoder output is written.
        filepath: destination WAV path.
        sampling_rate: sample rate stored in the WAV header.
        sigma: WaveGlow sampling sigma. Defaults to 0.8 as in inference.py, although the
            Flowtron paper says "During inference we used sigma = 0.7" (p. 4).
    """
    audio = waveglow.infer(mels.half(), sigma=sigma).float()
    audio = audio.cpu().numpy()[0]
    # normalize audio to a peak amplitude of 1.0; an all-zero output is left as is,
    # since dividing by a zero peak would turn every sample into NaN
    peak = np.abs(audio).max()
    if peak > 0:
        audio = audio / peak
    write(filepath, sampling_rate, audio)
def tile(a, dim, n_tile):
    """Repeat every slice of ``a`` along ``dim`` ``n_tile`` times, consecutively.

    ``Tensor.repeat`` appends whole copies (``[x0, x1, x0, x1]``); this function
    instead keeps each slice's copies together (``[x0, x0, x1, x1]``).

    Args:
        a: input tensor (any device).
        dim: dimension to expand.
        n_tile: number of copies of each slice.

    Returns:
        Tensor whose size along ``dim`` is ``a.size(dim) * n_tile``, on ``a``'s device.
    """
    init_dim = a.size(dim)
    repeat_idx = [1] * a.dim()
    repeat_idx[dim] = n_tile
    a = a.repeat(*repeat_idx)
    # After ``repeat``, copy j of original slice i sits at index ``j * init_dim + i``;
    # list all copies of slice 0, then of slice 1, and so on. The index is built on
    # ``a``'s device because ``index_select`` needs index and input on the same device.
    copy_offsets = torch.arange(n_tile, device=a.device) * init_dim
    slice_ids = torch.arange(init_dim, device=a.device)
    order_index = (slice_ids.unsqueeze(1) + copy_offsets.unsqueeze(0)).reshape(-1)
    return torch.index_select(a, dim, order_index)
def prepare_text(text, dataset=None, device="cuda"):
    """Encode ``text`` with a dataset's ``get_text`` and add a leading (batch) axis.

    Taken from inference.py.

    Args:
        text: raw sentence to encode.
        dataset: object providing ``get_text(text)``; defaults to the module-level
            ``trainset`` so existing callers keep working.
        device: device the encoded tensor is moved to (default ``"cuda"``).

    Returns:
        The encoded text indexed with ``[None]`` (one extra leading axis), on ``device``.
    """
    if dataset is None:
        dataset = trainset
    encoded = dataset.get_text(text).to(device)
    return encoded[None]
def synthesize_text(text, sigma, filename):
    """Decode ``text`` from seeded Gaussian noise scaled by ``sigma`` and save it as a WAV.

    Taken from inference.py. Reads the module-level globals SEED, N_FRAMES, model,
    speaker_vecs, waveglow and data_config.
    """
    # Reset both generators so every call starts from the same standard-normal noise
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed):
        seed_fn(SEED)

    with torch.no_grad():
        noise = torch.empty((1, 80, N_FRAMES), dtype=torch.float32, device="cuda").normal_()
        latent = noise * sigma
        mels, _ = model.infer(latent, speaker_vecs, text)
        mels_to_wav(waveglow, mels, filename, data_config['sampling_rate'])
##################
# Default settings
##################
FLOWTRON_PATH = 'models/flowtron_ljs.pt'
WAVEGLOW_PATH = 'models/waveglow_256channels_v4.pt'
SEED = 1234
SPEAKER_ID = 0
N_FRAMES = 400  # length (third axis) of every latent sampled or accumulated below
# `GATE_THRESHOLD` (0.5) is not set here, since 0.5 is the default value in `model.infer()`

# Read the configuration
with open('config.json', encoding='utf-8') as f:
    config = json.load(f)
data_config = config["data_config"]
model_config = config["model_config"]

# Keep cuDNN enabled but turn off its auto-tuner (benchmark mode), which would
# otherwise pick convolution algorithms by timing them on this hardware.
# NOTE(review): PyTorch's reproducibility notes also suggest
# torch.backends.cudnn.deterministic = True for fully repeatable results.
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = False

# Get the models (this also seeds the CPU and CUDA generators with SEED)
waveglow, model, trainset = get_models_and_train(SEED, WAVEGLOW_PATH, FLOWTRON_PATH, model_config, data_config)

# Prepare speaker: the speaker-id tensor for SPEAKER_ID with a leading (batch) axis
speaker_vecs = trainset.get_speaker_id(SPEAKER_ID).cuda()
speaker_vecs = speaker_vecs[None]
##########################################
# Replication of surprise
##########################################
import os  # only needed here, to create the output directory

SIGMA = 0.5
text = prepare_text('Humans are walking on the street.')
# All WAVs below are written under results/; scipy's wavfile.write() does not create
# missing directories, so make sure it exists before the first write.
os.makedirs('results', exist_ok=True)
synthesize_text(text, SIGMA, 'results/baseline_humans.wav')  # Synthesize baseline
# Taken from karkirowle's gist:
# Estimate the "surprise" style as the mean residual the model assigns to reference
# recordings, each conditioned on its own transcript.
with torch.no_grad():
    # Start with empty residual, in the (1, 80, N_FRAMES) layout used for latents
    residual_accumulator = torch.zeros((1, 80, N_FRAMES), device="cuda")
    sentence1 = glob("data/RAVDESS/*01-0[1-2]-24.wav")
    utterance_text1 = prepare_text('Kids are talking by the door.')
    sentence2 = glob("data/RAVDESS/*02-0[1-2]-24.wav")
    utterance_text2 = prepare_text('Dogs are sitting by the door.')

    # Pair every recording with its own transcript explicitly rather than by list
    # position, so texts cannot drift out of alignment (or run out) when a pattern
    # matches more or fewer than 4 files. Sorting makes the order independent of
    # the filesystem's directory listing.
    pairs = [(path, utterance_text1) for path in sorted(sentence1)]
    pairs += [(path, utterance_text2) for path in sorted(sentence2)]
    files = [path for path, _ in pairs]
    if not files:
        # Averaging over zero recordings would silently yield an all-NaN style
        raise FileNotFoundError("No reference recordings matched in data/RAVDESS/")

    for file, utterance_text in pairs:
        audio, _ = load_wav_to_torch(file)
        mel = trainset.get_mel(audio).to(device="cuda")
        # Add a leading (batch) axis; mel.size(2) below is then treated as the
        # frame count (assumes get_mel returns (n_mel, frames) — TODO confirm)
        mel = mel[None]
        # Out lens describes the length of the output, i.e. the audio
        out_lens = torch.LongTensor([mel.size(2)]).to(device="cuda")
        # In lens describes the length of the text input
        in_lens = torch.LongTensor([utterance_text.shape[1]]).to(device="cuda")
        # Residual between the speaker recording and the model's estimate from the
        # utterance text; only the first of the forward outputs is needed
        residual = model(mel, speaker_vecs, utterance_text, in_lens, out_lens)[0]
        # Move axis 0 last so the frame axis is axis 2, like the accumulator
        residual = residual.permute(1, 2, 0)
        # Keep at most N_FRAMES frames
        residual = residual[:, :, :N_FRAMES]
        if residual.shape[2] < N_FRAMES:
            # Too short: repeat each frame consecutively ([f0, f0, f1, f1, ...]) until
            # there are at least N_FRAMES frames (tiling instead of replication).
            # tile() is given a CPU tensor and its result moved back to CUDA.
            num_tile = int(np.ceil(N_FRAMES / residual.shape[2]))
            residual = tile(residual.cpu(), 2, num_tile).to("cuda")
        # Compute the residual sum
        residual_accumulator = residual_accumulator + residual[:, :, :N_FRAMES]

# Compute the average residual
residual_accumulator = residual_accumulator / len(files)
#####################################
# Sample without averaging over time
#####################################
# Draw one style latent per frame from N(average residual, SIGMA) and decode the
# baseline sentence with it.
dist = Normal(residual_accumulator, SIGMA)
z_style = dist.sample()
# Re-seed right before inference so the decoding step is reproducible
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
# Inference only, as in synthesize_text(): without no_grad the outputs may track
# gradients through the model parameters, and tensors that require grad cannot be
# converted with .numpy() (which mels_to_wav() calls).
with torch.no_grad():
    mels, _ = model.infer(z_style, speaker_vecs, text)
    mels_to_wav(waveglow, mels, 'results/surprised_speaker24_humans_transfer_without_time_avg.wav', data_config['sampling_rate'])
#################################
# Sample WITH averaging over time
#################################
# Collapse the frame axis: (1, 80, N_FRAMES) -> (1, 80)
residual_accumulator = residual_accumulator.mean(dim=2)
dist = Normal(residual_accumulator, SIGMA)
# N_FRAMES independent draws: (N_FRAMES, 1, 80) -> permuted back to (1, 80, N_FRAMES)
z_style = dist.sample((N_FRAMES,)).permute(1, 2, 0)
# Re-seed right before inference so the decoding step is reproducible
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
# Inference only, as in synthesize_text(): without no_grad the outputs may track
# gradients through the model parameters, and tensors that require grad cannot be
# converted with .numpy() (which mels_to_wav() calls).
with torch.no_grad():
    mels, _ = model.infer(z_style, speaker_vecs, text)
    mels_to_wav(waveglow, mels, 'results/surprised_speaker24_humans_transfer_time_avg.wav', data_config['sampling_rate'])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment