legal-qna / src /retrieval_lib.py
mitultiwari's picture
Update src/retrieval_lib.py
9e046c6 verified
raw
history blame contribute delete
No virus
3.32 kB
# import libraries
import os
import openai
from langchain_community.document_loaders import PyMuPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_openai import OpenAIEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.prompts import ChatPromptTemplate
from operator import itemgetter
from langchain_openai import ChatOpenAI
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
# Name of the OpenAI chat model that generate_answer() instantiates via ChatOpenAI.
LLM_MODEL_NAME: str = "gpt-3.5-turbo"
# load PDF doc and convert to text
def load_pdf_to_text(pdf_path):
# create a document loader
loader = PyMuPDFLoader(pdf_path)
# load the document
doc = loader.load()
return doc
def split_text(text, *, chunk_size=700, chunk_overlap=100):
    """Split loaded documents into overlapping chunks for embedding.

    Args:
        text: Documents to split, e.g. the output of ``load_pdf_to_text``.
            The parameter name is historical: this is passed to
            ``split_documents``, so it must be a sequence of LangChain
            ``Document`` objects, not a plain string.
        chunk_size: Target maximum chunk size handed to
            ``RecursiveCharacterTextSplitter`` (default 700, the value this
            function always used before it was parameterized).
        chunk_overlap: Overlap between consecutive chunks (default 100).

    Returns:
        The chunk documents produced by ``splitter.split_documents(text)``.
    """
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )
    # Return directly instead of binding a local named ``split_text``,
    # which shadowed this function's own name.
    return splitter.split_documents(text)
def load_text_to_index(doc_splits, *, embedding_model="text-embedding-3-small"):
    """Embed document chunks into a FAISS vector store and return a retriever.

    Args:
        doc_splits: Non-empty sequence of LangChain ``Document`` chunks,
            e.g. the output of ``split_text``.
        embedding_model: OpenAI embedding model name passed to
            ``OpenAIEmbeddings`` (defaults to the previously hard-coded
            ``"text-embedding-3-small"``).

    Returns:
        ``vector_store.as_retriever()`` with LangChain's default search
        settings.

    Raises:
        ValueError: If ``doc_splits`` is empty. Checked up front so that,
            e.g., a PDF that yielded no text fails with a clear message
            instead of an obscure error during index construction.
    """
    if not doc_splits:
        raise ValueError("load_text_to_index() needs at least one document chunk")
    embeddings = OpenAIEmbeddings(model=embedding_model)
    vector_store = FAISS.from_documents(doc_splits, embeddings)
    return vector_store.as_retriever()
def query_index(retriever, query):
    """Run ``query`` through ``retriever`` and hand back its result untouched.

    Args:
        retriever: Any runnable exposing ``invoke``, such as the retriever
            returned by ``load_text_to_index``.
        query: The search string.

    Returns:
        Exactly what ``retriever.invoke(query)`` returns.
    """
    return retriever.invoke(query)
def create_answer_prompt():
    """Build the chat prompt used to answer a question from retrieved context.

    The template has two input variables, ``{context}`` and ``{question}``,
    and instructs the model to reply "I don't know" when the context does
    not contain the answer.

    Returns:
        A ``ChatPromptTemplate`` created from that template.
    """
    answer_template = (
        "Answer the question based only on the following context. "
        "If you cannot answer the question with the context, "
        "please respond with 'I don't know':\n"
        "Context:\n"
        "{context}\n"
        "Question:\n"
        "{question}\n"
    )
    # Debug output: character length of the raw template.
    print("template: ", len(answer_template))
    return ChatPromptTemplate.from_template(answer_template)
def generate_answer(retriever, answer_prompt, query, *, model_name=None, temperature=0.0):
    """Answer ``query`` with a retrieval-augmented LangChain chain.

    The question is sent to ``retriever``. The retrieved documents become
    ``context`` and, together with the question, are rendered by
    ``answer_prompt`` and passed to the chat model.

    Args:
        retriever: Runnable retriever, e.g. from ``load_text_to_index``.
        answer_prompt: Prompt with ``{context}`` and ``{question}``
            variables, e.g. from ``create_answer_prompt``.
        query: The user's question.
        model_name: Chat model name; ``None`` (default) means use the
            module-level ``LLM_MODEL_NAME`` as of this call.
        temperature: Sampling temperature for the chat model (default 0.0).

    Returns:
        A dict with ``"response"`` (the chat model's output message) and
        ``"context"`` (the documents the retriever returned).
    """
    print("generate_answer()")
    if model_name is None:
        # Resolved at call time so rebinding LLM_MODEL_NAME still takes effect.
        model_name = LLM_MODEL_NAME
    qna_llm = ChatOpenAI(model_name=model_name, temperature=temperature)
    # The first stage must be a Runnable rather than a plain dict literal:
    # ``{...} | {...}`` between two dicts is Python's dict union, not a chain.
    retrieval_qna_chain = (
        # Keep the incoming "question" and add the retrieved docs as "context".
        RunnablePassthrough.assign(context=itemgetter("question") | retriever)
        # Produce the LLM answer while passing the retrieved docs through.
        | {"response": answer_prompt | qna_llm, "context": itemgetter("context")}
    )
    return retrieval_qna_chain.invoke({"question": query})
def initialize_index(pdf_path=None):
    """Load a PDF, split it into chunks, index it, and return a retriever.

    Args:
        pdf_path: Path of the PDF to index. When omitted, defaults to
            ``<cwd>/data/nvidia_earnings_report.pdf``, the file this
            function always loaded before it was parameterized. Another
            sample previously referenced here is
            ``data/musk-v-altman-openai-complaint-sf.pdf``.

    Returns:
        The retriever produced by ``load_text_to_index``.
    """
    if pdf_path is None:
        # Resolved against the current working directory, so the default
        # presumably requires running from the directory containing data/.
        pdf_path = os.path.join(
            os.path.abspath(os.getcwd()), "data", "nvidia_earnings_report.pdf"
        )
    print("path: ", pdf_path)
    doc = load_pdf_to_text(pdf_path)
    print("doc: \n", len(doc))
    doc_splits = split_text(doc)
    print("doc_splits length: \n", len(doc_splits))
    return load_text_to_index(doc_splits)
def main():
    """Script entry point: build the index, run one sample query, print results."""
    index_retriever = initialize_index()
    # NOTE(review): this question reads like it targets the Musk v. Altman
    # complaint, but initialize_index() loads nvidia_earnings_report.pdf by
    # default; the alternative question "Who is the E-VP, Operations" fits
    # that file better. Confirm which document/query pairing is intended.
    question = "what is the reason for the lawsuit"
    hits = query_index(index_retriever, question)
    print("retrieved_docs: \n", len(hits))
    prompt = create_answer_prompt()
    print("answer_prompt: \n", prompt)
    answer = generate_answer(index_retriever, prompt, question)
    print("result: \n", answer["response"].content)
# Run the demo only when this file is executed directly, not when imported.
if __name__ == "__main__":
    main()