oflakne26 committed
Commit
e53fb7b
1 Parent(s): afb8fad

Update main.py

Files changed (1)
  1. main.py +84 -51
main.py CHANGED
@@ -1,61 +1,94 @@
- import os
- from fastapi import FastAPI, HTTPException, Depends
  from pydantic import BaseModel
- from ctransformers import AutoModelForCausalLM
-
- # Pydantic object for request validation
- class Validation(BaseModel):
-     inputs: str
-     temperature: float = 0.0
-     max_new_tokens: int = 1048
-     top_p: float = 0.15
-     repetition_penalty: float = 1.0
-
- # Initialize FastAPI app
  app = FastAPI()

- # Function to load models and create endpoints
- def setup_endpoints(app):
-     model_base_path = './models'
-     if not os.path.exists(model_base_path) or not os.path.isdir(model_base_path):
-         raise RuntimeError("Models directory does not exist or is not a directory")
-
-     model_dirs = [d for d in os.listdir(model_base_path) if os.path.isdir(os.path.join(model_base_path, d))]
-
-     if not model_dirs:
-         raise RuntimeError("No models found in the models directory")
-
-     models = {}
-
-     # Load each model
-     for model_name in model_dirs:
-         model_path = os.path.join(model_base_path, model_name)
          try:
-             model = AutoModelForCausalLM.from_pretrained(model_path, threads=2)
-             models[model_name] = model
          except Exception as e:
-             print(f"Failed to load model {model_name}: {e}")
-             continue
-
-     # Function to get model dependency
-     def get_model(model_name: str):
-         if model_name not in models:
-             raise HTTPException(status_code=404, detail="Model not found")
-         return models[model_name]
-
-     # Create an endpoint for each model
-     for model_name in model_dirs:
-         @app.post(f"/{model_name}")
-         async def generate_response(item: Validation, model=Depends(lambda: get_model(model_name))):
-             try:
-                 response = model(item.inputs,
-                                  temperature=item.temperature,
-                                  max_new_tokens=item.max_new_tokens,
-                                  top_p=item.top_p,
-                                  repetition_penalty=item.repetition_penalty)
-                 return response
-             except Exception as e:
-                 raise HTTPException(status_code=500, detail=str(e))
-
- # Setup endpoints
- setup_endpoints(app)
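
For reference, the removed version exposed one POST route per directory under ./models, validated by the Validation model above. A call against it would have looked roughly like the sketch below; the URL and the "my-model" route name are placeholders, not values taken from the commit.

import requests

# Hypothetical client call against the old per-model route; "my-model"
# stands in for a directory name found under ./models at startup.
resp = requests.post(
    "http://localhost:8000/my-model",
    json={
        "inputs": "Hello!",
        "temperature": 0.0,
        "max_new_tokens": 1048,
        "top_p": 0.15,
        "repetition_penalty": 1.0,
    },
)
print(resp.json())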
 
+ from fastapi import FastAPI, HTTPException
+ from typing import Any
  from pydantic import BaseModel
+ from os import getenv
+ from huggingface_hub import InferenceClient
+ import random
+ from json_repair import repair_json
+ import nltk

  app = FastAPI()

+ # Download and load the Punkt sentence tokenizer used below.
+ nltk.download('punkt')
+
+ tokenizer = nltk.data.load('tokenizers/punkt/english.pickle')
+
+ HF_TOKEN = getenv("HF_TOKEN")
+ MODEL = "mistralai/Mistral-7B-Instruct-v0.2"
+ FALLBACK_MODELS = [
+     "mistralai/Mixtral-8x7B-Instruct-v0.1",
+     "mistralai/Mistral-7B-Instruct-v0.2",
+     "mistralai/Mistral-7B-Instruct-v0.1",
+ ]
+
+ # Pydantic object for request validation
+ class InputData(BaseModel):
+     model: str
+     system_prompt_template: str
+     prompt_template: str
+     system_prompt: str
+     user_input: str
+     json_prompt: str
+     history: str = ""
+
+ @app.post("/generate-response/")
+ async def generate_response(data: InputData) -> Any:
+     # Split the user input into sentences and build the structured payload
+     # the model is asked to mirror in its JSON answer.
+     sentences = tokenizer.tokenize(data.user_input)
+     data_dict = {'###New response###': [], '###Sentence count###': 0}
+     for i, sentence in enumerate(sentences):
+         data_dict["###New response###"].append(sentence)
+         data_dict["###Sentence count###"] = i + 1
+
+     data.history += data.prompt_template.replace("{Prompt}", str(data_dict))

+     inputs = (
+         data.system_prompt_template.replace("{SystemPrompt}", data.system_prompt) +
+         data.system_prompt_template.replace("{SystemPrompt}", data.json_prompt) +
+         data.history)

+     # One seed for the whole request keeps the fallback retries reproducible.
+     seed = random.randint(0, 2**32 - 1)

+     # Try the requested model first, then the hard-coded fallbacks.
+     models_to_try = [data.model] + FALLBACK_MODELS

+     for model in models_to_try:
          try:
+             # Create the client per attempt so the fallback models are actually
+             # queried (a single client would stay pinned to data.model).
+             client = InferenceClient(model=model, token=HF_TOKEN)
+             response = client.text_generation(inputs,
+                                               temperature=1.0,
+                                               max_new_tokens=1000,
+                                               seed=seed)
+
+             strict_response = str(response)
+
+             # Coerce possibly malformed model output back into a JSON object.
+             repaired_response = repair_json(strict_response,
+                                             return_objects=True)
+
+             if isinstance(repaired_response, str):
+                 # repair_json returns a string when no object can be recovered;
+                 # the except below catches this, so the next model is tried.
+                 raise HTTPException(status_code=500, detail="Invalid response from model")
+             else:
+                 # Strip the ### markers from the keys.
+                 cleaned_response = {}
+                 for key, value in repaired_response.items():
+                     cleaned_key = key.replace("###", "")
+                     cleaned_response[cleaned_key] = value
+
+                 # Keep the first sentence of at most the first three entries;
+                 # rebuild the list instead of deleting items mid-iteration.
+                 trimmed = []
+                 for text in cleaned_response["New response"][:3]:
+                     first = tokenizer.tokenize(text)
+                     if first:
+                         trimmed.append(first[0])
+                 cleaned_response["New response"] = trimmed
+
+                 if cleaned_response.get("Sentence count"):
+                     cleaned_response["Sentence count"] = min(cleaned_response["Sentence count"], 3)
+                 else:
+                     cleaned_response["Sentence count"] = len(cleaned_response["New response"])
+
+                 data.history += str(cleaned_response)
+
+                 return cleaned_response
+
          except Exception as e:
+             print(f"Model {model} failed with error: {e}")
+
+     raise HTTPException(status_code=500, detail="All models failed to generate response")
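
A request to the new endpoint bundles the templates, prompts, and history into one JSON body. The sketch below shows one plausible way to call it: every value is illustrative (the template strings are just one common Mistral-style layout, not taken from the commit), and the response shape follows from the key-cleaning and trimming code above.

import requests

# Hypothetical request against the new endpoint; all field values are
# illustrative and the URL depends on where the app is served.
payload = {
    "model": "mistralai/Mistral-7B-Instruct-v0.2",
    "system_prompt_template": "<s>[INST] {SystemPrompt} [/INST]",
    "prompt_template": "[INST] {Prompt} [/INST]",
    "system_prompt": "You are a concise assistant.",
    "user_input": "Tell me about FastAPI. Keep it brief.",
    "json_prompt": "Reply as JSON with keys ###New response### and ###Sentence count###.",
    "history": "",
}

resp = requests.post("http://localhost:8000/generate-response/", json=payload)

# After key cleaning and trimming, a successful response looks like:
# {"New response": ["...", "..."], "Sentence count": 2}
print(resp.json())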