# kdh044 / app.py
# Last commit d9f5870 by danny042: "Update app.py" (5.91 kB)
import streamlit as st
import tiktoken
from loguru import logger
from langchain.chains import ConversationalRetrievalChain
from langchain.chat_models import ChatOpenAI
from langchain.document_loaders.pdf import (PyPDFLoader, PyMuPDFLoader)
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.memory import ConversationBufferMemory
from langchain.vectorstores import FAISS
# from streamlit_chat import message
from langchain.callbacks import get_openai_callback
from langchain.memory import StreamlitChatMessageHistory
from gtts import gTTS
from IPython.display import Audio, display
from pydub import AudioSegment
#์‚ฌ์ดํŠธ ๊ด€๋ จ ํ•จ์ˆ˜
def main():
st.set_page_config(
page_title="์ฐจ๋Ÿ‰์šฉ Q&A ์ฑ—๋ด‡",
page_icon=":car:")
st.title("์ฐจ๋Ÿ‰์šฉ Q&A ์ฑ—๋ด‡ :car:")
if "conversation" not in st.session_state:
st.session_state.conversation = None
if "chat_history" not in st.session_state:
st.session_state.chat_history = None
if "processComplete" not in st.session_state:
st.session_state.processComplete = None
with st.sidebar:
uploaded_files = st.file_uploader("์ฐจ๋Ÿ‰ ๋ฉ”๋‰ด์–ผ PDF ํŒŒ์ผ์„ ๋„ฃ์–ด์ฃผ์„ธ์š”.", type=['pdf'], accept_multiple_files=True)
openai_api_key = st.text_input("OpenAI API Key", key="chatbot_api_key", type="password")
process = st.button("์‹คํ–‰")
if process:
if not openai_api_key:
st.info("Open AIํ‚ค๋ฅผ ์ž…๋ ฅํ•ด์ฃผ์„ธ์š”.")
st.stop()
files_text = get_text(uploaded_files)
text_chunks = get_text_chunks(files_text)
vetorestore = get_vectorstore(text_chunks)
st.session_state.conversation = get_conversation_chain(vetorestore, openai_api_key)
st.session_state.processComplete = True
if 'messages' not in st.session_state:
st.session_state['messages'] = [{"role": "assistant",
"content": "์•ˆ๋…•ํ•˜์„ธ์š”! ์ฃผ์–ด์ง„ ๋ฌธ์„œ์— ๋Œ€ํ•ด ๊ถ๊ธˆํ•˜์‹  ๊ฒƒ์ด ์žˆ์œผ๋ฉด ์–ธ์ œ๋“  ๋ฌผ์–ด๋ด์ฃผ์„ธ์š”!"}]
for message in st.session_state.messages:
with st.chat_message(message["role"]):
st.markdown(message["content"])
history = StreamlitChatMessageHistory(key="chat_messages")
# Chat logic
if query := st.chat_input("์งˆ๋ฌธ์„ ์ž…๋ ฅํ•ด์ฃผ์„ธ์š”."):
st.session_state.messages.append({"role": "user", "content": query})
with st.chat_message("user"):
st.markdown(query)
with st.chat_message("assistant"):
chain = st.session_state.conversation
with st.spinner("Thinking..."):
result = chain({"question": query})
with get_openai_callback() as cb:
st.session_state.chat_history = result['chat_history']
response = result['answer']
source_documents = result['source_documents']
# Text-to-Speech ๋ณ€ํ™˜
tts = gTTS(text=response, lang='ko')
tts.save('output.mp3') # ์Œ์„ฑ ํŒŒ์ผ ์ €์žฅ
# ์Œ์„ฑ ํŒŒ์ผ ๋กœ๋“œ
audio = AudioSegment.from_file("output.mp3", format="mp3")
# Streamlit์—์„œ ์Œ์„ฑ ์žฌ์ƒ
st.audio(audio.export(format='mp3').read(), start_time=0)
st.markdown(response)
with st.expander("์ฐธ๊ณ  ๋ฌธ์„œ ํ™•์ธ"):
st.markdown(source_documents[0].metadata['source'], help=source_documents[0].page_content)
st.markdown(source_documents[1].metadata['source'], help=source_documents[1].page_content)
st.markdown(source_documents[2].metadata['source'], help=source_documents[2].page_content)
# Add assistant message to chat history
st.session_state.messages.append({"role": "assistant", "content": response})
# Tokenization helper
def tiktoken_len(text):
    """Count the ``cl100k_base`` tokens in *text*.

    Passed to the text splitter as its ``length_function`` so chunk sizes are
    measured in tokens rather than characters.
    """
    return len(tiktoken.get_encoding("cl100k_base").encode(text))
