"""Gradio demo: generate Chinese song lyrics with a T5 model from singer, title and theme keywords."""
import re
from random import sample

import gradio as gr
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration

model_path = 'souljoy/t5-chinese-lyric'
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
tokenizer = T5Tokenizer.from_pretrained(model_path)
model = T5ForConditionalGeneration.from_pretrained(model_path).to(device)

# Maximum number of theme keywords accepted in a single request.
MAX_KEYWORDS = 10
# How many randomly chosen keywords are offered in the "sample keywords" checkbox group.
NUM_KEYWORD_SAMPLES = 100

# One keyword per line; blank lines are skipped so they never appear as empty checkboxes.
with open('keywords', mode='r', encoding='utf-8') as f:
    keywords_list = [word for word in (line.strip() for line in f) if word]


def _sample_keywords(k=NUM_KEYWORD_SAMPLES):
    """Return up to *k* distinct keywords drawn at random from ``keywords_list``.

    ``random.sample`` raises ``ValueError`` when asked for more items than the
    population holds, so the request is capped at the list length.
    """
    return sample(keywords_list, min(k, len(keywords_list)))


def _split_keywords(text):
    """Split keyword text on ASCII or full-width commas, dropping empty/blank entries."""
    return [word for word in re.split(r'[,,]', text) if word.strip()]


def key_cng_change(v_list):
    """Join the ticked sample keywords into the comma-separated keyword textbox value."""
    return ','.join(v_list)


def post_process(text):
    """Replace literal backslash escapes (``\\n``, ``\\t``) in decoded text with real newline/tab characters."""
    return text.replace("\\n", "\n").replace("\\t", "\t")


def _generate_lyric(singer, song, keys):
    """Build the prompt, sample a continuation from the T5 model and return the post-processed text."""
    text = '用户:写一首歌,歌手“{}”,歌曲“{}”,以“{}”为主题。\\n小L:'.format(singer, song, keys)
    # NOTE(review): with truncation=True the tokenizer cuts from the right by default, so a
    # prompt longer than 64 tokens would lose the trailing "小L:" marker. Confirm the input
    # length the model was trained with before raising this.
    encoding = tokenizer(
        text=[text],
        truncation=True,
        padding=True,
        max_length=64,
        return_tensors="pt",
    ).to(device)
    out = model.generate(
        **encoding,
        return_dict_in_generate=True,
        output_scores=False,
        max_new_tokens=512,
        do_sample=True,
        top_p=1,
        temperature=0.7,
        no_repeat_ngram_size=3,
    )
    out_text = tokenizer.batch_decode(out["sequences"], skip_special_tokens=True)
    return post_process(out_text[0])


def send(my_singer_txt, my_song_txt, my_key_txt):
    """Validate the form inputs and generate lyrics.

    Returns a pair: the generated lyric (or a validation message shown in the
    same textbox) and a component update that re-samples the keyword
    checkbox choices and clears their selection.
    """
    # Treat None and whitespace-only inputs as empty.
    singer = (my_singer_txt or '').strip()
    song = (my_song_txt or '').strip()
    keys = (my_key_txt or '').strip()

    if not singer:
        lyric = '请输入歌手名'
    elif not song:
        lyric = '请输入歌名'
    elif not keys:
        lyric = '请输入歌词主题词'
    elif len(_split_keywords(keys)) > MAX_KEYWORDS:
        lyric = f'歌词主题词太多!最多{MAX_KEYWORDS}个'
    else:
        lyric = _generate_lyric(singer, song, keys)
    # gr.update works on Gradio 3.x and 4.x; Component.update() was removed in Gradio 4.
    return lyric, gr.update(choices=_sample_keywords(), value=[])


with gr.Blocks() as demo:
    gr.Markdown("""#### 歌词创作机器人 📻 🎵 """)
    with gr.Row():
        with gr.Column():
            singer_txt = gr.Textbox(label='歌手名', placeholder='请输入你的名字,或者其他歌手名字')
            song_txt = gr.Textbox(label='歌名', placeholder='请输入要创作的歌词名称')
            key_txt = gr.Textbox(label='歌词主题词', placeholder='请输入主题词,以中文逗号(,)分割,或点击下方“主题词样例”。最多10个')
            send_button = gr.Button(value="提交")
            key_cng = gr.CheckboxGroup(label='主题词样例', choices=_sample_keywords())
        with gr.Column():
            lyric_txt = gr.Textbox(label='生成歌词')
    # Ticking sample keywords overwrites the keyword textbox with their comma-joined list.
    key_cng.change(key_cng_change, [key_cng], [key_txt])
    send_button.click(send, [singer_txt, song_txt, key_txt], [lyric_txt, key_cng])

if __name__ == "__main__":
    demo.launch()