Mahiruoshi committed
Commit 22980ff
1 Parent(s): 9c832f7

Update app.py

Files changed (1)
  1. app.py +106 -257
app.py CHANGED
@@ -1,275 +1,124 @@
- import logging
- logging.getLogger('numba').setLevel(logging.WARNING)
- logging.getLogger('matplotlib').setLevel(logging.WARNING)
- logging.getLogger('urllib3').setLevel(logging.WARNING)
- import json
  import re
- import numpy as np
- import IPython.display as ipd
  import torch
  import commons
  import utils
  from models import SynthesizerTrn
  from text import text_to_sequence
- import gradio as gr
- import time
- import datetime
- import os
- import pickle
- import openai
- from scipy.io.wavfile import write
- def is_japanese(string):
-     for ch in string:
-         if ord(ch) > 0x3040 and ord(ch) < 0x30FF:
-             return True
-     return False

- def is_english(string):
-     import re
-     pattern = re.compile('^[A-Za-z0-9.,:;!?()_*"\' ]+$')
-     if pattern.fullmatch(string):
-         return True
-     else:
-         return False

- def extrac(text):
-     text = re.sub("<[^>]*>","",text)
-     result_list = re.split(r'\n', text)
-     final_list = []
-     for i in result_list:
-         if is_english(i):
-             i = romajitable.to_kana(i).katakana
-         i = i.replace('\n','').replace(' ','')
-         #Current length of single sentence: 20
-         '''
-         if len(i)>1:
-             if len(i) > 20:
-                 try:
-                     cur_list = re.split(r'。|!', i)
-                     for i in cur_list:
-                         if len(i)>1:
-                             final_list.append(i+'。')
-                 except:
-                     pass
-             else:
-                 final_list.append(i)
-         '''
-         final_list.append(i)
-     final_list = [x for x in final_list if x != '']
-     print(final_list)
-     return final_list

- def to_numpy(tensor: torch.Tensor):
-     return tensor.detach().cpu().numpy() if tensor.requires_grad \
-         else tensor.detach().numpy()

- def chatgpt(text):
-     messages = []
-     try:
-         if text != 'exist':
-             with open('log.pickle', 'rb') as f:
-                 messages = pickle.load(f)
-             messages.append({"role": "user", "content": text},)
-             chat = openai.ChatCompletion.create(model="gpt-3.5-turbo", messages=messages)
-             reply = chat.choices[0].message.content
-             messages.append({"role": "assistant", "content": reply})
-             print(messages[-1])
-             if len(messages) == 12:
-                 messages[6:10] = messages[8:]
-                 del messages[-2:]
-             with open('log.pickle', 'wb') as f:
-                 pickle.dump(messages, f)
-             return reply
-     except:
-         messages.append({"role": "user", "content": text},)
-         chat = openai.ChatCompletion.create(model="gpt-3.5-turbo", messages=messages)
-         reply = chat.choices[0].message.content
-         messages.append({"role": "assistant", "content": reply})
-         print(messages[-1])
-         if len(messages) == 12:
-             messages[6:10] = messages[8:]
-             del messages[-2:]
-         with open('log.pickle', 'wb') as f:
-             pickle.dump(messages, f)
-         return reply

- def get_symbols_from_json(path):
-     assert os.path.isfile(path)
-     with open(path, 'r') as f:
-         data = json.load(f)
-     return data['symbols']

- def sle(language,text):
-     text = text.replace('\n', '').replace('\r', '').replace(" ", "")
-     if language == "中文":
-         tts_input1 = "[ZH]" + text + "[ZH]"
-         return tts_input1
-     elif language == "自动":
-         tts_input1 = f"[JA]{text}[JA]" if is_japanese(text) else f"[ZH]{text}[ZH]"
-         return tts_input1
-     elif language == "日文":
-         tts_input1 = "[JA]" + text + "[JA]"
-         return tts_input1
-     elif language == "英文":
-         tts_input1 = "[EN]" + text + "[EN]"
-         return tts_input1
-     elif language == "手动":
-         return text

- def get_text(text,hps_ms):
-     text_norm = text_to_sequence(text,hps_ms.symbols,hps_ms.data.text_cleaners)
-     if hps_ms.data.add_blank:
-         text_norm = commons.intersperse(text_norm, 0)
-     text_norm = torch.LongTensor(text_norm)
-     return text_norm

- def create_tts_fn(net_g,hps,speaker_id):
-     speaker_id = int(speaker_id)
-     def tts_fn(history,is_gpt,api_key,is_audio,audiopath,repeat_time,text, language, extract, n_scale= 0.667,n_scale_w = 0.8, l_scale = 1 ):
-         repeat_time = int(repeat_time)
-         if is_gpt:
-             openai.api_key = api_key
-             text = chatgpt(text)
-             history[-1][1] = text
-         if not extract:
-             print(text)
-             t1 = time.time()
-             stn_tst = get_text(sle(language,text),hps)
-             with torch.no_grad():
-                 x_tst = stn_tst.unsqueeze(0).to(dev)
-                 x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).to(dev)
-                 sid = torch.LongTensor([speaker_id]).to(dev)
-                 audio = net_g.infer(x_tst, x_tst_lengths, sid=sid, noise_scale=n_scale, noise_scale_w=n_scale_w, length_scale=l_scale)[0][0,0].data.cpu().float().numpy()
-             t2 = time.time()
-             spending_time = "推理时间为:"+str(t2-t1)+"s"
-             print(spending_time)
-             file_path = "subtitles.srt"
-             try:
-                 write(audiopath + '.wav',22050,audio)
-                 if is_audio:
-                     for i in range(repeat_time):
-                         cmd = 'ffmpeg -y -i ' + audiopath + '.wav' + ' -ar 44100 '+ audiopath.replace('temp','temp'+str(i))
-                         os.system(cmd)
-             except:
-                 pass
-             return history,file_path,(hps.data.sampling_rate,audio)
-         else:
-             a = ['【','[','(','(']
-             b = ['】',']',')',')']
-             for i in a:
-                 text = text.replace(i,'<')
-             for i in b:
-                 text = text.replace(i,'>')
-             final_list = extrac(text.replace('“','').replace('”',''))
-             audio_fin = []
-             c = 0
-             t = datetime.timedelta(seconds=0)
-             f1 = open("subtitles.srt",'w',encoding='utf-8')
-             for sentence in final_list:
-                 c +=1
-                 stn_tst = get_text(sle(language,sentence),hps)
-                 with torch.no_grad():
-                     x_tst = stn_tst.unsqueeze(0).to(dev)
-                     x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).to(dev)
-                     sid = torch.LongTensor([speaker_id]).to(dev)
-                     t1 = time.time()
-                     audio = net_g.infer(x_tst, x_tst_lengths, sid=sid, noise_scale=n_scale, noise_scale_w=n_scale_w, length_scale=l_scale)[0][0,0].data.cpu().float().numpy()
-                     t2 = time.time()
-                     spending_time = "第"+str(c)+"句的推理时间为:"+str(t2-t1)+"s"
-                     print(spending_time)
-                     time_start = str(t).split(".")[0] + "," + str(t.microseconds)[:3]
-                     last_time = datetime.timedelta(seconds=len(audio)/float(22050))
-                     t+=last_time
-                     time_end = str(t).split(".")[0] + "," + str(t.microseconds)[:3]
-                     print(time_end)
-                     f1.write(str(c-1)+'\n'+time_start+' --> '+time_end+'\n'+sentence+'\n\n')
-                     audio_fin.append(audio)
-             try:
-                 write(audiopath + '.wav',22050,np.concatenate(audio_fin))
-                 if is_audio:
-                     for i in range(repeat_time):
-                         cmd = 'ffmpeg -y -i ' + audiopath + '.wav' + ' -ar 44100 '+ audiopath.replace('temp','temp'+str(i))
-                         os.system(cmd)
-
-             except:
-                 pass
-
-             file_path = "subtitles.srt"
-             return history,file_path,(hps.data.sampling_rate, np.concatenate(audio_fin))
-     return tts_fn

- def bot(history,user_message):
-     return history + [[user_message, None]]

- if __name__ == '__main__':
-     hps = utils.get_hparams_from_file('checkpoints/tmp/config.json')
  dev = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-     models = []
-     schools = ["Seisho-Nijigasaki","Seisho-betterchinese","Nijigasaki","Nijigasaki-biaobei"]
-     lan = ["中文","日文","自动","手动"]
-     with open("checkpoints/info.json", "r", encoding="utf-8") as f:
-         models_info = json.load(f)
-     for i in models_info:
-         checkpoint = models_info[i]["checkpoint"]
-         phone_dict = {
-             symbol: i for i, symbol in enumerate(hps.symbols)
-         }
-         n_symbols = len(hps.symbols) if 'symbols' in hps.keys() else 0
-         net_g = SynthesizerTrn(
-             n_symbols,
-             hps.data.filter_length // 2 + 1,
-             hps.train.segment_size // hps.data.hop_length,
-             n_speakers=hps.data.n_speakers,
-             **hps.model).to(dev)
-         _ = net_g.eval()
-         _ = utils.load_checkpoint(checkpoint, net_g)
-         school = models_info[i]
-         speakers = school["speakers"]
-         content = []
-         for j in speakers:
-             sid = int(speakers[j]['sid'])
-             title = school
-             example = speakers[j]['speech']
-             name = speakers[j]["name"]
-             content.append((sid, name, title, example, create_tts_fn(net_g,hps,sid)))
-         models.append(content)
-
-     with gr.Blocks() as app:
-         with gr.Tabs():
-             for i in schools:
-                 with gr.TabItem(i):
-                     for (sid, name, title, example, tts_fn) in models[schools.index(i)]:
-                         with gr.TabItem(name):
-                             with gr.Column():
-                                 with gr.Row():
-                                     with gr.Row():
-                                         gr.Markdown(
-                                             '<div align="center">'
-                                             f'<img style="width:auto;height:400px;" src="file/image/{name}.png">'
-                                             '</div>'
-                                         )
-                                 chatbot = gr.Chatbot()
-                                 with gr.Row():
-                                     with gr.Column(scale=0.85):
-                                         input1 = gr.TextArea(label="Text", value=example,lines = 1)
-                                     with gr.Column(scale=0.15, min_width=0):
-                                         btnVC = gr.Button("Send")
-                                 output1 = gr.Audio(label="采样率22050")
-                                 with gr.Accordion(label="Setting", open=False):
-                                     input2 = gr.Dropdown(label="Language", choices=lan, value="自动", interactive=True)
-                                     input3 = gr.Checkbox(value=False, label="长句切割(小说合成)")
-                                     input4 = gr.Slider(minimum=0, maximum=1.0, label="更改噪声比例(noise scale),以控制情感", value=0.6)
-                                     input5 = gr.Slider(minimum=0, maximum=1.0, label="更改噪声偏差(noise scale w),以控制音素长短", value=0.668)
-                                     input6 = gr.Slider(minimum=0.1, maximum=10, label="duration", value=1)
-                                 with gr.Accordion(label="Advanced Setting", open=False):
-                                     audio_input3 = gr.Dropdown(label="重复次数", choices=list(range(101)), value='0', interactive=True)
-                                     api_input1 = gr.Checkbox(value=False, label="接入chatgpt")
-                                     api_input2 = gr.TextArea(label="api-key",lines=1,value = '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')
-                                     output2 = gr.outputs.File(label="字幕文件:subtitles.srt")
-                                     audio_input1 = gr.Checkbox(value=False, label="修改音频路径(live2d)")
-                                     audio_input2 = gr.TextArea(label="音频路径",lines=1,value = 'D:/app_develop/live2d_whole/2010002/sounds/temp.wav')
-                                 btnVC.click(bot, inputs = [chatbot,input1], outputs = [chatbot]).then(
-                                     tts_fn, inputs=[chatbot,api_input1,api_input2,audio_input1,audio_input2,audio_input3,input1,input2,input3,input4,input5,input6], outputs=[chatbot,output2,output1]
-                                 )
-
  app.launch()
  import re
+ import gradio as gr
  import torch
+ import unicodedata
  import commons
  import utils
  from models import SynthesizerTrn
  from text import text_to_sequence

+ config_json = "checkpoints/paimeng/config.json"
+ pth_path = "checkpoints/paimeng/model.pth"

+ def get_text(text, hps, cleaned=False):
+     if cleaned:
+         text_norm = text_to_sequence(text, hps.symbols, [])
+     else:
+         text_norm = text_to_sequence(text, hps.symbols, hps.data.text_cleaners)
+     if hps.data.add_blank:
+         text_norm = commons.intersperse(text_norm, 0)
+     text_norm = torch.LongTensor(text_norm)
+     return text_norm

+ def get_label(text, label):
+     if f'[{label}]' in text:
+         return True, text.replace(f'[{label}]', '')
+     else:
+         return False, text

+ def clean_text(text):
+     print(text)
+     jap = re.compile(r'[\u3040-\u309F\u30A0-\u30FF]') # 匹配日文
+     text = unicodedata.normalize('NFKC', text)
+     text = f"[JA]{text}[JA]" if jap.search(text) else f"[ZH]{text}[ZH]"
+     return text
+
+
+ def load_model(config_json, pth_path):
+     dev = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+     hps_ms = utils.get_hparams_from_file(f"{config_json}")
+     n_speakers = hps_ms.data.n_speakers if 'n_speakers' in hps_ms.data.keys() else 0
+     n_symbols = len(hps_ms.symbols) if 'symbols' in hps_ms.keys() else 0
+     net_g_ms = SynthesizerTrn(
+         n_symbols,
+         hps_ms.data.filter_length // 2 + 1,
+         hps_ms.train.segment_size // hps_ms.data.hop_length,
+         n_speakers=n_speakers,
+         **hps_ms.model).to(dev)
+     _ = net_g_ms.eval()
+     _ = utils.load_checkpoint(pth_path, net_g_ms)
+     return net_g_ms
+
+ net_g_ms = load_model(config_json, pth_path)

+ def selection(speaker):
+     if speaker == "南小鸟":
+         spk = 0
+         return spk

+     elif speaker == "园田海未":
+         spk = 1
+         return spk

+     elif speaker == "小泉花阳":
+         spk = 2
+         return spk
+
+     elif speaker == "星空凛":
+         spk = 3
+         return spk
+
+     elif speaker == "东条希":
+         spk = 4
+         return spk
+
+     elif speaker == "矢泽妮可":
+         spk = 5
+         return spk
+
+     elif speaker == "绚濑绘里":
+         spk = 6
+         return spk
+
+     elif speaker == "西木野真姬":
+         spk = 7
+         return spk
+
+     elif speaker == "高坂穗乃果":
+         spk = 8
+         return spk
+
+ def infer(text,speaker_id, n_scale= 0.667,n_scale_w = 0.8, l_scale = 1 ):
+     text = clean_text(text)
+     speaker_id = int(selection(speaker_id))
  dev = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+     hps_ms = utils.get_hparams_from_file(f"{config_json}")
+     with torch.no_grad():
+         stn_tst = get_text(text, hps_ms, cleaned=False)
+         x_tst = stn_tst.unsqueeze(0).to(dev)
+         x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).to(dev)
+         sid = torch.LongTensor([speaker_id]).to(dev)
+         audio = net_g_ms.infer(x_tst, x_tst_lengths, sid=sid, noise_scale=n_scale, noise_scale_w=n_scale_w, length_scale=l_scale)[0][
+             0, 0].data.cpu().float().numpy()
+     return (hps_ms.data.sampling_rate, audio)
+
+ idols = ["南小鸟","园田海未","小泉花阳","星空凛","东条希","矢泽妮可","绚濑绘里","西木野真姬","高坂穗乃果"]
+ app = gr.Blocks()
+ with app:
+     with gr.Tabs():
+
+         with gr.TabItem("Basic"):
+
+             tts_input1 = gr.TextArea(label="请输入纯中文或纯日文", value="大家好")
+             para_input1 = gr.Slider(minimum= 0.01,maximum=1.0,label="更改噪声比例", value=0.667)
+             para_input2 = gr.Slider(minimum= 0.01,maximum=1.0,label="更改噪声偏差", value=0.8)
+             para_input3 = gr.Slider(minimum= 0.1,maximum=10,label="更改时间比例", value=1)
+             tts_submit = gr.Button("Generate", variant="primary")
+             speaker1 = gr.Dropdown(label="选择说话人",choices=idols, value="高坂穗乃果", interactive=True)
+             tts_output2 = gr.Audio(label="Output")
+
+             tts_submit.click(infer, [tts_input1,speaker1,para_input1,para_input2,para_input3], [tts_output2])
  app.launch()
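
For reference, a minimal usage sketch (not part of the commit) of the infer() entry point added above, assuming the checkpoints/paimeng config and model files are present as configured at the top of the new app.py; the input text and speaker name below are illustrative values taken from the new UI defaults:

    # Call the new infer() directly instead of going through the Gradio Blocks UI.
    # clean_text() tags this input as [ZH] (it contains no kana), selection() maps
    # the speaker name to speaker id 8, and infer() returns (sampling_rate, waveform),
    # the same tuple that the gr.Audio output consumes.
    sr, waveform = infer("大家好", "高坂穗乃果", n_scale=0.667, n_scale_w=0.8, l_scale=1)

    # waveform is a float numpy array; if a file is needed it could be saved with,
    # for example, scipy.io.wavfile.write("out.wav", sr, waveform).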