mitultiwari committed
Commit 5255e92
1 Parent(s): 09d2176
Files changed (7)
  1. Dockerfile +11 -0
  2. README.md +4 -7
  3. app.py +55 -0
  4. chainlit.md +5 -0
  5. data/test.txt +1 -0
  6. requirements.txt +108 -0
  7. src/retrieval_lib.py +105 -0
Dockerfile ADDED
@@ -0,0 +1,11 @@
+ FROM python:3.9
+ RUN useradd -m -u 1000 user
+ USER user
+ ENV HOME=/home/user \
+     PATH=/home/user/.local/bin:$PATH
+ WORKDIR $HOME/app
+ COPY --chown=user . $HOME/app
+ COPY ./requirements.txt $HOME/app/requirements.txt
+ RUN pip install -r requirements.txt
+ COPY . .
+ CMD ["chainlit", "run", "app.py", "--port", "7860"]
README.md CHANGED
@@ -1,11 +1,8 @@
  ---
- title: Legal Qna
- emoji: 👀
- colorFrom: red
- colorTo: blue
+ title: PDF RAG Demo
+ emoji: 📉
+ colorFrom: pink
+ colorTo: yellow
  sdk: docker
  pinned: false
- license: openrail
  ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,55 @@
+ # Chainlit Python streaming reference: https://docs.chainlit.io/concepts/streaming/python
+
+ # OpenAI chat completion
+ import os
+ from openai import AsyncOpenAI  # OpenAI async client for API usage
+ import chainlit as cl  # Chainlit powers the chat UI
+ from chainlit.prompt import Prompt, PromptMessage  # prompt tools
+ from chainlit.playground.providers import ChatOpenAI  # ChatOpenAI playground provider
+ from dotenv import load_dotenv
+ from src.retrieval_lib import initialize_index, load_pdf_to_text, split_text, load_text_to_index, query_index, create_answer_prompt, generate_answer
+
+ load_dotenv()
+
+ retriever = initialize_index()  # build the retriever once at startup; all sessions share it
+
+ @cl.on_chat_start  # executed at the start of each user session
+ async def start_chat():
+     settings = {
+         "model": "gpt-3.5-turbo",
+         "temperature": 0,
+         "max_tokens": 500,
+         "top_p": 1,
+         "frequency_penalty": 0,
+         "presence_penalty": 0,
+     }
+     cl.user_session.set("settings", settings)
+
+
+ @cl.on_message  # executed each time the chatbot receives a message from a user
+ async def main(message: cl.Message):
+     settings = cl.user_session.get("settings")
+
+     client = AsyncOpenAI()
+
+     print(message.content)
+
+
+     # print([m.to_openai() for m in prompt.messages])
+
+     query = message.content
+     # query = "what is the reason for the lawsuit"
+     retrieved_docs = query_index(retriever, query)
+     print("retrieved_docs: ", len(retrieved_docs))
+     answer_prompt = create_answer_prompt()
+     print("answer_prompt: ", answer_prompt)
+     result = generate_answer(retriever, answer_prompt, query)
+     print("result: ", result["response"].content)
+
+     msg = cl.Message(content="")
+
+
+     msg.content = result["response"].content
+
+     # send and close the message stream
+     await msg.send()
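Note: `main()` above retrieves, generates, and then sends the finished answer as one message; the streaming doc linked in the header comment is not actually exercised. If token streaming is wanted, a minimal sketch could look like the following (this is not part of the commit; `llm` and `context` are assumed names, with `llm` a `ChatOpenAI` instance from langchain_openai):

    # Hypothetical streaming variant -- a sketch, not the commit's code.
    # Assumes: llm = ChatOpenAI(model_name="gpt-3.5-turbo", streaming=True)
    context = "\n\n".join(doc.page_content for doc in retrieved_docs)
    msg = cl.Message(content="")
    async for chunk in (answer_prompt | llm).astream({"context": context, "question": query}):
        await msg.stream_token(chunk.content)  # forward each token to the UI as it arrives
    await msg.send()  # close the message stream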
chainlit.md ADDED
@@ -0,0 +1,5 @@
+ # PDF RAG
+
+ RAG over a PDF document
+
+ Disclaimer: answers are generated by running the query over the PDF document with an LLM. LLMs can hallucinate and may produce wrong answers.
data/test.txt ADDED
@@ -0,0 +1 @@
+ i
requirements.txt ADDED
@@ -0,0 +1,108 @@
+ aiofiles==23.2.1
+ aiohttp==3.9.3
+ aiosignal==1.3.1
+ annotated-types==0.6.0
+ anyio==3.7.1
+ appdirs==1.4.4
+ async-timeout==4.0.3
+ asyncer==0.0.2
+ attrs==23.2.0
+ bidict==0.23.1
+ certifi==2024.2.2
+ chainlit==0.7.700
+ charset-normalizer==3.3.2
+ click==8.1.7
+ dataclasses-json==0.5.14
+ datasets==2.18.0
+ Deprecated==1.2.14
+ dill==0.3.8
+ distro==1.9.0
+ exceptiongroup==1.2.0
+ faiss-cpu==1.8.0
+ fastapi==0.100.1
+ fastapi-socketio==0.0.10
+ filelock==3.13.1
+ filetype==1.2.0
+ frozenlist==1.4.1
+ fsspec==2024.2.0
+ googleapis-common-protos==1.62.0
+ grpcio==1.62.1
+ h11==0.14.0
+ httpcore==0.17.3
+ httpx==0.24.1
+ huggingface-hub==0.21.4
+ idna==3.6
+ importlib-metadata==6.11.0
+ jsonpatch==1.33
+ jsonpointer==2.4
+ langchain==0.1.11
+ langchain-community==0.0.27
+ langchain-core==0.1.30
+ langchain-openai==0.0.8
+ langchain-text-splitters==0.0.1
+ langchainhub==0.1.15
+ langsmith==0.1.23
+ Lazify==0.4.0
+ marshmallow==3.21.1
+ multidict==6.0.5
+ multiprocess==0.70.16
+ mypy-extensions==1.0.0
+ nest-asyncio==1.6.0
+ numpy==1.26.4
+ openai==1.13.3
+ opentelemetry-api==1.23.0
+ opentelemetry-exporter-otlp==1.23.0
+ opentelemetry-exporter-otlp-proto-common==1.23.0
+ opentelemetry-exporter-otlp-proto-grpc==1.23.0
+ opentelemetry-exporter-otlp-proto-http==1.23.0
+ opentelemetry-instrumentation==0.44b0
+ opentelemetry-proto==1.23.0
+ opentelemetry-sdk==1.23.0
+ opentelemetry-semantic-conventions==0.44b0
+ orjson==3.9.15
+ packaging==23.2
+ pandas==2.2.1
+ protobuf==4.25.3
+ pyarrow==15.0.1
+ pyarrow-hotfix==0.6
+ pydantic==2.6.3
+ pydantic_core==2.16.3
+ PyJWT==2.8.0
+ PyMuPDF==1.23.26
+ PyMuPDFb==1.23.22
+ pysbd==0.3.4
+ python-dateutil==2.9.0.post0
+ python-dotenv==1.0.1
+ python-engineio==4.9.0
+ python-graphql-client==0.4.3
+ python-multipart==0.0.6
+ python-socketio==5.11.1
+ pytz==2024.1
+ PyYAML==6.0.1
+ ragas==0.1.3
+ regex==2023.12.25
+ requests==2.31.0
+ simple-websocket==1.0.0
+ six==1.16.0
+ sniffio==1.3.1
+ SQLAlchemy==2.0.28
+ starlette==0.27.0
+ syncer==2.0.3
+ tenacity==8.2.3
+ tiktoken==0.6.0
+ tomli==2.0.1
+ tqdm==4.66.2
+ types-requests==2.31.0.20240311
+ typing-inspect==0.9.0
+ typing_extensions==4.10.0
+ tzdata==2024.1
+ uptrace==1.22.0
+ urllib3==2.2.1
+ uvicorn==0.23.2
+ watchfiles==0.20.0
+ websockets==12.0
+ wrapt==1.16.0
+ wsproto==1.2.0
+ xxhash==3.4.1
+ yarl==1.9.4
+ zipp==3.17.0
src/retrieval_lib.py ADDED
@@ -0,0 +1,105 @@
+
+ # import libraries
+ import os
+ import openai
+ from langchain_community.document_loaders import PyMuPDFLoader
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain_openai import OpenAIEmbeddings
+ from langchain_community.vectorstores import FAISS
+ from langchain.prompts import ChatPromptTemplate
+ from operator import itemgetter
+ from langchain_openai import ChatOpenAI
+ from langchain_core.output_parsers import StrOutputParser
+ from langchain_core.runnables import RunnablePassthrough
+
+
+ LLM_MODEL_NAME = "gpt-3.5-turbo"
+
+
+ # load a PDF document and convert it to text
+ def load_pdf_to_text(pdf_path):
+     # create a document loader
+     loader = PyMuPDFLoader(pdf_path)
+     # load the document
+     doc = loader.load()
+     return doc
+
+ def split_text(text):
+     # create a text splitter
+     splitter = RecursiveCharacterTextSplitter(
+         chunk_size=700,
+         chunk_overlap=100,
+     )
+     # split the documents into chunks
+     chunks = splitter.split_documents(text)
+     return chunks
+
+ # load text chunks into a FAISS index and return a retriever
+ def load_text_to_index(doc_splits):
+     embeddings = OpenAIEmbeddings(
+         model="text-embedding-3-small"
+     )
+     vector_store = FAISS.from_documents(doc_splits, embeddings)
+     retriever = vector_store.as_retriever()
+     return retriever
+
+ # query the FAISS index
+ def query_index(retriever, query):
+     retrieved_docs = retriever.invoke(query)
+     return retrieved_docs
+
+ # create the answer prompt
+ def create_answer_prompt():
+     template = """Answer the question based only on the following context. If you cannot answer the question with the context, please respond with 'I don't know':
+
+ Context:
+ {context}
+
+ Question:
+ {question}
+ """
+     print("template length: ", len(template))
+     prompt = ChatPromptTemplate.from_template(template)
+     return prompt
+
+ # generate an answer with retrieval-augmented generation
+ def generate_answer(retriever, answer_prompt, query):
+     print("generate_answer()")
+     QnA_LLM = ChatOpenAI(model_name=LLM_MODEL_NAME, temperature=0.0)
+
+     retrieval_qna_chain = (
+         {"context": itemgetter("question") | retriever, "question": itemgetter("question")}
+         | RunnablePassthrough.assign(context=itemgetter("context"))
+         | {"response": answer_prompt | QnA_LLM, "context": itemgetter("context")}
+     )
+     result = retrieval_qna_chain.invoke({"question": query})
+     return result
+
+ def initialize_index():
+     # load the PDF
+     cwd = os.path.abspath(os.getcwd())
+     data_dir = "data"
+     pdf_file = "nvidia_earnings_report.pdf"
+     # pdf_file = "musk-v-altman-openai-complaint-sf.pdf"
+     pdf_path = os.path.join(cwd, data_dir, pdf_file)
+     print("path: ", pdf_path)
+     doc = load_pdf_to_text(pdf_path)
+     print("doc pages: ", len(doc))
+     doc_splits = split_text(doc)
+     print("doc_splits length: ", len(doc_splits))
+     retriever = load_text_to_index(doc_splits)
+     return retriever
+
+ def main():
+     retriever = initialize_index()
+     # query = "Who is the E-VP, Operations"
+     query = "what is the reason for the lawsuit"
+     retrieved_docs = query_index(retriever, query)
+     print("retrieved_docs: ", len(retrieved_docs))
+     answer_prompt = create_answer_prompt()
+     print("answer_prompt: ", answer_prompt)
+     result = generate_answer(retriever, answer_prompt, query)
+     print("result: ", result["response"].content)
+
+ if __name__ == "__main__":
+     main()
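Note: `initialize_index()` re-reads and re-embeds the PDF on every process start, paying an embeddings API call each time. LangChain's FAISS store can be persisted and reloaded; a minimal sketch, not part of this commit (the `faiss_index` directory name is an assumption):

    # Hypothetical persistence variant -- build once, reload on later startups.
    from langchain_community.vectorstores import FAISS
    from langchain_openai import OpenAIEmbeddings

    embeddings = OpenAIEmbeddings(model="text-embedding-3-small")

    # First run: embed the chunks and save the index to disk.
    vector_store = FAISS.from_documents(doc_splits, embeddings)
    vector_store.save_local("faiss_index")

    # Later runs: reload instead of re-embedding (newer langchain-community
    # releases also require allow_dangerous_deserialization=True here).
    vector_store = FAISS.load_local("faiss_index", embeddings)
    retriever = vector_store.as_retriever()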