Skip to content

Instantly share code, notes, and snippets.

@ichabodcole
Created January 6, 2024 02:58
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ichabodcole/1c0d19ef4c33b7b5705b0860c7c27f7b to your computer and use it in GitHub Desktop.
Save ichabodcole/1c0d19ef4c33b7b5705b0860c7c27f7b to your computer and use it in GitHub Desktop.
Combining XTTS speaker embeddings
import os
from typing import List

import torch
from pydub import AudioSegment
from torch import Tensor
from TTS.api import TTS
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts

import utils
from utils import CombineMethod
# All XTTS v2 model files (config, checkpoint, bundled speakers) live in this directory.
checkpoint_dir = './tts/tts_models--multilingual--multi-dataset--xtts_v2'
# Use the GPU when present; the model and the speaker tensors both follow this choice.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

config = XttsConfig()
config.load_json(f'{checkpoint_dir}/config.json')
model = Xtts.init_from_config(config)
# NOTE(review): DeepSpeed inference is assumed to need CUDA, so it is only requested
# on GPU; requesting it unconditionally would undermine the CPU fallback.
model.load_checkpoint(config, checkpoint_dir=checkpoint_dir, use_deepspeed=(device == 'cuda'))
model.to(device)

# Precomputed conditioning for the bundled speakers, loaded onto the model's device.
# torch.load unpickles the file: only load checkpoints from a trusted source.
speaker_data = torch.load(f'{checkpoint_dir}/speakers_xtts.pth', map_location=device)
# Each speaker's values are used positionally as (gpt_cond_latent, speaker_embedding).
# NOTE(review): this relies on the stored dict's key order — confirm it.
daisy = list(speaker_data['Daisy Studious'].values())
henriette = list(speaker_data['Henriette Usha'].values())
baldur = list(speaker_data['Baldur Sanjin'].values())
speakers = [daisy, henriette, baldur]
def generate_speech(
    speakers,
    combine_method: CombineMethod,
    speaker_weights: List | None = None,
    text: str = "This ascent represents your connection to the universe, your consciousness expanding to embrace the infinite.",
    output_dir: str = './output',
):
    """Synthesize ``text`` in a voice blended from several XTTS speakers and save it as MP3.

    Args:
        speakers: Sequence of ``[gpt_cond_latent, speaker_embedding]`` pairs to blend.
        combine_method: How the speakers' latents and embeddings are merged.
        speaker_weights: Optional per-speaker weights in the same order as
            ``speakers``; ``None`` means equal weights.
        text: English text to synthesize.
        output_dir: Directory the MP3 is written to; created if it is missing.

    Returns:
        str: Path of the written MP3 file.
    """
    avg_gpt_cond_latents, avg_speaker_embedding = utils.average_latents_and_embeddings(
        speakers, combine_method, speaker_weights
    )
    out = model.inference(
        text=text,
        language="en",
        gpt_cond_latent=avg_gpt_cond_latents,
        speaker_embedding=avg_speaker_embedding,
        temperature=0.7,
        speed=1.0,
    )
    twav = Tensor(out['wav'])
    numpy_wav = utils.tensor_to_numpy_array(twav)
    audio_segment = AudioSegment(
        data=numpy_wav.tobytes(),  # raw 16-bit PCM sample bytes
        sample_width=2,  # 2 bytes (16 bits) per sample
        frame_rate=24000,  # NOTE(review): assumed to equal the model's output sample rate — confirm
        channels=1,  # mono audio
    )
    # speaker_weights is optional; label the file with all-ones (equal weighting)
    # instead of failing on "-".join(None).
    label_weights = speaker_weights if speaker_weights is not None else [1] * len(speakers)
    weights = "-".join(map(str, label_weights))
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(
        output_dir, f'speaker-mix_method_{combine_method.value}_weight_{weights}.mp3'
    )
    audio_segment.export(output_path, format="mp3")
    return output_path
# Demo run: blend the three speakers with SUM, weighting the 2nd and 3rd twice as much as the 1st.
generate_speech(speakers, combine_method=CombineMethod.SUM, speaker_weights=[1, 2, 2])
import torch
from torch import Tensor
import numpy as np
from enum import Enum
from typing import List
class CombineMethod(Enum):
    """Strategy for reducing a stack of per-speaker tensors to a single tensor.

    Each member's value is a short lowercase string identifying the strategy.
    The exact math for each strategy lives in ``combine_embeddings``.
    """

    # Mean across speakers.
    MEAN = 'mean'
    # Sum across speakers.
    SUM = 'sum'
    # Element-wise median across speakers.
    MEDIAN = 'median'
    # Element-wise maximum across speakers.
    MAX = 'max'
    # Element-wise minimum across speakers.
    MIN = 'min'
    # Sum after norm-based normalization of the per-speaker tensors.
    NORMALIZED_SUM = 'normalized_sum'
