# txya900619's picture
# feat: disable use_default_emb_or_custom_radio label
# 77dc123 verified
# raw
# history blame
# No virus
# 7.71 kB
import os
import gradio as gr
import TTS
from TTS.utils.synthesizer import Synthesizer
import numpy as np
from huggingface_hub import snapshot_download
from omegaconf import OmegaConf
import spaces
import torch
from ipa.ipa import get_ipa, parse_ipa
from replace.tts import ChangedVitsConfig
# Replace Coqui's VitsConfig with the project's ChangedVitsConfig (from
# replace.tts) by rebinding the module attribute.
# NOTE(review): this only affects code that looks VitsConfig up on
# TTS.tts.configs.vits_config at call time; any module that already bound the
# name via `from ... import VitsConfig` keeps the original class. Confirm the
# Synthesizer's config loading picks up the patched class.
TTS.tts.configs.vits_config.VitsConfig = ChangedVitsConfig
def _rewrite_config_paths(content, model_dir):
    """Replace bare auxiliary file names in config text with paths under *model_dir*.

    Args:
        content: Raw text of the model's ``config.json``.
        model_dir: Directory the auxiliary files live in.

    Returns:
        The config text with every occurrence of ``speakers.pth``,
        ``language_ids.json`` and ``speaker_embs.pth`` replaced by the
        corresponding path inside *model_dir*.
    """
    import json

    for file_name in ("speakers.pth", "language_ids.json", "speaker_embs.pth"):
        full_path = os.path.join(model_dir, file_name)
        # The replacement lands inside a JSON string literal, so escape it
        # (e.g. Windows backslashes) and strip the surrounding quotes.
        content = content.replace(file_name, json.dumps(full_path)[1:-1])
    return content


def load_model(model_id):
    """Download a TTS model from the Hugging Face Hub and wrap it in a Synthesizer.

    Registered below as the ``load_model`` OmegaConf resolver.

    The model's ``config.json`` refers to ``speakers.pth``,
    ``language_ids.json`` and ``speaker_embs.pth`` by bare file name. Those
    names are rewritten to paths inside the downloaded snapshot, and the
    patched config is written to a fresh temporary ``.json`` file that the
    Synthesizer is built from.

    Args:
        model_id: Hugging Face Hub repository id passed to ``snapshot_download``.

    Returns:
        A ``Synthesizer`` for ``model.pth``, using CUDA when
        ``torch.cuda.is_available()``.
    """
    import tempfile

    model_dir = snapshot_download(model_id)
    config_file_path = os.path.join(model_dir, "config.json")
    model_ckpt_path = os.path.join(model_dir, "model.pth")

    with open(config_file_path, "r", encoding="utf-8") as f:
        content = _rewrite_config_paths(f.read(), model_dir)

    # Use one temp file per model instead of a shared "temp_config.json" in the
    # working directory, so loading several models never clobbers another
    # model's config. delete=False keeps the file around after the handle is
    # closed, so the Synthesizer can open it. The ".json" suffix is kept on
    # purpose (presumably the config loader picks its parser by extension).
    with tempfile.NamedTemporaryFile(
        "w", suffix=".json", delete=False, encoding="utf-8"
    ) as tmp:
        tmp.write(content)
        temp_config_path = tmp.name

    return Synthesizer(
        tts_checkpoint=model_ckpt_path,
        tts_config_path=temp_config_path,
        use_cuda=torch.cuda.is_available(),
    )
# Make load_model available to OmegaConf interpolations as `load_model`, then
# build the model registry from configs/models.yaml. to_object() resolves
# interpolations, so any ${load_model:...} entries are downloaded and loaded
# here, at import time.
OmegaConf.register_new_resolver(name="load_model", resolver=load_model)
models_config = OmegaConf.to_object(
    OmegaConf.load("configs/models.yaml")
)
@spaces.GPU
def _do_tts(model_id, ipa, language_name, speaker_name=None, speaker_wav=None):
    """Synthesize ``ipa`` with the model registered under ``model_id``.

    Args:
        model_id: Key into ``models_config``.
        ipa: Parsed IPA input handed to the synthesizer.
        language_name: Language/dialect name forwarded to the synthesizer.
        speaker_name: Built-in speaker to use when no reference audio is given.
        speaker_wav: Reference audio; when not None it is used instead of
            ``speaker_name``.

    Returns:
        Whatever ``Synthesizer.tts`` returns for the input.
    """
    synthesizer = models_config[model_id]["model"]
    # A reference clip, when present, takes precedence over a named speaker.
    if speaker_wav is None:
        voice = {"speaker_name": speaker_name}
    else:
        voice = {"speaker_wav": speaker_wav}
    return synthesizer.tts(
        ipa,
        language_name=language_name,
        split_sentences=False,
        **voice,
    )
