Created
May 29, 2023 23:14
-
-
Save melMass/86f164d0d5911c2f7a0a18f4dc66f49a to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Mostly extracted from
# https://github.com/gitmylo/bark-voice-cloning-HuBERT-quantizer
# and
# https://github.com/gitmylo/audio-webui
from bark.generation import load_codec_model, generate_text_semantic | |
from encodec.utils import convert_audio | |
import torchaudio | |
import torch | |
import numpy as np | |
import os | |
# Load HuBERT for semantic tokens | |
from hubert.pre_kmeans_hubert import CustomHubert | |
from hubert.customtokenizer import CustomTokenizer | |
import rich | |
from pathlib import Path | |
import urllib.request | |
import huggingface_hub | |
# Directory containing this script; model and prompt paths below are built relative to it.
HERE = Path(__file__).parent
# Example prompt text (not referenced by any code visible in this file).
SAMPLE_TXT = "Hello, my name is Bark. And, uh — and I like pizza. [laughs]"
class HuBERTManager:
    """Locate and, when missing, download the HuBERT checkpoint and its tokenizer.

    Every file is stored under ``HERE / "data" / "models" / "hubert"``.
    """

    @staticmethod
    def ensure_directory():
        """Create the model directory if it does not exist and return it as a ``Path``."""
        install_dir = HERE / "data" / "models" / "hubert"
        # exist_ok makes a separate exists() check unnecessary (and race-free).
        install_dir.mkdir(parents=True, exist_ok=True)
        return install_dir

    @staticmethod
    def make_sure_hubert_installed(
        download_url: str = "https://dl.fbaipublicfiles.com/hubert/hubert_base_ls960.pt", file_name: str = "hubert.pt"
    ):
        """Download the HuBERT checkpoint unless it is already present.

        Args:
            download_url: URL the checkpoint is fetched from.
            file_name: Name of the file inside the model directory.

        Returns:
            ``Path`` of the checkpoint on disk.
        """
        install_dir = HuBERTManager.ensure_directory()
        install_file = install_dir / file_name
        if not install_file.is_file():
            rich.print("[bold]Downloading HuBERT base model[/bold]")
            # Download to a temporary name first so an interrupted transfer never
            # leaves a truncated file that a later run would treat as installed.
            partial_file = install_file.with_name(install_file.name + ".part")
            urllib.request.urlretrieve(download_url, partial_file.as_posix())
            partial_file.replace(install_file)
            rich.print("[green]Downloaded HuBERT[/green]")
        return install_file

    @staticmethod
    def make_sure_tokenizer_installed(
        model: str = "quantifier_hubert_base_ls960_14.pth",
        repo: str = "GitMylo/bark-voice-cloning",
        local_file: str = "tokenizer.pth",
    ):
        """Download the custom HuBERT tokenizer unless it is already present.

        Args:
            model: File name of the tokenizer inside the Hugging Face repository.
            repo: Hugging Face repository id to download from.
            local_file: Name the tokenizer is stored under in the model directory.

        Returns:
            ``Path`` of the tokenizer on disk.
        """
        install_dir = HuBERTManager.ensure_directory()
        install_file = install_dir / local_file
        if not install_file.is_file():
            rich.print("[bold]Downloading HuBERT custom tokenizer[/bold]")
            downloaded = huggingface_hub.hf_hub_download(
                repo, model, local_dir=install_dir, local_dir_use_symlinks=False
            )
            # Use the path reported by the hub client rather than guessing where it landed.
            Path(downloaded).replace(install_file)
            rich.print("[green]Downloaded tokenizer[/green]")
        return install_file
def _generate_clip(prompt, speaker, text_temp, waveform_temp, allow_early_stop):
    """Run Bark's semantic -> coarse -> fine -> codec-decode chain for one prompt."""
    from bark.generation import codec_decode, generate_coarse, generate_fine, generate_text_semantic

    x_semantic = generate_text_semantic(
        prompt,
        history_prompt=speaker,
        temp=text_temp,
        top_k=50,
        top_p=0.95,
        use_kv_caching=True,
        allow_early_stop=allow_early_stop,
    )
    x_coarse_gen = generate_coarse(
        x_semantic,
        history_prompt=speaker,
        temp=waveform_temp,
        top_k=50,
        top_p=0.95,
        use_kv_caching=True,
    )
    x_fine_gen = generate_fine(x_coarse_gen, history_prompt=speaker, temp=0.5)
    return codec_decode(x_fine_gen)


