File size: 2,662 Bytes
40d108c
370df52
 
 
40d108c
370df52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import gradio as gr
from transformers import T5Tokenizer, T5ForConditionalGeneration
import torch
from random import sample

# Hugging Face model id (or local path) of the fine-tuned Chinese lyric T5 model.
model_path = 'souljoy/t5-chinese-lyric'
# Use the first GPU when available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
tokenizer = T5Tokenizer.from_pretrained(model_path)
model = T5ForConditionalGeneration.from_pretrained(model_path).to(device)
# Example theme keywords offered in the UI, read one per line from the
# 'keywords' file. Blank lines are skipped so they never appear as empty
# checkbox choices (or as empty entries in the joined keyword string).
with open('keywords', mode='r', encoding='utf-8') as f:
    keywords_list = [line.strip() for line in f if line.strip()]


def key_cng_change(v_list):
    """Join the selected example keywords into one comma-separated string.

    Used as the change handler of the example-keyword checkbox group; the
    returned string is written into the keyword textbox.

    Args:
        v_list: The currently selected keyword strings.

    Returns:
        The keywords joined with ',' (empty string when nothing is selected).
    """
    separator = ','
    joined = separator.join(v_list)
    return joined


def post_process(text):
    """Turn literal escape sequences in model output into real whitespace.

    The two-character sequences backslash-n and backslash-t become a newline
    and a tab respectively; everything else is returned unchanged.
    """
    for escaped, real in (("\\n", "\n"), ("\\t", "\t")):
        text = text.replace(escaped, real)
    return text


def _validate_send_inputs(singer, song, keywords):
    """Return a user-facing (Chinese) error message for invalid input, or None if valid.

    Expects already-stripped strings. Keywords may be separated by ASCII ','
    or full-width '，'; empty entries (e.g. from a trailing comma) are ignored.
    """
    if not singer:
        return '请输入歌手名'
    if not song:
        return '请输入歌名'
    parts = [k for k in keywords.replace('，', ',').split(',') if k.strip()]
    if not parts:
        return '请输入歌词主题词'
    if len(parts) > 10:
        return '歌词主题词太多!最多10个'
    return None


def _generate_lyric(singer, song, keywords):
    """Build the chat-style prompt, run the T5 model, and return the decoded lyric."""
    # The prompt contains a literal backslash-n (two characters), as in the
    # original code; presumably it mirrors the fine-tuning format — verify.
    # post_process later converts literal "\\n" in the output into newlines.
    text = '用户:写一首歌,歌手“{}”,歌曲“{}”,以“{}”为主题。\\n小L:'.format(singer, song, keywords)
    encoding = tokenizer(text=[text], truncation=True, padding=True, max_length=64, return_tensors="pt").to(device)
    out = model.generate(**encoding, return_dict_in_generate=True, output_scores=False, max_new_tokens=512,
                         do_sample=True, top_p=1, temperature=0.7, no_repeat_ngram_size=3)
    out_text = tokenizer.batch_decode(out["sequences"], skip_special_tokens=True)
    return post_process(out_text[0])


def send(my_singer_txt, my_song_txt, my_key_txt):
    """Generate lyrics for the submitted form and refresh the example keywords.

    Handler for the submit button. The three inputs are stripped and
    validated. If validation fails, the returned lyric is a prompt/error
    message instead of generated text.

    Args:
        my_singer_txt: Singer name entered by the user.
        my_song_txt: Song title entered by the user.
        my_key_txt: Theme keywords separated by ',' or '，' (at most 10).

    Returns:
        A ``(lyric, checkbox_update)`` tuple. ``lyric`` is the generated lyric
        or a validation message. ``checkbox_update`` re-samples up to 100
        example keywords and clears the current selection.
    """
    singer = (my_singer_txt or '').strip()
    song = (my_song_txt or '').strip()
    keywords = (my_key_txt or '').strip()

    error = _validate_send_inputs(singer, song, keywords)
    lyric = error if error is not None else _generate_lyric(singer, song, keywords)

    # Clamp the sample size: random.sample raises ValueError when asked for
    # more items than the keywords file provides.
    new_choices = sample(keywords_list, min(100, len(keywords_list)))
    return lyric, gr.CheckboxGroup.update(choices=new_choices, value=[])


with gr.Blocks() as demo:
    gr.Markdown("""#### 歌词创作机器人 📻 🎵 """)
    with gr.Row():
        # Left column: user inputs, submit button, and example keywords.
        with gr.Column():
            singer_txt = gr.Textbox(label='歌手名', placeholder='请输入你的名字,或者其他歌手名字')
            song_txt = gr.Textbox(label='歌名', placeholder='请输入要创作的歌词名称')
            key_txt = gr.Textbox(label='歌词主题词', placeholder='请输入主题词,以中文逗号(,)分割,或点击下方“主题词样例”。最多10个')
            send_button = gr.Button(value="提交")
            # Initial random example keywords. The sample size is clamped so a
            # keywords file with fewer than 100 entries does not raise
            # ValueError at startup.
            key_cng = gr.CheckboxGroup(label='主题词样例',
                                       choices=sample(keywords_list, min(100, len(keywords_list))))

        # Right column: generated lyric output.
        with gr.Column():
            lyric_txt = gr.Textbox(label='生成歌词')
    # Selecting example keywords overwrites the keyword textbox with them, comma-joined.
    key_cng.change(key_cng_change, [key_cng], [key_txt])
    # Submit: fill the lyric box and refresh the example keyword choices.
    send_button.click(send, [singer_txt, song_txt, key_txt], [lyric_txt, key_cng])

if __name__ == "__main__":
    demo.launch()