"""Streamlit UI for chatting with an uploaded PDF through a RAG chain.

The sidebar collects an OpenAI API key and a PDF upload, and exposes a button
that embeds the uploaded document. The main pane lets the user pick a RAG
method and chat; responses are streamed into the page as they arrive.
"""

import os
import shutil

import streamlit as st

import embed_pdf
from utils import make_discord_trace_text

make_discord_trace_text("RAG UI opened")


def clear_directory(directory: str) -> None:
    """Delete every file, symlink and sub-directory inside *directory*.

    The directory itself is kept. Failures on individual entries are printed
    and skipped so that one bad entry does not abort the clean-up.
    """
    for filename in os.listdir(directory):
        file_path = os.path.join(directory, filename)
        try:
            if os.path.isfile(file_path) or os.path.islink(file_path):
                os.unlink(file_path)
            elif os.path.isdir(file_path):
                shutil.rmtree(file_path)
        except Exception as e:
            print(f'Failed to delete {file_path}. Reason: {e}')


def clear_pdf_files(directory: str) -> None:
    """Delete the regular files directly inside *directory* ending in '.pdf'.

    Sub-directories are not descended into. Failures on individual files are
    printed and skipped.
    """
    for filename in os.listdir(directory):
        file_path = os.path.join(directory, filename)
        try:
            if os.path.isfile(file_path) and file_path.endswith('.pdf'):
                os.remove(file_path)
        except Exception as e:
            print(f'Failed to delete {file_path}. Reason: {e}')


# clear_pdf_files("pdf")
# clear_directory("index")

# create sidebar and ask for openai api key if not set in secrets
secrets_file_path = os.path.join(".streamlit", "secrets.toml")
# if os.path.exists(secrets_file_path):
#     try:
#         if "OPENAI_API_KEY" in st.secrets:
#             os.environ["OPENAI_API_KEY"] = st.secrets["OPENAI_API_KEY"]
#         else:
#             print("OpenAI API Key not found in environment variables")
#     except FileNotFoundError:
#         print('Secrets file not found')
# else:
#     print('Secrets file not found')

# if not os.getenv('OPENAI_API_KEY', '').startswith("sk-"):
#     os.environ["OPENAI_API_KEY"] = st.sidebar.text_input(
#         "OpenAI API Key", type="password"
#     )
# else:
#     if st.sidebar.button("Embed Documents"):
#         st.sidebar.info("Embedding documents...")
#         try:
#             embed_pdf.embed_all_pdf_docs()
#             st.sidebar.info("Done!")
#         except Exception as e:
#             st.sidebar.error(e)
#             st.sidebar.error("Failed to embed documents.")

# NOTE(review): os.environ is process-wide, so the key typed here is visible
# to everything running in this Python process -- confirm that is acceptable
# for the intended deployment.
os.environ["OPENAI_API_KEY"] = st.sidebar.text_input(
    "OpenAI API Key", type="password"
)
st.sidebar.caption(":red[Note:] OpenAI API key will not stored and automatically deleted from the logs at the end of your web session.")

st.sidebar.write("---")

uploaded_file = st.sidebar.file_uploader(
    "Upload Document", type=['pdf'], disabled=False
)

# The embed button is only usable once a file has been uploaded.
if st.sidebar.button("Embed Documents", disabled=uploaded_file is None):
    st.sidebar.info("Embedding documents...")
    try:
        embed_pdf.embed_all_inputed_pdf_docs(uploaded_file)
        # embed_pdf.embed_all_pdf_docs()
        st.sidebar.info("Done!")
    except Exception as e:
        # UI boundary: surface the error to the user instead of crashing.
        st.sidebar.error(e)
        st.sidebar.error("Failed to embed documents.")

st.sidebar.write("---")

st.sidebar.markdown('''
Steps to run app
1. Paste OpenAI API Key and press Enter
2. Upload PDF file
3. Click on Embed Documents button
4. Choose RAG method
5. Start Chatting with your PDF
''')

# create the app
st.title("Chat with your PDF")

# chosen_file = st.radio(
#     "Choose a file to search", embed_pdf.get_all_index_files(), index=0
# )

# check if openai api key is set
if not os.getenv('OPENAI_API_KEY', '').startswith("sk-"):
    st.warning("Please enter your OpenAI API key!", icon="⚠")
    st.stop()

# load the agent
# Imported only after the API key check above, matching the original order.
from llm_helper import convert_message, get_rag_chain, get_rag_fusion_chain

rag_method_map = {
    'Basic RAG': get_rag_chain,
    'RAG Fusion': get_rag_fusion_chain
}
chosen_rag_method = st.radio(
    "Choose a RAG method", list(rag_method_map), index=0
)
get_rag_chain_func = rag_method_map[chosen_rag_method]

## get the chain WITHOUT the retrieval callback (not used)
# custom_chain = get_rag_chain_func(chosen_file)

# create the message history state
if "messages" not in st.session_state:
    st.session_state.messages = []

# render older messages
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

# render the chat input
prompt = st.chat_input("Enter your message...")
if prompt:
    # Without an uploaded file there is no document name to build the chain
    # from; stop before touching the history instead of failing later on
    # `uploaded_file.name`.
    if uploaded_file is None:
        st.warning("Please upload a PDF document before chatting.")
        st.stop()

    st.session_state.messages.append({"role": "user", "content": prompt})

    # render the user's new message
    with st.chat_message("user"):
        st.markdown(prompt)
    make_discord_trace_text(prompt)

    # render the assistant's response
    with st.chat_message("assistant"):
        retrival_container = st.container()
        message_placeholder = st.empty()

        # retrieval_status = retrival_container.status("**Context Retrieval**")
        queried_questions = []
        rendered_questions = set()

        def update_retrieval_status():
            """Render each queried question that has not been shown yet."""
            for q in queried_questions:
                if q in rendered_questions:
                    continue
                rendered_questions.add(q)
                # retrieval_status.markdown(f"\n\n`- {q}`")
                retrival_container.markdown(f"\n\n`- {q}`")

        def retrieval_cb(qs):
            """Record newly seen questions in order and pass *qs* through."""
            for q in qs:
                if q not in queried_questions:
                    queried_questions.append(q)
            return qs

        # get the chain with the retrieval callback
        custom_chain = get_rag_chain_func(
            uploaded_file.name, retrieval_cb=retrieval_cb
        )

        # Everything except the prompt that was just appended is prior history.
        chat_history = [
            convert_message(m) for m in st.session_state.messages[:-1]
        ]

        full_response = ""
        for response in custom_chain.stream(
            {"input": prompt, "chat_history": chat_history}
        ):
            if "output" in response:
                full_response += response["output"]
            else:
                full_response += response.content

            # Trailing block character acts as a typing cursor while streaming.
            message_placeholder.markdown(full_response + "▌")
            update_retrieval_status()

        # retrival_container.update(state="complete")
        # retrieval_status.update(state="complete")
        message_placeholder.markdown(full_response)
        make_discord_trace_text(full_response)

    # add the full response to the message history
    st.session_state.messages.append({"role": "assistant", "content": full_response})