from fastapi import FastAPI, HTTPException
from typing import Any, Dict
from pydantic import BaseModel
from os import getenv
from huggingface_hub import InferenceClient
import random
from json_repair import repair_json
import nltk
import re
from word_forms.word_forms import get_word_forms

app = FastAPI()

# Download the Punkt model once at startup and load the English sentence tokenizer.
nltk.download('punkt')
tokenizer = nltk.data.load('tokenizers/punkt/english.pickle')

HF_TOKEN = getenv("HF_TOKEN")


class InputData(BaseModel):
    model: str
    system_prompt_template: str
    prompt_template: str
    end_token: str
    system_prompt: str
    user_input: str
    json_prompt: str
    history: str = ""


class WordCheckData(BaseModel):
    string: str
    word: str


@app.post("/generate-response/")
async def generate_response(data: InputData) -> Dict[str, Any]:
    client = InferenceClient(model=data.model, token=HF_TOKEN)

    # Split the user input into sentences and package it in the JSON-like
    # structure the model is prompted to mirror in its reply.
    sentences = tokenizer.tokenize(data.user_input)
    data_dict = {
        '###New response###': sentences,
        '###Sentence count###': len(sentences),
    }

    # Append the formatted user turn to the running conversation history.
    data.history += data.prompt_template.replace("{Prompt}", str(data_dict))

    # Build the full prompt: system prompt, JSON-format instructions, then the history.
    inputs = (
        data.system_prompt_template.replace("{SystemPrompt}", data.system_prompt)
        + data.system_prompt_template.replace("{SystemPrompt}", data.json_prompt)
        + data.history
    )

    # Random seed so repeated calls with the same prompt can still vary.
    seed = random.randint(0, 2**32 - 1)

    try:
        response = client.text_generation(
            inputs,
            temperature=1.0,
            max_new_tokens=1000,
            seed=seed,
        )

        # The model is expected to return JSON but may emit malformed output;
        # json_repair tries to coerce it into a Python object.
        repaired_response = repair_json(str(response), return_objects=True)

        if isinstance(repaired_response, str):
            # repair_json could not produce a dict, so the output is unusable.
            raise HTTPException(status_code=500, detail="Invalid response from model")

        # Strip the '###' markers from the keys, e.g. '###New response###' -> 'New response'.
        cleaned_response = {
            key.replace("###", ""): value for key, value in repaired_response.items()
        }

        # Re-join the model's sentences and re-tokenize, so sentence boundaries
        # come from the tokenizer rather than from the model's own splitting.
        strings = " ".join(cleaned_response["New response"])
        cleaned_response["New response"] = tokenizer.tokenize(strings)

        if cleaned_response.get("Sentence count"):
            # Cap the reported count at 3 sentences.
            if cleaned_response["Sentence count"] > 3:
                cleaned_response["Sentence count"] = 3
        else:
            # The model omitted the count, so derive it from the tokenized response.
            cleaned_response["Sentence count"] = len(cleaned_response["New response"])

        # Record the assistant turn in the history and return it with the end token appended.
        data.history += str(cleaned_response)
        return {
            "response": cleaned_response,
            "history": data.history + data.end_token,
        }
    except HTTPException:
        # Let the "Invalid response from model" error propagate instead of being
        # masked by the generic handler below.
        raise
    except Exception as e:
        print(f"Model {data.model} failed with error: {e}")
        raise HTTPException(
            status_code=500,
            detail=f"Model {data.model} failed to generate response",
        )


@app.post("/check-word/")
async def check_word(data: WordCheckData) -> Dict[str, Any]:
    input_string = data.string.lower()
    word = data.word.lower()

    # Collect every morphological form of the target word (noun, verb, adjective, adverb).
    forms = get_word_forms(word)
    all_forms = set()
    for words in forms.values():
        all_forms.update(words)

    # Split the input string into words with a regex so spaces and punctuation are ignored.
    words_in_string = re.findall(r'\b\w+\b', input_string)

    # Report whether any word in the string matches one of the target word's forms.
    found = any(w in all_forms for w in words_in_string)

    return {"found": found}
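
# Usage sketch: assumes this module is saved as main.py, that uvicorn is installed,
# and that HF_TOKEN is set in the environment; the host, port, and payload values
# below are illustrative assumptions, not part of the API itself.
#
#   uvicorn main:app --host 0.0.0.0 --port 8000
#
#   curl -X POST http://localhost:8000/check-word/ \
#        -H "Content-Type: application/json" \
#        -d '{"string": "She was running late.", "word": "run"}'
#   # expected response: {"found": true}, since "running" is a form of "run"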