ka1kuk committed on
Commit
8ef9b69
1 Parent(s): 129a4cd

Update apis/chat_api.py

Browse files
Files changed (1) hide show
  1. apis/chat_api.py +61 -2
apis/chat_api.py CHANGED
@@ -3,13 +3,15 @@ import os
3
  import sys
4
  import time
5
  import uvicorn
 
 
6
 
7
  from pathlib import Path
8
  from fastapi import FastAPI, Depends
9
  from fastapi.responses import HTMLResponse
10
  from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
11
  from pydantic import BaseModel, Field
12
- from typing import Union
13
  from sse_starlette.sse import EventSourceResponse, ServerSentEvent
14
  from utils.logger import logger
15
  from networks.message_streamer import MessageStreamer
@@ -18,7 +20,6 @@ from mocks.stream_chat_mocker import stream_chat_mock
18
 
19
  from fastapi.middleware.cors import CORSMiddleware
20
 
21
-
22
  class ChatAPIApp:
23
  def __init__(self):
24
  self.app = FastAPI(
@@ -79,6 +80,13 @@ class ChatAPIApp:
79
  "created": current_time,
80
  "owned_by": "codellama",
81
  },
 
 
 
 
 
 
 
82
  ],
83
  }
84
  return self.available_models
@@ -103,6 +111,23 @@ class ChatAPIApp:
103
  logger.warn("Not provide HF Token!")
104
  return None
105
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
  class ChatCompletionsPostItem(BaseModel):
107
  model: str = Field(
108
  default="mixtral-8x7b",
@@ -161,6 +186,28 @@ class ChatAPIApp:
161
  data_response = streamer.chat_return_dict(stream_response)
162
  return data_response
163
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164
  def setup_routes(self):
165
  for prefix in ["", "/v1", "/api", "/api/v1"]:
166
  if prefix in ["/api/v1"]:
@@ -180,6 +227,18 @@ class ChatAPIApp:
180
  include_in_schema=include_in_schema,
181
  )(self.chat_completions)
182
 
 
 
 
 
 
 
 
 
 
 
 
 
183
 
184
  class ArgParser(argparse.ArgumentParser):
185
  def __init__(self, *args, **kwargs):
 
3
  import sys
4
  import time
5
  import uvicorn
6
+ import requests
7
+ import asyncio
8
 
9
  from pathlib import Path
10
  from fastapi import FastAPI, Depends
11
  from fastapi.responses import HTMLResponse
12
  from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
13
  from pydantic import BaseModel, Field
14
+ from typing import Union, List
15
  from sse_starlette.sse import EventSourceResponse, ServerSentEvent
16
  from utils.logger import logger
17
  from networks.message_streamer import MessageStreamer
 
20
 
21
  from fastapi.middleware.cors import CORSMiddleware
22
 
 
23
  class ChatAPIApp:
24
  def __init__(self):
25
  self.app = FastAPI(
 
80
  "created": current_time,
81
  "owned_by": "codellama",
82
  },
83
+ {
84
+ "id": "bert-base-uncased",
85
+ "description": "[google-bert/bert-base-uncased]: https://huggingface.co/google-bert/bert-base-uncased",
86
+ "object": "embedding",
87
+ "created": current_time,
88
+ "owned_by": "google",
89
+ },
90
  ],
91
  }
92
  return self.available_models
 
111
  logger.warn("Not provide HF Token!")
112
  return None
113
 
114
class QueryRequest(BaseModel):
    """Request payload for the embedding endpoint."""

    # Batch of input strings to embed in a single request.
    texts: List[str]
    # Hugging Face model repo id for the feature-extraction pipeline — required.
    model_name: str = Field(..., example="bert-base-uncased")
    # Caller-supplied Hugging Face API token; forwarded upstream as a Bearer token.
    api_key: str = Field(..., example="your_hf_api_key_here")
