# New-Place / main.py
from os import getenv
import random
from typing import Any

import nltk
from fastapi import FastAPI, HTTPException
from huggingface_hub import InferenceClient
from json_repair import repair_json
from pydantic import BaseModel

app = FastAPI()

# Download the Punkt sentence tokenizer once at startup.
nltk.download('punkt')
tokenizer = nltk.data.load('tokenizers/punkt/english.pickle')

HF_TOKEN = getenv("HF_TOKEN")

# Default model (currently unused; callers pass `model` in the request body).
MODEL = "mistralai/Mistral-7B-Instruct-v0.2"

# Fallback models, tried in order after the requested model fails.
FALLBACK_MODELS = [
    "mistralai/Mixtral-8x7B-Instruct-v0.1",
    "mistralai/Mistral-7B-Instruct-v0.2",
    "mistralai/Mistral-7B-Instruct-v0.1",
]


class InputData(BaseModel):
    model: str
    system_prompt_template: str
    prompt_template: str
    system_prompt: str
    user_input: str
    json_prompt: str
    history: str = ""
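
# Hypothetical example payload for /generate-response/ (a sketch, not taken
# from the Space itself; the template strings depend on the target model's
# chat format):
#
# {
#   "model": "mistralai/Mistral-7B-Instruct-v0.2",
#   "system_prompt_template": "<s>[INST] {SystemPrompt} [/INST]",
#   "prompt_template": "[INST] {Prompt} [/INST]",
#   "system_prompt": "You are a friendly conversation partner.",
#   "user_input": "Hello there. How are you today?",
#   "json_prompt": "Answer as JSON with the keys ###New response### and ###Sentence count###.",
#   "history": ""
# }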


@app.post("/generate-response/")
async def generate_response(data: InputData) -> Any:
    # Split the user input into sentences and pack them into the JSON
    # structure the model is asked to mirror in its reply.
    sentences = tokenizer.tokenize(data.user_input)
    data_dict = {
        '###New response###': sentences,
        '###Sentence count###': len(sentences),
    }
    data.history += data.prompt_template.replace("{Prompt}", str(data_dict))
    # Full prompt: system prompt + JSON-format instructions + running history.
    inputs = (
        data.system_prompt_template.replace("{SystemPrompt}", data.system_prompt)
        + data.system_prompt_template.replace("{SystemPrompt}", data.json_prompt)
        + data.history
    )
    # One random seed, reused across every fallback attempt.
    seed = random.randint(0, 2**32 - 1)
    models_to_try = [data.model] + FALLBACK_MODELS
    for model in models_to_try:
        try:
            # Build the client per attempt so fallback models are actually
            # queried rather than only the model from the request.
            client = InferenceClient(model=model, token=HF_TOKEN)
            response = client.text_generation(
                inputs,
                temperature=1.0,
                max_new_tokens=1000,
                seed=seed,
            )
            # Models often emit slightly malformed JSON; repair it into a
            # Python object rather than failing outright.
            repaired_response = repair_json(str(response), return_objects=True)
            if isinstance(repaired_response, str):
                # repair_json returns a string when it cannot recover an
                # object; treat that as a model failure and try the next one.
                raise ValueError("Model did not return parseable JSON")

            # Strip the '###' markers from the keys.
            cleaned_response = {
                key.replace("###", ""): value
                for key, value in repaired_response.items()
            }
            # Keep at most three reply entries, truncating each to its first
            # sentence; build a new list instead of deleting from the one
            # being iterated.
            trimmed = []
            for text in cleaned_response.get("New response", [])[:3]:
                entry_sentences = tokenizer.tokenize(text)
                if entry_sentences:
                    trimmed.append(entry_sentences[0])
            cleaned_response["New response"] = trimmed
            # Clamp the reported sentence count to at most three, falling
            # back to the actual length of the trimmed reply.
            if cleaned_response.get("Sentence count"):
                cleaned_response["Sentence count"] = min(
                    cleaned_response["Sentence count"], 3)
            else:
                cleaned_response["Sentence count"] = len(trimmed)

            data.history += str(cleaned_response)
            return cleaned_response
        except Exception as e:
            # Log the failure and fall through to the next model in the list.
            print(f"Model {model} failed with error: {e}")

    raise HTTPException(status_code=500,
                        detail="All models failed to generate a response")
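

# Minimal local-run sketch (an assumption; Hugging Face Spaces typically
# launches the app via its own entrypoint, and port 7860 is the Spaces
# convention). Requires `uvicorn` to be installed.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)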