File size: 3,198 Bytes
b5b2e6a
 
 
 
 
 
 
 
2a8e87b
b5b2e6a
ca4359e
b5b2e6a
 
9dbae8b
b5b2e6a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
676b3da
 
8a69f83
 
 
5cc985a
b5b2e6a
 
 
949bf2b
 
 
c8510e0
 
 
 
 
 
 
 
fa14017
 
b128493
c32bfde
676b3da
fa14017
676b3da
c8510e0
676b3da
c8510e0
c41eb74
b5b2e6a
676b3da
b5b2e6a
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import asyncio
import os
import re

from dotenv import load_dotenv
from fastapi import FastAPI
from fastapi import Query
from fastapi.middleware.cors import CORSMiddleware
from fuzzy_json import loads
from half_json.core import JSONFixer
from pydantic import BaseModel, Field
from retry import retry
from together import Together

# Credentials and system prompts used by the request helpers in this module.
# NOTE(review): `load_dotenv` is imported but never called in the visible code,
# so the key is only read from the real process environment -- confirm whether
# loading a .env file was intended.
TOGETHER_API_KEY = os.environ.get("TOGETHER_API_KEY")

# Generic system prompt for free-form completions.
SysPromptDefault = (
    "You are an expert AI, complete the given task. "
    "Do not add any additional comments."
)
# System prompt asking the model to reply with nothing but a Python list.
SysPromptList = (
    "You are now in the role of an expert AI who can extract structured "
    "information from user request. "
    "All elements must be in double quotes. "
    "You must respond ONLY with a valid python List. "
    "Do not add any additional comments."
)

@retry(tries=3, delay=1)
def together_response(message, model = "meta-llama/Llama-3-8b-chat-hf", SysPrompt = SysPromptDefault):
    """Send one system + user exchange to the Together chat API and return the reply text.

    Args:
        message: Text sent as the user turn.
        model: Together model identifier to query.
        SysPrompt: Text sent as the system turn.

    Returns:
        The ``content`` of the first choice in the completion response.

    The call is wrapped by ``retry`` configured with ``tries=3, delay=1``.
    """
    system_turn = {"role": "system", "content": SysPrompt}
    user_turn = {"role": "user", "content": message}

    # A new client is built for every call, using the module-level API key.
    completion = Together(api_key=TOGETHER_API_KEY).chat.completions.create(
        model=model,
        messages=[system_turn, user_turn],
        temperature=0.2,
    )
    first_choice = completion.choices[0]
    return first_choice.message.content

def json_from_text(text):
    """Extract and parse the JSON-like payload embedded in ``text``.

    The payload is located in this order:

    1. the widest ``{...}`` span (first ``{`` through the last ``}``);
    2. otherwise the widest ``[...]`` span (first ``[`` through the last ``]``),
       so a list reply surrounded by prose is still isolated;
    3. otherwise the whole of ``text``.

    Args:
        text: Raw string that is expected to contain a JSON object or list.

    Returns:
        Whatever the fuzzy ``loads`` parser produces for the located payload.

    Raises:
        Any exception raised by the second (repair) parsing attempt; a failure
        of the first attempt is deliberately absorbed and triggers the repair.
    """
    candidate = text
    # Object span keeps priority so inputs containing braces behave as before;
    # the list span is only a fallback when no braces are present.
    for pattern in (r'\{[\s\S]*\}', r'\[[\s\S]*\]'):
        found = re.search(pattern, text)
        if found:
            candidate = found.group(0)
            break

    try:
        # Using fuzzy json loader
        return loads(candidate)
    except Exception:
        # Best-effort repair of broken / truncated JSON via JSONFixer.
        # Remove this fallback if the caller should see the original exception.
        fixer = JSONFixer()
        return loads(fixer.fix(candidate).line)
    
def generate_topics(user_input, num_topics, previous_queries):
    """Ask the model for subtopics of ``user_input`` and parse its reply.

    Args:
        user_input: The task or query the subtopics should support.
        num_topics: How many subtopics to request in the prompt.
        previous_queries: Earlier queries, joined with ``" -> "`` and passed to
            the model as context. May be empty or ``None``.

    Returns:
        The value parsed by ``json_from_text`` from the model reply. The prompt
        requests a list of ``["Subtopic", "Description"]`` pairs, but the shape
        is not validated here.
    """
    previous_context = " -> ".join(previous_queries or [])
    # Only mention the context when there is one; an empty history would
    # otherwise leave a dangling "in the context of ," in the prompt.
    context_clause = f" in the context of {previous_context}" if previous_context else ""
    prompt = (
        f"create a list of {num_topics} subtopics along with descriptions"
        f" to follow for conducting {user_input}{context_clause},"
        " RETURN A VALID PYTHON LIST"
        " Respond in the following format: \n"
        '        [["Subtopic","Description"],["Subtopic","Description"]]'
    )
    response_topics = together_response(
        prompt,
        model="meta-llama/Llama-3-8b-chat-hf",
        SysPrompt=SysPromptList,
    )
    return json_from_text(response_topics)

# Define the app
app = FastAPI()

# CORS policy: every origin, HTTP method and request header is accepted, and
# credentials are allowed.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive -- confirm this is intended for a public deployment.
_cors_options = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)

# Pydantic model describing the JSON body accepted by POST /generate_topics
class TopicInput(BaseModel):
    """Request body for subtopic generation.

    Fields are declared with ``pydantic.Field`` because this model is used as
    a request body; ``fastapi.Query`` is meant for query parameters declared
    in a path-operation signature.
    """

    # Task or query for which subtopics are generated.
    user_input: str = Field(default="market research", description="input query to generate subtopics")
    # Number of subtopics requested from the model.
    num_topics: int = Field(default=5, description="Number of subtopics to generate (default: 5)")
    # Earlier queries used as context; defaults to a fresh empty list.
    previous_queries: list[str] = Field(default_factory=list, description="List of previous queries for context")

@app.get("/", tags=["Home"])
def api_home():
    """Landing endpoint: return a welcome message pointing at the hosted docs page."""
    welcome = 'Welcome to FastAPI Subtopics API! Visit https://pvanand-generate-subtopics.hf.space/docs to test'
    return dict(detail=welcome)

@app.post("/generate_topics")
async def create_topics(input: TopicInput):
    """Generate subtopics for the posted query.

    Args:
        input: Validated request body carrying the query, the number of
            subtopics to request and the previous queries used as context.

    Returns:
        A dict ``{"topics": ...}`` wrapping the value returned by
        ``generate_topics``.
    """
    # generate_topics is synchronous and performs a network call (with
    # retries), so run it in a worker thread instead of blocking the event
    # loop that serves other requests.
    topics = await asyncio.to_thread(
        generate_topics,
        input.user_input,
        input.num_topics,
        input.previous_queries,
    )
    return {"topics": topics}