def text_to_speech(
    model_id: str,
    use_default_emb_or_custom: str,
    speaker_wav,
    speaker: str,
    dialect,
    text: str,
):
    """Convert Hakka text to speech with the selected model.

    Args:
        model_id: Key into ``models_config``.
        use_default_emb_or_custom: Speaker-mode radio value. "客製化語者"
            selects the reference-audio path; any other value uses the
            model's built-in speakers.
        speaker_wav: Reference audio file path for the custom-speaker mode.
        speaker: Built-in speaker name from the speaker dropdown.
        dialect: Dialect value, used both for IPA conversion and as the
            synthesizer's language name.
        text: Input text.

    Returns:
        ``(words, pinyin, (sample_rate, waveform))`` for the word
        segmentation, pinyin and audio output components.

    Raises:
        gr.Error: on empty or whitespace-only text, on words that cannot be
            converted to IPA, or when the custom-speaker mode is selected
            without reference audio.
    """
    if not text or not text.strip():
        raise gr.Error("請勿輸入空字串。")

    words, ipa, pinyin, missing_words = get_ipa(text, dialect=dialect)
    if missing_words:
        raise gr.Error(
            f"句子中的[{','.join(missing_words)}]目前無法轉成 ipa。請嘗試其他句子。"
        )
    parsed_ipa = parse_ipa(ipa)

    model_config = models_config[model_id]
    # Same test as the radio's visibility handler: only "客製化語者" means
    # reference audio. Any other value (including an unexpected one) falls
    # back to the built-in speakers.
    if use_default_emb_or_custom == "客製化語者":
        if speaker_wav is None:
            # Without this check _do_tts would silently synthesize with no
            # speaker at all.
            raise gr.Error("請上傳客製化語者的參考音檔。")
        wav = _do_tts(
            model_id,
            parsed_ipa,
            speaker_wav=speaker_wav,
            language_name=dialect,
        )
    else:
        # With at most one speaker in the mapping, no speaker name is passed.
        speaker_name = speaker if len(model_config["speaker_mapping"]) > 1 else None
        wav = _do_tts(
            model_id,
            parsed_ipa,
            speaker_name=speaker_name,
            language_name=dialect,
        )

    sample_rate = model_config["model"].tts_model.config.audio.sample_rate
    return words, pinyin, (sample_rate, np.array(wav))
def when_model_selected(model_id):
    """Refresh the speaker, dialect and speaker-mode widgets for a newly selected model.

    Args:
        model_id: Key into ``models_config``.

    Returns:
        Three ``gr.update`` objects, in this order: the speaker dropdown, the
        dialect dropdown and the default/custom speaker radio.
    """
    model_config = models_config[model_id]
    speaker_choices = list(model_config["speaker_mapping"].items())
    dialect_choices = list(model_config["dialect_mapping"].items())
    # The custom (reference-audio) speaker option is only offered when the
    # model config sets a speaker encoder path.
    supports_custom_speaker = bool(
        model_config["model"].tts_model.config.model_args.speaker_encoder_model_path
    )
    return (
        gr.update(
            choices=speaker_choices,
            value=speaker_choices[0][1] if speaker_choices else None,
            interactive=len(speaker_choices) > 1,
            # The radio is reset to the default-speaker mode below, so show
            # the speaker dropdown again even if custom mode had hidden it.
            visible=True,
        ),
        gr.update(
            choices=dialect_choices,
            value=dialect_choices[0][1] if dialect_choices else None,
            interactive=len(dialect_choices) > 1,
        ),
        # Reset to the default-speaker choice. The previous value "default"
        # is not one of the radio's choices.
        gr.update(visible=supports_custom_speaker, value="預設語者"),
    )
def use_default_emb_or_custom_radio_input(use_default_emb_or_custom):
    """Toggle between the reference-audio input and the speaker dropdown.

    Args:
        use_default_emb_or_custom: Selected radio value.

    Returns:
        ``(audio_update, speaker_dropdown_update)``. The audio input is shown
        only for "客製化語者"; the speaker dropdown is shown otherwise.
    """
    show_reference_audio = use_default_emb_or_custom == "客製化語者"
    return (
        gr.update(visible=show_reference_audio),
        gr.update(visible=not show_reference_audio),
    )
