from os import getenv
from typing import Any, Dict, List, Optional
import random

import nltk
from fastapi import FastAPI, HTTPException
from huggingface_hub import InferenceClient
from pydantic import BaseModel
from word_forms.word_forms import get_word_forms

app = FastAPI()

# Download the Punkt sentence-tokenizer data and load the English model once at startup.
nltk.download('punkt')
tokenizer = nltk.data.load('tokenizers/punkt/english.pickle')

HF_TOKEN = getenv("HF_TOKEN")


class InputData(BaseModel):
    model: str
    system_prompt_template: str
    prompt_template: str
    end_token: str
    system_prompts: List[str]
    user_inputs: List[str]
    history: str = ""
    segment: bool = False
    max_sentences: Optional[int] = None


class WordCheckData(BaseModel):
    string: str
    word: str


@app.post("/generate-response/")
async def generate_response(data: InputData) -> Dict[str, Any]:
    # A non-zero max_sentences implies sentence segmentation of input and output.
    if data.max_sentences is not None and data.max_sentences != 0:
        data.segment = True
    elif data.max_sentences == 0:
        # max_sentences == 0: record the prompts in the history but skip generation.
        for user_input in data.user_inputs:
            data.history += data.prompt_template.replace("{Prompt}", user_input) + "\n"
        return {
            "response": [],
            "sentence_count": None,
            "history": data.history + data.end_token
        }

    # Append each user input to the history, one prompt per line; when segmenting,
    # split each input into one sentence per line first.
    if data.segment:
        for user_input in data.user_inputs:
            user_sentences = tokenizer.tokenize(user_input)
            user_input_str = "\n".join(user_sentences)
            data.history += data.prompt_template.replace("{Prompt}", user_input_str) + "\n"
    else:
        for user_input in data.user_inputs:
            data.history += data.prompt_template.replace("{Prompt}", user_input) + "\n"

    # Build the full prompt: all system prompts first, then the conversation history.
    inputs = ""
    for system_prompt in data.system_prompts:
        inputs += data.system_prompt_template.replace("{SystemPrompt}", system_prompt) + "\n"
    inputs += data.history

    # Random seed so repeated calls with the same prompt can yield different outputs.
    seed = random.randint(0, 2**32 - 1)

    try:
        client = InferenceClient(model=data.model, token=HF_TOKEN)
        response = client.text_generation(
            inputs,
            temperature=1.0,
            max_new_tokens=1000,
            seed=seed
        )
        response_str = str(response)

        if data.segment:
            # Split the generated text into sentences and truncate to max_sentences.
            ai_sentences = tokenizer.tokenize(response_str)
            if data.max_sentences is not None:
                ai_sentences = ai_sentences[:data.max_sentences]
            responses = ai_sentences
            sentence_count = len(ai_sentences)
        else:
            responses = [response_str]
            sentence_count = None

        data.history += response_str + "\n"

        return {
            "response": responses,
            "sentence_count": sentence_count,
            "history": data.history + data.end_token
        }
    except Exception as e:
        print(f"Model {data.model} failed with error: {e}")
        raise HTTPException(status_code=500, detail=f"Model {data.model} failed to generate a response")


@app.post("/check-word/")
async def check_word(data: WordCheckData) -> Dict[str, Any]:
    input_string = data.string.lower()
    word = data.word.lower()

    # Collect every known form of the target word (plurals, tenses, derivations, ...).
    forms = get_word_forms(word)
    all_forms = set()
    for words in forms.values():
        all_forms.update(words)

    found = False

    # Compare each whitespace-separated token, stripped down to its alphabetic
    # characters, against the set of word forms.
    for input_word in input_string.split():
        input_word = ''.join(filter(str.isalpha, input_word))
        if input_word in all_forms:
            found = True
            break  # Exit early once a match is found

    return {"found": found}
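
# Example usage (a minimal sketch, not part of the service itself). Assuming this
# file is saved as app.py and served with `uvicorn app:app`, the endpoints can be
# exercised as below; the template strings, end token, and the "<hf-model-id>"
# placeholder are illustrative values, not fixed by the service.
#
#   curl -X POST http://localhost:8000/check-word/ \
#        -H "Content-Type: application/json" \
#        -d '{"string": "The cats were running around.", "word": "run"}'
#   # -> {"found": true}  ("running" is among the word forms of "run")
#
#   curl -X POST http://localhost:8000/generate-response/ \
#        -H "Content-Type: application/json" \
#        -d '{"model": "<hf-model-id>",
#             "system_prompt_template": "System: {SystemPrompt}",
#             "prompt_template": "User: {Prompt}",
#             "end_token": "</s>",
#             "system_prompts": ["You are a helpful assistant."],
#             "user_inputs": ["Hello!"],
#             "max_sentences": 2}'
#   # -> {"response": [...], "sentence_count": <=2, "history": "..."}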