oflakne26 commited on
Commit
f319ec8
1 Parent(s): d17ac0e

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +20 -15
main.py CHANGED
@@ -1,5 +1,5 @@
1
  from fastapi import FastAPI, HTTPException
2
- from typing import Any, Dict, Optional
3
  from pydantic import BaseModel
4
  from os import getenv
5
  from huggingface_hub import InferenceClient
@@ -18,11 +18,11 @@ HF_TOKEN = getenv("HF_TOKEN")
18
 
19
  class InputData(BaseModel):
20
  model: str
21
- system_prompt_template: str
22
- prompt_template: str
23
  end_token: str
24
- system_prompt: str
25
- user_input: str
26
  history: str = ""
27
  segment: bool = False
28
  max_sentences: Optional[int] = None
@@ -36,24 +36,29 @@ async def generate_response(data: InputData) -> Dict[str, Any]:
36
  if data.max_sentences is not None and data.max_sentences != 0:
37
  data.segment = True
38
  elif data.max_sentences == 0:
39
- data.history += data.prompt_template.replace("{Prompt}", data.user_input)
 
 
40
  return {
41
  "response": "",
42
  "history": data.history + data.end_token
43
  }
44
 
 
45
  if data.segment:
46
- user_sentences = tokenizer.tokenize(data.user_input)
47
- user_input_str = "\n".join(user_sentences)
 
48
  else:
49
- user_input_str = data.user_input
50
 
51
- data.history += data.prompt_template.replace("{Prompt}", user_input_str)
 
52
 
53
- inputs = (
54
- data.system_prompt_template.replace("{SystemPrompt}", data.system_prompt) +
55
- data.history
56
- )
57
 
58
  seed = random.randint(0, 2**32 - 1)
59
 
@@ -111,4 +116,4 @@ async def check_word(data: WordCheckData) -> Dict[str, Any]:
111
  "found": found
112
  }
113
 
114
- return result
 
1
  from fastapi import FastAPI, HTTPException
2
+ from typing import Any, Dict, List, Optional
3
  from pydantic import BaseModel
4
  from os import getenv
5
  from huggingface_hub import InferenceClient
 
18
 
19
  class InputData(BaseModel):
20
  model: str
21
+ system_prompt_template: List[str]
22
+ prompt_template: List[str]
23
  end_token: str
24
+ system_prompt: List[str]
25
+ user_input: List[str]
26
  history: str = ""
27
  segment: bool = False
28
  max_sentences: Optional[int] = None
 
36
  if data.max_sentences is not None and data.max_sentences != 0:
37
  data.segment = True
38
  elif data.max_sentences == 0:
39
+ for prompt in data.prompt_template:
40
+ for user_input in data.user_input:
41
+ data.history += prompt.replace("{Prompt}", user_input) + "\n"
42
  return {
43
  "response": "",
44
  "history": data.history + data.end_token
45
  }
46
 
47
+ user_input_str = ""
48
  if data.segment:
49
+ for user_input in data.user_input:
50
+ user_sentences = tokenizer.tokenize(user_input)
51
+ user_input_str += "\n".join(user_sentences) + "\n"
52
  else:
53
+ user_input_str = "\n".join(data.user_input)
54
 
55
+ for prompt in data.prompt_template:
56
+ data.history += prompt.replace("{Prompt}", user_input_str) + "\n"
57
 
58
+ inputs = ""
59
+ for system_prompt in data.system_prompt_template:
60
+ inputs += system_prompt.replace("{SystemPrompt}", "\n".join(data.system_prompt)) + "\n"
61
+ inputs += data.history
62
 
63
  seed = random.randint(0, 2**32 - 1)
64
 
 
116
  "found": found
117
  }
118
 
119
+ return result