def tensor_to_numpy_array(tensor: Tensor) -> np.ndarray:
    """Convert a floating-point audio tensor to 16-bit PCM samples.

    Args:
        tensor: Float tensor of samples, nominally in [-1.0, 1.0]. It may live
            on any device and may require grad.

    Returns:
        An ``np.int16`` array of the same shape. Samples are clipped to
        [-1.0, 1.0], scaled by ``np.iinfo(np.int16).max`` and truncated toward zero.
    """
    # Detach before moving so no autograd history is kept for the CPU copy.
    samples = tensor.detach().cpu().numpy()
    # Clip first: the float -> int16 cast does not saturate, so out-of-range
    # samples would overflow and produce garbage instead of full-scale values.
    samples = np.clip(samples, -1.0, 1.0)
    return (samples * np.iinfo(np.int16).max).astype(np.int16)
def average_latents_and_embeddings(latent_embedding_pairs, combine_method: CombineMethod = CombineMethod.MEAN, speaker_weights: List | None = None):
    """
    Combine several speakers' (gpt_cond_latents, speaker_embedding) pairs into one pair.

    Despite the name, the reduction is not necessarily an average. The
    latents and the embeddings are each reduced with ``combine_embeddings``,
    using the same ``combine_method`` and ``speaker_weights``.

    Args:
        latent_embedding_pairs: Iterable of pairs; item 0 of each pair is the
            gpt_cond_latents and item 1 is the speaker_embedding. It is consumed
            in a single pass, so a generator is accepted.
        combine_method (CombineMethod): How to reduce across speakers.
            Defaults to ``CombineMethod.MEAN``.
        speaker_weights (list, optional): Per-speaker weights in the same order
            as the pairs; ``None`` lets ``combine_embeddings`` apply its default.

    Returns:
        tuple: (combined gpt_cond_latents, combined speaker_embedding).
    """
    gpt_cond_latents_list = []
    speaker_embeddings_list = []
    # One pass over the input: iterating twice would exhaust a one-shot iterator
    # and leave the second list empty.
    for pair in latent_embedding_pairs:
        gpt_cond_latents_list.append(pair[0])
        speaker_embeddings_list.append(pair[1])
    combined_gpt_cond_latents = combine_embeddings(gpt_cond_latents_list, combine_method, speaker_weights)
    combined_speaker_embedding = combine_embeddings(speaker_embeddings_list, combine_method, speaker_weights)
    return combined_gpt_cond_latents, combined_speaker_embedding
def combine_embeddings(embeddings, method, weights: List | None = None):
    """Combine a sequence of same-shaped tensors into a single tensor.

    With per-embedding weights ``w_i``, the reduction selected by ``method`` is:

    * ``MEAN``: weighted mean, ``sum(w_i * e_i) / sum(w_i)``.
    * ``SUM``: weighted sum, ``sum(w_i * e_i)``.
    * ``MEDIAN`` / ``MAX`` / ``MIN``: element-wise median / max / min over the
      weighted embeddings ``w_i * e_i``.
    * ``NORMALIZED_SUM``: ``sum(w_i * e_i / ||e_i||)``, where each embedding is
      scaled by its (whole-tensor) norm before its weight is applied.

    Args:
        embeddings: Non-empty sequence of tensors with identical shapes.
        method: A ``CombineMethod`` member selecting the reduction.
        weights: Optional per-embedding weights; defaults to 1 for each.

    Returns:
        A tensor with the same shape as each input embedding.

    Raises:
        ValueError: if ``embeddings`` is empty, the number of weights differs
            from the number of embeddings, the weights sum to zero for
            ``MEAN``, or ``method`` is not a known ``CombineMethod``.
    """
    # len() rather than truthiness, so a stacked tensor is accepted as input too.
    if len(embeddings) == 0:
        raise ValueError("At least one embedding is required to combine.")
    if weights is None:
        weights = [1 for _ in embeddings]
    if len(weights) != len(embeddings):
        raise ValueError("Weights must match the number of embeddings for weighted average.")

    if method == CombineMethod.NORMALIZED_SUM:
        # Normalize each embedding *before* weighting: normalizing w * e would
        # divide the weight straight back out and make the weights a no-op.
        normalized = [
            weight * (embedding / torch.norm(embedding))
            for embedding, weight in zip(embeddings, weights)
        ]
        return torch.sum(torch.stack(normalized), dim=0)

    stacked = torch.stack([embedding * weight for embedding, weight in zip(embeddings, weights)])
    if method == CombineMethod.MEAN:
        # True weighted mean: divide by the total weight, not the speaker count,
        # so non-uniform weights do not inflate or shrink the result's magnitude.
        total_weight = sum(weights)
        if total_weight == 0:
            raise ValueError("Sum of weights cannot be zero.")
        return torch.sum(stacked, dim=0) / total_weight
    if method == CombineMethod.SUM:
        return torch.sum(stacked, dim=0)
    if method == CombineMethod.MEDIAN:
        return torch.median(stacked, dim=0).values
    if method == CombineMethod.MAX:
        return torch.max(stacked, dim=0).values
    if method == CombineMethod.MIN:
        return torch.min(stacked, dim=0).values
    raise ValueError("Invalid combine method specified.")
def normalize_weights(weights):
    """Return ``weights`` rescaled so that they sum to one.

    Args:
        weights: Sequence of numeric weights.

    Returns:
        A new list in which every weight is divided by the total.

    Raises:
        ValueError: if the weights sum to zero, so normalization is undefined.
    """
    weight_sum = sum(weights)
    if weight_sum == 0:
        raise ValueError("Sum of weights cannot be zero.")
    return list(map(lambda weight: weight / weight_sum, weights))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment