# New-Place / main.py
# Author: oflakne26 — commit cdf5dea ("Update main.py", verified), 3.74 kB.
# (Header recovered from the Hugging Face file page; malware scan reported no virus.)
import os
import random
from os import getenv
from typing import Any, Dict

import nltk
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from huggingface_hub import InferenceClient
from json_repair import repair_json
from pydantic import BaseModel
# FastAPI application that the route decorators in this module register on.
app = FastAPI()

# Hugging Face access token passed to InferenceClient; None when HF_TOKEN is unset.
HF_TOKEN = os.environ.get("HF_TOKEN")

# Make the Punkt sentence-tokenizer data available locally, then load the
# English model that the endpoints use to split text into sentences.
# NOTE(review): NLTK 3.8.2+/3.9 replaced the pickled Punkt models with
# "punkt_tab"; confirm the pinned nltk version still loads english.pickle.
nltk.download('punkt')
tokenizer = nltk.data.load('tokenizers/punkt/english.pickle')
class InputData(BaseModel):
    """Request body for the ``/generate-response/`` endpoint.

    Every field except ``history`` is required.
    """
    # Model id handed to InferenceClient(model=...).
    model: str
    # Template containing a "{SystemPrompt}" placeholder; rendered once with
    # system_prompt and once with json_prompt.
    system_prompt_template: str
    # Template containing a "{Prompt}" placeholder; rendered with the
    # sentence-split user turn and appended to history.
    prompt_template: str
    # Appended to the history string returned to the caller.
    end_token: str
    # Text substituted into system_prompt_template for the first system block.
    system_prompt: str
    # The user's message; split into sentences before being added to the prompt.
    user_input: str
    # Text substituted into system_prompt_template for the second system block.
    json_prompt: str
    # Prior conversation text; extended with this turn and echoed back.
    history: str = ""
def _clean_model_response(repaired_response: Any) -> Dict[str, Any]:
    """Normalise the model's repaired JSON output into the endpoint payload.

    Removes the ``###`` markers from every key, trims each of the first three
    ``New response`` entries to its first sentence (dropping entries that
    yield no sentence), and normalises ``Sentence count``.

    Args:
        repaired_response: Value returned by
            ``repair_json(..., return_objects=True)``.

    Returns:
        The cleaned response dict.

    Raises:
        HTTPException: 500 when the output is not a JSON object containing a
            ``New response`` list.
    """
    # repair_json may hand back a str (nothing recoverable), a list or a
    # scalar instead of an object; none of those can be cleaned.
    if not isinstance(repaired_response, dict):
        raise HTTPException(status_code=500, detail="Invalid response from model")

    cleaned_response = {
        key.replace("###", ""): value for key, value in repaired_response.items()
    }

    new_response = cleaned_response.get("New response")
    if not isinstance(new_response, list):
        raise HTTPException(status_code=500, detail="Invalid response from model")

    # Build a fresh list rather than deleting while enumerating (which skipped
    # the entry after every deletion). Only the first three entries are
    # trimmed to one sentence; later entries pass through unchanged.
    trimmed = []
    for index, text in enumerate(new_response):
        if index > 2:
            trimmed.append(text)
            continue
        sentences = tokenizer.tokenize(text) if isinstance(text, str) else []
        if sentences:
            trimmed.append(sentences[0])
    cleaned_response["New response"] = trimmed

    # Accept a positive integer count from the model (capped at 3); missing,
    # zero, negative or non-integer values fall back to the actual list length.
    count = cleaned_response.get("Sentence count")
    if isinstance(count, int) and not isinstance(count, bool) and count > 0:
        cleaned_response["Sentence count"] = min(count, 3)
    else:
        cleaned_response["Sentence count"] = len(trimmed)
    return cleaned_response


@app.post("/generate-response/")
async def generate_response(data: InputData) -> Dict[str, Any]:
    """Generate the next model turn for a sentence-split user message.

    The user input is split into sentences and appended to ``data.history``
    through ``prompt_template``. The system prompt and the JSON prompt are each
    rendered through ``system_prompt_template`` and prepended to the history to
    form the model input. The model output is repaired into JSON, cleaned, and
    appended to the history.

    Returns:
        ``{"response": <cleaned dict>, "history": <history + end_token>}``.

    Raises:
        HTTPException: 500 if the model call fails or its output is not a
            usable JSON object.
    """
    client = InferenceClient(model=data.model, token=HF_TOKEN)

    # Describe the user turn with the same "###"-delimited keys that the
    # model's reply is later parsed from.
    sentences = list(tokenizer.tokenize(data.user_input))
    data_dict = {
        '###New response###': sentences,
        '###Sentence count###': len(sentences),
    }
    # str() renders a Python dict repr (not JSON); kept so prompt and history
    # formatting stay identical to what earlier turns produced.
    data.history += data.prompt_template.replace("{Prompt}", str(data_dict))

    inputs = (
        data.system_prompt_template.replace("{SystemPrompt}", data.system_prompt)
        + data.system_prompt_template.replace("{SystemPrompt}", data.json_prompt)
        + data.history
    )
    seed = random.randint(0, 2**32 - 1)

    try:
        # InferenceClient.text_generation is a blocking HTTP call; run it in a
        # worker thread so this async endpoint does not stall the event loop.
        response = await run_in_threadpool(
            client.text_generation,
            inputs,
            temperature=1.0,
            max_new_tokens=1000,
            seed=seed,
        )
        repaired_response = repair_json(str(response), return_objects=True)
        cleaned_response = _clean_model_response(repaired_response)
    except HTTPException:
        # Keep our specific error instead of re-wrapping it as a generic failure.
        raise
    except Exception as e:
        print(f"Model {data.model} failed with error: {e}")
        raise HTTPException(
            status_code=500,
            detail=f"Model {data.model} failed to generate response",
        ) from e

    data.history += str(cleaned_response)
    return {
        "response": cleaned_response,
        "history": data.history + data.end_token,
    }
@app.post("/get-medieval-name/")
async def get_medieval_name() -> Dict[str, str]:
    """Return a random name from ``medieval_names.txt``.

    The file path is relative to the process's current working directory; the
    file holds one name per line. Blank or whitespace-only lines are ignored,
    and surrounding whitespace is stripped from each name.

    Returns:
        ``{"name": <randomly chosen name>}``.

    Raises:
        HTTPException: 404 if the file is missing or contains no names;
            500 for any other error while reading it.
    """
    file_path = "medieval_names.txt"
    try:
        with open(file_path, "r", encoding="utf-8") as file:
            names = [line.strip() for line in file.read().splitlines() if line.strip()]
    except FileNotFoundError:
        # Previously this 404 was swallowed by the generic handler and became a 500.
        raise HTTPException(status_code=404, detail="File not found") from None
    except Exception as e:
        print(f"Error: {e}")
        raise HTTPException(
            status_code=500,
            detail="An error occurred while processing the request",
        ) from e

    if not names:
        raise HTTPException(status_code=404, detail="No names found in the file")
    return {"name": random.choice(names)}