txya900619's picture
feat: update default model and let 1p1l model show dialect and speaker dropdown
a0c110f
raw
history blame
No virus
6.91 kB
import json
import os
import tempfile
import gradio as gr
import TTS
from TTS.utils.synthesizer import Synthesizer
import numpy as np
from huggingface_hub import snapshot_download
from omegaconf import OmegaConf
from ipa.ipa import get_ipa, parse_ipa
from replace.tts import ChangedVitsConfig
# Monkey-patch: swap the VitsConfig class exposed by TTS.tts.configs.vits_config
# for the project's ChangedVitsConfig. This runs at import time, before any
# model is loaded below.
TTS.tts.configs.vits_config.VitsConfig = ChangedVitsConfig
def load_model(model_id):
    """Download a model snapshot and build a ``Synthesizer`` from it.

    The snapshot's ``config.json`` refers to its auxiliary files
    (``speakers.pth``, ``language_ids.json``, ``speaker_embs.pth``) by bare
    file name. Those names are rewritten to point inside the downloaded
    snapshot directory, and the patched config is handed to ``Synthesizer``
    through a temporary file.

    Args:
        model_id: Repository id passed to ``snapshot_download``.

    Returns:
        A ``Synthesizer`` built from the snapshot's ``model.pth`` and the
        patched configuration.
    """
    model_dir = snapshot_download(model_id)
    config_file_path = os.path.join(model_dir, "config.json")
    model_ckpt_path = os.path.join(model_dir, "model.pth")

    with open(config_file_path, "r", encoding="utf-8") as f:
        content = f.read()

    # Same replacement order as before: speakers, language ids, embeddings.
    for file_name in ("speakers.pth", "language_ids.json", "speaker_embs.pth"):
        absolute_path = os.path.join(model_dir, file_name)
        # The path is spliced into JSON text, so escape it as a JSON string
        # body (backslashes, quotes, non-ASCII) instead of inserting it raw.
        content = content.replace(file_name, json.dumps(absolute_path)[1:-1])

    # A unique temp file per call avoids clobbering a shared
    # "temp_config.json" in the working directory. The ".json" suffix is kept
    # so the config is still recognised by its extension.
    with tempfile.NamedTemporaryFile(
        "w",
        suffix=".json",
        prefix="tts_config_",
        delete=False,
        encoding="utf-8",
    ) as tmp:
        tmp.write(content)
        temp_config_path = tmp.name

    try:
        return Synthesizer(
            tts_checkpoint=model_ckpt_path,
            tts_config_path=temp_config_path,
        )
    finally:
        # NOTE(review): assumes Synthesizer reads the config only during
        # construction, so the patched copy can be removed afterwards.
        os.remove(temp_config_path)
