Skip to content

Instantly share code, notes, and snippets.

@shivammehta25
Created November 13, 2022 10:10
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save shivammehta25/c2bc3a5a875c268e538edb774733f7e8 to your computer and use it in GitHub Desktop.
Save shivammehta25/c2bc3a5a875c268e538edb774733f7e8 to your computer and use it in GitHub Desktop.
Hosting the models on the server using FastAPI
import json
import sys
sys.path.append('src/model')
sys.path.insert(0, './hifigan')
import logging
import os
from pathlib import Path
from uuid import uuid4
import numpy as np
import soundfile as sf
import torch
import uvicorn
from fastapi import BackgroundTasks, FastAPI
from fastapi.responses import FileResponse
from nltk import word_tokenize
from hifigan.env import AttrDict
from hifigan.models import Generator
from hifigandenoiser import Denoiser
from src.hparams import create_hparams
from src.training_module import TrainingModule
from src.utilities.text import phonetise_text, text_to_sequence
# Run inference on the GPU when CUDA is available, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
app = FastAPI()

# Append DEBUG-and-above records to log.txt.
# The format string previously ended with a stray apostrophe
# ("%(message)s'"), which was emitted at the end of every log line.
logging.basicConfig(
    filename="log.txt",
    level=logging.DEBUG,
    format="%(asctime)s - %(levelname)s: %(message)s",
    filemode="a",
)

print("[+] Loading Models..")
hparams = create_hparams()
def load_model(checkpoint_path, speaker):
    """Restore a ``TrainingModule`` from a checkpoint and prepare it for inference.

    Args:
        checkpoint_path: Checkpoint file handed to
            ``TrainingModule.load_from_checkpoint``.
        speaker: Label used only in the progress message printed after loading.

    Returns:
        The object returned by ``load_from_checkpoint``, after it has been
        moved to ``device``, switched to eval mode and cast to half precision.
    """
    restored = TrainingModule.load_from_checkpoint(checkpoint_path)
    # Same call chain as before, spelled out step by step; the chain's result
    # is discarded and the originally loaded object is what gets returned.
    on_device = restored.to(device)
    in_eval_mode = on_device.eval()
    in_eval_mode.half()
    print(f"[+] Model Loaded: {speaker}")
    return restored
# One acoustic-model checkpoint per supported speaker. Each model is loaded
# eagerly at import time, in the order mandarin, arabic, british, african.
checkpoint_mandarin = "checkpoints/MandrinRun_Male/checkpoint_105000.ckpt"
model_mandarin = load_model(checkpoint_mandarin, "mandarin")

checkpoint_arabic = "checkpoints/ArabRun_Male/checkpoint_110000.ckpt"
model_arabic = load_model(checkpoint_arabic, "arabic")

checkpoint_british = "checkpoints/BritishRun_male/checkpoint_105500.ckpt"
model_british = load_model(checkpoint_british, "british")

checkpoint_african = "checkpoints/Nigerian2Run_male/checkpoint_111500.ckpt"
model_african = load_model(checkpoint_african, "african")

print("[+] Models Loaded..")
print("[+] Loading HiFi-GAN..")
# Locate the HiFi-GAN vocoder configuration and generator checkpoint.
hifigan_loc = 'hifigan/'
config_file = hifigan_loc + 'config_v1.json'
# NOTE(review): unlike config_file, this path is not prefixed with
# hifigan_loc, so it resolves relative to the working directory — confirm
# that is intended.
hifi_checkpoint_file = 'g_02500000'

# Parse the JSON config straight from the file handle. The encoding is given
# explicitly so parsing does not depend on the platform's locale default.
with open(config_file, encoding="utf-8") as f:
    json_config = json.load(f)
def load_checkpoint(filepath, device):
    """Load a serialized checkpoint from disk.

    Args:
        filepath: Path of the checkpoint file to read.
        device: Forwarded to ``torch.load`` as ``map_location``.

    Returns:
        The object that ``torch.load`` deserializes from ``filepath``.

    Raises:
        FileNotFoundError: If ``filepath`` is not an existing regular file.
    """
    print(filepath)
    # Validate explicitly: an ``assert`` is stripped when Python runs with -O,
    # which would silently skip this check.
    if not os.path.isfile(filepath):
        raise FileNotFoundError(f"Checkpoint file not found: {filepath}")
    print("Loading '{}'".format(filepath))
    # NOTE: torch.load unpickles the file; only load checkpoints from a
    # trusted source.
    checkpoint_dict = torch.load(filepath, map_location=device)
    print("Complete.")
    return checkpoint_dict
