File size: 2,180 Bytes
04dbc8e
8a7a2b3
5dab16f
04dbc8e
5dab16f
04dbc8e
e3d6348
 
 
 
 
 
e4cef2a
04dbc8e
5dab16f
4e06fbd
04dbc8e
 
 
 
 
 
 
 
 
 
 
 
 
8a7a2b3
04dbc8e
 
 
 
 
 
 
 
 
8a7a2b3
 
 
 
 
 
 
 
04dbc8e
8a7a2b3
04dbc8e
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import os
from fastapi import FastAPI, HTTPException, Depends
from pydantic import BaseModel
from ctransformers import AutoModelForCausalLM

# Pydantic model for request validation: describes the JSON body accepted by
# the generation endpoints. Only `inputs` is required; every other field has a
# default and may be omitted by the caller.
class Validation(BaseModel):
    # Prompt text; passed to the model as its single positional argument.
    inputs: str
    # The remaining fields are forwarded unchanged to the model call as keyword
    # arguments of the same name.
    temperature: float = 0.0
    # NOTE(review): 1048 is an unusual default; possibly 1024 was intended — confirm.
    max_new_tokens: int = 1048
    top_p: float = 0.15
    repetition_penalty: float = 1.0

# Initialize the FastAPI app. It starts with no model routes; they are
# registered on it by setup_endpoints().
app = FastAPI()

# Function to load models and create endpoints
def setup_endpoints(app):
    """Load every model under ``./models`` and register one POST route per model.

    Each sub-directory of ``./models`` is treated as one model and is exposed
    at ``POST /<directory name>``. The request body is validated against
    ``Validation`` and the model's output is returned as the response.

    Loading is best-effort: a model that fails to load is reported on stdout
    and skipped. Its route is still registered and answers with
    404 "Model not found".

    Args:
        app: The FastAPI application the routes are registered on.

    Raises:
        RuntimeError: If ``./models`` is missing or is not a directory, or if
            it contains no sub-directories.
    """
    model_base_path = './models'
    # isdir() is already False for a missing path, so one check covers both cases.
    if not os.path.isdir(model_base_path):
        raise RuntimeError("Models directory does not exist or is not a directory")

    model_dirs = [
        d for d in os.listdir(model_base_path)
        if os.path.isdir(os.path.join(model_base_path, d))
    ]

    if not model_dirs:
        raise RuntimeError("No models found in the models directory")

    models = {}

    # Load each model; one broken model must not prevent the others from loading.
    for model_name in model_dirs:
        model_path = os.path.join(model_base_path, model_name)
        try:
            models[model_name] = AutoModelForCausalLM.from_pretrained(model_path, threads=2)
        except Exception as e:
            print(f"Failed to load model {model_name}: {e}")

    # Function to get model dependency
    def get_model(model_name: str):
        if model_name not in models:
            raise HTTPException(status_code=404, detail="Model not found")
        return models[model_name]

    def register_endpoint(name):
        """Register ``POST /<name>`` bound to the model called *name*."""
        # `name` is a parameter of this helper, so every route gets its own
        # binding. Closing over the loop variable directly would make all
        # routes resolve to the last directory, because the dependency is
        # only evaluated at request time. A `lambda name=name:` default is
        # deliberately avoided: FastAPI would read it as a query parameter.
        def resolve_model():
            return get_model(name)

        @app.post(f"/{name}")
        async def generate_response(item: Validation, model=Depends(resolve_model)):
            # NOTE(review): the model call is synchronous inside an async
            # handler, so it runs on the event loop — confirm this is acceptable
            # for the expected generation times.
            try:
                return model(
                    item.inputs,
                    temperature=item.temperature,
                    max_new_tokens=item.max_new_tokens,
                    top_p=item.top_p,
                    repetition_penalty=item.repetition_penalty,
                )
            except Exception as e:
                raise HTTPException(status_code=500, detail=str(e)) from e

    # Create an endpoint for each model
    for model_name in model_dirs:
        register_endpoint(model_name)

# Setup endpoints: runs at import time, so models are loaded and routes are
# registered as soon as this module is imported. Raises RuntimeError if the
# models directory is missing or contains no model sub-directories.
setup_endpoints(app)