# App shell. The CSS @import pulls in tauhu-oo.css, which presumably provides
# the "tauhu-oo" font listed first in the theme's font stack. Source Sans Pro
# and generic sans-serif families serve as fallbacks.
demo = gr.Blocks(
    theme=gr.themes.Default(
        font=(
            "tauhu-oo",
            gr.themes.GoogleFont("Source Sans Pro"),
            "ui-sans-serif",
            "system-ui",
            "sans-serif",
        ),
    ),
    css="@import url(https://tauhu.tw/tauhu-oo.css);",
    title="臺灣客語語音生成系統",
)
with demo:
    # The first model in configs/models.yaml is the initial selection.
    default_model_id = list(models_config.keys())[0]
    default_model_config = models_config[default_model_id]
    default_speaker_mapping = default_model_config["speaker_mapping"]
    default_dialect_mapping = default_model_config["dialect_mapping"]

    model_drop_down = gr.Dropdown(
        list(models_config.keys()),
        value=default_model_id,
        label="模型",
    )
    # Offer the custom (reference-audio) speaker option only when the default
    # model config sets a speaker encoder path. This is the same check used
    # when the user switches models, so the initial state is consistent.
    use_default_emb_or_custom_radio = gr.Radio(
        label=None,
        choices=["預設語者", "客製化語者"],
        value="預設語者",
        visible=bool(
            default_model_config["model"].tts_model.config.model_args.speaker_encoder_model_path
        ),
        show_label=False,
    )
    # Reference audio for the custom-speaker mode; hidden until that mode is chosen.
    speaker_wav = gr.Audio(
        label="客製化語音",
        visible=False,
        editable=False,
        type="filepath",
        waveform_options=gr.WaveformOptions(
            show_controls=False,
            sample_rate=16000,
        ),
    )
    speaker_drop_down = gr.Dropdown(
        choices=list(default_speaker_mapping.items()),
        value=next(iter(default_speaker_mapping.values()), None),
        label="語者",
        interactive=len(default_speaker_mapping) > 1,
    )
    use_default_emb_or_custom_radio.input(
        use_default_emb_or_custom_radio_input,
        inputs=[use_default_emb_or_custom_radio],
        outputs=[speaker_wav, speaker_drop_down],
    )
    dialect_drop_down = gr.Dropdown(
        choices=list(default_dialect_mapping.items()),
        value=next(iter(default_dialect_mapping.values()), None),
        label="腔調",
        interactive=len(default_dialect_mapping) > 1,
    )
    model_drop_down.input(
        when_model_selected,
        inputs=[model_drop_down],
        outputs=[speaker_drop_down, dialect_drop_down, use_default_emb_or_custom_radio],
    )

    # Page header and credits. The text is kept at column 0 so it renders as
    # Markdown whether or not the component dedents it.
    gr.Markdown(
        """
# 臺灣客語語音合成系統
### Taiwanese Hakka Text-to-Speech System
### 研發團隊
- **[李鴻欣 Hung-Shin Lee](mailto:hungshinlee@gmail.com)([聯和科創](https://www.104.com.tw/company/1a2x6bmu75))**
- **[陳力瑋 Li-Wei Chen](mailto:wayne900619@gmail.com)([聯和科創](https://www.104.com.tw/company/1a2x6bmu75))**
### 合作單位
- **[國立聯合大學智慧客家實驗室](https://www.gohakka.org)**
"""
    )

    # The components above are reused as the Interface's inputs, so the input
    # order must match text_to_speech's parameter order.
    gr.Interface(
        text_to_speech,
        inputs=[
            model_drop_down,
            use_default_emb_or_custom_radio,
            speaker_wav,
            speaker_drop_down,
            dialect_drop_down,
            gr.Textbox(
                label="輸入文字",
                value="客家族群个六堆運動會會一直延續下去,為臺灣个體育史寫下特別个一頁。",
            ),
        ],
        outputs=[
            gr.Textbox(interactive=False, label="斷詞"),
            gr.Textbox(interactive=False, label="客語拼音"),
            gr.Audio(interactive=False, label="合成語音", show_download_button=True),
        ],
        allow_flagging="auto",
    )
# Start the server only when run as a script (e.g. `python app.py`), so that
# importing this module, for tooling or tests, does not block on a server.
if __name__ == "__main__":
    demo.launch()