Update apis/chat_api.py
Browse files- apis/chat_api.py +6 -3
apis/chat_api.py
CHANGED
@@ -24,7 +24,7 @@ from fastapi.middleware.cors import CORSMiddleware
|
|
24 |
class EmbeddingResponseItem(BaseModel):
|
25 |
object: str = "embedding"
|
26 |
index: int
|
27 |
-
embedding: List[float]
|
28 |
|
29 |
class EmbeddingResponse(BaseModel):
|
30 |
object: str = "list"
|
@@ -193,11 +193,14 @@ class ChatAPIApp:
|
|
193 |
response = requests.post(api_url, headers=headers, json={"inputs": texts})
|
194 |
result = response.json()
|
195 |
if isinstance(result, list) and len(result) > 0 and isinstance(result[0], list):
|
196 |
-
|
|
|
|
|
197 |
elif "error" in result:
|
198 |
raise RuntimeError("The model is currently loading, please re-run the query.")
|
199 |
else:
|
200 |
-
raise RuntimeError("Unexpected response format.")
|
|
|
201 |
|
202 |
async def embedding(self, request: QueryRequest):
|
203 |
try:
|
|
|
24 |
class EmbeddingResponseItem(BaseModel):
|
25 |
object: str = "embedding"
|
26 |
index: int
|
27 |
+
embedding: List[List[float]]
|
28 |
|
29 |
class EmbeddingResponse(BaseModel):
|
30 |
object: str = "list"
|
|
|
193 |
response = requests.post(api_url, headers=headers, json={"inputs": texts})
|
194 |
result = response.json()
|
195 |
if isinstance(result, list) and len(result) > 0 and isinstance(result[0], list):
|
196 |
+
# Assuming each embedding is a list of lists of floats, flatten it
|
197 |
+
flattened_embeddings = [sum(embedding, []) for embedding in result]
|
198 |
+
return flattened_embeddings
|
199 |
elif "error" in result:
|
200 |
raise RuntimeError("The model is currently loading, please re-run the query.")
|
201 |
else:
|
202 |
+
raise RuntimeError("Unexpected response format.")
|
203 |
+
|
204 |
|
205 |
async def embedding(self, request: QueryRequest):
|
206 |
try:
|