txya900619's picture
feat: add type hint to func and delete un use import
5c3ce0e
raw
history blame
No virus
6.99 kB
import json
import os
import tempfile

import gradio as gr
import numpy as np
import TTS
from huggingface_hub import snapshot_download
from omegaconf import OmegaConf
from TTS.utils.synthesizer import Synthesizer

from ipa.ipa import get_ipa, parse_ipa
from replace.tts import ChangedVitsConfig
# Swap the stock VITS config class for the project's ChangedVitsConfig.
# Import the submodule explicitly: `import TTS` alone does not guarantee that
# `TTS.tts.configs.vits_config` is loaded, and relying on some other import
# having loaded it as a side effect is fragile.
# NOTE(review): this only rebinds the module attribute; any code that already
# did `from TTS.tts.configs.vits_config import VitsConfig` keeps the original
# class. Presumably TTS looks the config class up on this module at load
# time; confirm.
import TTS.tts.configs.vits_config

TTS.tts.configs.vits_config.VitsConfig = ChangedVitsConfig
def load_model(model_id):
    """Download a TTS model snapshot and build a ``Synthesizer`` for it.

    The snapshot's ``config.json`` refers to ``speakers.pth``,
    ``language_ids.json`` and ``speaker_embs.pth`` by bare file name. Every
    occurrence of those names is rewritten to the absolute path inside the
    downloaded snapshot before the config is handed to ``Synthesizer``.

    Args:
        model_id: Repository id passed to ``huggingface_hub.snapshot_download``.

    Returns:
        A ``Synthesizer`` built from the snapshot's ``model.pth`` and the
        rewritten config.
    """
    model_dir = snapshot_download(model_id)
    config_file_path = os.path.join(model_dir, "config.json")
    model_ckpt_path = os.path.join(model_dir, "model.pth")
    # Bare file name as written in config.json -> absolute path in the snapshot.
    # Order matters only in the same way it did before: replacements are
    # applied sequentially to the whole config text.
    replacements = {
        "speakers.pth": os.path.join(model_dir, "speakers.pth"),
        "language_ids.json": os.path.join(model_dir, "language_ids.json"),
        "speaker_embs.pth": os.path.join(model_dir, "speaker_embs.pth"),
    }

    # Read as UTF-8 explicitly; the locale default (e.g. cp950 on Windows)
    # cannot be relied on for a config that may contain non-ASCII text.
    with open(config_file_path, "r", encoding="utf-8") as f:
        content = f.read()
    for file_name, abs_path in replacements.items():
        # The substitution lands inside a JSON string literal, so escape it:
        # raw backslashes (Windows paths) or quotes would corrupt the JSON.
        content = content.replace(file_name, json.dumps(abs_path)[1:-1])

    # Use a private temporary file instead of a fixed "temp_config.json" in
    # the working directory, which every call overwrote and nothing removed.
    # The ".json" suffix is kept so the config loader still sees a JSON file.
    fd, temp_config_path = tempfile.mkstemp(suffix=".json")
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            f.write(content)
        # NOTE(review): assumes Synthesizer reads the config only during
        # construction, so the temp file can be removed afterwards; confirm.
        return Synthesizer(tts_checkpoint=model_ckpt_path, tts_config_path=temp_config_path)
    finally:
        os.remove(temp_config_path)
# Expose load_model to OmegaConf interpolations under the resolver name
# "load_model". OmegaConf.to_object resolves every interpolation, so any model
# referenced through that resolver in configs/models.yaml is downloaded and
# loaded here, at import time. (Presumably each entry's "model" field is such
# an interpolation; confirm against configs/models.yaml.)
OmegaConf.register_new_resolver(name="load_model", resolver=load_model)
models_config = OmegaConf.to_object(
    OmegaConf.load("configs/models.yaml")
)
def text_to_speech(
    model_id: str,
    use_default_emb_or_custom: str,
    speaker_wav,
    speaker: str,
    dialect: str,
    text: str,
):
    """Synthesize speech for ``text`` with the selected model.

    Args:
        model_id: Key into ``models_config`` selecting the model entry.
        use_default_emb_or_custom: ``"default"`` to use the dropdown speaker;
            any other value clones the voice recorded in ``speaker_wav``.
        speaker_wav: Value of the microphone component (``type="filepath"``),
            i.e. a file path, or ``None`` when nothing was recorded.
        speaker: Speaker value from the speaker dropdown.
        dialect: Dialect value from the dialect dropdown. It is used both for
            IPA conversion and as the model's ``language_name``.
        text: Input text to synthesize.

    Returns:
        ``(words, pinyin, (sample_rate, waveform))`` for the segmentation
        textbox, the romanization textbox and the audio player.

    Raises:
        gr.Error: If the text is empty or whitespace-only, if some words
            cannot be converted to IPA, or if custom mode is selected without
            a recording.
    """
    model = models_config[model_id]["model"]
    # Whitespace-only input is as empty as "" for synthesis purposes.
    if not text or not text.strip():
        raise gr.Error("請勿輸入空字串。")
    words, ipa, pinyin, missing_words = get_ipa(text, dialect=dialect)
    if missing_words:
        raise gr.Error(
            f"句子中的[{','.join(missing_words)}]目前無法轉成 ipa。請嘗試其他句子。"
        )

    use_default = use_default_emb_or_custom == "default"
    # Custom mode needs a reference recording. Without this check,
    # speaker_wav=None would be passed straight into synthesis.
    if not use_default and not speaker_wav:
        raise gr.Error("請先錄製參考語者的聲音。")

    tts_kwargs = {"language_name": dialect, "split_sentences": False}
    if use_default:
        # A speaker name is passed only when the mapping offers more than one.
        is_multi_speaker = len(models_config[model_id]["speaker_mapping"]) > 1
        tts_kwargs["speaker_name"] = speaker if is_multi_speaker else None
    else:
        tts_kwargs["speaker_wav"] = speaker_wav
    wav = model.tts(parse_ipa(ipa), **tts_kwargs)

    return (
        words,
        pinyin,
        (model.tts_model.config.audio.sample_rate, np.array(wav)),
    )
def when_model_selected(model_id):
    """Refresh the speaker/dialect dropdowns and the embedding radio for a new model.

    Args:
        model_id: Key into ``models_config``.

    Returns:
        Three ``gr.update`` objects, in this order: the speaker dropdown, the
        dialect dropdown, and the default/custom embedding radio. The radio is
        reset to ``"default"``.
    """
    model_config = models_config[model_id]
    # Dropdown choices are (label, value) pairs taken straight from the mapping.
    speaker_choices = list(model_config["speaker_mapping"].items())
    dialect_choices = list(model_config["dialect_mapping"].items())
    # Offer the custom-recording option only when the model config names a
    # speaker encoder model.
    has_speaker_encoder = bool(
        model_config["model"].tts_model.config.model_args.speaker_encoder_model_path
    )

    def _dropdown_update(choices):
        # Preselect the first value (None for an empty mapping, instead of an
        # IndexError); lock the dropdown when there is nothing to choose.
        return gr.update(
            choices=choices,
            value=choices[0][1] if choices else None,
            interactive=len(choices) > 1,
        )

    return (
        _dropdown_update(speaker_choices),
        _dropdown_update(dialect_choices),
        gr.update(visible=has_speaker_encoder, value="default"),
    )
def use_default_emb_or_custom_radio_input(use_default_emb_or_custom):
    """Toggle between the reference-recording input and the speaker dropdown.

    Args:
        use_default_emb_or_custom: Current value of the default/custom radio.

    Returns:
        A pair of ``gr.update`` objects. The first is visible only for
        ``"custom"``; the second is visible for every other value.
    """
    wants_custom_recording = use_default_emb_or_custom == "custom"
    return (
        gr.update(visible=wants_custom_recording),
        gr.update(visible=not wants_custom_recording),
    )
# Font stack: "tauhu-oo" first (presumably provided by the stylesheet imported
# via `css` below), then Source Sans Pro from Google Fonts, then generic
# sans-serif system fonts.
demo = gr.Blocks(
    title="臺灣客語語音生成系統",
    css="@import url(https://tauhu.tw/tauhu-oo.css);",
    theme=gr.themes.Default(
        font=tuple(
            [
                "tauhu-oo",
                gr.themes.GoogleFont("Source Sans Pro"),
                "ui-sans-serif",
                "system-ui",
                "sans-serif",
            ]
        )
    ),
)
with demo:
    # The first model in configs/models.yaml is the one selected on page load.
    default_model_id = next(iter(models_config))
    default_model_config = models_config[default_model_id]
    default_speaker_choices = list(default_model_config["speaker_mapping"].items())
    default_dialect_choices = list(default_model_config["dialect_mapping"].items())

    model_drop_down = gr.Dropdown(
        list(models_config),
        value=default_model_id,
        label="模型",
    )
    # Same visibility rule as on model change: show the default/custom choice
    # only when the model config names a speaker encoder. Previously the radio
    # started hidden even when the default model supported custom recordings.
    use_default_emb_or_custom_radio = gr.Radio(
        label="use default speaker embedding or custom speaker embedding",
        choices=["default", "custom"],
        value="default",
        visible=bool(
            default_model_config["model"].tts_model.config.model_args.speaker_encoder_model_path
        ),
    )
    # Reference recording for voice cloning; shown only in "custom" mode.
    speaker_wav = gr.Microphone(
        label="speaker wav",
        visible=False,
        editable=False,
        type="filepath",
        waveform_options=gr.WaveformOptions(
            show_controls=False,
            sample_rate=16000,
        ),
    )
    speaker_drop_down = gr.Dropdown(
        choices=default_speaker_choices,
        # None instead of an IndexError when the mapping is empty.
        value=default_speaker_choices[0][1] if default_speaker_choices else None,
        label="語者",
        interactive=len(default_speaker_choices) > 1,
    )
    use_default_emb_or_custom_radio.input(
        use_default_emb_or_custom_radio_input,
        inputs=[use_default_emb_or_custom_radio],
        outputs=[speaker_wav, speaker_drop_down],
    )
    dialect_drop_down = gr.Dropdown(
        choices=default_dialect_choices,
        value=default_dialect_choices[0][1] if default_dialect_choices else None,
        label="腔調",
        interactive=len(default_dialect_choices) > 1,
    )
    model_drop_down.input(
        when_model_selected,
        inputs=[model_drop_down],
        outputs=[speaker_drop_down, dialect_drop_down, use_default_emb_or_custom_radio],
    ).then(
        # when_model_selected resets the radio to "default", but the
        # microphone is not among its outputs, and a programmatic value change
        # does not fire the radio's .input listener. Re-apply the radio's
        # visibility rule so a recorder left open for the previous model does
        # not stay visible while the speaker dropdown stays hidden.
        use_default_emb_or_custom_radio_input,
        inputs=[use_default_emb_or_custom_radio],
        outputs=[speaker_wav, speaker_drop_down],
    )
    gr.Markdown(
        """
# 臺灣客語語音合成系統
### Taiwanese Hakka Text-to-Speech System
### 模型
- **sixian-1f-240417**(四縣腔,單一語者)
### 研發
- **[李鴻欣 Hung-Shin Lee](mailto:hungshinlee@gmail.com)(諾思資訊 North Co., Ltd.)**
- **[陳力瑋 Li-Wei Chen](mailto:wayne900619@gmail.com)(諾思資訊 North Co., Ltd.)**
"""
    )
    # Input order must match text_to_speech's positional parameters.
    gr.Interface(
        text_to_speech,
        inputs=[
            model_drop_down,
            use_default_emb_or_custom_radio,
            speaker_wav,
            speaker_drop_down,
            dialect_drop_down,
            gr.Textbox(label="輸入文字", value="客家族群个六堆運動會會一直延續下去,為臺灣个體育史寫下特別个一頁。"),
        ],
        outputs=[
            gr.Textbox(interactive=False, label="斷詞"),
            gr.Textbox(interactive=False, label="客語拼音"),
            gr.Audio(interactive=False, label="合成語音", show_download_button=True),
        ],
        allow_flagging="auto",
    )
# Start the server only when run as a script (e.g. `python app.py`). Importing
# the module, as gradio's reload mode does, then only builds `demo`.
if __name__ == "__main__":
    demo.launch()