RAG / app.py
orrinin's picture
Update app.py
1ebf8b3 verified
raw
history blame
No virus
4.52 kB
import os
from bs4 import BeautifulSoup
from llama_index.core import Document
from llama_index.core import Settings
from llama_index.core import SimpleDirectoryReader
from llama_index.core import StorageContext
from llama_index.core import VectorStoreIndex
from llama_index.readers.web import SimpleWebPageReader
from llama_index.vector_stores.chroma import ChromaVectorStore
import chromadb
import re
from llama_index.llms.gemini import Gemini
from llama_index.embeddings.gemini import GeminiEmbedding
from llama_index.core import PromptTemplate
from llama_index.core.llms import ChatMessage
import gradio as gr
import uuid
# Gemini credentials are read from the API_KEY environment variable.
api_key = os.getenv("API_KEY")

# Chat model used to generate answers.
llm = Gemini(model_name="models/gemini-1.5-flash-latest", api_key=api_key)

# Embedding model used to vectorise documents and queries.
gemini_embedding_model = GeminiEmbedding(
    model_name="models/embedding-001",
    api_key=api_key,
)

# Register both models as llama_index's global defaults so that indexes and
# query engines built later pick them up without being passed explicitly.
Settings.llm = llm
Settings.embed_model = gemini_embedding_model
def extract_web(url):
    """Fetch a web page and keep only the text of its ``<p>`` elements.

    Args:
        url: Address of the page to load.

    Returns:
        A ``(documents, option)`` tuple: a one-element list holding a
        ``Document`` built from the paragraph text, and the source tag
        ``"web"``.
    """
    web_documents = SimpleWebPageReader().load_data([url])
    html_content = web_documents[0].text
    # Parse the fetched page as HTML and keep paragraph text only. Each
    # paragraph is followed by a newline, exactly as before.
    soup = BeautifulSoup(html_content, "html.parser")
    # find_all is the current API; findAll is a deprecated BS3-era alias.
    # join avoids the quadratic cost of repeated string +=.
    text_content = "".join(p.text + "\n" for p in soup.find_all("p"))
    return [Document(text=text_content)], "web"
def extract_doc(path):
    """Read uploaded files into llama_index Documents.

    Args:
        path: List of file paths. It is passed straight through as
            ``SimpleDirectoryReader``'s ``input_files``.

    Returns:
        ``(documents, "doc")``: the loaded documents and the source tag.
    """
    reader = SimpleDirectoryReader(input_files=path)
    loaded = reader.load_data()
    return loaded, "doc"
def create_col(documents):
    """Embed documents into a brand-new persistent Chroma collection.

    Each call writes to its own directory under ``database/``, so vectors
    from one call never end up in a collection created by an earlier call.

    Args:
        documents: llama_index ``Document`` objects to embed and store.

    Returns:
        The on-disk path of the Chroma database that was written. The
        collection inside it is named ``"quickstart"``.
    """
    # uuid4() returns a UUID object, which cannot be sliced; the original
    # ``uuid.uuid4()[:4]`` raised TypeError. Use the full hex form so two
    # calls cannot collide on one directory. A collision would make
    # get_or_create_collection silently reuse previously stored vectors.
    db_path = f"database/{uuid.uuid4().hex}"
    client = chromadb.PersistentClient(path=db_path)
    chroma_collection = client.get_or_create_collection("quickstart")
    vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
    storage_context = StorageContext.from_defaults(vector_store=vector_store)
    # Building the index embeds the documents (using the globally configured
    # embedding model) and writes the vectors into the Chroma store on disk.
    VectorStoreIndex.from_documents(documents, storage_context=storage_context)
    return db_path
# Prompt used for every retrieval-augmented answer.
_QA_TEMPLATE = """ You are an assistant for question-answering tasks.
Use the following context to answer the question.
If you don't know the answer, just say that you don't know.
Use five sentences maximum and keep the answer concise.\n
Question: {query_str} \nContext: {context_str} \nAnswer:"""


def _is_url(text):
    """Return True if text starts with an http:// or https:// scheme."""
    return text.startswith(("http://", "https://"))


def _sources_from_history(history):
    """Recover previously supplied data sources from the chat history.

    Args:
        history: List of ``(prompt, answer)`` pairs as passed to ``infer``.
            A tuple ``prompt`` is treated as a file upload whose first
            element is the file path. NOTE(review): this mirrors Gradio's
            tuple-style history; confirm against the installed Gradio
            version. A string ``prompt`` that is a URL marks a web page the
            user loaded earlier.

    Returns:
        ``(files, url)``: the uploaded file paths in history order, and the
        most recently entered URL, or ``None`` if there was none.
    """
    files = []
    url = None
    for prompt, _answer in history:
        if isinstance(prompt, tuple):
            # append, not +=: += would splice the path in character by character.
            files.append(prompt[0])
        elif isinstance(prompt, str) and _is_url(prompt):
            url = prompt
    return files, url


def _load_index(documents):
    """Embed documents into a new Chroma store and reopen it as an index."""
    db_path = create_col(documents)
    load_client = chromadb.PersistentClient(path=db_path)
    chroma_collection = load_client.get_collection("quickstart")
    vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
    return VectorStoreIndex.from_vector_store(vector_store)


def infer(message: dict, history: list):
    """Answer a chat turn with retrieval-augmented generation.

    Data sources are tried in this priority order:

    1. Files attached to this message or to any earlier turn.
    2. A URL typed as this message.
    3. The most recent URL typed in an earlier turn.

    The chosen source is embedded into a fresh Chroma store on every call
    and queried with ``message["text"]``.

    Args:
        message: Multimodal chat input with a ``"text"`` string and a
            ``"files"`` list of uploaded file paths.
        history: Earlier ``(prompt, answer)`` turns of the conversation.

    Returns:
        The answer text. When this message was itself a URL, returns a short
        confirmation that the page was loaded instead.

    Raises:
        gr.Error: When no file has been uploaded and no URL has been given.
    """
    print(f'message: {message}')
    print(f'history: {history}')
    text = message.get("text") or ""
    history_files, history_url = _sources_from_history(history)
    # Build a new list so the caller's message["files"] is not mutated.
    files_list = list(message.get("files") or []) + history_files

    if files_list:
        documents, _ = extract_doc(files_list)
    elif _is_url(text):
        # The message is just the URL, so there is no question to answer
        # yet. Fetch the page now to surface load errors immediately.
        extract_web(text)
        return "Get the web data! You can ask it."
    elif history_url is not None:
        # Follow-up question about a page loaded in an earlier turn.
        documents, _ = extract_web(history_url)
    else:
        # gr.Error only reaches the user when it is raised.
        raise gr.Error("Please input an url or upload file at first.")

    index = _load_index(documents)
    query_engine = index.as_query_engine(
        text_qa_template=PromptTemplate(_QA_TEMPLATE)
    )
    response = query_engine.query(text)
    # The chat component displays text, so convert the llama_index Response
    # object to its answer string.
    return str(response)
# Chat UI: a multimodal ChatInterface, so the user can paste a URL or attach
# files and then ask questions; every turn is handled by `infer`.
chatbot = gr.Chatbot()
with gr.Blocks(theme="soft") as demo:
    gr.ChatInterface(
        fn=infer,
        chatbot=chatbot,
        multimodal=True,
        title="RAG demo",
    )

if __name__ == "__main__":
    # Enable the request queue with its API closed, and hide the API page.
    queued = demo.queue(api_open=False)
    queued.launch(share=False, show_api=False)