Spaces:
Runtime error
Runtime error
import gradio as gr | |
from transformers import T5Tokenizer, T5ForConditionalGeneration | |
import torch | |
from random import sample | |
# Load the lyric model once at import time so every request reuses the same
# tokenizer/model instances (on GPU when one is available).
model_path = 'souljoy/t5-chinese-lyric'
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
tokenizer = T5Tokenizer.from_pretrained(model_path)
model = T5ForConditionalGeneration.from_pretrained(model_path).to(device)

# Example theme keywords, one per line in the local 'keywords' file.
# Blank lines are skipped so they never show up as empty checkbox choices
# (and do not count toward the pool that random.sample draws from).
with open('keywords', mode='r', encoding='utf-8') as f:
    keywords_list = [kw for kw in (line.strip() for line in f) if kw]
def key_cng_change(v_list):
    """Join the selected example keywords into one comma-separated string.

    Used as the change handler of the example-keyword checkbox group; the
    result is written into the keyword textbox.
    """
    separator = ','
    return separator.join(v_list)
def post_process(text):
    r"""Turn literal ``\n`` and ``\t`` escape sequences into real newline/tab characters.

    The escapes are replaced in order: ``\n`` first, then ``\t``.
    """
    for escaped, actual in (("\\n", "\n"), ("\\t", "\t")):
        text = text.replace(escaped, actual)
    return text
def _validate_inputs(singer, song, keywords):
    """Return a user-facing validation message, or None when the form is valid.

    Keywords may be separated by ASCII ',' or full-width '，' commas; empty
    entries (e.g. from a trailing comma) are not counted toward the limit.
    """
    if not singer:
        return '请输入歌手名'
    if not song:
        return '请输入歌名'
    keyword_items = [k for k in keywords.replace('，', ',').split(',') if k.strip()]
    if not keyword_items:
        return '请输入歌词主题词'
    if len(keyword_items) > 10:
        return '歌词主题词太多!最多10个'
    return None


def _generate_lyric(singer, song, keywords):
    """Build the prompt from the form fields, sample from the model and return decoded lyrics."""
    # The prompt contains a literal backslash-n (not a newline); post_process()
    # converts such escapes in the output, so presumably the model was trained
    # on escaped text — NOTE(review): confirm before changing the prompt format.
    text = '用户:写一首歌,歌手“{}”,歌曲“{}”,以“{}”为主题。\\n小L:'.format(singer, song, keywords)
    # NOTE(review): max_length=64 truncates long prompts, which can cut off the
    # trailing "小L:" cue — confirm against the model's training setup.
    encoding = tokenizer(text=[text], truncation=True, padding=True, max_length=64,
                         return_tensors="pt").to(device)
    out = model.generate(**encoding, return_dict_in_generate=True, output_scores=False,
                         max_new_tokens=512, do_sample=True, top_p=1, temperature=0.7,
                         no_repeat_ngram_size=3)
    out_text = tokenizer.batch_decode(out["sequences"], skip_special_tokens=True)
    return post_process(out_text[0])


def send(my_singer_txt, my_song_txt, my_key_txt):
    """Click handler for the submit button: validate the form and generate lyrics.

    Args:
        my_singer_txt: singer name from the textbox.
        my_song_txt: song title from the textbox.
        my_key_txt: theme keywords, separated by ',' or '，' (at most 10).

    Returns:
        A 2-tuple ``(lyric, update)``: the generated lyric (or a validation
        message) for the output textbox, and a ``gr.update`` that re-samples
        the example-keyword choices and clears the current selection.
    """
    # Whitespace-only fields are treated as empty.
    singer = (my_singer_txt or '').strip()
    song = (my_song_txt or '').strip()
    keywords = (my_key_txt or '').strip()

    lyric = _validate_inputs(singer, song, keywords)
    if lyric is None:
        lyric = _generate_lyric(singer, song, keywords)

    # random.sample raises ValueError if asked for more items than exist.
    n_samples = min(100, len(keywords_list))
    # gr.update works on Gradio 3.x and 4.x; the component-specific
    # gr.CheckboxGroup.update was removed in Gradio 4.0.
    return lyric, gr.update(choices=sample(keywords_list, n_samples), value=[])
# Two-column UI: input form plus example-keyword picker on the left,
# generated lyrics on the right.
with gr.Blocks() as demo:
    gr.Markdown("""#### 歌词创作机器人 📻 🎵 """)
    with gr.Row():
        with gr.Column():
            singer_txt = gr.Textbox(label='歌手名', placeholder='请输入你的名字,或者其他歌手名字')
            song_txt = gr.Textbox(label='歌名', placeholder='请输入要创作的歌词名称')
            key_txt = gr.Textbox(label='歌词主题词', placeholder='请输入主题词,以中文逗号(,)分割,或点击下方“主题词样例”。最多10个')
            send_button = gr.Button(value="提交")
            # Cap the sample size: random.sample raises ValueError when the
            # keyword file has fewer than 100 entries.
            key_cng = gr.CheckboxGroup(label='主题词样例',
                                       choices=sample(keywords_list, min(100, len(keywords_list))))
        with gr.Column():
            lyric_txt = gr.Textbox(label='生成歌词')
    # Selecting example keywords fills the keyword textbox.
    key_cng.change(key_cng_change, [key_cng], [key_txt])
    # Submit writes the lyric and refreshes the example-keyword choices.
    send_button.click(send, [singer_txt, song_txt, key_txt], [lyric_txt, key_cng])
if __name__ == "__main__":
    # Only start the web server when this file is executed directly,
    # not when it is imported.
    launch_app = demo.launch
    launch_app()