import random
from os import getenv
from typing import Any

import nltk
from fastapi import FastAPI, HTTPException
from huggingface_hub import InferenceClient
from json_repair import repair_json
from pydantic import BaseModel

app = FastAPI()

# Fetch the Punkt data at startup and load the English sentence tokenizer.
nltk.download('punkt')

tokenizer = nltk.data.load('tokenizers/punkt/english.pickle')
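# For example (illustrative):
#   tokenizer.tokenize("Hi there. How are you?")
#   -> ['Hi there.', 'How are you?']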

HF_TOKEN = getenv("HF_TOKEN")
MODEL = "mistralai/Mistral-7B-Instruct-v0.2"
# NOTE: FALLBACK_MODELS is defined but never consulted by the endpoint below;
# a retry sketch follows.
FALLBACK_MODELS = [
    "mistralai/Mixtral-8x7B-Instruct-v0.1",
    "mistralai/Mistral-7B-Instruct-v0.2",
    "mistralai/Mistral-7B-Instruct-v0.1",
]
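
# A minimal fallback sketch (an assumption about the intended use of
# FALLBACK_MODELS; it is not called anywhere in this file, and the helper
# name and parameters are illustrative):
def generate_with_fallback(prompt: str, seed: int) -> str:
    last_error = None
    for model in [MODEL] + FALLBACK_MODELS:
        try:
            client = InferenceClient(model=model, token=HF_TOKEN)
            return client.text_generation(prompt,
                                          temperature=1.0,
                                          max_new_tokens=1000,
                                          seed=seed)
        except Exception as e:
            # Remember the failure and try the next candidate model.
            last_error = e
    raise RuntimeError("All candidate models failed") from last_error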

class InputData(BaseModel):
    model: str
    system_prompt_template: str
    prompt_template: str
    system_prompt: str
    user_input: str
    json_prompt: str
    history: str = ""

@app.post("/generate-response/")
async def generate_response(data: InputData) -> Any:
    client = InferenceClient(model=data.model, token=HF_TOKEN)

    # Split the user input into sentences and package them in the ###-keyed
    # structure the prompt templates expect, as illustrated below.
    sentences = tokenizer.tokenize(data.user_input)
    data_dict = {
        '###New response###': sentences,
        '###Sentence count###': len(sentences),
    }

    # Record the current turn in the rolling history, then assemble the full
    # prompt: system prompt, JSON-format instructions, and the history.
    data.history += data.prompt_template.replace("{Prompt}", str(data_dict))

    inputs = (
        data.system_prompt_template.replace("{SystemPrompt}",
                                            data.system_prompt) +
        data.system_prompt_template.replace("{SystemPrompt}",
                                            data.json_prompt) +
        data.history)

    seed = random.randint(0, 2**32 - 1)

    try:
        response = client.text_generation(inputs,
                                          temperature=1.0,
                                          max_new_tokens=1000,
                                          seed=seed)

        # text_generation returns a string; repair_json fixes malformed JSON
        # and, with return_objects=True, returns the parsed object. Getting a
        # plain string back means the output could not be coerced into JSON.
        repaired_response = repair_json(str(response), return_objects=True)

        if isinstance(repaired_response, str):
            raise HTTPException(status_code=500,
                                detail="Invalid response from model")
        else:
            # Strip the ### markers from the keys for the outgoing payload.
            cleaned_response = {}
            for key, value in repaired_response.items():
                cleaned_response[key.replace("###", "")] = value

            # Keep at most three sentences, each trimmed to its first
            # sentence. Truncating a copy up front fixes the original bug of
            # deleting list items while iterating over the same list.
            responses = cleaned_response["New response"][:3]
            for i, text in enumerate(responses):
                sentences = tokenizer.tokenize(text)
                if sentences:
                    responses[i] = sentences[0]
            cleaned_response["New response"] = responses
            # Clamp the model-reported count to the three-sentence cap, or
            # fall back to the actual list length when it is missing.
            if cleaned_response.get("Sentence count"):
                cleaned_response["Sentence count"] = min(
                    int(cleaned_response["Sentence count"]), 3)
            else:
                cleaned_response["Sentence count"] = len(
                    cleaned_response["New response"])

            # data is a per-request object, so this appended history does not
            # persist across requests.
            data.history += str(cleaned_response)

            return cleaned_response

    except HTTPException:
        # Propagate the specific validation error raised above rather than
        # masking it with the generic failure message below.
        raise
    except Exception as e:
        print(f"Model {data.model} failed with error: {e}")
        raise HTTPException(
            status_code=500,
            detail=f"Model {data.model} failed to generate response") from e
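
# Usage sketch (assumptions, not part of the original file): serve the app
# with uvicorn, then POST a payload shaped like InputData. The template
# strings are illustrative placeholders, not the project's actual prompts.
#
#   uvicorn app:app --host 0.0.0.0 --port 8000   # assuming this file is app.py
#
#   import httpx
#   payload = {
#       "model": "mistralai/Mistral-7B-Instruct-v0.2",
#       "system_prompt_template": "<<SYS>>{SystemPrompt}<</SYS>>\n",
#       "prompt_template": "[INST] {Prompt} [/INST]\n",
#       "system_prompt": "You are a helpful assistant.",
#       "user_input": "Hello there. How are you today?",
#       "json_prompt": "Reply as JSON with the keys ###New response### and ###Sentence count###.",
#       "history": "",
#   }
#   print(httpx.post("http://localhost:8000/generate-response/", json=payload).json())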