# New-Place / main.py
import random
import re
from os import getenv
from typing import Any, Dict

import nltk
from fastapi import FastAPI, HTTPException
from huggingface_hub import InferenceClient
from json_repair import repair_json
from pydantic import BaseModel
from word_forms.word_forms import get_word_forms

app = FastAPI()

# Download the Punkt sentence-tokenizer data once at startup, then load the
# English model used to split text into sentences.
nltk.download('punkt')
tokenizer = nltk.data.load('tokenizers/punkt/english.pickle')

# Hugging Face Inference API token, supplied via the environment.
HF_TOKEN = getenv("HF_TOKEN")

class InputData(BaseModel):
    """Request body for /generate-response/."""
    model: str                   # Hugging Face model ID to run inference on
    system_prompt_template: str  # template containing a {SystemPrompt} placeholder
    prompt_template: str         # template containing a {Prompt} placeholder
    end_token: str               # marker appended to the returned history
    system_prompt: str
    user_input: str
    json_prompt: str             # system prompt asking the model to answer in JSON
    history: str = ""            # running conversation history

class WordCheckData(BaseModel):
    """Request body for /check-word/."""
    string: str  # text to search
    word: str    # word whose inflected forms are searched for
@app.post("/generate-response/")
async def generate_response(data: InputData) -> Dict[str, Any]:
client = InferenceClient(model=data.model, token=HF_TOKEN)
sentences = tokenizer.tokenize(data.user_input)
data_dict = {'###New response###': [], '###Sentence count###': 0}
for i, sentence in enumerate(sentences):
data_dict["###New response###"].append(sentence)
data_dict["###Sentence count###"] = i + 1
data.history += data.prompt_template.replace("{Prompt}", str(data_dict))
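
    # Assemble the full model input: the system prompt and the JSON-format
    # prompt are each rendered through the system template, followed by the
    # running history.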
inputs = (
data.system_prompt_template.replace("{SystemPrompt}", data.system_prompt) +
data.system_prompt_template.replace("{SystemPrompt}", data.json_prompt) +
data.history
)
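
    # Draw a fresh seed per request so identical prompts can still yield
    # different generations.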
seed = random.randint(0, 2**32 - 1)
try:
response = client.text_generation(
inputs,
temperature=1.0,
max_new_tokens=1000,
seed=seed
)
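
        # The model is asked for JSON but may return malformed output;
        # json_repair parses what it can, and a plain-string result is
        # treated below as a failed generation.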
strict_response = str(response)
repaired_response = repair_json(strict_response, return_objects=True)
if isinstance(repaired_response, str):
raise HTTPException(status_code=500, detail="Invalid response from model")
        else:
            # Strip the "###" markers from keys, e.g. "###New response###" -> "New response".
            cleaned_response = {
                key.replace("###", ""): value
                for key, value in repaired_response.items()
            }
strings = ""
for i, text in enumerate(cleaned_response["New response"]):
if i != len(cleaned_response["New response"]) - 1:
strings += text + " "
else:
strings += text
sentences = tokenizer.tokenize(strings)
cleaned_response["New response"] = sentences

            # Cap the model-reported sentence count at 3; if it is missing or
            # zero, fall back to the actual number of sentences.
            if cleaned_response.get("Sentence count"):
                cleaned_response["Sentence count"] = min(cleaned_response["Sentence count"], 3)
            else:
                cleaned_response["Sentence count"] = len(cleaned_response["New response"])
data.history += str(cleaned_response)
return {
"response": cleaned_response,
"history": data.history + data.end_token
}
    except HTTPException:
        # Re-raise as-is so the 500 raised for an unrepairable model response
        # is not swallowed and re-wrapped by the generic handler below.
        raise
    except Exception as e:
        print(f"Model {data.model} failed with error: {e}")
        raise HTTPException(status_code=500, detail=f"Model {data.model} failed to generate a response")
@app.post("/check-word/")
async def check_word(data: WordCheckData) -> Dict[str, Any]:
input_string = data.string.lower()
word = data.word.lower()
forms = get_word_forms(word)
all_forms = set()
for words in forms.values():
all_forms.update(words)

    # Split the input string into words with a regular expression so spaces
    # and punctuation are handled uniformly.
    words_in_string = re.findall(r'\b\w+\b', input_string)

    # The word counts as present if any token matches one of its forms.
    found = any(w in all_forms for w in words_in_string)

    return {"found": found}
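
# Example usage (a sketch: assumes the app is served locally via
# `uvicorn main:app` and that HF_TOKEN is set in the environment):
#
#   curl -X POST http://localhost:8000/check-word/ \
#        -H "Content-Type: application/json" \
#        -d '{"string": "She went running yesterday.", "word": "run"}'
#   # -> {"found": true}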