def do_inference(
    text, speaker, output_path, long=False, text_temp=0.7, waveform_temp=0.7, output_full=True, allow_early_stop=True
):
    """Generate speech for ``text`` with Bark and write it to ``output_path`` as a WAV file.

    Args:
        text: Text to synthesize.
        speaker: Value forwarded to Bark as ``history_prompt`` for every stage.
        output_path: Destination passed to ``scipy.io.wavfile.write``.
        long: When true, split ``text`` into sentences with NLTK, synthesize each
            one separately and join them with a short pause in between.
        text_temp: Temperature for the text-to-semantic stage.
        waveform_temp: Temperature for the coarse stage.
        output_full: Unused; kept so existing callers keep working.
        allow_early_stop: Forwarded to ``generate_text_semantic``.

    Raises:
        ValueError: If ``long`` is true and ``text`` contains no sentences.
    """
    from scipy.io.wavfile import write as write_wav
    from bark.generation import SAMPLE_RATE, preload_models

    # download and load all models
    preload_models(
        text_use_gpu=True,
        text_use_small=False,
        coarse_use_gpu=True,
        coarse_use_small=False,
        fine_use_gpu=True,
        fine_use_small=False,
        codec_use_gpu=True,
        force_reload=False,
    )

    if long:
        import nltk

        sentences = nltk.sent_tokenize(text)
        if not sentences:
            raise ValueError("Cannot generate audio: the provided text contains no sentences")
        silence_samples = int(0.15 * SAMPLE_RATE)  # 0.15 seconds of silence between sentences
        pieces = []
        for i, sentence in enumerate(sentences):
            rich.print(f"[bold blue]Generating sentence {i+1}/{len(sentences)}[/bold blue]")
            clip = _generate_clip(sentence, speaker, text_temp, waveform_temp, allow_early_stop)
            if pieces:
                # Match the clip dtype so concatenation does not upcast the audio.
                pieces.append(np.zeros(silence_samples, dtype=clip.dtype))
            pieces.append(clip)
        audio_array = np.concatenate(pieces)
    else:
        audio_array = _generate_clip(text, speaker, text_temp, waveform_temp, allow_early_stop)

    write_wav(output_path, SAMPLE_RATE, audio_array)
    rich.print(f"Saved audio to [bold blue]{output_path}[/bold blue]")
def main(input_sound, dest_path, voice_name, count=1, text=None):
    """Clone a voice from an audio file and/or synthesize ``text`` with a voice.

    When ``input_sound`` is given, each of the ``count`` iterations extracts
    semantic tokens (HuBERT + custom tokenizer) and EnCodec codes from the audio
    and saves them to ``dest_path / f"{voice_name}_{iteration:02d}.npz"``.
    When it is not given, ``voice_name`` is resolved as a path and used as the
    voice directly.  If ``text`` is provided, each iteration also writes a WAV
    file to ``dest_path`` via ``do_inference``.

    Args:
        input_sound: Path of the audio file to clone, or a falsy value for text-only use.
        dest_path: Directory that receives the generated ``.npz`` / ``.wav`` files.
        voice_name: Name used for saved voices, or a path to an existing voice file.
        count: Number of iterations to run.
        text: Optional text to synthesize with the voice.
    """
    # rich print with all args semantically highlighted each with a different color
    rich.print(
        f"Cloning voice [bold cyan]{voice_name}[/bold cyan] from [bold yellow]{input_sound}[/bold yellow] to [bold blue]{dest_path}[/bold blue] for [bold magenta]{count}[/bold magenta] runs"
    )
    # make dirs for dest_path (exist_ok avoids the check-then-create race)
    os.makedirs(dest_path, exist_ok=True)
    rich.print(text)

    if input_sound:
        # The codec, HuBERT and tokenizer are only needed when cloning from audio,
        # so text-only runs skip downloading and loading them.
        # Fall back to CPU instead of crashing on hosts without CUDA.
        # NOTE(review): assumes load_codec_model(use_gpu=True) picks the same device — confirm.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = load_codec_model(use_gpu=True)
        hubert_path = HuBERTManager.make_sure_hubert_installed()
        tokenizer_path = HuBERTManager.make_sure_tokenizer_installed()
        rich.print("[bold yellow]Loading HuBERT model[/bold yellow]")
        # Load the HuBERT model from the path the manager reports
        hubert_model = CustomHubert(checkpoint_path=hubert_path.as_posix()).to(device)
        # Load the CustomTokenizer model
        tokenizer = CustomTokenizer.load_from_checkpoint(tokenizer_path.as_posix()).to(device)
        rich.print("[bold yellow]Loading audio waveform[/bold yellow]")
        # Load the audio and convert it to the codec's sample rate and channel count
        wav, sr = torchaudio.load(input_sound)
        wav = convert_audio(wav, sr, model.sample_rate, model.channels)
        wav = wav.to(device)

    rich.print(f"[bold yellow]Generating semantic tokens ({count} runs)[/bold yellow]")
    for i in range(count):
        iteration = i + 1
        if input_sound:
            semantic_vectors = hubert_model.forward(wav, input_sample_hz=model.sample_rate)
            semantic_tokens = tokenizer.get_token(semantic_vectors)
            rich.print("[bold yellow]Generating discrete codes[/bold yellow]")
            # Extract discrete codes from EnCodec
            with torch.no_grad():
                encoded_frames = model.encode(wav.unsqueeze(0))
            codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1).squeeze()  # [n_q, T]
            rich.print("[bold yellow]Moving to CPU (codes and semantics)[/bold yellow]")
            codes = codes.cpu().numpy()
            semantic_tokens = semantic_tokens.cpu().numpy()
            output_path = (Path(dest_path) / f"{voice_name}_{iteration:02d}.npz").as_posix()
            rich.print(
                f"[bold yellow]Saving voice to disk ({iteration}/{count})[/bold yellow]: [bold blue]{output_path}[/bold blue]"
            )
            # The coarse prompt is the first two codebooks of the fine prompt.
            np.savez(output_path, fine_prompt=codes, coarse_prompt=codes[:2, :], semantic_prompt=semantic_tokens)
        else:
            # No audio to clone: voice_name is taken as the path of an existing voice file.
            output_path = Path(voice_name).resolve().as_posix()
        if text:
            rich.print(
                f"[bold yellow]Generating audio from text[/bold yellow]: [bold blue]{text}[/bold blue] with voice [bold blue]{voice_name}[/bold blue]"
            )
            wav_dest = (Path(dest_path) / f"{Path(voice_name).stem}_{iteration:02d}.wav").as_posix()
            rich.print(f"[bold yellow]Saving audio to disk[/bold yellow]: [bold blue]{wav_dest}[/bold blue]")
            do_inference(text, output_path, wav_dest, long=True)
