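"""Clone a voice for Bark from a short audio sample.

Extracts semantic tokens (HuBERT plus a custom quantizer) and EnCodec codes
from the input audio, saves them as an .npz voice prompt, and can optionally
synthesize speech from text using the cloned voice.
"""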
# mostly extracted from
# https://github.com/gitmylo/bark-voice-cloning-HuBERT-quantizer
# and
# https://github.com/gitmylo/audio-webui
import os
import urllib.request
from pathlib import Path

import huggingface_hub
import numpy as np
import rich
import torch
import torchaudio
from bark.generation import load_codec_model
from encodec.utils import convert_audio

# HuBERT models used to extract semantic tokens
from hubert.pre_kmeans_hubert import CustomHubert
from hubert.customtokenizer import CustomTokenizer

HERE = Path(__file__).parent
SAMPLE_TXT = "Hello, my name is Bark. And, uh — and I like pizza. [laughs]"
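

# Fetches the HuBERT base checkpoint and GitMylo's custom quantizer on first
# use, caching both under data/models/hubert next to this script.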
class HuBERTManager:
    @staticmethod
    def ensure_directory():
        install_dir = HERE / "data" / "models" / "hubert"
        if not install_dir.exists():
            install_dir.mkdir(parents=True, exist_ok=True)
        return install_dir

    @staticmethod
    def make_sure_hubert_installed(
        download_url: str = "https://dl.fbaipublicfiles.com/hubert/hubert_base_ls960.pt", file_name: str = "hubert.pt"
    ):
        install_dir = HuBERTManager.ensure_directory()
        install_file = install_dir / file_name
        if not install_file.is_file():
            rich.print("[bold]Downloading HuBERT base model[/bold]")
            urllib.request.urlretrieve(download_url, install_file.as_posix())
            rich.print("[green]Downloaded HuBERT[/green]")
        return install_file

    @staticmethod
    def make_sure_tokenizer_installed(
        model: str = "quantifier_hubert_base_ls960_14.pth",
        repo: str = "GitMylo/bark-voice-cloning",
        local_file: str = "tokenizer.pth",
    ):
        install_dir = HuBERTManager.ensure_directory()
        install_file = install_dir / local_file
        if not install_file.is_file():
            rich.print("[bold]Downloading HuBERT custom tokenizer[/bold]")
            huggingface_hub.hf_hub_download(repo, model, local_dir=install_dir, local_dir_use_symlinks=False)
            (install_dir / model).rename(install_file)
            rich.print("[green]Downloaded tokenizer[/green]")
        return install_file
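

# Bark synthesizes audio in stages: text -> semantic tokens, semantic ->
# coarse EnCodec codes, coarse -> fine codes, which are then decoded to a
# waveform. Calling the stages manually (instead of bark.api.generate_audio)
# allows per-stage sampling parameters.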
def do_inference(
    text, speaker, output_path, long=False, text_temp=0.7, waveform_temp=0.7, output_full=True, allow_early_stop=True
):
    from scipy.io.wavfile import write as write_wav
    from bark.generation import (
        SAMPLE_RATE,
        preload_models,
        codec_decode,
        generate_coarse,
        generate_fine,
        generate_text_semantic,
    )

    # download and load all models
    preload_models(
        text_use_gpu=True,
        text_use_small=False,
        coarse_use_gpu=True,
        coarse_use_small=False,
        fine_use_gpu=True,
        fine_use_small=False,
        codec_use_gpu=True,
        force_reload=False,
    )
    audio_array = None
    if long:
        import nltk

        sentences = nltk.sent_tokenize(text)
        silence = np.zeros(int(0.15 * SAMPLE_RATE))  # 0.15 s of silence between sentences
        pieces = []
        for i, sentence in enumerate(sentences):
            rich.print(f"[bold blue]Generating sentence {i+1}/{len(sentences)}[/bold blue]")
            # generation with more control
            x_semantic = generate_text_semantic(
                sentence,
                history_prompt=speaker,
                temp=text_temp,
                top_k=50,
                top_p=0.95,
                use_kv_caching=True,
                allow_early_stop=allow_early_stop,
            )
            x_coarse_gen = generate_coarse(
                x_semantic,
                history_prompt=speaker,
                temp=waveform_temp,
                top_k=50,
                top_p=0.95,
                use_kv_caching=True,
            )
            x_fine_gen = generate_fine(x_coarse_gen, history_prompt=speaker, temp=0.5)
            pieces += [codec_decode(x_fine_gen), silence.copy()]
        audio_array = np.concatenate(pieces)
    else:
        # generation with more control
        x_semantic = generate_text_semantic(
            text,
            history_prompt=speaker,
            temp=text_temp,
            top_k=50,
            top_p=0.95,
            use_kv_caching=True,
            allow_early_stop=allow_early_stop,
        )
        x_coarse_gen = generate_coarse(
            x_semantic,
            history_prompt=speaker,
            temp=waveform_temp,
            top_k=50,
            top_p=0.95,
            use_kv_caching=True,
        )
        x_fine_gen = generate_fine(x_coarse_gen, history_prompt=speaker, temp=0.5)
        audio_array = codec_decode(x_fine_gen)
    write_wav(output_path, SAMPLE_RATE, audio_array)
    rich.print(f"Saved audio to [bold blue]{output_path}[/bold blue]")
def main(input_sound, dest_path, voice_name, count=1, text=None):
    # print all args, each highlighted in a different color
    rich.print(
        f"Cloning voice [bold cyan]{voice_name}[/bold cyan] from [bold yellow]{input_sound}[/bold yellow] to [bold blue]{dest_path}[/bold blue] for [bold magenta]{count}[/bold magenta] runs"
    )
    # make dirs for dest_path
    if not os.path.exists(dest_path):
        os.makedirs(dest_path)
    rich.print(text)
    device = "cuda"  # or "cpu"
    model = load_codec_model(use_gpu=True)
    hubert_manager = HuBERTManager()
    hubert_manager.make_sure_hubert_installed()
    hubert_manager.make_sure_tokenizer_installed()
    rich.print("[bold yellow]Loading HuBERT model[/bold yellow]")
    # Load the HuBERT model
    hubert_model = CustomHubert(checkpoint_path=(HERE / "data" / "models" / "hubert" / "hubert.pt").as_posix()).to(
        device
    )
    # Load the CustomTokenizer model (automatically uses the right layers)
    tokenizer = CustomTokenizer.load_from_checkpoint(
        (HERE / "data" / "models" / "hubert" / "tokenizer.pth").as_posix()
    ).to(device)
    rich.print("[bold yellow]Loading audio waveform[/bold yellow]")
    if input_sound:
        # Load and pre-process the audio you want to clone (under 13 seconds)
        wav, sr = torchaudio.load(input_sound)
        wav = convert_audio(wav, sr, model.sample_rate, model.channels)
        wav = wav.to(device)
    rich.print(f"[bold yellow]Generating semantic tokens ({count} runs)[/bold yellow]")
    for i in range(count):
        iteration = i + 1
        output_path = None
        if input_sound:
            semantic_vectors = hubert_model.forward(wav, input_sample_hz=model.sample_rate)
            semantic_tokens = tokenizer.get_token(semantic_vectors)
            rich.print("[bold yellow]Generating discrete codes[/bold yellow]")
            # Extract discrete codes from EnCodec
            with torch.no_grad():
                encoded_frames = model.encode(wav.unsqueeze(0))
            codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1).squeeze()  # [n_q, T]
            rich.print("[bold yellow]Moving to CPU (codes and semantics)[/bold yellow]")
            # move codes and semantic tokens to cpu for saving
            codes = codes.cpu().numpy()
            semantic_tokens = semantic_tokens.cpu().numpy()
            output_path = (Path(dest_path) / f"{voice_name}_{iteration:02d}.npz").as_posix()
            rich.print(
                f"[bold yellow]Saving voice to disk ({iteration}/{count})[/bold yellow]: [bold blue]{output_path}[/bold blue]"
            )
            np.savez(output_path, fine_prompt=codes, coarse_prompt=codes[:2, :], semantic_prompt=semantic_tokens)
        else:
            output_path = Path(voice_name).resolve().as_posix()
        if text:
            rich.print(
                f"[bold yellow]Generating audio from text[/bold yellow]: [bold blue]{text}[/bold blue] with voice [bold blue]{voice_name}[/bold blue]"
            )
            wav_dest = (Path(dest_path) / f"{Path(voice_name).stem}_{iteration:02d}.wav").as_posix()
            rich.print(f"[bold yellow]Saving audio to disk[/bold yellow]: [bold blue]{wav_dest}[/bold blue]")
            do_inference(text, speaker=output_path, output_path=wav_dest, long=True)
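

# Example invocations (script and file names are illustrative):
#   python clone_voice.py sample.wav --voice myvoice
#   python clone_voice.py --voice myvoice --text "Hello there!"
#   python clone_voice.py sample.wav --voice myvoice --text-file script.txt --count 3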
if __name__ == "__main__":
    import argparse
    from rich_argparse import ArgumentDefaultsRichHelpFormatter

    parser = argparse.ArgumentParser(description="Voice cloner", formatter_class=ArgumentDefaultsRichHelpFormatter)
    parser.add_argument("audio", type=str, help="Audio file to clone", nargs="?")
    parser.add_argument("--output", type=str, help="Output voice file", default=(HERE / "data" / "prompts").as_posix())
    parser.add_argument(
        "--voice", type=str, help="Voice name: just a name to save under, or a full npz path to load", required=True
    )
    parser.add_argument(
        "--text", type=str, help="If provided, also generate a wav from the cloned voice", default=None
    )
    parser.add_argument("--text-file", type=str, help="Path to a text file to use as input", default=None)
    parser.add_argument("--count", type=int, help="Number of times to clone", default=1)
    args = parser.parse_args()
    voice = HERE / "data" / "prompts" / f"{args.voice}.npz"
    txt = args.text
    # validate args
    if args.audio is None and args.text is None and args.text_file is None:
        rich.print(
            "[bold red]Error:[/bold red] You must provide either an audio file, a string or a text file to generate from"
        )
        rich.print(args)
        exit(1)
    if args.text_file:
        with open(args.text_file, "r", encoding="utf-8") as f:
            txt = f.read()
    if args.text and args.text_file:
        rich.print("[yellow]Warning:[/yellow] You provided both a text and a text file, the text file will be used")
    if Path(args.voice).exists():
        if args.audio:
            rich.print(
                "[bold red]Error:[/bold red] Only pass a voice file path when generating audio from text. To clone an audio file, pass a plain voice name, not a path."
            )
            rich.print(args)
            exit(1)
    elif args.text or args.text_file:
        if not voice.exists() and not args.audio:
            rich.print(
                "[bold red]Error:[/bold red] The provided voice name does not exist. Please provide an audio file to clone instead."
            )
            exit(1)
    main(args.audio, args.output, voice if voice.exists() else args.voice, args.count, txt)
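
# The saved .npz matches Bark's history-prompt format (semantic_prompt,
# coarse_prompt, fine_prompt), so it should also work directly with Bark's
# high-level API. A sketch, with an illustrative prompt path:
#
#   from bark.api import generate_audio
#   audio = generate_audio("Hello!", history_prompt="data/prompts/myvoice_01.npz")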