# PDF loading
def get_text(docs):
    """Load every uploaded PDF into LangChain ``Document`` pages.

    Args:
        docs: uploaded file objects exposing ``.name`` and ``.getvalue()``
            (as returned by ``st.file_uploader``).

    Returns:
        list of Documents produced by ``PyMuPDFLoader.load_and_split()`` for
        every upload whose name ends in ``.pdf`` (case-insensitive); other
        uploads are skipped. Each Document's ``source`` metadata is set to the
        uploaded file name.
    """
    import os
    import tempfile

    doc_list = []
    # A private temp dir avoids clobbering files shared with other sessions in
    # the working directory and is removed automatically afterwards.
    with tempfile.TemporaryDirectory() as tmp_dir:
        for doc in docs:
            # Drop any path components from the client-supplied name.
            file_name = os.path.basename(doc.name)
            logger.info(f"Uploaded {file_name}")
            if not file_name.lower().endswith('.pdf'):
                continue
            file_path = os.path.join(tmp_dir, file_name)
            with open(file_path, "wb") as file:
                file.write(doc.getvalue())
            loader = PyMuPDFLoader(file_path)
            documents = loader.load_and_split()
            for page in documents:
                # Show the user's file name, not the temp path, as the reference.
                page.metadata['source'] = file_name
            doc_list.extend(documents)
    return doc_list
# Text splitting
def get_text_chunks(text):
    """Split loaded Documents into token-sized chunks.

    Uses a RecursiveCharacterTextSplitter targeting 1000-token chunks with a
    100-token overlap, where length is measured by ``tiktoken_len``.

    Args:
        text: list of Documents (as produced by ``get_text``).

    Returns:
        list of chunked Documents.
    """
    splitter = RecursiveCharacterTextSplitter(
        length_function=tiktoken_len,
        chunk_overlap=100,
        chunk_size=1000,
    )
    return splitter.split_documents(text)
# Embedding and vector-store construction
def get_vectorstore(text_chunks):
    """Embed *text_chunks* and index them in an in-memory FAISS store.

    Embeddings come from the HuggingFace model
    ``jhgan/ko-sroberta-multitask``, run on CPU with normalized vectors.

    Args:
        text_chunks: list of Documents to index.

    Returns:
        the FAISS vector store built from the chunks.
    """
    model_options = {'device': 'cpu'}
    encode_options = {'normalize_embeddings': True}
    embedder = HuggingFaceEmbeddings(
        model_name="jhgan/ko-sroberta-multitask",
        model_kwargs=model_options,
        encode_kwargs=encode_options,
    )
    return FAISS.from_documents(text_chunks, embedder)
# Retriever and LLM
def get_conversation_chain(vetorestore, openai_api_key, *, model_name='gpt-3.5-turbo', temperature=0):
    """Build a conversational retrieval chain over *vetorestore*.

    Args:
        vetorestore: vector store to retrieve from (MMR search).
        openai_api_key: key passed to ``ChatOpenAI``.
        model_name: OpenAI chat model name.
        temperature: sampling temperature for the LLM.

    Returns:
        a ``ConversationalRetrievalChain`` ("stuff" combine strategy) with a
        ``ConversationBufferMemory``. Calling it returns ``answer``,
        ``chat_history`` and ``source_documents``.
    """
    llm = ChatOpenAI(openai_api_key=openai_api_key, model_name=model_name, temperature=temperature)
    conversation_chain = ConversationalRetrievalChain.from_llm(
        llm=llm,
        chain_type="stuff",
        retriever=vetorestore.as_retriever(search_type='mmr'),
        # output_key='answer' tells the memory which output to store, because
        # the chain also returns source_documents.
        memory=ConversationBufferMemory(memory_key='chat_history', return_messages=True, output_key='answer'),
        # The default get_chat_history formats the message list as
        # "Human: ... / Assistant: ..." text for the question-condensing prompt.
        return_source_documents=True,
        verbose=True,
    )
    return conversation_chain
# Launch the app when this module is executed as the main script.
if __name__ == "__main__":
    main()