Skip to content

Instantly share code, notes, and snippets.

@litagin02
Last active February 25, 2024 16:13
Show Gist options
  • Star 8 You must be signed in to star a gist
  • Fork 9 You must be signed in to fork a gist
  • Save litagin02/f07a5d7217c9efa4918de0812fd99cd3 to your computer and use it in GitHub Desktop.
Save litagin02/f07a5d7217c9efa4918de0812fd99cd3 to your computer and use it in GitHub Desktop.
Bert-VITS2のモデルマージするやつ(声音・感情表現それぞれを取っ替えたり混ぜたり)
import os
import gradio as gr
import torch
from infer import get_net_g, infer
import utils
voice_keys = ["dec", "flow"]
speech_style_keys = ["enc_p"]
tempo_keys = ["sdp", "dp"]
models_dir = "merge"
model_list = [
os.path.join(models_dir, f) for f in os.listdir(models_dir) if f.endswith(".pth")
]
config_path = os.path.join(models_dir, "config.json")
def tts(model_path, text):
hps = utils.get_hparams_from_file(config_path)
speaker_name = next(iter(hps.data.spk2id.keys()))
device = "cuda" if torch.cuda.is_available() else "cpu"
net_g = get_net_g(
model_path=model_path,
version=hps.version,
device=device,
hps=hps,
)
if hps.version == "2.1":
emotion = 0
elif hps.version == "2.2":
emotion = ""
with torch.no_grad():
audio = infer(
text=text,
sdp_ratio=0.2,
noise_scale=0.6,
noise_scale_w=0.8,
length_scale=1,
sid=speaker_name,
language="JP",
hps=hps,
net_g=net_g,
device=device,
emotion=emotion,
)
return (hps.data.sampling_rate, audio)
def merge_models(
model_path_a, model_path_b, voice_weight, speech_style_weight, tempo_weight
):
"""model Aを起点に、model Bの各要素を重み付けしてマージする"""
# モデルを読み込む
model_a = torch.load(model_path_a, map_location="cpu")
model_b = torch.load(model_path_b, map_location="cpu")
merged_model = model_a.copy()
for key in model_a["model"].keys():
if any([key.startswith(prefix) for prefix in voice_keys]):
weight = voice_weight
elif any([key.startswith(prefix) for prefix in speech_style_keys]):
weight = speech_style_weight
elif any([key.startswith(prefix) for prefix in tempo_keys]):
weight = tempo_weight
else:
continue
merged_model["model"][key] = (
model_a["model"][key] * (1 - weight) + model_b["model"][key] * weight
)
merged_model_path = os.path.join(models_dir, "merged_model.pth")
torch.save(merged_model, merged_model_path)
return merged_model_path
def refresh_models():
model_list = [
os.path.join(models_dir, f)
for f in os.listdir(models_dir)
if f.endswith(".pth")
]
return gr.Dropdown(choices=model_list), gr.Dropdown(choices=model_list)
initial_md = """
# Bert-VITS2 モデルマージツール
2つのBert-VITS2モデルから、声質・話し方・話す速さを取り替えたり混ぜたりするやつです。
確認したバージョンは2.1と2.2です。同じバージョン同士(2.1同士、2.2同士)でしか動きません。
挙動としてはモデルAを起点にするので、configファイル等はモデルAのものを使ってください。もしかしたら若干モデルAに結果が寄りがちかもしれないけど多分そんなに変わらないです。
## 使い方
`merge`フォルダを作って、直下に混ぜたい`*.pth`ファイルを置いてください。またモデルAのconfig.jsonファイルも同じところに置いてください。
"""
# Gradioインターフェースの作成
with gr.Blocks() as demo:
gr.Markdown(initial_md)
with gr.Row():
with gr.Column():
with gr.Row():
model_a = gr.Dropdown(
label="モデルA (pthファイル)", choices=model_list, scale=2
)
model_b = gr.Dropdown(
label="モデルB (pthファイル)", choices=model_list, scale=2
)
refresh_button = gr.Button("モデルリストの再読み込み", scale=1)
refresh_button.click(fn=refresh_models, outputs=[model_a, model_b])
voice_slider = gr.Slider(
label="声質",
value=0,
minimum=0,
maximum=1,
step=0.1,
)
speech_style_slider = gr.Slider(
label="話し方(抑揚・感情表現等)",
value=0,
minimum=0,
maximum=1,
step=0.1,
)
tempo_slider = gr.Slider(
label="話す速さ・リズム・テンポ",
value=0,
minimum=0,
maximum=1,
step=0.1,
)
merge_button = gr.Button("マージ")
merged_model_output = gr.Textbox(label="マージされたモデルのパス")
# マージボタンの動作を定義
merge_button.click(
fn=merge_models,
inputs=[
model_a,
model_b,
voice_slider,
speech_style_slider,
tempo_slider,
],
outputs=merged_model_output,
)
with gr.Column():
gr.Markdown("### マージされたモデルでのTTSテスト")
input_text = gr.Textbox(label="テキスト")
play_button = gr.Button("再生")
tts_output = gr.Audio(label="オーディオ出力")
play_button.click(
fn=tts,
inputs=[merged_model_output, input_text],
outputs=tts_output,
)
demo.launch(inbrowser=True)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment