Skip to content

Instantly share code, notes, and snippets.

@polvanrijn
Created August 11, 2020 16:21
Show Gist options
  • Save polvanrijn/059601cb6dcb16a014121e1b01c018b0 to your computer and use it in GitHub Desktop.
Save polvanrijn/059601cb6dcb16a014121e1b01c018b0 to your computer and use it in GitHub Desktop.
Flattened version of inference_style_transfer.ipynb
# Imports
import json
import sys
import torch
from torch.distributions import Normal
from flowtron import Flowtron
from data import Data
from train import update_params
sys.path.insert(0, "tacotron2")
sys.path.insert(0, "tacotron2/waveglow")
from denoiser import Denoiser
# LOAD FLOWTRON
config_path = "config.json"
# "section.key=value" overrides applied on top of config.json before the model
# and dataset are built (these params are new relative to the stock config).
params = ["model_config.dummy_speaker_embedding=0",
          "data_config.p_arpabet=1.0"]
# JSON is UTF-8 by specification; do not depend on the platform locale.
with open(config_path, encoding="utf-8") as f:
    config = json.load(f)
# update_params (from train) is called for its effect on `config`; its return
# value is unused, so presumably it applies the overrides in place — TODO confirm.
update_params(config, params)
data_config = config["data_config"]
model_config = config["model_config"]
## LOAD MODEL
model_path = "models/flowtron_ljs.pt"
# Restore the pretrained weights onto the CPU first; the model is moved to
# the GPU only after its state dict has been loaded.
checkpoint = torch.load(model_path, map_location='cpu')
state_dict = checkpoint['state_dict']
model = Flowtron(**model_config)
model.load_state_dict(state_dict)
# Evaluation mode, then onto the GPU.
model.eval()
model.cuda()
### LOAD WAVEGLOW
# Uses a newer waveglow model
waveglow_path = 'models/waveglow_256channels_universal_v5.pt'
# The checkpoint stores the whole vocoder object under 'model'. Map it to the
# CPU first, exactly like the Flowtron checkpoint, so a checkpoint saved on a
# GPU index that does not exist on this machine still loads; it is moved to
# the GPU right after.
# NOTE(review): torch.load unpickles arbitrary objects — only load trusted
# checkpoints. On PyTorch >= 2.6 (weights_only=True by default) loading a
# pickled module needs weights_only=False — confirm for your torch version.
waveglow = torch.load(waveglow_path, map_location='cpu')['model']
waveglow.eval()
waveglow.cuda()
# Denoiser (from tacotron2/waveglow) wraps the vocoder; used when vocoding.
denoiser = Denoiser(waveglow).cuda().eval()
### PREPARE DATALOADER
# File list of the reference utterances whose latents define the target style.
dataset_path = 'data/surprised_samples/surprised_audiofilelist_text.txt'
# Forward every data_config entry to Data except the train/validation file
# lists; the reference list above is passed in their place.
excluded_keys = ('training_files', 'validation_files')
data_kwargs = {key: value for key, value in data_config.items()
               if key not in excluded_keys}
dataset = Data(dataset_path, **data_kwargs)
import os

from scipy.io import wavfile

# Style transfer with the Flowtron posterior:
#   1. map every reference utterance to its latent z with Flowtron's forward pass,
#   2. compute a posterior mean from those latents,
#   3. n_samples times, synthesise `text` from z drawn from that posterior and,
#      as a baseline, from N(0, sigma^2), and write both as WAV files
#      (results/posterior_<k>.wav, results/baseline_<k>.wav, k = 1..n_samples).
# Steps 1-2 and the text/speaker encoding do not depend on the sample index,
# so they run once instead of once per sample.

### SETTINGS
n_samples = 10
force_speaker_id = 0  # > -1: replace every reference utterance's speaker id with this value
lambd = 0.0001  # posterior scaling: ratio = n_references / lambd
sigma = 1.  # std of the posterior and of the baseline Gaussian
n_frames = 300  # number of frames of the sampled latents
aggregation_type = 'batch'  # 'batch' or 'time_and_batch'
text = "Humans are walking on the streets?"
speaker = 0
output_dir = "results"

