# audio_chat/main.py
# Standard library
import json
import os
import re

# Third-party
from dotenv import load_dotenv
from fastapi import FastAPI
from fastapi import Query
from fastapi.middleware.cors import CORSMiddleware
from fuzzy_json import loads
from half_json.core import JSONFixer
from openai import OpenAI
from pydantic import BaseModel
from retry import retry
# Retrieve environment variables.
# NOTE(review): `load_dotenv` is imported at the top of the file but never
# called, so these lookups read the process environment only — confirm
# whether a load_dotenv() call was intended before this point.
TOGETHER_API_KEY = os.getenv("TOGETHER_API_KEY")
# Groq keys are stored here without their "gsk_" prefix, so it is re-attached.
# The `or ""` guard keeps a missing env var from raising
# TypeError ("gsk_" + None) at import time.
GROQ_API_KEY = "gsk_" + (os.getenv("GROQ_API_KEY") or "")

# System prompts selecting the output shape requested from the LLM below.
SysPromptDefault = "You are an expert AI, complete the given task. Do not add any additional comments."
SysPromptList = "You are now in the role of an expert AI who can extract structured information from user request. All elements must be in double quotes. You must respond ONLY with a valid python List. Do not add any additional comments."
SysPromptJson = "You are now in the role of an expert AI who can extract structured information from user request. Both key and value pairs must be in double quotes. You must respond ONLY with a valid JSON file. Do not add any additional comments."
SysPromptMd = "You are an expert AI who can create a structured report using information provided in the context from user request.The report should be in markdown format consists of markdown tables structured into subtopics. Do not add any additional comments."
SysPromptMdOffline = "You are an expert AI who can create a structured report using your knowledge on user request.The report should be in markdown format consists of markdown tables/lists/paragraphs as needed, structured into subtopics. Do not add any additional comments."
@retry(tries=3, delay=1)
def together_response(message, model = "meta-llama/Llama-3-8b-chat-hf", SysPrompt = SysPromptDefault):
    """Send `message` to the Groq chat-completions API and return the reply text.

    Retries up to 3 times with a 1-second delay on any exception.

    NOTE: requests are always routed to Groq's "llama3-8b-8192" model; the
    `model` argument is accepted for interface compatibility but is not used.
    """
    groq_client = OpenAI(
        base_url="https://api.groq.com/openai/v1",
        api_key=GROQ_API_KEY,
    )
    chat_messages = [
        {"role": "system", "content": SysPrompt},
        {"role": "user", "content": message},
    ]
    completion = groq_client.chat.completions.create(
        model="llama3-8b-8192",
        messages=chat_messages,
        temperature=0.2,
    )
    return completion.choices[0].message.content
def json_from_text(text):
    """Extract and parse the first JSON object embedded in `text`.

    The outermost ``{...}`` span is located with a greedy regex; if no
    braces are found, the whole string is treated as the candidate JSON.
    Parsing is attempted in order of strictness:

    1. stdlib ``json.loads`` — fast path for well-formed output;
    2. the fuzzy JSON loader — tolerates minor formatting noise;
    3. ``JSONFixer`` — repairs truncated/half JSON before reparsing.

    Returns the parsed Python object (typically a dict).
    """
    match = re.search(r'\{[\s\S]*\}', text)
    json_out = match.group(0) if match else text
    try:
        # Fast path: well-formed JSON needs no fuzzy handling.
        return json.loads(json_out)
    except ValueError:
        pass
    try:
        # Using fuzzy json loader
        return loads(json_out)
    except Exception:
        # Using JSON fixer/ Fixes even half json/ Remove if you need an exception
        fix_json = JSONFixer()
        return loads(fix_json.fix(json_out).line)
def generate_topics(user_input, num_topics, previous_queries):
    """Ask the LLM for `num_topics` [subtopic, description] pairs.

    `previous_queries` is joined into a breadcrumb trail so earlier searches
    give the model context; the model's text reply is parsed into a Python
    list via `json_from_text`.
    """
    context_trail = " -> ".join(previous_queries)
    request = f"""create a list of {num_topics} subtopics along with descriptions to follow for conducting {user_input} in the context of {context_trail}, RETURN A VALID PYTHON LIST"""\
+""" Respond in the following format:
[["Subtopic","Description"],["Subtopic","Description"]]"""
    raw_reply = together_response(request, model="meta-llama/Llama-3-8b-chat-hf", SysPrompt=SysPromptList)
    return json_from_text(raw_reply)
def generate_report(topic, description):
    """Generate a markdown report on `topic`, guided by `description`.

    Uses the offline-knowledge system prompt so the model answers from its
    own knowledge rather than supplied context.  Returns the model's
    markdown output as a string.
    """
    prompt = f"""create a detailed report on {topic} by following the instructions: {description}"""
    md_report = together_response(prompt, model = "meta-llama/Llama-3-70b-chat-hf", SysPrompt = SysPromptMdOffline)
    # FIX: the original returned md_to_html(md_report), but md_to_html is not
    # defined or imported anywhere in this file, so every call raised
    # NameError.  Return the raw markdown instead.
    return md_report
# Define the app
app = FastAPI()
# Allow cross-origin requests so browser front-ends hosted elsewhere can
# call this API.
# NOTE(review): allow_credentials=True combined with allow_origins=["*"] is
# contradictory under the CORS spec — browsers refuse credentialed responses
# served with a wildcard origin.  If credentials are actually needed, list
# explicit origins; otherwise drop allow_credentials.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Create a Pydantic model to handle the input data
class TopicInput(BaseModel):
    # Request-body schema for POST /generate_topics.
    # NOTE(review): fastapi.Query is used as the field default; for a body
    # model, pydantic's Field is the conventional choice — confirm the Query
    # metadata (defaults/descriptions) is actually honored here.
    user_input: str = Query(default="market research", description="input query to generate subtopics")
    num_topics: int = Query(default=5, description="Number of subtopics to generate (default: 5)")
    previous_queries: list[str] = Query(default=[], description="List of previous queries for context")
class ReportInput(BaseModel):
    # Request-body schema for POST /generate_report.  Both fields are
    # required (no default is supplied to Query).
    topic: str = Query(description="The main topic for the report")
    description: str = Query(description="A brief description of the topic")
@app.get("/", tags=["Home"])
def api_home():
    """Landing route pointing visitors at the interactive API docs."""
    welcome = {'detail': 'Welcome to FastAPI Subtopics API! Visit https://pvanand-generate-subtopics.hf.space/docs to test'}
    return welcome
@app.post("/generate_topics")
async def create_topics(input: TopicInput):
    """POST endpoint: expand the user query into a list of subtopics."""
    subtopic_list = generate_topics(
        input.user_input,
        input.num_topics,
        input.previous_queries,
    )
    return {"topics": subtopic_list}
@app.post("/generate_report")
async def create_report(input: ReportInput):
    """POST endpoint: build a markdown report for the requested topic."""
    generated = generate_report(input.topic, input.description)
    return {"report": generated}