118
+
119
@staticmethod
async def send_request_to_hugging_face(texts, model_name, api_key):
    """POST `texts` to the Hugging Face feature-extraction Inference API.

    Returns the nested list of per-text embedding vectors on success.
    Raises RuntimeError when the upstream reports an error (typically the
    model is still loading) or the response has an unexpected shape.

    Declared as a @staticmethod: it reads no instance state, and the
    sibling handler must be able to call it through the class without an
    implicit `self` being bound into the `texts` parameter.
    """
    api_url = f"https://api-inference.huggingface.co/pipeline/feature-extraction/{model_name}"
    headers = {"Authorization": f"Bearer {api_key}"}

    def _post():
        # `requests` is synchronous; add a timeout so a stuck upstream
        # call cannot hang this worker indefinitely.
        return requests.post(
            api_url, headers=headers, json={"inputs": texts}, timeout=120
        )

    # Run the blocking HTTP call in a thread-pool executor so the asyncio
    # event loop stays responsive. The original called requests.post()
    # directly from this coroutine, blocking every concurrent request for
    # the full duration of the upstream call.
    loop = asyncio.get_running_loop()
    response = await loop.run_in_executor(None, _post)

    result = response.json()
    if isinstance(result, list) and len(result) > 0 and isinstance(result[0], list):
        return result
    if "error" in result:
        # Surface the upstream error text (usually "Model ... is currently
        # loading") instead of a hard-coded guess about the cause.
        raise RuntimeError(f"Hugging Face API error: {result['error']}")
    raise RuntimeError("Unexpected response format.")
130
+
131
  class ChatCompletionsPostItem(BaseModel):
132
  model: str = Field(
133
  default="mixtral-8x7b",
 
186
  data_response = streamer.chat_return_dict(stream_response)
187
  return data_response
188
 
189
async def embedding(self, request: QueryRequest):
    """Handle POST /embedding: return OpenAI-style embeddings for request.texts.

    Fixes relative to the committed version:
    - Added the missing `self` parameter. The route is registered as
      `self.embedding`, so the bound method was silently consuming the
      `request` argument as `self`.
    - Calls the helper through the class instead of as a bare name (which
      was a NameError for a function defined in the class body).
    - The final 503 is raised outside the blanket exception handler, which
      previously rewrote it into a 500.
    """
    # HTTPException is not imported at the top of this file; import it
    # locally so this handler stands on its own.
    from fastapi import HTTPException

    # The HF model may still be loading; retry up to 3 times with a delay.
    for attempt in range(3):
        try:
            # Class access avoids binding `self` into the helper's first
            # positional parameter (`texts`).
            embeddings = await type(self).send_request_to_hugging_face(
                request.texts, request.model_name, request.api_key
            )
        except RuntimeError:
            if attempt < 2:  # don't sleep after the final attempt
                await asyncio.sleep(10)  # give the model time to finish loading
            continue
        except Exception as e:
            # Anything unexpected (network failure, bad JSON, ...) is a 500.
            raise HTTPException(status_code=500, detail=str(e))
        data = [
            {"object": "embedding", "index": i, "embedding": emb}
            for i, emb in enumerate(embeddings)
        ]
        return {
            "object": "list",
            "data": data,
            "model": request.model_name,
            # NOTE(review): this counts input texts, not tokens; kept as-is
            # for compatibility with the previous response shape.
            "usage": {
                "prompt_tokens": len(request.texts),
                "total_tokens": len(request.texts),
            },
        }
    # All retries exhausted — report the loading condition honestly.
    raise HTTPException(
        status_code=503,
        detail="The model is currently loading, please try again later.",
    )
210
+
211
  def setup_routes(self):
212
  for prefix in ["", "/v1", "/api", "/api/v1"]:
213
  if prefix in ["/api/v1"]:
 
227
  include_in_schema=include_in_schema,
228
  )(self.chat_completions)
229
 
230
+ if prefix in ["/v1"]:
231
+ include_in_schema = True
232
+ else:
233
+ include_in_schema = False
234
+
235
+ self.app.post(
236
+ prefix + "/embedding", # Use the specific prefix for this route
237
+ summary="Generate embeddings for the given texts",
238
+ include_in_schema=include_in_schema,
239
+ response_model=List # Adapt based on your actual response model
240
+ )(self.embedding)
241
+
242
 
243
  class ArgParser(argparse.ArgumentParser):
244
  def __init__(self, *args, **kwargs):