# Build the HiFi-GAN generator from the parsed JSON config, load its trained
# weights, and wrap it in a denoiser. Statement order is kept exactly as is.
h = AttrDict(json_config)
# Seed torch's RNG with the seed stored in the vocoder config.
torch.manual_seed(h.seed)
generator = Generator(h).to(device)
state_dict_g = load_checkpoint(hifi_checkpoint_file, device)
# The checkpoint keeps the generator weights under the 'generator' key.
generator.load_state_dict(state_dict_g['generator'])
# Switch to eval mode and cast to half precision.
# NOTE(review): the half-precision cast is applied regardless of `device`;
# some ops may not support half precision on CPU — confirm CPU fallback works.
generator.eval().half()
generator.remove_weight_norm()
# Denoiser constructed around the generator; presumably used to suppress
# vocoder bias noise — verify against hifigandenoiser.
denoiser = Denoiser(generator, mode='zeros')
print("[+] HiFi-GAN Loaded..")
def text_to_seq(text):
    """Convert raw text into a batched tensor of symbol ids.

    The text is phonetised with ``hparams.cmu_phonetiser`` (tokenised by
    ``word_tokenize``), mapped to ids with the ``english_cleaners`` cleaner,
    and given a leading axis of length 1.

    Args:
        text: The text to convert.

    Returns:
        A ``torch`` long tensor on ``device`` holding the id sequence.
    """
    phonetised = phonetise_text(hparams.cmu_phonetiser, text, word_tokenize)
    symbol_ids = text_to_sequence(phonetised, ['english_cleaners'])
    # Prepend an axis of length 1 before handing the ids to torch.
    batched_ids = np.array(symbol_ids)[None, :]
    return torch.from_numpy(batched_ids).to(device).long()
# Scratch directory for generated WAV files; created at import time if it
# does not exist yet (including any missing parent directories).
savepath = Path('temp')
savepath.mkdir(parents=True, exist_ok=True)
def del_file(file):
    """Delete ``file`` from the filesystem.

    Args:
        file: Path of the file to remove.

    Raises:
        FileNotFoundError: If ``file`` does not exist.
    """
    # os.unlink is documented as semantically identical to os.remove.
    os.unlink(file)
def log(text, speaker):
    """Write an INFO record noting that audio for ``text`` was sent.

    Args:
        text: The text that was synthesised.
        speaker: The speaker label the request used.
    """
    message = f"\tSent successfully {speaker}: {text}"
    logging.info(message)
@app.get("/speak/")
async def get_audio_from_text(text: str, speaker: str, bg_tasks: BackgroundTasks, speed: float = 0.55):
    """Synthesise ``text`` with the requested speaker's model and return a WAV file.

    Args:
        text: Text to synthesise. A trailing "." is appended before it is
            converted to a symbol sequence.
        speaker: One of "mandarin", "arabic", "british" or "african".
        bg_tasks: FastAPI background-task queue; used to delete the temporary
            WAV file and to write the log record.
        speed: Value written to the model's ``duration_quantile_threshold``
            hyper-parameter for this request.

    Returns:
        A ``FileResponse`` serving the generated file as ``audio/wav``.

    Raises:
        HTTPException: With status 400 when ``speaker`` is not supported.
    """
    # Imported locally so this handler does not depend on a change to the
    # file-level import list.
    from fastapi import HTTPException

    models_by_speaker = {
        "mandarin": model_mandarin,
        "arabic": model_arabic,
        "british": model_british,
        "african": model_african,
    }
    model = models_by_speaker.get(speaker)
    if model is None:
        # Previously an unknown speaker left `model` unbound and surfaced as
        # an UnboundLocalError (HTTP 500).
        raise HTTPException(
            status_code=400,
            detail=f"Unknown speaker {speaker!r}; expected one of {sorted(models_by_speaker)}",
        )

    # NOTE: these settings mutate the shared, module-level model object. The
    # handler body contains no `await`, so it runs to completion without
    # yielding to another request on the same event loop.
    hmm = model.model.hmm
    hmm.hparams.max_sampling_time = 10000
    hmm.hparams.duration_quantile_threshold = speed
    hmm.hparams.deterministic_transition = True
    hmm.hparams.predict_means = False
    hmm.hparams.prenet_dropout_while_eval = True
    hmm.prenet.prenet_dropout = 0.5

    text += "."
    try:
        sequence = text_to_seq(text)
        # `torch.no_grad() and torch.inference_mode()` evaluates to just the
        # second context manager, so only inference_mode was ever entered;
        # say so directly.
        with torch.inference_mode():
            mel_output, _, _, _ = model.sample(sequence.squeeze(0), sampling_temp=0.334)
            mel_output = mel_output.transpose(1, 2)
            audio = generator(mel_output)
            audio = denoiser(audio[:, 0], strength=0.004)[:, 0]

        filename = savepath / f"{uuid4()}.wav"
        # NOTE(review): 22500 looks like a typo for 22050 (the usual HiFi-GAN
        # sampling rate) — confirm against the vocoder config before changing.
        sf.write(filename, audio.detach().squeeze().cpu().numpy(), 22500, 'PCM_24')
    finally:
        # Restore the default threshold even when synthesis fails.
        hmm.hparams.duration_quantile_threshold = 0.55

    # Queue the temp-file cleanup and the log record as background tasks.
    bg_tasks.add_task(del_file, filename)
    bg_tasks.add_task(log, text, speaker)
    return FileResponse(filename, media_type="audio/wav")
if __name__ == "__main__":
    # Serve the API on port 8020, listening on all network interfaces.
    uvicorn.run(
        app,
        port=8020,
        host="0.0.0.0",
    )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment