Ritesh-hf committed
Commit 294a7fa
1 Parent(s): ab310ba

revert previous changes
Files changed (2)
  1. app.py +132 -36
  2. test.py +2 -2
app.py CHANGED
@@ -1,42 +1,138 @@
 import gradio as gr
 import spaces
-from transformers import AutoTokenizer, AutoModel
 import torch
 
-# Load the model and tokenizer
-model_name = "Alibaba-NLP/gte-large-en-v1.5"  # Adjust the model identifier if necessary
-tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
-model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
-
-# Move model to GPU if available
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-model.to(device)
-
-@spaces.GPU(duration=1)
-def generate_embeddings(text):
-    # Tokenize input text
-    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
-
-    # Move inputs to GPU if available
-    inputs = {key: value.to(device) for key, value in inputs.items()}
-
-    # Get model outputs
-    with torch.no_grad():
-        outputs = model(**inputs)
-
-    # Extract embeddings (using the mean of the last hidden state as a simple approach)
-    embeddings = outputs.last_hidden_state.mean(dim=1).cpu().squeeze().tolist()
-
-    return embeddings
-
-# Define the Gradio interface
-interface = gr.Interface(
-    fn=generate_embeddings,
-    inputs=gr.Textbox(lines=2, placeholder="Enter text here..."),
-    outputs=gr.JSON(label="Text Embeddings"),
-    title="Text Embeddings Generator",
-    description="Generate text embeddings using the Alibaba-NLP/gte-large-en-v1.5 model."
 )
 
-if __name__ == "__main__":
-    interface.launch()
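The removed `generate_embeddings` pools with an unweighted mean over `last_hidden_state`, which averages padding positions too once inputs are batched. A minimal sketch of mask-aware mean pooling for comparison; the helper below is illustrative, not part of the commit:

```python
import torch

def masked_mean_pool(last_hidden_state: torch.Tensor,
                     attention_mask: torch.Tensor) -> torch.Tensor:
    # Zero out padding positions before averaging so pad tokens
    # do not dilute the sentence embedding.
    mask = attention_mask.unsqueeze(-1).type_as(last_hidden_state)
    summed = (last_hidden_state * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1e-9)
    return summed / counts
```

For a single unpadded input the two pooling strategies coincide, so the simpler mean is harmless in the one-text-at-a-time Gradio flow above. The replacement below drops this embeddings demo in favor of a hybrid-search RAG chatbot.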
+import os
+from dotenv import load_dotenv
+load_dotenv(".env")
+
+os.environ['USER_AGENT'] = os.getenv("USER_AGENT")
+os.environ["GROQ_API_KEY"] = os.getenv("GROQ_API_KEY")
+os.environ["TOKENIZERS_PARALLELISM"] = 'true'
+
+import nltk
+nltk.download('punkt_tab')
+
+from langchain.chains import create_history_aware_retriever, create_retrieval_chain
+from langchain.chains.combine_documents import create_stuff_documents_chain
+from langchain_community.chat_message_histories import ChatMessageHistory
+from langchain_community.document_loaders import WebBaseLoader
+from langchain_core.chat_history import BaseChatMessageHistory
+from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
+from langchain_core.runnables.history import RunnableWithMessageHistory
+
+from pinecone import Pinecone
+from pinecone_text.sparse import BM25Encoder
+
+from langchain_huggingface import HuggingFaceEmbeddings
+from langchain_community.retrievers import PineconeHybridSearchRetriever
+
+from langchain_groq import ChatGroq
+
 import gradio as gr
 import spaces
 import torch
 
+
+try:
+    pc = Pinecone(api_key=os.getenv("PINECONE_API_KEY"))
+    index_name = "traveler-demo-website-vectorstore"
+    # connect to index
+    pinecone_index = pc.Index(index_name)
+except:
+    pc = Pinecone(api_key=os.getenv("PINECONE_API_KEY"))
+    index_name = "traveler-demo-website-vectorstore"
+    # connect to index
+    pinecone_index = pc.Index(index_name)
+
+bm25 = BM25Encoder().load("./bm25_traveler_website.json")
+
+embed_model = HuggingFaceEmbeddings(model_name="Alibaba-NLP/gte-large-en-v1.5", model_kwargs={"trust_remote_code": True, "device": "cuda"})
+
+retriever = PineconeHybridSearchRetriever(
+    embeddings=embed_model,
+    sparse_encoder=bm25,
+    index=pinecone_index,
+    top_k=20,
+    alpha=0.5,
+)
+
+llm = ChatGroq(model="llama-3.1-70b-versatile", temperature=0.1, max_tokens=1024, max_retries=2)
+
+### Contextualize question ###
+contextualize_q_system_prompt = """Given a chat history and the latest user question \
+which might reference context in the chat history, formulate a standalone question \
+which can be understood without the chat history. Do NOT answer the question, \
+just reformulate it if needed and otherwise return it as is.
+"""
+contextualize_q_prompt = ChatPromptTemplate.from_messages(
+    [
+        ("system", contextualize_q_system_prompt),
+        MessagesPlaceholder("chat_history"),
+        ("human", "{input}")
+    ]
+)
+
+history_aware_retriever = create_history_aware_retriever(
+    llm, retriever, contextualize_q_prompt
 )
 
+
+qa_system_prompt = """You are a highly skilled information retrieval assistant. Use the following pieces of retrieved context to answer the question. \
+Provide links to sources provided in the answer. \
+If you don't know the answer, just say that you don't know. \
+Do not give extra long answers. \
+When responding to queries, your responses should be comprehensive and well-organized. For each response: \
+1. Provide Clear Answers \
+2. Include Detailed References: \
+  - Include links to sources and any links or sites that are mentioned in the answer. \
+  - Links to Sources: Provide URLs to credible sources where users can verify the information or explore further. \
+  - Downloadable Materials: Include links to any relevant downloadable resources if applicable. \
+  - Reference Sites: Mention specific websites or platforms that offer additional information. \
+3. Formatting for Readability: \
+  - Bullet Points or Lists: Where applicable, use bullet points or numbered lists to present information clearly. \
+  - Emphasize Important Information: Use bold or italics to highlight key details. \
+4. Organize Content Logically \
+Do not include anything about context in the answer. \
+{context}
+"""
+qa_prompt = ChatPromptTemplate.from_messages(
+    [
+        ("system", qa_system_prompt),
+        MessagesPlaceholder("chat_history"),
+        ("human", "{input}")
+    ]
+)
+question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
+
+rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)
+
+### Statefully manage chat history ###
+store = {}
+
+def get_session_history(session_id: str) -> BaseChatMessageHistory:
+    if session_id not in store:
+        store[session_id] = ChatMessageHistory()
+    return store[session_id]
+
+
+conversational_rag_chain = RunnableWithMessageHistory(
+    rag_chain,
+    get_session_history,
+    input_messages_key="input",
+    history_messages_key="chat_history",
+    output_messages_key="answer",
+)
+
+@spaces.GPU(duration=8)
+def handle_message(question, history={}):
+    response = ''
+    chain = conversational_rag_chain.pick("answer")
+    for chunk in chain.stream(
+        {"input": question},
+        config={
+            "configurable": {"session_id": "abc123"}
+        },
+    ):
+        response += chunk
+        yield response
+
+if __name__ == '__main__':
+    demo = gr.ChatInterface(fn=handle_message)
+    demo.launch()
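The new `app.py` keeps one `ChatMessageHistory` per session in a module-level dict and narrows the streamed chain output to the `answer` key via `.pick("answer")`. A minimal sketch of the same wiring that runs without Pinecone or Groq credentials; `fake_rag` is a stand-in for the real retrieval chain and is purely illustrative:

```python
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.runnables import RunnableLambda
from langchain_core.runnables.history import RunnableWithMessageHistory

store = {}

def get_session_history(session_id: str) -> BaseChatMessageHistory:
    # One in-memory history per session id, as in the committed code.
    if session_id not in store:
        store[session_id] = ChatMessageHistory()
    return store[session_id]

def fake_rag(inputs: dict) -> dict:
    # Stand-in for rag_chain: echoes the question plus how many prior
    # messages the history wrapper injected under "chat_history".
    n = len(inputs["chat_history"])
    return {"answer": f"[{n} prior messages] You asked: {inputs['input']}"}

chain = RunnableWithMessageHistory(
    RunnableLambda(fake_rag),
    get_session_history,
    input_messages_key="input",
    history_messages_key="chat_history",
    output_messages_key="answer",
)

# .pick("answer") yields just the answer text, as in handle_message.
for chunk in chain.pick("answer").stream(
    {"input": "hello"},
    config={"configurable": {"session_id": "abc123"}},
):
    print(chunk)
```

With this stub, `stream()` should yield a single chunk; the real Groq-backed chain streams token by token. Note that `handle_message` hard-codes `session_id="abc123"`, so all ChatInterface visitors share one history.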
test.py CHANGED
@@ -8,8 +8,8 @@ while True:
     question = input("Question: ")
     start_time = timeit.default_timer()
     result = client.predict(
-        question=question,
-        api_name="/chat"
+        text=question,
+        api_name="/predict"
     )
     end_time = timeit.default_timer()
     print(result)
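`test.py` exercises the Space through `gradio_client`; the hunk shows only the loop body, so the `Client` construction is implied. A sketch of the full harness under that assumption; the Space identifier and the final timing print are placeholders, not from the commit:

```python
import timeit
from gradio_client import Client

client = Client("user/space-name")  # placeholder Space id

while True:
    question = input("Question: ")
    start_time = timeit.default_timer()
    result = client.predict(
        text=question,
        api_name="/predict"
    )
    end_time = timeit.default_timer()
    print(result)
    print(f"elapsed: {end_time - start_time:.2f}s")  # assumed use of end_time
```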