PLatonG commited on
Commit
fc0baa4
1 Parent(s): 2054247

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +273 -0
app.py ADDED
@@ -0,0 +1,273 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from Prompter import Prompter
2
+ from Callback import Stream, Iteratorize
3
+ import os
4
+ import sys
5
+
6
+ import gradio as gr
7
+ import torch
8
+ import transformers
9
+ from peft import PeftModel
10
+ from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer
11
+ import pandas as pd
12
+ import numpy as np
13
+
14
+ if torch.cuda.is_available():
15
+ device = "cuda"
16
+ else:
17
+ device = "cpu"
18
+
19
+ try:
20
+ if torch.backends.mps.is_available():
21
+ device = "mps"
22
+ except: # noqa: E722
23
+ pass
24
+
25
+ base_model = "openthaigpt/openthaigpt-1.0.0-beta-7b-chat-ckpt-hf"
26
+ load_8bit = True
27
+ # lora_weights = "PLatonG/openthaigpt-1.0.0-beta-7b-expert-recommendations"
28
+ lora_weights = "PLatonG/openthaigpt-1.0.0-beta-7b-expert-recommendations"
29
+ prompter = Prompter("alpaca")
30
+ tokenizer = LlamaTokenizer.from_pretrained(base_model)
31
+
32
+ model = LlamaForCausalLM.from_pretrained(
33
+ base_model,
34
+ load_in_8bit=load_8bit,
35
+ torch_dtype=torch.float16,
36
+ device_map="auto",
37
+ offload_folder = "./offload"
38
+ )
39
+ model = PeftModel.from_pretrained(
40
+ model,
41
+ lora_weights,
42
+ torch_dtype=torch.float16,
43
+ offload_folder = "./offload"
44
+ )
45
+
46
+ # unwind broken decapoda-research config
47
+ model.config.pad_token_id = tokenizer.pad_token_id = 0 # unk
48
+ model.config.bos_token_id = 1
49
+ model.config.eos_token_id = 2
50
+
51
+ if not load_8bit:
52
+ model.half() # seems to fix bugs for some users.
53
+
54
+ model.eval()
55
+ if torch.__version__ >= "2" and sys.platform != "win32":
56
+ model = torch.compile(model)
57
+
58
+ def evaluate(
59
+ instruction,
60
+ input=None,
61
+ stream_output=False,
62
+ **kwargs,
63
+ ):
64
+ temperature=0.5
65
+ top_p=0.75
66
+ top_k=40
67
+ num_beams=4
68
+ max_new_tokens=380
69
+
70
+ prompt = prompter.generate_prompt(instruction, input)
71
+ inputs = tokenizer(prompt, return_tensors="pt")
72
+ input_ids = inputs["input_ids"].to(device)
73
+ generation_config = GenerationConfig(
74
+ temperature=temperature,
75
+ top_p=top_p,
76
+ top_k=top_k,
77
+ num_beams=num_beams,
78
+ **kwargs,
79
+ )
80
+
81
+ generate_params = {
82
+ "input_ids": input_ids,
83
+ "generation_config": generation_config,
84
+ "return_dict_in_generate": True,
85
+ "output_scores": True,
86
+ "max_new_tokens": max_new_tokens,
87
+ }
88
+
89
+ if stream_output:
90
+ # Stream the reply 1 token at a time.
91
+ # This is based on the trick of using 'stopping_criteria' to create an iterator,
92
+ # from https://github.com/oobabooga/text-generation-webui/blob/ad37f396fc8bcbab90e11ecf17c56c97bfbd4a9c/modules/text_generation.py#L216-L243.
93
+
94
+ def generate_with_callback(callback=None, **kwargs):
95
+ kwargs.setdefault(
96
+ "stopping_criteria", transformers.StoppingCriteriaList()
97
+ )
98
+ kwargs["stopping_criteria"].append(
99
+ Stream(callback_func=callback)
100
+ )
101
+ with torch.no_grad():
102
+ model.generate(**kwargs)
103
+
104
+ def generate_with_streaming(**kwargs):
105
+ return Iteratorize(
106
+ generate_with_callback, kwargs, callback=None
107
+ )
108
+
109
+ with generate_with_streaming(**generate_params) as generator:
110
+ for output in generator:
111
+ # new_tokens = len(output) - len(input_ids[0])
112
+ decoded_output = tokenizer.decode(output)
113
+
114
+ if output[-1] in [tokenizer.eos_token_id]:
115
+ break
116
+
117
+ yield prompter.get_response(decoded_output)
118
+ return # early return for stream_output
119
+
120
+ # Without streaming
121
+ with torch.no_grad():
122
+ generation_output = model.generate(
123
+ input_ids=input_ids,
124
+ generation_config=generation_config,
125
+ return_dict_in_generate=True,
126
+ output_scores=True,
127
+ max_new_tokens=max_new_tokens,
128
+ )
129
+ s = generation_output.sequences[0]
130
+ output = tokenizer.decode(s)
131
+ yield prompter.get_response(output)
132
+
133
+
134
+ # From SMOTE with 4 neightbor
135
+ fourNSMOTE = pd.read_csv("FILTER_GREATERTHANTHREE_FROM_SHEETS_SMOTE_train.csv")
136
+
137
+ with gr.Blocks() as demo:
138
+ birth_year = gr.components.Number(minimum = 2536, maximum = 2567, value= 2545,
139
+ label="ปีเกิด",
140
+ info="ต่ำสุด : 2536 สูงสุด : 2567")
141
+ nationality_name = gr.components.Dropdown(choices=fourNSMOTE.NATIONALITY_NAME.unique().tolist(),
142
+ label="สัญชาติ",
143
+ value = fourNSMOTE.NATIONALITY_NAME.unique().tolist()[0])
144
+ religion_name = gr.components.Dropdown(choices=fourNSMOTE.RELIGION_NAME.unique().tolist(),
145
+ label="ศาสนา",
146
+ value = fourNSMOTE.RELIGION_NAME.unique().tolist()[0])
147
+ sex = gr.components.Dropdown(choices=fourNSMOTE.JVN_SEX.unique().tolist(),
148
+ label="เพศ",
149
+ value = fourNSMOTE.JVN_SEX.unique().tolist()[0])
150
+ inform_status = gr.components.Dropdown(choices=fourNSMOTE.INFORM_STATUS_TXT.unique().tolist(),
151
+ label="เหตุที่นำมาสู่การดำเนินคดี",
152
+ value = fourNSMOTE.INFORM_STATUS_TXT.unique().tolist()[0])
153
+ age = gr.components.Number(minimum = 10, maximum = 19, value= 17,
154
+ label="อายุตอนกระทำผิด",
155
+ info="ต่ำสุด : 10 ปี สูงสุด : 19")
156
+
157
+ offense_name = gr.components.Dropdown(choices=fourNSMOTE.OFFENSE_NAME.unique().tolist(),
158
+ label="คดีที่กระทำผิด",
159
+ value = fourNSMOTE.OFFENSE_NAME.unique().tolist()[0])
160
+
161
+ ref_value = fourNSMOTE.OFFENSE_NAME.unique().tolist()[0]
162
+
163
+ allegation_name = gr.components.Dropdown(choices=fourNSMOTE.ALLEGATION_NAME.unique().tolist(), label="ชื่อของข้อกล่าวหา",
164
+ value = fourNSMOTE.query("OFFENSE_NAME == @ref_value")["ALLEGATION_NAME"].unique().tolist()[0])
165
+
166
+ allegation_desc = gr.components.Dropdown(choices=fourNSMOTE.ALLEGATION_DESC.unique().tolist(), label="รายละเอียดของข้อกล่าวหา",
167
+ value = fourNSMOTE.query("OFFENSE_NAME == @ref_value")["ALLEGATION_DESC"].unique().tolist()[0])
168
+
169
+ def update_dropDown(value):
170
+ query_state = fourNSMOTE.query("OFFENSE_NAME == @value")
171
+ allegation_name = gr.components.Dropdown(choices=query_state["ALLEGATION_NAME"].unique().tolist())
172
+ allegation_desc = gr.components.Dropdown(choices=query_state["ALLEGATION_DESC"].unique().tolist())
173
+ return allegation_name, allegation_desc
174
+
175
+ offense_name.change(fn=update_dropDown, inputs=offense_name, outputs=[allegation_name, allegation_desc])
176
+
177
+ rn1 = gr.components.Radio(choices=["ถูก", "ผิด"],
178
+ label="ปรากฎลักษณะนิสัย/พฤติกรรมที่ไม่เหมาะสมของบุคคลในครอบครัวและบุคคลที่เด็ก/เยาวชนอาศัยอยู่",
179
+ value="ถูก")
180
+ rn2 = gr.components.Radio(choices=["ถูก", "ผิด"],
181
+ label="ปรากฎประวัติการกระทำผิดของบุคคลในครอบครัวและบุคคลที่เด็ก/เยาวชนอาศัยอยู่ด้วย",
182
+ value = "ถูก")
183
+ rn3 = gr.components.Radio(choices=["ถูก", "ผิด"],
184
+ label="ปรากฎประวัติการเกี่ยวข้องกับยาเสพติดของบุคคลในครอบครัว",
185
+ value = "ถูก")
186
+
187
+ education = gr.components.Dropdown(choices=fourNSMOTE.RN3_14_HIS_EDU_FLAG.unique().tolist(),
188
+ label="สถาณะการศึกษา",
189
+ value = fourNSMOTE.RN3_14_HIS_EDU_FLAG.unique().tolist()[0])
190
+ occupation = gr.components.Dropdown(choices=fourNSMOTE.RN3_19_OCCUPATION_STATUS.unique().tolist(),
191
+ label="สถาณะการประกอบอาชีพ",
192
+ value = fourNSMOTE.RN3_19_OCCUPATION_STATUS.unique().tolist()[0])
193
+ province = gr.components.Dropdown(choices=fourNSMOTE.PROVINCE_NAME.unique().tolist(),
194
+ label="จังหวัดที่กระทำผิด",
195
+ value = fourNSMOTE.PROVINCE_NAME.unique().tolist()[0])
196
+
197
+
198
+ def generate_input(birth_year, nationality_name, religion_name, sex,
199
+ inform_status, age, offense_name, allegation_name,
200
+ allegation_desc, rn1, rn2, rn3, education, occupation, province):
201
+
202
+ birth_year = f"เกิดเมื่อปี พ.ศ. {int(birth_year)}"
203
+
204
+ if int(age) >= 10 or int(age) <=15:
205
+ age = f"มีอายุอยู่ในช่วง 10 ถึง 15 ปี"
206
+ elif int(age) >=16 or int(age) <= 20:
207
+ age = f"มีอายุอยู่ในช่วง 16 ถึง 20 ปี"
208
+ elif int(age) >=21 or int(age) <= 25:
209
+ age = f"มีอายุอยู่ในช่วง 21 ถึง 25 ปี"
210
+ elif int(age) >=26:
211
+ age = f"มีอายุอยู่ในช่วง 26 ปีขึ้นไป"
212
+
213
+ if rn1 == "ถูก":
214
+ rn1 = "มีลักษณะนิสัย/พฤติกรรมที่ไม่เหมาะสมของบุคคลในครอบครัวและบุคคลที่เด็ก/เยาวชนอาศัยอยู่"
215
+ else:
216
+ rn1 = "ไม่มีลั��ษณะนิสัย/พฤติกรรมที่ไม่เหมาะสมของบุคคลในครอบครัวและบุคคลที่เด็ก/เยาวชนอาศัยอยู่"
217
+
218
+ if rn2 == "ถูก":
219
+ rn2 = "มีประวัติการกระทำผิดของบุคคลในครอบครัวและบุคคลที่เด็ก/เยาวชนอาศัยอยู่ด้วย"
220
+ else:
221
+ rn2 = "ไม่มีประวัติการกระทำผิดของบุคคลในครอบครัวและบุคคลที่เด็ก/เยาวชนอาศัยอยู่ด้วย"
222
+
223
+ if rn3 == "ถูก":
224
+ rn3 = "มีประวัติการเกี่ยวข้องกับยาเสพติดของบุคคลในครอบครัว"
225
+ else:
226
+ rn3 = "ไม่มีประวัติการเกี่ยวข้องกับยาเสพติดของบุคคลในครอบครัว"
227
+
228
+ instruciton = "จงสร้างคำแนะนำของผู้เชี่ยวชาญจากปัจจัยดังต่อไปนี้"
229
+ input = f"{birth_year} {nationality_name} {religion_name} {sex} {inform_status} {age} {offense_name} {allegation_name} {allegation_desc} {rn1} {rn2} {rn3} {education} {occupation} {province}"
230
+
231
+
232
+ return input
233
+
234
+ def generate_output(instruction, input):
235
+
236
+ return input
237
+
238
+
239
+ def generate_input2(*values):
240
+ return "คำสั่ง : จงสร้างคำแนะนำของผู้เชี่ยวชาญจากปัจจัยดังต่อไปนี้ " + " ".join(str(value) for value in values)
241
+
242
+
243
+
244
+ instruction = gr.Textbox(label = "คำสั่ง", value="จงสร้างคำแนะนำของผู้เชี่ยวชาญจากปัจจัยดังต่อไปนี้", visible=True, interactive=False)
245
+ input_compo = gr.Textbox(label = "ข้อมูลเข้า (input)")
246
+ outputModel = gr.Textbox(label= "ผลลัพธ์ (output)")
247
+ stream_output = gr.components.Checkbox(label="Stream output")
248
+
249
+
250
+ btn1 = gr.Button("GENERATE INPUT")
251
+
252
+ # show input text format for user
253
+ btn1.click(fn=generate_input, inputs=[birth_year, nationality_name, religion_name, sex,
254
+ inform_status, age, offense_name, allegation_name,
255
+ allegation_desc, rn1, rn2, rn3, education, occupation, province],
256
+ outputs=input_compo)
257
+
258
+ btn2 = gr.Button("GENERATE OUTPUT")
259
+ btn2.click(fn=evaluate, inputs=[instruction, input_compo, stream_output], outputs=outputModel)
260
+
261
+ # outputChatInterface = gr.ChatInterface(fn=evaluate)
262
+
263
+
264
+
265
+ # input text format for model
266
+ # btn.click(fn=generate_text_test2, inputs = [birth_year, nationality_name, religion_name, sex,
267
+ # inform_status, age, offense_name, allegation_name,
268
+ # allegation_desc, rn1, rn2, rn3, education, occupation, province],
269
+ # outputs = input_compo)
270
+
271
+
272
+
273
+ demo.launch(debug=True, share=True)