souljoy's picture
Update app.py
370df52
raw
history blame
No virus
2.66 kB
import gradio as gr
from transformers import T5Tokenizer, T5ForConditionalGeneration
import torch
from random import sample
# Hugging Face Hub id of the fine-tuned T5 lyric model loaded below.
model_path = 'souljoy/t5-chinese-lyric'
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
tokenizer = T5Tokenizer.from_pretrained(model_path)
# Load the weights, then move the model onto the selected device.
model = T5ForConditionalGeneration.from_pretrained(model_path)
model = model.to(device)
# Candidate theme keywords, one per line in the local 'keywords' file.
# Surrounding whitespace is stripped and blank lines are skipped so that they
# never appear as empty choices in the keyword CheckboxGroup.
with open('keywords', mode='r', encoding='utf-8') as f:
    keywords_list = [word for word in (line.strip() for line in f) if word]
def key_cng_change(v_list):
    """Join the selected sample keywords into a single comma-separated string.

    Wired as the CheckboxGroup change handler; the result is written into the
    theme-keyword textbox.
    """
    return str.join(',', v_list)
def post_process(text):
    """Turn literal backslash escape sequences in generated text into real characters.

    The two-character sequences backslash-n and backslash-t become a newline
    and a tab respectively; everything else is returned unchanged.
    """
    escape_table = (("\\n", "\n"), ("\\t", "\t"))
    for escaped, actual in escape_table:
        text = text.replace(escaped, actual)
    return text
# Maximum number of theme keywords accepted per request.
_MAX_KEYWORDS = 10
# Number of sample keywords offered in the CheckboxGroup after each request.
_SAMPLE_SIZE = 100


def _split_keywords(key_txt):
    """Split a keyword string on ASCII or full-width commas, dropping empty items."""
    # The UI asks for the Chinese full-width comma (U+FF0C), while the checkbox
    # helper joins with ASCII ','; accept both so the count is accurate.
    normalized = key_txt.replace('\uff0c', ',')
    return [word.strip() for word in normalized.split(',') if word.strip()]


def _validation_error(singer_txt, song_txt, key_txt):
    """Return a user-facing error message for invalid input, or None when valid."""
    if not (singer_txt or '').strip():
        return '请输入歌手名'
    if not (song_txt or '').strip():
        return '请输入歌名'
    keywords = _split_keywords(key_txt or '')
    if not keywords:
        return '请输入歌词主题词'
    if len(keywords) > _MAX_KEYWORDS:
        return '歌词主题词太多!最多{}个'.format(_MAX_KEYWORDS)
    return None


def _generate_lyric(singer_txt, song_txt, key_txt):
    """Prompt the T5 model with singer, song title and theme; return the decoded lyric."""
    # The prompt embeds a literal backslash-n (not a real newline), presumably
    # matching the model's training format; post_process converts such
    # sequences in the output back into real newlines.
    text = '用户:写一首歌,歌手“{}”,歌曲“{}”,以“{}”为主题。\\n小L:'.format(singer_txt, song_txt, key_txt)
    # The prompt is truncated to 64 tokens.
    encoding = tokenizer(text=[text], truncation=True, padding=True, max_length=64,
                         return_tensors="pt").to(device)
    out = model.generate(**encoding, return_dict_in_generate=True, output_scores=False, max_new_tokens=512,
                         do_sample=True, top_p=1, temperature=0.7, no_repeat_ngram_size=3)
    out_text = tokenizer.batch_decode(out["sequences"], skip_special_tokens=True)
    return post_process(out_text[0])


def send(my_singer_txt, my_song_txt, my_key_txt):
    """Handle the submit button: validate the inputs, then generate a lyric.

    Args:
        my_singer_txt: singer name typed by the user.
        my_song_txt: song title typed by the user.
        my_key_txt: theme keywords, separated by ASCII or full-width commas.

    Returns:
        A pair ``(lyric, checkbox_update)``. ``lyric`` is either the generated
        lyric or a validation message. ``checkbox_update`` refreshes the sample
        keyword CheckboxGroup with a new random selection and clears its value.
    """
    lyric = _validation_error(my_singer_txt, my_song_txt, my_key_txt)
    if lyric is None:
        lyric = _generate_lyric(my_singer_txt, my_song_txt, my_key_txt)
    # Cap the sample size so random.sample cannot raise ValueError when the
    # keyword file holds fewer than _SAMPLE_SIZE entries.
    choices = sample(keywords_list, min(_SAMPLE_SIZE, len(keywords_list)))
    # NOTE(review): Gradio 4 removed component-level `.update`; `gr.update(...)`
    # is the replacement there. Confirm against the pinned Gradio version.
    return lyric, gr.CheckboxGroup.update(choices=choices, value=[])
# Two-column UI: inputs and sample keywords on the left, generated lyric on the right.
with gr.Blocks() as demo:
    gr.Markdown("""#### 歌词创作机器人 📻 🎵 """)
    with gr.Row():
        with gr.Column():
            singer_txt = gr.Textbox(label='歌手名', placeholder='请输入你的名字,或者其他歌手名字')
            song_txt = gr.Textbox(label='歌名', placeholder='请输入要创作的歌词名称')
            key_txt = gr.Textbox(label='歌词主题词', placeholder='请输入主题词,以中文逗号(,)分割,或点击下方“主题词样例”。最多10个')
            send_button = gr.Button(value="提交")
            # Random subset of sample keywords, capped at the list size so
            # random.sample cannot raise ValueError at startup.
            key_cng = gr.CheckboxGroup(label='主题词样例',
                                       choices=sample(keywords_list, min(100, len(keywords_list))))
        with gr.Column():
            lyric_txt = gr.Textbox(label='生成歌词')
    # Ticking sample keywords rewrites the theme textbox with the joined selection.
    key_cng.change(key_cng_change, [key_cng], [key_txt])
    # Submit generates the lyric and refreshes the sample keyword choices.
    send_button.click(send, [singer_txt, song_txt, key_txt], [lyric_txt, key_cng])
if __name__ == "__main__":
    # Start the Gradio server only when run as a script, not when imported.
    launch_app = demo.launch
    launch_app()