File size: 2,180 Bytes
04dbc8e 8a7a2b3 5dab16f 04dbc8e 5dab16f 04dbc8e e3d6348 e4cef2a 04dbc8e 5dab16f 4e06fbd 04dbc8e 8a7a2b3 04dbc8e 8a7a2b3 04dbc8e 8a7a2b3 04dbc8e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 |
import os
from fastapi import FastAPI, HTTPException, Depends
from pydantic import BaseModel
from ctransformers import AutoModelForCausalLM
# Pydantic object for request validation
class Validation(BaseModel):
    """Request body accepted by the per-model generation endpoints.

    Only ``inputs`` is required. Every other field has a default and is
    forwarded unchanged as the keyword argument of the same name when the
    model is called.
    """

    # Prompt text passed to the model as its positional argument.
    inputs: str

    # Generation settings, forwarded to the model call as keyword arguments.
    temperature: float = 0.0
    max_new_tokens: int = 1048
    top_p: float = 0.15
    repetition_penalty: float = 1.0
# Module-level FastAPI application instance; setup_endpoints() attaches one
# POST route per model directory to it.
app = FastAPI()
# Function to load models and create endpoints
def setup_endpoints(app):
    """Load every model under ``./models`` and register one POST route each.

    Each sub-directory of ``./models`` is treated as a model: it is loaded
    with ``AutoModelForCausalLM.from_pretrained`` and exposed at
    ``POST /<directory name>``. A directory whose model fails to load still
    gets a route, but that route answers 404 ("Model not found").

    Args:
        app: The FastAPI application the routes are registered on.

    Raises:
        RuntimeError: If ``./models`` is missing or not a directory, or if
            it contains no sub-directories.
    """
    model_base_path = './models'
    # os.path.isdir() is already False for a non-existent path, so one check
    # covers both "missing" and "not a directory".
    if not os.path.isdir(model_base_path):
        raise RuntimeError("Models directory does not exist or is not a directory")

    model_dirs = [
        d for d in os.listdir(model_base_path)
        if os.path.isdir(os.path.join(model_base_path, d))
    ]
    if not model_dirs:
        raise RuntimeError("No models found in the models directory")

    models = {}
    # Load each model. Loading is best-effort: a failure is reported and the
    # remaining models are still loaded.
    for model_name in model_dirs:
        model_path = os.path.join(model_base_path, model_name)
        try:
            models[model_name] = AutoModelForCausalLM.from_pretrained(model_path, threads=2)
        except Exception as e:
            print(f"Failed to load model {model_name}: {e}")

    def get_model(model_name: str):
        """Return the loaded model for ``model_name`` or raise a 404."""
        if model_name not in models:
            raise HTTPException(status_code=404, detail="Model not found")
        return models[model_name]

    def register_endpoint(bound_name: str) -> None:
        """Register ``POST /<bound_name>`` tied to exactly that model.

        The name is received as a function argument so each route gets its
        own binding. A lambda closing over the loop variable would be
        evaluated at request time and every route would resolve to the last
        directory in ``model_dirs``.
        """

        def resolve_model():
            # Takes no parameters on purpose: FastAPI turns the parameters of
            # a dependency callable into request parameters.
            return get_model(bound_name)

        @app.post(f"/{bound_name}")
        async def generate_response(item: Validation, model=Depends(resolve_model)):
            # NOTE: the model call is synchronous inside an async handler, so
            # it runs on the event loop for the whole generation.
            try:
                return model(
                    item.inputs,
                    temperature=item.temperature,
                    max_new_tokens=item.max_new_tokens,
                    top_p=item.top_p,
                    repetition_penalty=item.repetition_penalty,
                )
            except Exception as e:
                # Surface any generation failure as a 500, keeping the
                # original exception as the cause.
                raise HTTPException(status_code=500, detail=str(e)) from e

    # Create an endpoint for each model directory.
    for model_name in model_dirs:
        register_endpoint(model_name)
# Load the models and register their routes as soon as this module is imported.
# setup_endpoints() raises RuntimeError when ./models is missing or holds no
# sub-directories, so in that case the import itself fails.
setup_endpoints(app)
|