彭见峡 commited on
Commit
276b6a9
1 Parent(s): 93543cf

Add application file

Browse files
Files changed (2) hide show
  1. app.py +209 -0
  2. requirements.txt +2 -0
app.py ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba Cloud.
2
+ #
3
+ # This source code is licensed under the license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ """A simple web interactive chat demo based on gradio."""
7
+ import os
8
+ from argparse import ArgumentParser
9
+
10
+ import gradio as gr
11
+ import mdtex2html
12
+
13
+ import torch
14
+ from transformers import AutoModelForCausalLM, AutoTokenizer
15
+ from transformers.generation import GenerationConfig
16
+
17
+
18
+ DEFAULT_CKPT_PATH = 'Qwen/Qwen-7B-Chat'
19
+
20
+
21
+ def _get_args():
22
+ parser = ArgumentParser()
23
+ parser.add_argument("-c", "--checkpoint-path", type=str, default=DEFAULT_CKPT_PATH,
24
+ help="Checkpoint name or path, default to %(default)r")
25
+ parser.add_argument("--cpu-only", action="store_true", help="Run demo with CPU only")
26
+
27
+ parser.add_argument("--share", action="store_true", default=False,
28
+ help="Create a publicly shareable link for the interface.")
29
+ parser.add_argument("--inbrowser", action="store_true", default=False,
30
+ help="Automatically launch the interface in a new tab on the default browser.")
31
+ parser.add_argument("--server-port", type=int, default=8000,
32
+ help="Demo server port.")
33
+ parser.add_argument("--server-name", type=str, default="127.0.0.1",
34
+ help="Demo server name.")
35
+
36
+ args = parser.parse_args()
37
+ return args
38
+
39
+
40
+ def _load_model_tokenizer(args):
41
+ tokenizer = AutoTokenizer.from_pretrained(
42
+ args.checkpoint_path, trust_remote_code=True, resume_download=True,
43
+ )
44
+
45
+ if args.cpu_only:
46
+ device_map = "cpu"
47
+ else:
48
+ device_map = "auto"
49
+
50
+ model = AutoModelForCausalLM.from_pretrained(
51
+ args.checkpoint_path,
52
+ device_map=device_map,
53
+ trust_remote_code=True,
54
+ resume_download=True,
55
+ ).eval()
56
+
57
+ config = GenerationConfig.from_pretrained(
58
+ args.checkpoint_path, trust_remote_code=True, resume_download=True,
59
+ )
60
+
61
+ return model, tokenizer, config
62
+
63
+
64
+ def postprocess(self, y):
65
+ if y is None:
66
+ return []
67
+ for i, (message, response) in enumerate(y):
68
+ y[i] = (
69
+ None if message is None else mdtex2html.convert(message),
70
+ None if response is None else mdtex2html.convert(response),
71
+ )
72
+ return y
73
+
74
+
75
+ gr.Chatbot.postprocess = postprocess
76
+
77
+
78
+ def _parse_text(text):
79
+ lines = text.split("\n")
80
+ lines = [line for line in lines if line != ""]
81
+ count = 0
82
+ for i, line in enumerate(lines):
83
+ if "```" in line:
84
+ count += 1
85
+ items = line.split("`")
86
+ if count % 2 == 1:
87
+ lines[i] = f'<pre><code class="language-{items[-1]}">'
88
+ else:
89
+ lines[i] = f"<br></code></pre>"
90
+ else:
91
+ if i > 0:
92
+ if count % 2 == 1:
93
+ line = line.replace("`", r"\`")
94
+ line = line.replace("<", "&lt;")
95
+ line = line.replace(">", "&gt;")
96
+ line = line.replace(" ", "&nbsp;")
97
+ line = line.replace("*", "&ast;")
98
+ line = line.replace("_", "&lowbar;")
99
+ line = line.replace("-", "&#45;")
100
+ line = line.replace(".", "&#46;")
101
+ line = line.replace("!", "&#33;")
102
+ line = line.replace("(", "&#40;")
103
+ line = line.replace(")", "&#41;")
104
+ line = line.replace("$", "&#36;")
105
+ lines[i] = "<br>" + line
106
+ text = "".join(lines)
107
+ return text
108
+
109
+
110
+ def _gc():
111
+ import gc
112
+ gc.collect()
113
+ if torch.cuda.is_available():
114
+ torch.cuda.empty_cache()
115
+
116
+
117
+ def _launch_demo(args, model, tokenizer, config):
118
+
119
+ def predict(_query, _chatbot, _task_history):
120
+ print(f"User: {_parse_text(_query)}")
121
+ _chatbot.append((_parse_text(_query), ""))
122
+ full_response = ""
123
+
124
+ for response in model.chat_stream(tokenizer, _query, history=_task_history, generation_config=config):
125
+ _chatbot[-1] = (_parse_text(_query), _parse_text(response))
126
+
127
+ yield _chatbot
128
+ full_response = _parse_text(response)
129
+
130
+ print(f"History: {_task_history}")
131
+ _task_history.append((_query, full_response))
132
+ print(f"Qwen-Chat: {_parse_text(full_response)}")
133
+
134
+ def regenerate(_chatbot, _task_history):
135
+ if not _task_history:
136
+ yield _chatbot
137
+ return
138
+ item = _task_history.pop(-1)
139
+ _chatbot.pop(-1)
140
+ yield from predict(item[0], _chatbot, _task_history)
141
+
142
+ def reset_user_input():
143
+ return gr.update(value="")
144
+
145
+ def reset_state(_chatbot, _task_history):
146
+ _task_history.clear()
147
+ _chatbot.clear()
148
+ _gc()
149
+ return _chatbot
150
+
151
+ with gr.Blocks() as demo:
152
+ gr.Markdown("""\
153
+ <p align="center"><img src="https://qianwen-res.oss-cn-beijing.aliyuncs.com/logo_qwen.jpg" style="height: 80px"/><p>""")
154
+ gr.Markdown("""<center><font size=8>Qwen-Chat Bot</center>""")
155
+ gr.Markdown(
156
+ """\
157
+ <center><font size=3>This WebUI is based on Qwen-Chat, developed by Alibaba Cloud. \
158
+ (本WebUI基于Qwen-Chat打造,实现聊天机器人功能。)</center>""")
159
+ gr.Markdown("""\
160
+ <center><font size=4>
161
+ Qwen-7B <a href="https://modelscope.cn/models/qwen/Qwen-7B/summary">🤖 </a> |
162
+ <a href="https://huggingface.co/Qwen/Qwen-7B">🤗</a>&nbsp |
163
+ Qwen-7B-Chat <a href="https://modelscope.cn/models/qwen/Qwen-7B-Chat/summary">🤖 </a> |
164
+ <a href="https://huggingface.co/Qwen/Qwen-7B-Chat">🤗</a>&nbsp |
165
+ Qwen-14B <a href="https://modelscope.cn/models/qwen/Qwen-14B/summary">🤖 </a> |
166
+ <a href="https://huggingface.co/Qwen/Qwen-14B">🤗</a>&nbsp |
167
+ Qwen-14B-Chat <a href="https://modelscope.cn/models/qwen/Qwen-14B-Chat/summary">🤖 </a> |
168
+ <a href="https://huggingface.co/Qwen/Qwen-14B-Chat">🤗</a>&nbsp |
169
+ &nbsp<a href="https://github.com/QwenLM/Qwen">Github</a></center>""")
170
+
171
+ chatbot = gr.Chatbot(label='Qwen-Chat', elem_classes="control-height")
172
+ query = gr.Textbox(lines=2, label='Input')
173
+ task_history = gr.State([])
174
+
175
+ with gr.Row():
176
+ empty_btn = gr.Button("🧹 Clear History (清除历史)")
177
+ submit_btn = gr.Button("🚀 Submit (发送)")
178
+ regen_btn = gr.Button("🤔️ Regenerate (重试)")
179
+
180
+ submit_btn.click(predict, [query, chatbot, task_history], [chatbot], show_progress=True)
181
+ submit_btn.click(reset_user_input, [], [query])
182
+ empty_btn.click(reset_state, [chatbot, task_history], outputs=[chatbot], show_progress=True)
183
+ regen_btn.click(regenerate, [chatbot, task_history], [chatbot], show_progress=True)
184
+
185
+ gr.Markdown("""\
186
+ <font size=2>Note: This demo is governed by the original license of Qwen. \
187
+ We strongly advise users not to knowingly generate or allow others to knowingly generate harmful content, \
188
+ including hate speech, violence, pornography, deception, etc. \
189
+ (注:本演示受Qwen的许可协议限制。我们强烈建议,用户不应传播及不应允许他人传播以下内容,\
190
+ 包括但不限于仇恨言论、暴力、色情、欺诈相关的有害信息。)""")
191
+
192
+ demo.queue().launch(
193
+ share=args.share,
194
+ inbrowser=args.inbrowser,
195
+ server_port=args.server_port,
196
+ server_name=args.server_name,
197
+ )
198
+
199
+
200
+ def main():
201
+ args = _get_args()
202
+
203
+ model, tokenizer, config = _load_model_tokenizer(args)
204
+
205
+ _launch_demo(args, model, tokenizer, config)
206
+
207
+
208
+ if __name__ == '__main__':
209
+ main()
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ gradio<3.42
2
+ mdtex2html