ka1kuk committed on
Commit
7c2f128
1 Parent(s): 62f3d3a

Update apis/chat_api.py

Browse files
Files changed (1) hide show
  1. apis/chat_api.py +14 -26
apis/chat_api.py CHANGED
@@ -187,42 +187,30 @@ class ChatAPIApp:
187
  data_response = streamer.chat_return_dict(stream_response)
188
  return data_response
189
 
190
- async def chat_embedding(self, input, model_name, api_key: str = Depends(extract_api_key)):
191
  api_url = f"https://api-inference.huggingface.co/pipeline/feature-extraction/{model_name}"
192
  headers = {"Authorization": f"Bearer {api_key}"}
193
- response = requests.post(api_url, headers=headers, json={"inputs": input})
194
  result = response.json()
195
  if isinstance(result, list) and len(result) > 0 and isinstance(result[0], list):
196
- # Assuming each embedding is a list of lists of floats, flatten it
197
- flattened_embeddings = [sum(embedding, []) for embedding in result]
198
- return flattened_embeddings
199
  elif "error" in result:
200
  raise RuntimeError("The model is currently loading, please re-run the query.")
201
  else:
202
  raise RuntimeError("Unexpected response format.")
203
-
204
 
205
  async def embedding(self, request: QueryRequest, api_key: str = Depends(extract_api_key)):
206
- try:
207
- for attempt in range(3): # Retry logic
208
- try:
209
- embeddings = await self.chat_embedding(request.input, request.model, api_key)
210
- data = [
211
- {"object": "embedding", "index": i, "embedding": embedding}
212
- for i, embedding in enumerate(embeddings)
213
- ]
214
- return {
215
- "object": "list",
216
- "data": data,
217
- "model": request.model,
218
- "usage": {"prompt_tokens": len(request.input), "total_tokens": len(request.input)}
219
- }
220
- except RuntimeError as e:
221
- if attempt < 2: # Don't sleep on the last attempt
222
- await asyncio.sleep(10) # Delay for the retry
223
- raise HTTPException(status_code=503, detail="The model is currently loading, please try again later.")
224
- except Exception as e:
225
- raise HTTPException(status_code=500, detail=str(e))
226
 
227
  def setup_routes(self):
228
  for prefix in ["", "/v1", "/api", "/api/v1"]:
 
187
  data_response = streamer.chat_return_dict(stream_response)
188
  return data_response
189
 
190
+ async def chat_embedding(self, input_text: str, model_name: str, api_key: str):
191
  api_url = f"https://api-inference.huggingface.co/pipeline/feature-extraction/{model_name}"
192
  headers = {"Authorization": f"Bearer {api_key}"}
193
+ response = requests.post(api_url, headers=headers, json={"inputs": input_text})
194
  result = response.json()
195
  if isinstance(result, list) and len(result) > 0 and isinstance(result[0], list):
196
+ return [item for sublist in result for item in sublist] # Flatten the list of lists
 
 
197
  elif "error" in result:
198
  raise RuntimeError("The model is currently loading, please re-run the query.")
199
  else:
200
  raise RuntimeError("Unexpected response format.")
 
201
 
202
  async def embedding(self, request: QueryRequest, api_key: str = Depends(extract_api_key)):
203
+ try:
204
+ embeddings = await self.chat_embedding(request.input, request.model, api_key)
205
+ data = [{"object": "embedding", "index": i, "embedding": embedding} for i, embedding in enumerate(embeddings)]
206
+ return EmbeddingResponse(
207
+ object="list",
208
+ data=data,
209
+ model=request.model,
210
+ usage={"prompt_tokens": len(request.input), "total_tokens": len(request.input)}
211
+ )
212
+ except Exception as e:
213
+ raise HTTPException(status_code=500, detail=str(e))
 
 
 
 
 
 
 
 
 
214
 
215
  def setup_routes(self):
216
  for prefix in ["", "/v1", "/api", "/api/v1"]: