souljoy commited on
Commit
370df52
1 Parent(s): 157a9c5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +59 -23
app.py CHANGED
@@ -1,25 +1,61 @@
1
  import gradio as gr
 
 
 
2
 
3
- def request_download(IMAGE_URL):
4
- import requests
5
- r = requests.get(IMAGE_URL)
6
- with open('img.jpg', 'wb') as f:
7
- f.write(r.content)
8
-
9
-
10
- demo = gr.Blocks()
11
-
12
- def change_img():
13
- request_download('https://y.qq.com/music/photo_new/T002R300x300M0000007ujpB2dTrlW_2.jpg')
14
- return gr.Image.update(value='img.jpg')
15
-
16
- with demo:
17
- album_image = gr.Image(value='fws.jpg', visible=True)
18
- fetch_songs = gr.Button(value="更新").style(full_width=True)
19
- fetch_songs.click(
20
- fn=change_img,
21
- inputs=None,
22
- outputs=album_image
23
- )
24
-
25
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ from transformers import T5Tokenizer, T5ForConditionalGeneration
3
+ import torch
4
+ from random import sample
5
 
6
+ model_path = 'souljoy/t5-chinese-lyric'
7
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
8
+ tokenizer = T5Tokenizer.from_pretrained(model_path)
9
+ model = T5ForConditionalGeneration.from_pretrained(model_path).to(device)
10
+ keywords_list = []
11
+ with open('keywords', mode='r', encoding='utf-8') as f:
12
+ for line in f:
13
+ keywords_list.append(line.strip())
14
+
15
+
16
+ def key_cng_change(v_list):
17
+ return ','.join(v_list)
18
+
19
+
20
+ def post_process(text):
21
+ return text.replace("\\n", "\n").replace("\\t", "\t")
22
+
23
+
24
+ def send(my_singer_txt, my_song_txt, my_key_txt):
25
+ if len(my_singer_txt) == 0:
26
+ lyric = '请输入歌手名'
27
+ elif len(my_song_txt) == 0:
28
+ lyric = '请输入歌名'
29
+ elif len(my_key_txt) == 0:
30
+ lyric = '歌词主题词'
31
+ elif len(my_key_txt.split(',')) > 10:
32
+ lyric = '歌词主太多!最多10个'
33
+ else:
34
+ text = '用户:写一首歌,歌手“{}”,歌曲“{}”,以“{}”为主题。\\n小L:'.format(my_singer_txt, my_song_txt, my_key_txt)
35
+ encoding = tokenizer(text=[text], truncation=True, padding=True, max_length=64, return_tensors="pt").to(device)
36
+ out = model.generate(**encoding, return_dict_in_generate=True, output_scores=False, max_new_tokens=512,
37
+ do_sample=True, top_p=1, temperature=0.7, no_repeat_ngram_size=3)
38
+ out_text = tokenizer.batch_decode(out["sequences"], skip_special_tokens=True)
39
+ lyric = post_process(out_text[0])
40
+
41
+ return lyric, gr.CheckboxGroup.update(choices=sample(keywords_list, 100), value=[])
42
+
43
+
44
+ with gr.Blocks() as demo:
45
+ gr.Markdown("""#### 歌词创作机器人 📻 🎵 """)
46
+ with gr.Row():
47
+ with gr.Column():
48
+ singer_txt = gr.Textbox(label='歌手名', placeholder='请输入你的名字,或者其他歌手名字')
49
+ song_txt = gr.Textbox(label='歌名', placeholder='请输入要创作的歌词名称')
50
+ key_txt = gr.Textbox(label='歌词主题词', placeholder='请输入主题词,以中文逗号(,)分割,或点击下方“主题词样例”。最多10个')
51
+ send_button = gr.Button(value="提交")
52
+ key_cng = gr.CheckboxGroup(label='主题词样例', choices=sample(keywords_list, 100))
53
+
54
+
55
+ with gr.Column():
56
+ lyric_txt = gr.Textbox(label='生成歌词')
57
+ key_cng.change(key_cng_change, [key_cng], [key_txt])
58
+ send_button.click(send, [singer_txt, song_txt, key_txt], [lyric_txt,key_cng])
59
+
60
+ if __name__ == "__main__":
61
+ demo.launch()