New-Place / main.py
oflakne26's picture
Update main.py
8a7a2b3 verified
raw
history blame
No virus
2.18 kB
import logging
import os
import threading

from fastapi import FastAPI, HTTPException, Depends
from pydantic import BaseModel
from ctransformers import AutoModelForCausalLM
# Request body schema shared by every model route; pydantic validates and
# coerces the incoming JSON against these field types.
class Validation(BaseModel):
    """Generation request accepted by each per-model POST route.

    Only ``inputs`` is required; every other field falls back to the default
    written next to it.
    """

    inputs: str  # prompt text handed to the model
    temperature: float = 0.0
    max_new_tokens: int = 1048  # NOTE(review): possibly intended as 1024 — confirm
    top_p: float = 0.15
    repetition_penalty: float = 1.0
# Initialize FastAPI app
app = FastAPI()
# Model discovery, loading, and per-model route registration
def _find_model_dirs(model_base_path):
    """Return the sorted names of the immediate subdirectories of *model_base_path*.

    Raises:
        RuntimeError: if *model_base_path* is not an existing directory, or if
            it contains no subdirectories.
    """
    # os.path.isdir() is already False for a missing path, so one check
    # covers both "does not exist" and "is not a directory".
    if not os.path.isdir(model_base_path):
        raise RuntimeError("Models directory does not exist or is not a directory")
    # Sorted so route registration order does not depend on listdir() order.
    model_dirs = sorted(
        entry
        for entry in os.listdir(model_base_path)
        if os.path.isdir(os.path.join(model_base_path, entry))
    )
    if not model_dirs:
        raise RuntimeError("No models found in the models directory")
    return model_dirs


def _load_models(model_base_path, model_names, threads):
    """Load each named model directory, returning ``{name: model}``.

    Models that fail to load are logged and left out of the result.
    """
    logger = logging.getLogger(__name__)
    models = {}
    for model_name in model_names:
        model_path = os.path.join(model_base_path, model_name)
        try:
            models[model_name] = AutoModelForCausalLM.from_pretrained(
                model_path, threads=threads
            )
        except Exception as e:
            # Best effort: one unloadable model must not prevent serving the rest.
            logger.warning("Failed to load model %s: %s", model_name, e)
    return models


def _register_endpoint(app, model_name, models):
    """Add ``POST /<model_name>`` to *app*, serving ``models[model_name]``.

    The route is registered even when the model failed to load; requests to
    it then get a 404 "Model not found" from the dependency.
    """
    # Generation now runs in FastAPI's worker threads (sync endpoint below).
    # The model object is not assumed to be safe for concurrent calls, so
    # calls to the same model are serialized, as they were on the event loop.
    lock = threading.Lock()

    def get_model():
        # model_name is this factory's own parameter, so each route binds its
        # own name. (A lambda closing over the caller's loop variable would
        # resolve it at request time and always see the last model.)
        if model_name not in models:
            raise HTTPException(status_code=404, detail="Model not found")
        return models[model_name]

    @app.post(f"/{model_name}")
    def generate_response(item: Validation, model=Depends(get_model)):
        # Plain `def`: FastAPI runs it in a threadpool, so a long, CPU-bound
        # generation no longer blocks the event loop for every other request.
        try:
            with lock:
                return model(
                    item.inputs,
                    temperature=item.temperature,
                    max_new_tokens=item.max_new_tokens,
                    top_p=item.top_p,
                    repetition_penalty=item.repetition_penalty,
                )
        except Exception as e:
            raise HTTPException(status_code=500, detail=str(e)) from e


def setup_endpoints(app, *, model_base_path='./models', threads=2):
    """Load every model under *model_base_path* and expose one POST route each.

    Args:
        app: the FastAPI application to register routes on.
        model_base_path: directory whose immediate subdirectories are models;
            each subdirectory name becomes the route path ``/<name>``.
        threads: value passed as ``threads=`` to
            ``AutoModelForCausalLM.from_pretrained`` for every model.

    Raises:
        RuntimeError: if *model_base_path* is not a directory or holds no
            subdirectories.
    """
    model_dirs = _find_model_dirs(model_base_path)
    models = _load_models(model_base_path, model_dirs, threads)
    for model_name in model_dirs:
        _register_endpoint(app, model_name, models)
# Register the per-model routes at import time, so whatever imports `app`
# (presumably an ASGI server such as uvicorn — confirm) sees them immediately.
setup_endpoints(app=app)