if __name__ == "__main__":
    # Command-line entry point: parse arguments, validate them, then call main().
    import argparse

    from rich_argparse import ArgumentDefaultsRichHelpFormatter

    parser = argparse.ArgumentParser(description="Voice cloner", formatter_class=ArgumentDefaultsRichHelpFormatter)
    parser.add_argument("audio", type=str, help="Audio file to clone", nargs="?")
    parser.add_argument("--output", type=str, help="Output voice file", default=(HERE / "data" / "prompts").as_posix())
    parser.add_argument(
        "--voice", type=str, help="Voice name, just the name to save it and a full nzp path to load", required=True
    )
    parser.add_argument(
        "--text", type=str, help="If provided will also generate a wav from the clones voice", default=None
    )
    parser.add_argument("--text-file", type=str, help="Path to a text file to use as input", default=None)
    parser.add_argument("--count", type=int, help="Number of times to clone", default=1)
    args = parser.parse_args()

    # Where a voice saved under this name would live in the default prompts directory.
    voice = HERE / "data" / "prompts" / f"{args.voice}.npz"
    txt = args.text

    # validate args
    if args.audio is None and args.text is None and args.text_file is None:
        rich.print(
            "[bold red]Error:[/bold red] You must provide either an audio file, a string or a text file to generate from"
        )
        rich.print(args)
        # `exit` is a `site` convenience that may be missing; SystemExit always works.
        raise SystemExit(1)

    if args.text_file:
        with open(args.text_file, "r", encoding="utf-8") as f:
            txt = f.read()
        if args.text and args.text_file:
            rich.print("[yellow]Warning:[/yellow] You provided both a text and a text file, the text file will be used")

    if Path(args.voice).exists():
        if args.audio:
            rich.print(
                "[bold red]Error:[/bold red] Please only use the voice name if you want to generate audio from text. If you want to clone an audio file, do not provide a path to a voice file, but only a name."
            )
            rich.print(args)
            raise SystemExit(1)
    elif args.text or args.text_file:
        if not voice.exists() and not args.audio:
            rich.print(
                "[bold red]Error:[/bold red] The provided voice name does not exist. Please provide an audio file to clone instead."
            )
            raise SystemExit(1)

    # When cloning from audio, main() builds the output file name from the voice
    # *name*; handing it the resolved .npz path would produce a mangled file name.
    # The stored voice path is only used for text-only generation.
    voice_arg = voice if (voice.exists() and not args.audio) else args.voice
    main(args.audio, args.output, voice_arg, args.count, txt)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment