Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -1,25 +1,61 @@
|
|
1 |
import gradio as gr
|
|
|
|
|
|
|
2 |
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
return
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import gradio as gr
|
2 |
+
from transformers import T5Tokenizer, T5ForConditionalGeneration
|
3 |
+
import torch
|
4 |
+
from random import sample
|
5 |
|
6 |
+
model_path = 'souljoy/t5-chinese-lyric'
|
7 |
+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
8 |
+
tokenizer = T5Tokenizer.from_pretrained(model_path)
|
9 |
+
model = T5ForConditionalGeneration.from_pretrained(model_path).to(device)
|
10 |
+
keywords_list = []
|
11 |
+
with open('keywords', mode='r', encoding='utf-8') as f:
|
12 |
+
for line in f:
|
13 |
+
keywords_list.append(line.strip())
|
14 |
+
|
15 |
+
|
16 |
+
def key_cng_change(v_list):
|
17 |
+
return ','.join(v_list)
|
18 |
+
|
19 |
+
|
20 |
+
def post_process(text):
|
21 |
+
return text.replace("\\n", "\n").replace("\\t", "\t")
|
22 |
+
|
23 |
+
|
24 |
+
def send(my_singer_txt, my_song_txt, my_key_txt):
|
25 |
+
if len(my_singer_txt) == 0:
|
26 |
+
lyric = '请输入歌手名'
|
27 |
+
elif len(my_song_txt) == 0:
|
28 |
+
lyric = '请输入歌名'
|
29 |
+
elif len(my_key_txt) == 0:
|
30 |
+
lyric = '歌词主题词'
|
31 |
+
elif len(my_key_txt.split(',')) > 10:
|
32 |
+
lyric = '歌词主太多!最多10个'
|
33 |
+
else:
|
34 |
+
text = '用户:写一首歌,歌手“{}”,歌曲“{}”,以“{}”为主题。\\n小L:'.format(my_singer_txt, my_song_txt, my_key_txt)
|
35 |
+
encoding = tokenizer(text=[text], truncation=True, padding=True, max_length=64, return_tensors="pt").to(device)
|
36 |
+
out = model.generate(**encoding, return_dict_in_generate=True, output_scores=False, max_new_tokens=512,
|
37 |
+
do_sample=True, top_p=1, temperature=0.7, no_repeat_ngram_size=3)
|
38 |
+
out_text = tokenizer.batch_decode(out["sequences"], skip_special_tokens=True)
|
39 |
+
lyric = post_process(out_text[0])
|
40 |
+
|
41 |
+
return lyric, gr.CheckboxGroup.update(choices=sample(keywords_list, 100), value=[])
|
42 |
+
|
43 |
+
|
44 |
+
with gr.Blocks() as demo:
|
45 |
+
gr.Markdown("""#### 歌词创作机器人 📻 🎵 """)
|
46 |
+
with gr.Row():
|
47 |
+
with gr.Column():
|
48 |
+
singer_txt = gr.Textbox(label='歌手名', placeholder='请输入你的名字,或者其他歌手名字')
|
49 |
+
song_txt = gr.Textbox(label='歌名', placeholder='请输入要创作的歌词名称')
|
50 |
+
key_txt = gr.Textbox(label='歌词主题词', placeholder='请输入主题词,以中文逗号(,)分割,或点击下方“主题词样例”。最多10个')
|
51 |
+
send_button = gr.Button(value="提交")
|
52 |
+
key_cng = gr.CheckboxGroup(label='主题词样例', choices=sample(keywords_list, 100))
|
53 |
+
|
54 |
+
|
55 |
+
with gr.Column():
|
56 |
+
lyric_txt = gr.Textbox(label='生成歌词')
|
57 |
+
key_cng.change(key_cng_change, [key_cng], [key_txt])
|
58 |
+
send_button.click(send, [singer_txt, song_txt, key_txt], [lyric_txt,key_cng])
|
59 |
+
|
60 |
+
if __name__ == "__main__":
|
61 |
+
demo.launch()
|