pvanand committed
Commit 68394ea
1 Parent(s): b067e10

Update main.py

Files changed (1)
  1. main.py +110 -17
main.py CHANGED
@@ -1,16 +1,22 @@
-from fastapi import FastAPI, HTTPException, Depends, Security
+from fastapi import FastAPI, HTTPException, Depends, Security, BackgroundTasks
 from fastapi.security import APIKeyHeader
 from fastapi.responses import StreamingResponse
 from pydantic import BaseModel, Field
-from typing import Literal
+from typing import Literal, List, Dict
 import os
 from functools import lru_cache
 from openai import OpenAI
+from uuid import uuid4
+import tiktoken
+import sqlite3
+import time
+from datetime import datetime, timedelta
+import asyncio
 
 app = FastAPI()
 
 API_KEY_NAME = "X-API-Key"
-API_KEY = os.environ.get("API_KEY", "default_secret_key")  # Set this in your environment variables
+API_KEY = os.environ.get("API_KEY", "default_secret_key")
 api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False)
 
 ModelID = Literal[
@@ -29,12 +35,16 @@ class QueryModel(BaseModel):
         default="meta-llama/llama-3-70b-instruct",
         description="ID of the model to use for response generation"
     )
+    conversation_id: str = Field(default_factory=lambda: str(uuid4()), description="Unique identifier for the conversation")
+    user_id: str = Field(..., description="Unique identifier for the user")
 
     class Config:
         schema_extra = {
             "example": {
                 "user_query": "How do I implement a binary search in Python?",
-                "model_id": "meta-llama/llama-3-70b-instruct"
+                "model_id": "meta-llama/llama-3-70b-instruct",
+                "conversation_id": "123e4567-e89b-12d3-a456-426614174000",
+                "user_id": "user123"
             }
         }
 
@@ -47,7 +57,28 @@ def get_api_keys():
 api_keys = get_api_keys()
 or_client = OpenAI(api_key=api_keys["OPENROUTER_API_KEY"], base_url="https://openrouter.ai/api/v1")
 
-def chat_with_llama_stream(messages, model, max_output_tokens=2500):
+# In-memory storage for conversations
+conversations: Dict[str, List[Dict[str, str]]] = {}
+last_activity: Dict[str, float] = {}
+
+# Token encoding
+encoding = tiktoken.encoding_for_model("gpt-3.5-turbo")
+
+def limit_tokens(input_string, token_limit=6000):
+    return encoding.decode(encoding.encode(input_string)[:token_limit])
+
+def calculate_tokens(msgs):
+    return sum(len(encoding.encode(str(m))) for m in msgs)
+
+def chat_with_llama_stream(messages, model="gpt-3.5-turbo", max_llm_history=4, max_output_tokens=2500):
+    while calculate_tokens(messages) > (8000 - max_output_tokens):
+        if len(messages) > max_llm_history + 1:  # +1 for the pinned system message, so truncation can actually shrink the list
+            messages = [messages[0]] + messages[-max_llm_history:]
+        else:
+            max_llm_history -= 1
+            if max_llm_history < 2:
+                raise ValueError("Unable to reduce message length below token limit")
+
     try:
         response = or_client.chat.completions.create(
             model=model,
@@ -56,9 +87,16 @@ def chat_with_llama_stream(messages, model, max_output_tokens=2500):
             stream=True
         )
 
+        full_response = ""
         for chunk in response:
             if chunk.choices[0].delta.content is not None:
-                yield chunk.choices[0].delta.content
+                content = chunk.choices[0].delta.content
+                full_response += content
+                yield content
+
+        # After streaming, add the full response to the conversation history
+        messages.append({"role": "assistant", "content": full_response})
+        return full_response
     except Exception as e:
         raise HTTPException(status_code=500, detail=f"Error in model response: {str(e)}")
 
@@ -67,8 +105,48 @@ async def verify_api_key(api_key: str = Security(api_key_header)):
         raise HTTPException(status_code=403, detail="Could not validate credentials")
     return api_key
 
+# SQLite setup
+def init_db():
+    conn = sqlite3.connect('conversations.db')
+    c = conn.cursor()
+    c.execute('''CREATE TABLE IF NOT EXISTS conversations
+                 (id INTEGER PRIMARY KEY AUTOINCREMENT,
+                  user_id TEXT,
+                  conversation_id TEXT,
+                  message TEXT,
+                  response TEXT,
+                  timestamp DATETIME DEFAULT CURRENT_TIMESTAMP)''')
+    conn.commit()
+    conn.close()
+
+init_db()
+
+def update_db(user_id, conversation_id, message, response):
+    conn = sqlite3.connect('conversations.db')
+    c = conn.cursor()
+    c.execute('''INSERT INTO conversations (user_id, conversation_id, message, response)
+                 VALUES (?, ?, ?, ?)''', (user_id, conversation_id, message, response))
+    conn.commit()
+    conn.close()
+
+async def clear_inactive_conversations():
+    while True:
+        current_time = time.time()
+        inactive_convos = [conv_id for conv_id, last_time in last_activity.items()
+                           if current_time - last_time > 1800]  # 30 minutes
+        for conv_id in inactive_convos:
+            if conv_id in conversations:
+                del conversations[conv_id]
+            if conv_id in last_activity:
+                del last_activity[conv_id]
+        await asyncio.sleep(60)  # Check every minute
+
+@app.on_event("startup")
+async def startup_event():
+    asyncio.create_task(clear_inactive_conversations())
+
 @app.post("/coding-assistant")
-async def coding_assistant(query: QueryModel, api_key: str = Depends(verify_api_key)):
+async def coding_assistant(query: QueryModel, background_tasks: BackgroundTasks, api_key: str = Depends(verify_api_key)):
     """
     Coding assistant endpoint that provides programming help based on user queries.
 
@@ -83,16 +161,31 @@ async def coding_assistant(query: QueryModel, api_key: str = Depends(verify_api_key)):
 
     Requires API Key authentication via X-API-Key header.
     """
-    system_prompt = "You are a helpful assistant proficient in coding tasks. Help the user in understanding and writing code."
-    messages = [
-        {"role": "system", "content": system_prompt},
-        {"role": "user", "content": query.user_query}
-    ]
-
-    return StreamingResponse(
-        chat_with_llama_stream(messages, model=query.model_id),
-        media_type="text/event-stream"
-    )
+    if query.conversation_id not in conversations:
+        conversations[query.conversation_id] = [
+            {"role": "system", "content": "You are a helpful assistant proficient in coding tasks. Help the user in understanding and writing code."}
+        ]
+
+    conversations[query.conversation_id].append({"role": "user", "content": query.user_query})
+    last_activity[query.conversation_id] = time.time()
+
+    # Limit tokens in the conversation history
+    limited_conversation = conversations[query.conversation_id]
+    while calculate_tokens(limited_conversation) > 8000:
+        if len(limited_conversation) > 2:  # Keep at least the system message and the latest user message
+            limited_conversation.pop(1)
+        else:
+            error_message = "Token limit exceeded. Please shorten your input or start a new conversation."
+            raise HTTPException(status_code=400, detail=error_message)
+
+    async def process_response():
+        full_response = ""
+        for content in chat_with_llama_stream(limited_conversation, model=query.model_id):  # plain for: chat_with_llama_stream is a sync generator
+            full_response += content
+            yield content
+        background_tasks.add_task(update_db, query.user_id, query.conversation_id, query.user_query, full_response)
+
+    return StreamingResponse(process_response(), media_type="text/event-stream")
 
 if __name__ == "__main__":
     import uvicorn
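A note on the token budgeting this commit introduces: the counts come from tiktoken's gpt-3.5-turbo encoding even though the served models are Llama variants, so the 8000-token budget is an approximation of the real context window rather than an exact match. The sketch below is a minimal, self-contained way to see the trimming behavior; the sample messages are made up for illustration and assume tiktoken is installed.

import tiktoken

encoding = tiktoken.encoding_for_model("gpt-3.5-turbo")

def calculate_tokens(msgs):
    # Same approach as the commit: stringify each message dict and count its tokens
    return sum(len(encoding.encode(str(m))) for m in msgs)

# Made-up history: one system message plus five long user turns
messages = [{"role": "system", "content": "You are a helpful assistant."}]
messages += [{"role": "user", "content": "word " * 2000} for _ in range(5)]

print("before:", calculate_tokens(messages), "tokens,", len(messages), "messages")

# Drop the oldest non-system message until the history fits the budget,
# mirroring the loop in the /coding-assistant handler
while calculate_tokens(messages) > 8000 and len(messages) > 2:
    messages.pop(1)

print("after: ", calculate_tokens(messages), "tokens,", len(messages), "messages")

Counting str(m) includes the dict's braces and keys, so this slightly overestimates what the model will actually see; as a safety margin for a budget check, that errs in the right direction.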
 
 
108
+ # SQLite setup
109
+ def init_db():
110
+ conn = sqlite3.connect('conversations.db')
111
+ c = conn.cursor()
112
+ c.execute('''CREATE TABLE IF NOT EXISTS conversations
113
+ (id INTEGER PRIMARY KEY AUTOINCREMENT,
114
+ user_id TEXT,
115
+ conversation_id TEXT,
116
+ message TEXT,
117
+ response TEXT,
118
+ timestamp DATETIME DEFAULT CURRENT_TIMESTAMP)''')
119
+ conn.commit()
120
+ conn.close()
121
+
122
+ init_db()
123
+
124
+ def update_db(user_id, conversation_id, message, response):
125
+ conn = sqlite3.connect('conversations.db')
126
+ c = conn.cursor()
127
+ c.execute('''INSERT INTO conversations (user_id, conversation_id, message, response)
128
+ VALUES (?, ?, ?, ?)''', (user_id, conversation_id, message, response))
129
+ conn.commit()
130
+ conn.close()
131
+
132
+ async def clear_inactive_conversations():
133
+ while True:
134
+ current_time = time.time()
135
+ inactive_convos = [conv_id for conv_id, last_time in last_activity.items()
136
+ if current_time - last_time > 1800] # 30 minutes
137
+ for conv_id in inactive_convos:
138
+ if conv_id in conversations:
139
+ del conversations[conv_id]
140
+ if conv_id in last_activity:
141
+ del last_activity[conv_id]
142
+ await asyncio.sleep(60) # Check every minute
143
+
144
+ @app.on_event("startup")
145
+ async def startup_event():
146
+ asyncio.create_task(clear_inactive_conversations())
147
+
148
  @app.post("/coding-assistant")
149
+ async def coding_assistant(query: QueryModel, background_tasks: BackgroundTasks, api_key: str = Depends(verify_api_key)):
150
  """
151
  Coding assistant endpoint that provides programming help based on user queries.
152
 
 
161
 
162
  Requires API Key authentication via X-API-Key header.
163
  """
164
+ if query.conversation_id not in conversations:
165
+ conversations[query.conversation_id] = [
166
+ {"role": "system", "content": "You are a helpful assistant proficient in coding tasks. Help the user in understanding and writing code."}
167
+ ]
168
+
169
+ conversations[query.conversation_id].append({"role": "user", "content": query.user_query})
170
+ last_activity[query.conversation_id] = time.time()
171
+
172
+ # Limit tokens in the conversation history
173
+ limited_conversation = conversations[query.conversation_id]
174
+ while calculate_tokens(limited_conversation) > 8000:
175
+ if len(limited_conversation) > 2: # Keep at least the system message and the latest user message
176
+ limited_conversation.pop(1)
177
+ else:
178
+ error_message = "Token limit exceeded. Please shorten your input or start a new conversation."
179
+ raise HTTPException(status_code=400, detail=error_message)
180
+
181
+ async def process_response():
182
+ full_response = ""
183
+ async for content in chat_with_llama_stream(limited_conversation, model=query.model_id):
184
+ full_response += content
185
+ yield content
186
+ background_tasks.add_task(update_db, query.user_id, query.conversation_id, query.user_query, full_response)
187
+
188
+ return StreamingResponse(process_response(), media_type="text/event-stream")
189
 
190
  if __name__ == "__main__":
191
  import uvicorn
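For reference, a minimal client sketch against the updated endpoint. The host and port are assumptions, since the uvicorn.run() call falls outside the diff context shown above, and the default API key only works if the API_KEY environment variable was left unset.

import requests

resp = requests.post(
    "http://localhost:8000/coding-assistant",  # assumed host/port
    headers={"X-API-Key": "default_secret_key"},  # assumed: API_KEY env var unset
    json={
        "user_query": "How do I implement a binary search in Python?",
        "user_id": "user123",  # required field added in this commit
        "conversation_id": "123e4567-e89b-12d3-a456-426614174000",  # reuse to continue a conversation
    },
    stream=True,
)
resp.raise_for_status()

# The endpoint streams plain text chunks with media_type="text/event-stream"
for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
    print(chunk, end="", flush=True)

Note that the server never returns the conversation_id it generates by default, so a client that wants multi-turn behavior should supply its own id and reuse it, as above.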