# Register load_model under the resolver name "load_model". This must happen
# before the YAML below is loaded and converted.
# NOTE(review): presumably configs/models.yaml references it through
# ${load_model:...} interpolations -- confirm against the YAML.
OmegaConf.register_new_resolver("load_model", load_model)
# Model registry as plain Python objects, keyed by model id. The rest of this
# file reads the "model", "speaker_mapping" and "dialect_mapping" entries of
# each value.
models_config = OmegaConf.to_object(OmegaConf.load("configs/models.yaml"))
def text_to_speech(
    model_id: str,
    use_default_emb_or_custom: str,
    speaker_wav,
    speaker: str,
    dialect,
    text: str,
):
    """Synthesize speech for ``text`` with the selected model.

    Args:
        model_id: Key into ``models_config`` selecting the model.
        use_default_emb_or_custom: ``"default"`` to use the speaker chosen in
            the dropdown; any other value uses ``speaker_wav`` as reference.
        speaker_wav: Reference recording used in custom mode.
        speaker: Speaker name used in default mode when the model defines
            more than one speaker.
        dialect: Dialect passed to ``get_ipa`` and used as the language name.
        text: Input sentence.

    Returns:
        A tuple ``(words, pinyin, (sample_rate, samples))`` where ``samples``
        is the synthesized waveform as a NumPy array.

    Raises:
        gr.Error: If ``text`` is empty or blank, if custom mode is selected
            without a reference recording, or if some words cannot be
            converted to IPA.
    """
    model_entry = models_config[model_id]
    model = model_entry["model"]

    # Reject None and whitespace-only input as well as the empty string.
    if not text or not text.strip():
        raise gr.Error("請勿輸入空字串。")

    use_default = use_default_emb_or_custom == "default"
    # Without a recording the synthesizer would be called with
    # speaker_wav=None; surface a clear message to the user instead.
    if not use_default and not speaker_wav:
        raise gr.Error("請先錄製語者音檔。")

    words, ipa, pinyin, missing_words = get_ipa(text, dialect=dialect)
    if len(missing_words) > 0:
        raise gr.Error(
            f"句子中的[{','.join(missing_words)}]目前無法轉成 ipa。請嘗試其他句子。"
        )

    # Each mode passes exactly the keyword it passed before: speaker_name in
    # default mode only, speaker_wav in custom mode only.
    tts_kwargs = {"language_name": dialect, "split_sentences": False}
    if use_default:
        has_multiple_speakers = len(model_entry["speaker_mapping"]) > 1
        tts_kwargs["speaker_name"] = speaker if has_multiple_speakers else None
    else:
        tts_kwargs["speaker_wav"] = speaker_wav

    wav = model.tts(parse_ipa(ipa), **tts_kwargs)

    return (
        words,
        pinyin,
        (model.tts_model.config.audio.sample_rate, np.array(wav)),
    )
def when_model_selected(model_id):
    """Build the UI updates to apply after a model is chosen.

    Args:
        model_id: Key into ``models_config`` for the newly selected model.

    Returns:
        A 3-tuple of ``gr.update`` objects for, in order: the speaker
        dropdown, the dialect dropdown, and the default/custom embedding
        radio. Each dropdown selects its first entry (or nothing when the
        mapping is empty) and is interactive only when it offers more than
        one choice. The radio is reset to ``"default"`` and shown only when
        the model config has a speaker encoder model path.
    """
    model_config = models_config[model_id]
    speaker_choices = list(model_config["speaker_mapping"].items())
    dialect_choices = list(model_config["dialect_mapping"].items())

    model_args = model_config["model"].tts_model.config.model_args
    show_embedding_radio = bool(model_args.speaker_encoder_model_path)

    # NOTE(review): the radio is reset to "default" here, but the recorder and
    # speaker dropdown visibility are not among this handler's outputs, so a
    # previous "custom" selection may leave them in the custom layout.
    return (
        gr.update(
            choices=speaker_choices,
            value=speaker_choices[0][1] if speaker_choices else None,
            interactive=len(speaker_choices) > 1,
        ),
        gr.update(
            choices=dialect_choices,
            # Guarded like the speaker dropdown so an empty mapping does not
            # raise IndexError.
            value=dialect_choices[0][1] if dialect_choices else None,
            interactive=len(dialect_choices) > 1,
        ),
        gr.update(visible=show_embedding_radio, value="default"),
    )
def use_default_emb_or_custom_radio_input(use_default_emb_or_custom):
    """Return visibility updates for the two speaker inputs.

    The first update is visible only when ``"custom"`` is selected and the
    second only otherwise; the radio handler wires them to the microphone
    recorder and the speaker dropdown respectively.
    """
    wants_custom = use_default_emb_or_custom == "custom"
    return gr.update(visible=wants_custom), gr.update(visible=not wants_custom)
# Top-level Gradio container for the app. The CSS pulls in an external
# stylesheet, and the theme lists "tauhu-oo" first in its font stack, followed
# by a Google font and generic fallbacks.
# NOTE(review): "tauhu-oo" is presumably defined by the imported stylesheet --
# confirm.
demo = gr.Blocks(
    title="臺灣客語語音生成系統",
    css="@import url(https://tauhu.tw/tauhu-oo.css);",
    theme=gr.themes.Default(
        font=(
            "tauhu-oo",
            gr.themes.GoogleFont("Source Sans Pro"),
            "ui-sans-serif",
            "system-ui",
            "sans-serif",
        )
    ),
)
with demo:
    # The first model listed in the config is the initial selection.
    default_model_id = list(models_config.keys())[0]
    model_drop_down = gr.Dropdown(
        models_config.keys(),
        value=default_model_id,
        label="模型",
    )
    # Choice between the dropdown speaker and a recorded reference. Created
    # hidden; when_model_selected decides whether to show it.
    use_default_emb_or_custom_radio = gr.Radio(
        label="use default speaker embedding or custom speaker embedding",
        choices=["default", "custom"],
        value="default",
        visible=False,
    )
    # Microphone input for the reference recording. Created hidden; shown by
    # the radio handler when "custom" is selected. Its value is passed to
    # text_to_speech as a file path.
    speaker_wav = gr.Microphone(
        label="speaker wav",
        visible=False,
        editable=False,
        type="filepath",
        waveform_options=gr.WaveformOptions(
            show_controls=False,
            sample_rate=16000,
        ),
    )
    # Speaker dropdown, pre-populated from the default model's mapping and
    # interactive only when there is more than one speaker to choose from.
    speaker_drop_down = gr.Dropdown(
        choices=[
            (k, v)
            for k, v in models_config[default_model_id]["speaker_mapping"].items()
        ],
        value=list(models_config[default_model_id]["speaker_mapping"].values())[0],
        label="語者",
        interactive=len(models_config[default_model_id]["speaker_mapping"]) > 1,
    )
    # Toggle between the recorder and the speaker dropdown on user input.
    use_default_emb_or_custom_radio.input(
        use_default_emb_or_custom_radio_input,
        inputs=[use_default_emb_or_custom_radio],
        outputs=[speaker_wav, speaker_drop_down],
    )
    # Dialect dropdown, built the same way as the speaker dropdown.
    dialect_drop_down = gr.Dropdown(
        choices=[
            (k, v)
            for k, v in models_config[default_model_id]["dialect_mapping"].items()
        ],
        value=list(models_config[default_model_id]["dialect_mapping"].values())[0],
        label="腔調",
        interactive=len(models_config[default_model_id]["dialect_mapping"]) > 1,
    )
    # Refresh the speaker/dialect dropdowns and the radio when the user picks
    # a different model.
    model_drop_down.input(
        when_model_selected,
        inputs=[model_drop_down],
        outputs=[speaker_drop_down, dialect_drop_down, use_default_emb_or_custom_radio],
    )
    # Static page header: title, model list and credits.
    gr.Markdown(
        """
# 臺灣客語語音合成系統
### Taiwanese Hakka Text-to-Speech System
### 模型
- **sixian-1p-240417**(四縣腔,單一語者)
### 研發
- **[李鴻欣 Hung-Shin Lee](mailto:hungshinlee@gmail.com)(諾思資訊 North Co., Ltd.)**
- **[陳力瑋 Li-Wei Chen](mailto:wayne900619@gmail.com)(諾思資訊 North Co., Ltd.)**
"""
    )
    # Synthesis form: reuses the components created above as inputs (in the
    # positional order text_to_speech expects) plus a new text box, and shows
    # the segmented words, the pinyin and the synthesized audio.
    gr.Interface(
        text_to_speech,
        inputs=[
            model_drop_down,
            use_default_emb_or_custom_radio,
            speaker_wav,
            speaker_drop_down,
            dialect_drop_down,
            gr.Textbox(label="輸入文字"),
        ],
        outputs=[
            gr.Textbox(interactive=False, label="斷詞"),
            gr.Textbox(interactive=False, label="客語拼音"),
            gr.Audio(interactive=False, label="合成語音", show_download_button=True),
        ],
        allow_flagging="auto",
    )
# Start the Gradio app; this runs whenever the module is executed or imported.
demo.launch()