if aggregation_type not in ('batch', 'time_and_batch'):
    raise ValueError("unknown aggregation_type: %r" % (aggregation_type,))
if len(dataset) == 0:
    raise ValueError("no reference utterances in %s" % dataset_path)

### COLLECT Z VALUES
z_values = []
for i in range(len(dataset)):
    mel, sid, text_ids = dataset[i]
    mel, sid, text_ids = mel[None].cuda(), sid.cuda(), text_ids[None].cuda()
    if force_speaker_id > -1:
        sid = sid * 0 + force_speaker_id
    in_lens = torch.LongTensor([text_ids.shape[1]]).cuda()
    with torch.no_grad():
        z = model(mel, sid, text_ids, in_lens, None)[0]
    # After the permute: (batch, channels, frames) — frames are tiled along
    # dim 2 below and channels sit on dim 1 of the z fed to model.infer.
    z_values.append(z.permute(1, 2, 0))

### COMPUTE POSTERIOR
n_channels = z_values[0].size(1)
ratio = len(z_values) / float(lambd)
print(ratio)
if aggregation_type == 'time_and_batch':
    # One mean per channel, averaged over frames and utterances: (channels, 1).
    z_mean = torch.cat([z.mean(dim=2) for z in z_values]).mean(dim=0)[:, None]
else:  # 'batch'
    # Tile each utterance along time and cut it to exactly n_frames, then
    # average over utterances frame by frame: (channels, n_frames).
    padded = []
    for z in z_values:
        n_repeats = -(-n_frames // z.size(2))  # ceil(n_frames / frames)
        padded.append(z.repeat(1, 1, n_repeats)[:, :, :n_frames])
    z_mean = torch.cat(padded, dim=0).mean(dim=0)
mu_posterior = ratio * z_mean / (ratio + 1)
dist = Normal(mu_posterior.cpu(), sigma)

### TEXT AND SPEAKER TO SYNTHESISE
text_encoded = dataset.get_text(text).cuda()[None]
speaker_id = torch.LongTensor([speaker]).cuda()
os.makedirs(output_dir, exist_ok=True)


def _vocode_and_save(mel_spec, path):
    """Vocode `mel_spec` with WaveGlow, denoise it and write it to `path` as WAV."""
    with torch.no_grad():
        audio = denoiser(waveglow.infer(mel_spec, sigma=0.75), 0.01)
    # librosa.output.write_wav (removed in librosa 0.8) delegated to this same
    # scipy writer without normalisation, so the written samples are unchanged.
    wavfile.write(path, data_config['sampling_rate'],
                  audio[0].data.cpu().numpy().T)


for iteration in range(n_samples):
    ### Z BASELINE AND POSTERIOR SAMPLE, both (1, channels, n_frames)
    z_baseline = torch.randn(1, n_channels, n_frames, device='cuda') * sigma
    if aggregation_type == 'time_and_batch':
        # (n_frames, channels, 1) -> (1, channels, n_frames)
        z_posterior = dist.sample([n_frames]).permute(2, 1, 0).cuda()
    else:  # 'batch': a sample is already (channels, n_frames)
        z_posterior = dist.sample()[None].cuda()

    #### Perform inference sampling the posterior and a Gaussian baseline
    with torch.no_grad():
        mel_posterior = model.infer(z_posterior, speaker_id, text_encoded)[0]
        mel_baseline = model.infer(z_baseline, speaker_id, text_encoded)[0]

    _vocode_and_save(mel_posterior, os.path.join(output_dir, "posterior_%d.wav" % (iteration + 1)))
    _vocode_and_save(mel_baseline, os.path.join(output_dir, "baseline_%d.wav" % (iteration + 1)))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment