New-Place / main.py
oflakne26's picture
Update main.py
c81e2b1 verified
raw
history blame
No virus
3.19 kB
from fastapi import FastAPI, HTTPException
from typing import Any
from pydantic import BaseModel
from os import getenv
from huggingface_hub import InferenceClient
import random
from json_repair import repair_json
import nltk
import sys
# ASGI application object; the route handler defined later in this file
# registers itself on it via @app.post.
app = FastAPI()
# Import-time side effect: fetch the NLTK "punkt" resource before the
# tokenizer below is loaded from it. This runs on every process start.
nltk.download('punkt')
# English sentence tokenizer, loaded once and reused by every request.
tokenizer = nltk.data.load('tokenizers/punkt/english.pickle')
# Hugging Face API token taken from the environment (None when unset).
HF_TOKEN = getenv("HF_TOKEN")
# NOTE(review): MODEL and FALLBACK_MODELS are not referenced by the code
# visible in this file (the handler uses the model named in the request) —
# confirm whether a fallback mechanism is still planned or these are leftovers.
MODEL = "mistralai/Mistral-7B-Instruct-v0.2"
FALLBACK_MODELS = [
"mistralai/Mixtral-8x7B-Instruct-v0.1",
"mistralai/Mistral-7B-Instruct-v0.2", "mistralai/Mistral-7B-Instruct-v0.1"
]
# Request body accepted by the /generate-response/ endpoint.
class InputData(BaseModel):
    # Model identifier passed to InferenceClient for this request.
    model: str
    # Template containing a "{SystemPrompt}" placeholder; it is rendered
    # twice, once with system_prompt and once with json_prompt.
    system_prompt_template: str
    # Template containing a "{Prompt}" placeholder; it is rendered with the
    # sentence-split user input and appended to history.
    prompt_template: str
    # Text substituted into the first rendering of system_prompt_template.
    system_prompt: str
    # Raw user message; split into sentences before being sent to the model.
    user_input: str
    # Text substituted into the second rendering of system_prompt_template.
    # presumably describes the JSON shape the model must answer in — verify
    # against the caller.
    json_prompt: str
    # Conversation text accumulated so far; optional, defaults to empty.
    history: str = ""
# Upper bound on how many reply sentences are kept from a model answer.
_MAX_RESPONSE_SENTENCES = 3


def _build_prompt(data: InputData) -> str:
    """Render the full prompt for the model and record the turn in history.

    The user input is split into sentences and embedded, as the string form
    of a dict, into ``data.prompt_template``. The rendered turn is appended
    to ``data.history`` (the request object is mutated on purpose so a
    direct caller can observe the updated history).

    Returns:
        The system prompt, the JSON-format prompt and the updated history,
        concatenated in that order.
    """
    sentences = tokenizer.tokenize(data.user_input)
    payload = {
        '###New response###': list(sentences),
        '###Sentence count###': len(sentences),
    }
    data.history += data.prompt_template.replace("{Prompt}", str(payload))
    return (
        data.system_prompt_template.replace("{SystemPrompt}", data.system_prompt)
        + data.system_prompt_template.replace("{SystemPrompt}", data.json_prompt)
        + data.history
    )


def _clean_model_response(parsed: Any) -> dict:
    """Validate and normalise the object decoded from the model output.

    Strips the ``###`` markers from the keys, keeps at most
    ``_MAX_RESPONSE_SENTENCES`` entries of "New response" (each reduced to
    its first sentence) and makes sure "Sentence count" is present and does
    not exceed that limit.

    Raises:
        HTTPException: 500 when the decoded value is not a JSON object or
            its "New response" entry is missing or not a list.
    """
    if not isinstance(parsed, dict):
        # repair_json can hand back a string, list or scalar when the model
        # output is not a JSON object; none of those is usable here.
        raise HTTPException(status_code=500, detail="Invalid response from model")

    cleaned = {key.replace("###", ""): value for key, value in parsed.items()}

    replies = cleaned.get("New response")
    if not isinstance(replies, list):
        raise HTTPException(status_code=500, detail="Invalid response from model")

    # Build a new list instead of deleting from the one being iterated:
    # deleting in place shifts the remaining items and skips some of them.
    # NOTE(review): entries past the limit are dropped, matching the cap
    # applied to the sentence count below — confirm this is the intended
    # contract.
    trimmed = []
    for text in replies[:_MAX_RESPONSE_SENTENCES]:
        sentences = tokenizer.tokenize(text)
        # An entry with no detectable sentence is kept unchanged.
        trimmed.append(sentences[0] if sentences else text)
    cleaned["New response"] = trimmed

    count = cleaned.get("Sentence count")
    is_number = isinstance(count, (int, float)) and not isinstance(count, bool)
    if is_number and count:
        cleaned["Sentence count"] = min(count, _MAX_RESPONSE_SENTENCES)
    else:
        # Missing, zero or non-numeric count: derive it from the list.
        cleaned["Sentence count"] = len(trimmed)
    return cleaned


@app.post("/generate-response/")
async def generate_response(data: InputData) -> Any:
    """Generate a short JSON reply to the user's input with a hosted model.

    Args:
        data: Request body carrying the model name, prompt templates, user
            input and conversation history.

    Returns:
        A dict with the ``###`` markers stripped from its keys, containing
        "New response" (at most three single-sentence strings) and
        "Sentence count".

    Raises:
        HTTPException: 500 with "Invalid response from model" when the
            model output cannot be interpreted as the expected JSON object,
            or 500 with a generic message when the inference call or the
            post-processing fails for any other reason.
    """
    client = InferenceClient(model=data.model, token=HF_TOKEN)
    inputs = _build_prompt(data)
    seed = random.randint(0, 2**32 - 1)

    try:
        # NOTE(review): text_generation is a blocking network call inside an
        # async handler, so it holds the event loop for its whole duration —
        # consider running it in a worker thread.
        response = client.text_generation(inputs,
                                          temperature=1.0,
                                          max_new_tokens=1000,
                                          seed=seed)
        parsed = repair_json(str(response), return_objects=True)
        cleaned_response = _clean_model_response(parsed)
    except HTTPException:
        # Already a deliberate API error; do not let the generic handler
        # below overwrite its detail message.
        raise
    except Exception as e:
        print(f"Model {data.model} failed with error: {e}")
        raise HTTPException(
            status_code=500,
            detail=f"Model {data.model} failed to generate response",
        ) from e

    data.history += str(cleaned_response)
    return cleaned_response