Update main.py
Browse files
main.py
CHANGED
@@ -18,11 +18,11 @@ HF_TOKEN = getenv("HF_TOKEN")
|
|
18 |
|
19 |
class InputData(BaseModel):
|
20 |
model: str
|
21 |
-
system_prompt_template:
|
22 |
-
prompt_template:
|
23 |
end_token: str
|
24 |
-
|
25 |
-
|
26 |
history: str = ""
|
27 |
segment: bool = False
|
28 |
max_sentences: Optional[int] = None
|
@@ -36,28 +36,25 @@ async def generate_response(data: InputData) -> Dict[str, Any]:
|
|
36 |
if data.max_sentences is not None and data.max_sentences != 0:
|
37 |
data.segment = True
|
38 |
elif data.max_sentences == 0:
|
39 |
-
for
|
40 |
-
|
41 |
-
data.history += prompt.replace("{Prompt}", user_input) + "\n"
|
42 |
return {
|
43 |
"response": "",
|
44 |
"history": data.history + data.end_token
|
45 |
}
|
46 |
|
47 |
-
user_input_str = ""
|
48 |
if data.segment:
|
49 |
-
for user_input in data.
|
50 |
user_sentences = tokenizer.tokenize(user_input)
|
51 |
-
user_input_str
|
|
|
52 |
else:
|
53 |
-
|
54 |
-
|
55 |
-
for prompt in data.prompt_template:
|
56 |
-
data.history += prompt.replace("{Prompt}", user_input_str) + "\n"
|
57 |
|
58 |
inputs = ""
|
59 |
-
for system_prompt in data.
|
60 |
-
inputs +=
|
61 |
inputs += data.history
|
62 |
|
63 |
seed = random.randint(0, 2**32 - 1)
|
@@ -116,4 +113,4 @@ async def check_word(data: WordCheckData) -> Dict[str, Any]:
|
|
116 |
"found": found
|
117 |
}
|
118 |
|
119 |
-
return result
|
|
|
18 |
|
19 |
class InputData(BaseModel):
|
20 |
model: str
|
21 |
+
system_prompt_template: str
|
22 |
+
prompt_template: str
|
23 |
end_token: str
|
24 |
+
system_prompts: List[str]
|
25 |
+
user_inputs: List[str]
|
26 |
history: str = ""
|
27 |
segment: bool = False
|
28 |
max_sentences: Optional[int] = None
|
|
|
36 |
if data.max_sentences is not None and data.max_sentences != 0:
|
37 |
data.segment = True
|
38 |
elif data.max_sentences == 0:
|
39 |
+
for user_input in data.user_inputs:
|
40 |
+
data.history += data.prompt_template.replace("{Prompt}", user_input)
|
|
|
41 |
return {
|
42 |
"response": "",
|
43 |
"history": data.history + data.end_token
|
44 |
}
|
45 |
|
|
|
46 |
if data.segment:
|
47 |
+
for user_input in data.user_inputs:
|
48 |
user_sentences = tokenizer.tokenize(user_input)
|
49 |
+
user_input_str = "\n".join(user_sentences)
|
50 |
+
data.history += data.prompt_template.replace("{Prompt}", user_input_str) + "\n"
|
51 |
else:
|
52 |
+
for user_input in data.user_inputs:
|
53 |
+
data.history += data.prompt_template.replace("{Prompt}", user_input) + "\n"
|
|
|
|
|
54 |
|
55 |
inputs = ""
|
56 |
+
for system_prompt in data.system_prompts:
|
57 |
+
inputs += data.system_prompt_template.replace("{SystemPrompt}", system_prompt) + "\n"
|
58 |
inputs += data.history
|
59 |
|
60 |
seed = random.randint(0, 2**32 - 1)
|
|
|
113 |
"found": found
|
114 |
}
|
115 |
|
116 |
+
return result
|