Karthikeyan committed on
Commit
ca6370e
1 Parent(s): 8fce35a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -34
app.py CHANGED
@@ -1,51 +1,69 @@
1
  import gradio as gr
 
 
2
 
3
  from langchain.document_loaders import OnlinePDFLoader
4
 
5
  from langchain.text_splitter import CharacterTextSplitter
6
 
7
- from langchain.llms import HuggingFaceHub
8
 
9
- from langchain.embeddings import HuggingFaceHubEmbeddings
10
 
11
- from langchain.vectorstores import Chroma
12
-
13
- from langchain.chains import RetrievalQA
14
 
 
15
 
 
16
 
17
  def loading_pdf():
18
  return "Loading..."
19
 
20
- def pdf_changes(pdf_doc, repo_id):
21
-
22
- loader = OnlinePDFLoader(pdf_doc.name)
23
- documents = loader.load()
24
- text_splitter = CharacterTextSplitter(chunk_size=300, chunk_overlap=0)
25
- texts = text_splitter.split_documents(documents)
26
- embeddings = HuggingFaceHubEmbeddings()
27
- db = Chroma.from_documents(texts, embeddings)
28
- retriever = db.as_retriever()
29
- llm = HuggingFaceHub(repo_id=repo_id, model_kwargs={"temperature":0.1, "max_new_tokens":250})
30
- global qa
31
- qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever, return_source_documents=True)
32
- return "Ready"
 
 
 
 
 
33
 
34
  def add_text(history, text):
35
  history = history + [(text, None)]
36
  return history, ""
37
 
38
  def bot(history):
39
- response = infer(history[-1][0])
40
- history[-1][1] = response['result']
41
- return history
 
 
 
 
 
42
 
43
- def infer(question):
44
 
 
 
 
 
 
 
 
45
  query = question
46
- result = qa({"query": query})
47
-
48
- return result
49
 
50
  css="""
51
  #col-container {max-width: 700px; margin-left: auto; margin-right: auto;}
@@ -60,27 +78,26 @@ title = """
60
  """
61
 
62
 
63
- with gr.Blocks(css=css,theme=gr.themes.Soft()) as demo:
 
64
  with gr.Column(elem_id="col-container"):
65
  gr.HTML(title)
66
-
67
  with gr.Column():
 
68
  pdf_doc = gr.File(label="Load a pdf", file_types=['.pdf'], type="file")
69
- repo_id = gr.Dropdown(label="LLM", choices=["google/flan-ul2", "OpenAssistant/oasst-sft-1-pythia-12b", "bigscience/bloomz"], value="google/flan-ul2")
70
  with gr.Row():
71
  langchain_status = gr.Textbox(label="Status", placeholder="", interactive=False)
72
- load_pdf = gr.Button("Load to langchain")
73
 
74
  chatbot = gr.Chatbot([], elem_id="chatbot").style(height=350)
75
  question = gr.Textbox(label="Question", placeholder="Type your question and hit Enter ")
76
- submit_btn = gr.Button("Send message")
77
- repo_id.change(pdf_changes, inputs=[pdf_doc, repo_id], outputs=[langchain_status], queue=False)
78
- load_pdf.click(pdf_changes, inputs=[pdf_doc, repo_id], outputs=[langchain_status], queue=False)
79
  question.submit(add_text, [chatbot, question], [chatbot, question]).then(
80
  bot, chatbot, chatbot
81
  )
82
  submit_btn.click(add_text, [chatbot, question], [chatbot, question]).then(
83
- bot, chatbot, chatbot
84
- )
85
 
86
  demo.launch()
 
1
  import gradio as gr
2
+ import os
3
+ import time
4
 
5
  from langchain.document_loaders import OnlinePDFLoader
6
 
7
  from langchain.text_splitter import CharacterTextSplitter
8
 
 
9
 
10
+ from langchain.llms import OpenAI
11
 
12
+ from langchain.embeddings import OpenAIEmbeddings
 
 
13
 
14
+ from langchain.vectorstores import Chroma
15
 
16
+ from langchain.chains import ConversationalRetrievalChain
17
 
18
def loading_pdf():
    """Return the interim status message shown while the PDF is processed."""
    return "Loading..."
20
 
21
def pdf_changes(pdf_doc, open_ai_key):
    """Index the uploaded PDF and build the conversational QA chain.

    Loads the PDF, splits it into 1000-char chunks, embeds them into an
    in-memory Chroma vector store, and stores a
    ``ConversationalRetrievalChain`` in the module-level ``qa`` global
    that ``infer`` later queries.

    Args:
        pdf_doc: Gradio ``File`` value; ``.name`` is the temp-file path.
        open_ai_key: OpenAI API key typed by the user (may be "" or None).

    Returns:
        Status string for the UI: "Ready" on success, otherwise a
        reminder that the key is missing.
    """
    # BUG FIX: the original tested `openai_key` — the module-level
    # gr.Textbox COMPONENT, which is never None — instead of the
    # `open_ai_key` parameter, so the missing-key branch could never
    # fire. A blank Textbox also yields "" (not None), so test
    # truthiness rather than `is not None`.
    if not open_ai_key:
        return "You forgot OpenAI API key"

    # LangChain's OpenAI wrappers read the key from the environment.
    os.environ['OPENAI_API_KEY'] = open_ai_key
    loader = OnlinePDFLoader(pdf_doc.name)
    documents = loader.load()
    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
    texts = text_splitter.split_documents(documents)
    embeddings = OpenAIEmbeddings()
    db = Chroma.from_documents(texts, embeddings)
    retriever = db.as_retriever()
    global qa
    qa = ConversationalRetrievalChain.from_llm(
        llm=OpenAI(temperature=0.5),
        retriever=retriever,
        return_source_documents=False)
    return "Ready"
39
 
40
def add_text(history, text):
    """Record the user's message as a new, unanswered chat turn.

    Returns the extended history together with an empty string, which
    Gradio uses to clear the question textbox.
    """
    return history + [(text, None)], ""
43
 
44
def bot(history):
    """Stream the answer for the latest question into the chat history.

    Fetches the complete answer via ``infer`` and yields the history
    after every appended character (with a short pause) so the UI shows
    a typewriter effect.
    """
    answer = infer(history[-1][0], history)
    history[-1][1] = ""
    partial = ""
    for ch in answer:
        partial += ch
        history[-1][1] = partial
        time.sleep(0.05)
        yield history
52
+
53
 
54
def infer(question, history):
    """Query the global ``qa`` chain with *question* plus prior turns.

    All completed (human, ai) pairs — i.e. everything except the
    in-progress last entry — are forwarded as the chat history.

    Returns the chain's answer string.
    """
    chat_history = [(human, ai) for human, ai in history[:-1]]
    result = qa({"question": question, "chat_history": chat_history})
    return result["answer"]
67
 
68
  css="""
69
  #col-container {max-width: 700px; margin-left: auto; margin-right: auto;}
 
78
  """
79
 
80
 
81
+
82
# UI layout and event wiring for the PDF-chat demo.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.HTML(title)
        with gr.Column():
            # NOTE(review): label typo "You" -> "Your"; runtime string, left unchanged.
            openai_key = gr.Textbox(label="You OpenAI API key", type="password")
            pdf_doc = gr.File(label="Load a pdf", file_types=['.pdf'], type="file")
        with gr.Row():
            langchain_status = gr.Textbox(label="Status", placeholder="", interactive=False)
            load_pdf = gr.Button("Load pdf to langchain")
        chatbot = gr.Chatbot([], elem_id="chatbot").style(height=350)
        question = gr.Textbox(label="Question", placeholder="Type your question and hit Enter ")
        submit_btn = gr.Button("Send Message")
        # Two handlers on the same click: the first immediately shows
        # "Loading...", the second does the slow indexing and replaces
        # the status with its result ("Ready" or the missing-key hint).
        load_pdf.click(loading_pdf, None, langchain_status, queue=False)
        load_pdf.click(pdf_changes, inputs=[pdf_doc, openai_key], outputs=[langchain_status], queue=False)
        # Enter key and Send button run the same two-step pipeline:
        # add_text records the user turn, then bot streams the answer.
        question.submit(add_text, [chatbot, question], [chatbot, question]).then(
            bot, chatbot, chatbot
        )
        submit_btn.click(add_text, [chatbot, question], [chatbot, question]).then(
            bot, chatbot, chatbot)

demo.launch()