XThomasBU committed
Commit f0018f2
1 Parent(s): 40de40e

Code to add metadata to the chunks

.chainlit/config.toml CHANGED
@@ -22,7 +22,7 @@ prompt_playground = true
 unsafe_allow_html = false
 
 # Process and display mathematical expressions. This can clash with "$" characters in messages.
-latex = false
+latex = true
 
 # Authorize users to upload files with messages
 multi_modal = true
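
Turning on `latex` lets Chainlit render TeX math in chat messages (with the noted caveat about literal "$" characters). A minimal sketch of a handler that relies on this setting; the handler name and message text are illustrative, not part of this repo:

import chainlit as cl


@cl.on_message
async def main(message: cl.Message):
    # With latex = true, the $...$ span below is rendered as math in the chat UI.
    await cl.Message(
        content="Attention weights are $\\mathrm{softmax}(QK^T / \\sqrt{d_k})$."
    ).send()
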
code/config.yml CHANGED
@@ -2,14 +2,14 @@ embedding_options:
   embedd_files: False # bool
   data_path: 'storage/data' # str
   url_file_path: 'storage/data/urls.txt' # str
-  expand_urls: True # bool
-  db_option : 'FAISS' # str
+  expand_urls: False # bool
+  db_option : 'RAGatouille' # str [FAISS, Chroma, RAGatouille]
   db_path : 'vectorstores' # str
   model : 'sentence-transformers/all-MiniLM-L6-v2' # str [sentence-transformers/all-MiniLM-L6-v2, text-embedding-ada-002']
   search_top_k : 3 # int
-  score_threshold : 0.5 # float
+  score_threshold : 0.2 # float
 llm_params:
-  use_history: True # bool
+  use_history: False # bool
   memory_window: 3 # int
   llm_loader: 'local_llm' # str [local_llm, openai]
   openai_params:
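
These options are read back elsewhere in the code with a plain YAML load. A minimal sketch of how the retrieval-related keys might be accessed; the snippet itself is illustrative, only the file path and key names come from the diff:

import yaml

with open("code/config.yml") as f:
    config = yaml.safe_load(f)

embed_opts = config["embedding_options"]
print(embed_opts["db_option"])        # 'RAGatouille' after this commit
print(embed_opts["search_top_k"])     # 3
print(embed_opts["score_threshold"])  # 0.2 (only used for FAISS/Chroma retrieval)
print(config["llm_params"]["use_history"])  # False
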
code/modules/data_loader.py CHANGED
@@ -2,7 +2,7 @@ import os
 import re
 import requests
 import pysrt
-from langchain.document_loaders import (
+from langchain_community.document_loaders import (
     PyMuPDFLoader,
     Docx2txtLoader,
     YoutubeLoader,
@@ -16,6 +16,15 @@ import logging
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain_experimental.text_splitter import SemanticChunker
 from langchain_openai.embeddings import OpenAIEmbeddings
+from ragatouille import RAGPretrainedModel
+from langchain.chains import LLMChain
+from langchain.llms import OpenAI
+from langchain import PromptTemplate
+
+try:
+    from modules.helpers import get_lecture_metadata
+except:
+    from helpers import get_lecture_metadata
 
 logger = logging.getLogger(__name__)
 
@@ -58,23 +67,6 @@ class FileReader:
         return None
 
     def read_pdf(self, temp_file_path: str):
-        # parser = LlamaParse(
-        #     api_key="",
-        #     result_type="markdown",
-        #     num_workers=4,
-        #     verbose=True,
-        #     language="en",
-        # )
-        # documents = parser.load_data(temp_file_path)
-
-        # with open("temp/output.md", "a") as f:
-        #     for doc in documents:
-        #         f.write(doc.text + "\n")
-
-        # markdown_path = "temp/output.md"
-        # loader = UnstructuredMarkdownLoader(markdown_path)
-        # loader = PyMuPDFLoader(temp_file_path)  # This loader preserves more metadata
-        # return loader.load()
         loader = self.pdf_reader.get_loader(temp_file_path)
         documents = self.pdf_reader.get_documents(loader)
         return documents
@@ -108,8 +100,6 @@ class FileReader:
 class ChunkProcessor:
     def __init__(self, config):
         self.config = config
-        self.document_chunks_full = []
-        self.document_names = []
 
        if config["splitter_options"]["use_splitter"]:
            if config["splitter_options"]["split_by_token"]:
@@ -130,6 +120,17 @@ class ChunkProcessor:
             self.splitter = None
         logger.info("ChunkProcessor instance created")
 
+    # def extract_metadata(self, document_content):
+
+    #     llm = OpenAI()
+    #     prompt_template = PromptTemplate(
+    #         input_variables=["document_content"],
+    #         template="Extract metadata for this document:\n\n{document_content}\n\nMetadata:",
+    #     )
+    #     chain = LLMChain(llm=llm, prompt=prompt_template)
+    #     metadata = chain.run(document_content=document_content)
+    #     return metadata
+
     def remove_delimiters(self, document_chunks: list):
         for chunk in document_chunks:
             for delimiter in self.config["splitter_options"]["delimiters_to_remove"]:
@@ -146,11 +147,23 @@ class ChunkProcessor:
         logger.info(f"\tNumber of pages after skipping: {len(document_chunks)}")
         return document_chunks
 
-    def process_chunks(self, documents):
-        if self.splitter:
+    def process_chunks(
+        self, documents, file_type="txt", source="", page=0, metadata={}
+    ):
+        documents = [Document(page_content=documents, source=source, page=page)]
+        if file_type == "txt":
             document_chunks = self.splitter.split_documents(documents)
-        else:
-            document_chunks = documents
+        elif file_type == "pdf":
+            document_chunks = documents  # Full page for now
+
+        # add the source and page number back to the metadata
+        for chunk in document_chunks:
+            chunk.metadata["source"] = source
+            chunk.metadata["page"] = page
+
+            # add the metadata extracted from the document
+            for key, value in metadata.items():
+                chunk.metadata[key] = value
 
         if self.config["splitter_options"]["remove_leftover_delimiters"]:
             document_chunks = self.remove_delimiters(document_chunks)
@@ -161,38 +174,77 @@ class ChunkProcessor:
 
     def get_chunks(self, file_reader, uploaded_files, weblinks):
         self.document_chunks_full = []
-        self.document_names = []
+        self.parent_document_names = []
+        self.child_document_names = []
+        self.documents = []
+        self.document_metadata = []
+
+        lecture_metadata = get_lecture_metadata(
+            "https://dl4ds.github.io/sp2024/lectures/"
+        )  # TODO: Use more efficiently
 
         for file_index, file_path in enumerate(uploaded_files):
             file_name = os.path.basename(file_path)
             file_type = file_name.split(".")[-1].lower()
 
-            try:
-                if file_type == "pdf":
-                    documents = file_reader.read_pdf(file_path)
-                elif file_type == "txt":
-                    documents = file_reader.read_txt(file_path)
-                elif file_type == "docx":
-                    documents = file_reader.read_docx(file_path)
-                elif file_type == "srt":
-                    documents = file_reader.read_srt(file_path)
-                else:
-                    logger.warning(f"Unsupported file type: {file_type}")
-                    continue
-
-                document_chunks = self.process_chunks(documents)
-                self.document_names.append(file_name)
-                self.document_chunks_full.extend(document_chunks)
-
-            except Exception as e:
-                logger.error(f"Error processing file {file_name}: {str(e)}")
+            # try:
+            if file_type == "pdf":
+                documents = file_reader.read_pdf(file_path)
+            elif file_type == "txt":
+                documents = file_reader.read_txt(file_path)
+            elif file_type == "docx":
+                documents = file_reader.read_docx(file_path)
+            elif file_type == "srt":
+                documents = file_reader.read_srt(file_path)
+            else:
+                logger.warning(f"Unsupported file type: {file_type}")
+                continue
+
+            # full_text = ""
+            # for doc in documents:
+            #     full_text += doc.page_content
+            #     break  # getting only first page for now
+
+            # extracted_metadata = self.extract_metadata(full_text)
+
+            for doc in documents:
+                page_num = doc.metadata.get("page", 0)
+                self.documents.append(doc.page_content)
+                self.document_metadata.append({"source": file_path, "page": page_num})
+                if "lecture" in file_path.lower():
+                    metadata = lecture_metadata.get(file_path, {})
+                    metadata["source_type"] = "lecture"
+                    self.document_metadata[-1].update(metadata)
+                else:
+                    metadata = {"source_type": "other"}
+
+                self.child_document_names.append(f"{file_name}_{page_num}")
+
+                self.parent_document_names.append(file_name)
+                if self.config["embedding_options"]["db_option"] not in ["RAGatouille"]:
+                    document_chunks = self.process_chunks(
+                        self.documents[-1],
+                        file_type,
+                        source=file_path,
+                        page=page_num,
+                        metadata=metadata,
+                    )
+                    self.document_chunks_full.extend(document_chunks)
+
+            # except Exception as e:
+            #     logger.error(f"Error processing file {file_name}: {str(e)}")
 
         self.process_weblinks(file_reader, weblinks)
 
         logger.info(
             f"Total document chunks extracted: {len(self.document_chunks_full)}"
         )
-        return self.document_chunks_full, self.document_names
+        return (
+            self.document_chunks_full,
+            self.child_document_names,
+            self.documents,
+            self.document_metadata,
+        )
 
     def process_weblinks(self, file_reader, weblinks):
         if weblinks[0] != "":
@@ -206,9 +258,26 @@ class ChunkProcessor:
                     else:
                         documents = file_reader.read_html(link)
 
-                    document_chunks = self.process_chunks(documents)
-                    self.document_names.append(link)
-                    self.document_chunks_full.extend(document_chunks)
+                    for doc in documents:
+                        page_num = doc.metadata.get("page", 0)
+                        self.documents.append(doc.page_content)
+                        self.document_metadata.append(
+                            {"source": link, "page": page_num}
+                        )
+                        self.child_document_names.append(f"{link}")
+
+                    self.parent_document_names.append(link)
+                    if self.config["embedding_options"]["db_option"] not in [
+                        "RAGatouille"
+                    ]:
+                        document_chunks = self.process_chunks(
+                            self.documents[-1],
+                            "txt",
+                            source=link,
+                            page=0,
+                            metadata={"source_type": "webpage"},
+                        )
+                        self.document_chunks_full.extend(document_chunks)
                 except Exception as e:
                     logger.error(
                         f"Error splitting link {link_index+1} : {link}: {str(e)}"
code/modules/embedding_model_loader.py CHANGED
@@ -1,6 +1,6 @@
 from langchain_community.embeddings import OpenAIEmbeddings
-from langchain.embeddings import HuggingFaceEmbeddings
-from langchain.embeddings import LlamaCppEmbeddings
+from langchain_community.embeddings import HuggingFaceEmbeddings
+from langchain_community.embeddings import LlamaCppEmbeddings
 
 try:
     from modules.constants import *
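
These are straight import moves to `langchain_community` (the old `langchain.embeddings` paths are deprecated); usage is unchanged. A minimal sketch with the model named in config.yml:

from langchain_community.embeddings import HuggingFaceEmbeddings

embedding_model = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-MiniLM-L6-v2"
)
vector = embedding_model.embed_query("What is RLHF?")
print(len(vector))  # 384 dimensions for this model
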
code/modules/helpers.py CHANGED
@@ -4,6 +4,8 @@ from tqdm import tqdm
 from urllib.parse import urlparse
 import chainlit as cl
 from langchain import PromptTemplate
+import requests
+from bs4 import BeautifulSoup
 
 try:
     from modules.constants import *
@@ -138,67 +140,133 @@ def get_prompt(config):
 
 
 def get_sources(res, answer):
-    source_elements_dict = {}
     source_elements = []
-    found_sources = []
-
     source_dict = {}  # Dictionary to store URL elements
 
     for idx, source in enumerate(res["source_documents"]):
         source_metadata = source.metadata
         url = source_metadata["source"]
         score = source_metadata.get("score", "N/A")
+        page = source_metadata.get("page", 1)
+
+        lecture_tldr = source_metadata.get("tldr", "N/A")
+        lecture_recording = source_metadata.get("lecture_recording", "N/A")
+        suggested_readings = source_metadata.get("suggested_readings", "N/A")
 
-        if url not in source_dict:
-            source_dict[url] = [(source.page_content, score)]
+        source_type = source_metadata.get("source_type", "N/A")
+
+        url_name = f"{url}_{page}"
+        if url_name not in source_dict:
+            source_dict[url_name] = {
+                "text": source.page_content,
+                "url": url,
+                "score": score,
+                "page": page,
+                "lecture_tldr": lecture_tldr,
+                "lecture_recording": lecture_recording,
+                "suggested_readings": suggested_readings,
+                "source_type": source_type,
+            }
         else:
-            source_dict[url].append((source.page_content, score))
+            source_dict[url_name]["text"] += f"\n\n{source.page_content}"
 
-    for source_idx, (url, text_list) in enumerate(source_dict.items()):
-        full_text = ""
-        for url_idx, (text, score) in enumerate(text_list):
-            full_text += f"Source {url_idx + 1} (Score: {score}):\n{text}\n\n\n"
-        source_elements.append(cl.Text(name=url, content=full_text))
-        found_sources.append(f"{url} (Score: {score})")
+    # First, display the answer
+    full_answer = "**Answer:**\n"
+    full_answer += answer
 
-    if found_sources:
-        answer += f"\n\nSources: {', '.join(found_sources)}"
-    else:
-        answer += f"\n\nNo source found."
-
-    # for idx, source in enumerate(res["source_documents"]):
-    #     title = source.metadata["source"]
-
-    #     if title not in source_elements_dict:
-    #         source_elements_dict[title] = {
-    #             "page_number": [source.metadata["page"]],
-    #             "url": source.metadata["source"],
-    #             "content": source.page_content,
-    #         }
-
-    #     else:
-    #         source_elements_dict[title]["page_number"].append(source.metadata["page"])
-    #         source_elements_dict[title][
-    #             "content_" + str(source.metadata["page"])
-    #         ] = source.page_content
-    #     # sort the page numbers
-    #     # source_elements_dict[title]["page_number"].sort()
-
-    # for title, source in source_elements_dict.items():
-    #     # create a string for the page numbers
-    #     page_numbers = ", ".join([str(x) for x in source["page_number"]])
-    #     text_for_source = f"Page Number(s): {page_numbers}\nURL: {source['url']}"
-    #     source_elements.append(cl.Pdf(name="File", path=title))
-    #     found_sources.append("File")
-    #     # for pn in source["page_number"]:
-    #     #     source_elements.append(
-    #     #         cl.Text(name=str(pn), content=source["content_"+str(pn)])
-    #     #     )
-    #     #     found_sources.append(str(pn))
-
-    # if found_sources:
-    #     answer += f"\nSource:{', '.join(found_sources)}"
-    # else:
-    #     answer += f"\nNo source found."
-
-    return answer, source_elements
+    # Then, display the sources
+    full_answer += "\n\n**Sources:**\n"
+    for idx, (url_name, source_data) in enumerate(source_dict.items()):
+        full_answer += f"\nSource {idx + 1} (Score: {source_data['score']}): {source_data['url']}\n"
+
+        name = f"Source {idx + 1} Text\n"
+        full_answer += name
+        source_elements.append(cl.Text(name=name, content=source_data["text"]))
+
+        # Add a PDF element if the source is a PDF file
+        if source_data["url"].lower().endswith(".pdf"):
+            name = f"Source {idx + 1} PDF\n"
+            full_answer += name
+            pdf_url = f"{source_data['url']}#page={source_data['page']+1}"
+            source_elements.append(cl.Pdf(name=name, url=pdf_url))
+
+    # Finally, include lecture metadata for each unique source
+    # displayed_urls = set()
+    # full_answer += "\n**Metadata:**\n"
+    # for url_name, source_data in source_dict.items():
+    #     if source_data["url"] not in displayed_urls:
+    #         full_answer += f"\nSource: {source_data['url']}\n"
+    #         full_answer += f"Type: {source_data['source_type']}\n"
+    #         full_answer += f"TL;DR: {source_data['lecture_tldr']}\n"
+    #         full_answer += f"Lecture Recording: {source_data['lecture_recording']}\n"
+    #         full_answer += f"Suggested Readings: {source_data['suggested_readings']}\n"
+    #         displayed_urls.add(source_data["url"])
+    full_answer += "\n**Metadata:**\n"
+    for url_name, source_data in source_dict.items():
+        full_answer += f"\nSource: {source_data['url']}\n"
+        full_answer += f"Page: {source_data['page']}\n"
+        full_answer += f"Type: {source_data['source_type']}\n"
+        full_answer += f"TL;DR: {source_data['lecture_tldr']}\n"
+        full_answer += f"Lecture Recording: {source_data['lecture_recording']}\n"
+        full_answer += f"Suggested Readings: {source_data['suggested_readings']}\n"
+
+    return full_answer, source_elements
+
+
+def get_lecture_metadata(schedule_url):
+    """
+    Function to get the lecture metadata from the schedule URL.
+    """
+    lecture_metadata = {}
+
+    # Get the main schedule page content
+    r = requests.get(schedule_url)
+    soup = BeautifulSoup(r.text, "html.parser")
+
+    # Find all lecture blocks
+    lecture_blocks = soup.find_all("div", class_="lecture-container")
+
+    for block in lecture_blocks:
+        try:
+            # Extract the lecture title
+            title = block.find("span", style="font-weight: bold;").text.strip()
+
+            # Extract the TL;DR
+            tldr = block.find("strong", text="tl;dr:").next_sibling.strip()
+
+            # Extract the link to the slides
+            slides_link_tag = block.find("a", title="Download slides")
+            slides_link = slides_link_tag["href"].strip() if slides_link_tag else None
+
+            # Extract the link to the lecture recording
+            recording_link_tag = block.find("a", title="Download lecture recording")
+            recording_link = (
+                recording_link_tag["href"].strip() if recording_link_tag else None
+            )
+
+            # Extract suggested readings or summary if available
+            suggested_readings_tag = block.find("p", text="Suggested Readings:")
+            if suggested_readings_tag:
+                suggested_readings = suggested_readings_tag.find_next_sibling("ul")
+                if suggested_readings:
+                    suggested_readings = suggested_readings.get_text(
+                        separator="\n"
+                    ).strip()
+                else:
+                    suggested_readings = "No specific readings provided."
+            else:
+                suggested_readings = "No specific readings provided."
+
+            # Add to the dictionary
+            slides_link = f"https://dl4ds.github.io{slides_link}"
+            lecture_metadata[slides_link] = {
+                "tldr": tldr,
+                "title": title,
+                "lecture_recording": recording_link,
+                "suggested_readings": suggested_readings,
+            }
+        except Exception as e:
+            print(f"Error processing block: {e}")
+            continue
+
+    return lecture_metadata
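
`get_lecture_metadata` scrapes the course schedule page and keys each entry by the slide URL, which is what `get_chunks` in data_loader.py looks up per file. A minimal usage sketch; the URL is the one hard-coded in data_loader.py, and the printed keys match the dictionary built above:

try:
    from modules.helpers import get_lecture_metadata
except ImportError:
    from helpers import get_lecture_metadata

lecture_metadata = get_lecture_metadata("https://dl4ds.github.io/sp2024/lectures/")

for slides_url, meta in lecture_metadata.items():
    print(slides_url)
    print("  title:", meta["title"])
    print("  tl;dr:", meta["tldr"])
    print("  recording:", meta["lecture_recording"])
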
code/modules/llm_tutor.py CHANGED
@@ -8,7 +8,6 @@ from langchain.llms import CTransformers
 from langchain.memory import ConversationBufferWindowMemory
 from langchain.chains.conversational_retrieval.prompts import QA_PROMPT
 import os
-
 from modules.constants import *
 from modules.helpers import get_prompt
 from modules.chat_model_loader import ChatModelLoader
@@ -34,14 +33,21 @@ class LLMTutor:
 
     # Retrieval QA Chain
     def retrieval_qa_chain(self, llm, prompt, db):
-        retriever = VectorDBScore(
-            vectorstore=db,
-            search_type="similarity_score_threshold",
-            search_kwargs={
-                "score_threshold": self.config["embedding_options"]["score_threshold"],
-                "k": self.config["embedding_options"]["search_top_k"],
-            },
-        )
+        if self.config["embedding_options"]["db_option"] in ["FAISS", "Chroma"]:
+            retriever = VectorDBScore(
+                vectorstore=db,
+                search_type="similarity_score_threshold",
+                search_kwargs={
+                    "score_threshold": self.config["embedding_options"][
+                        "score_threshold"
+                    ],
+                    "k": self.config["embedding_options"]["search_top_k"],
+                },
+            )
+        elif self.config["embedding_options"]["db_option"] == "RAGatouille":
+            retriever = db.as_langchain_retriever(
+                k=self.config["embedding_options"]["search_top_k"]
+            )
         if self.config["llm_params"]["use_history"]:
             memory = ConversationBufferWindowMemory(
                 k=self.config["llm_params"]["memory_window"],
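
The chain now chooses its retriever from `db_option`: the score-thresholded `VectorDBScore` wrapper for FAISS/Chroma, or RAGatouille's LangChain adapter for the ColBERT index. A standalone sketch of the same switch, using LangChain's stock `as_retriever` in place of the repo's custom `VectorDBScore` (which additionally copies scores into metadata); `config`, `llm`, and `db` are assumed to exist as in the module above:

from langchain.chains import RetrievalQA


def build_retriever(config, db):
    # db: a FAISS/Chroma vector store, or a RAGPretrainedModel index (see VectorDB.load_database)
    opts = config["embedding_options"]
    if opts["db_option"] in ["FAISS", "Chroma"]:
        return db.as_retriever(
            search_type="similarity_score_threshold",
            search_kwargs={
                "score_threshold": opts["score_threshold"],
                "k": opts["search_top_k"],
            },
        )
    elif opts["db_option"] == "RAGatouille":
        return db.as_langchain_retriever(k=opts["search_top_k"])
    raise ValueError(f"Unsupported db_option: {opts['db_option']}")


# qa = RetrievalQA.from_chain_type(
#     llm=llm, retriever=build_retriever(config, db), return_source_documents=True
# )
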
code/modules/vector_db.py CHANGED
@@ -1,11 +1,12 @@
 import logging
 import os
 import yaml
-from langchain.vectorstores import FAISS, Chroma
+from langchain_community.vectorstores import FAISS, Chroma
 from langchain.schema.vectorstore import VectorStoreRetriever
 from langchain.callbacks.manager import CallbackManagerForRetrieverRun
 from langchain.schema.document import Document
 from langchain_core.callbacks import AsyncCallbackManagerForRetrieverRun
+from ragatouille import RAGPretrainedModel
 
 try:
     from modules.embedding_model_loader import EmbeddingModelLoader
@@ -25,7 +26,7 @@ class VectorDBScore(VectorStoreRetriever):
 
     # See https://github.com/langchain-ai/langchain/blob/61dd92f8215daef3d9cf1734b0d1f8c70c1571c3/libs/langchain/langchain/vectorstores/base.py#L500
     def _get_relevant_documents(
-        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
+        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
     ) -> List[Document]:
         docs_and_similarities = (
             self.vectorstore.similarity_search_with_relevance_scores(
@@ -55,7 +56,6 @@ class VectorDBScore(VectorStoreRetriever):
         return docs
 
 
-
 class VectorDB:
     def __init__(self, config, logger=None):
         self.config = config
@@ -116,7 +116,15 @@ class VectorDB:
         self.embedding_model_loader = EmbeddingModelLoader(self.config)
         self.embedding_model = self.embedding_model_loader.load_embedding_model()
 
-    def initialize_database(self, document_chunks: list, document_names: list):
+    def initialize_database(
+        self,
+        document_chunks: list,
+        document_names: list,
+        documents: list,
+        document_metadata: list,
+    ):
+        if self.db_option in ["FAISS", "Chroma"]:
+            self.create_embedding_model()
         # Track token usage
         self.logger.info("Initializing vector_db")
         self.logger.info("\tUsing {} as db_option".format(self.db_option))
@@ -136,6 +144,14 @@ class VectorDB:
                     + self.config["embedding_options"]["model"],
                 ),
             )
+        elif self.db_option == "RAGatouille":
+            self.RAG = RAGPretrainedModel.from_pretrained("colbert-ir/colbertv2.0")
+            index_path = self.RAG.index(
+                index_name="new_idx",
+                collection=documents,
+                document_ids=document_names,
+                document_metadatas=document_metadata,
+            )
         self.logger.info("Completed initializing vector_db")
 
     def create_database(self):
@@ -146,11 +162,13 @@ class VectorDB:
             files += lecture_pdfs
         if "storage/data/urls.txt" in files:
             files.remove("storage/data/urls.txt")
-        document_chunks, document_names = data_loader.get_chunks(files, urls)
+        document_chunks, document_names, documents, document_metadata = (
+            data_loader.get_chunks(files, urls)
+        )
         self.logger.info("Completed loading data")
-
-        self.create_embedding_model()
-        self.initialize_database(document_chunks, document_names)
+        self.initialize_database(
+            document_chunks, document_names, documents, document_metadata
+        )
 
     def save_database(self):
         if self.db_option == "FAISS":
@@ -166,6 +184,9 @@ class VectorDB:
         elif self.db_option == "Chroma":
             # db is saved in the persist directory during initialization
             pass
+        elif self.db_option == "RAGatouille":
+            # index is saved during initialization
+            pass
         self.logger.info("Saved database")
 
     def load_database(self):
@@ -180,7 +201,7 @@ class VectorDB:
                     + self.config["embedding_options"]["model"],
                 ),
                 self.embedding_model,
-                # allow_dangerous_deserialization=True, <- unexpected keyword argument to load_local
+                allow_dangerous_deserialization=True,
             )
         elif self.db_option == "Chroma":
             self.vector_db = Chroma(
@@ -193,6 +214,10 @@ class VectorDB:
                 ),
                 embedding_function=self.embedding_model,
             )
+        elif self.db_option == "RAGatouille":
+            self.vector_db = RAGPretrainedModel.from_index(
+                ".ragatouille/colbert/indexes/new_idx"
+            )
         self.logger.info("Loaded database")
         return self.vector_db
 
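Taken together with data_loader.py, the RAGatouille path indexes whole pages plus their metadata instead of embedding chunks. A minimal end-to-end sketch using only the calls that appear above; the sample collection, IDs, and metadata are illustrative:

from ragatouille import RAGPretrainedModel

# Build the index (mirrors initialize_database for db_option == "RAGatouille")
RAG = RAGPretrainedModel.from_pretrained("colbert-ir/colbertv2.0")
RAG.index(
    index_name="new_idx",
    collection=["Page text about RAG ...", "Page text about RLHF ..."],       # illustrative
    document_ids=["15_RAG_CoT.pdf_0", "21_RL_RLHF_v2.pdf_0"],                  # illustrative
    document_metadatas=[{"source_type": "lecture"}, {"source_type": "lecture"}],
)

# Reload it later (mirrors load_database) and hand it to LangChain as a retriever
RAG = RAGPretrainedModel.from_index(".ragatouille/colbert/indexes/new_idx")
retriever = RAG.as_langchain_retriever(k=3)
docs = retriever.get_relevant_documents("What is retrieval-augmented generation?")
print(docs[0].metadata)  # the document_metadatas entry should come back with each result
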
requirements.txt CHANGED
@@ -17,3 +17,4 @@ fake-useragent==1.4.0
 git+https://github.com/huggingface/accelerate.git
 llama-cpp-python
 PyPDF2==3.0.1
+ragatouille==0.0.8.post2
storage/data/urls.txt CHANGED
@@ -1 +1,3 @@
 https://dl4ds.github.io/sp2024/
+https://dl4ds.github.io/sp2024/static_files/lectures/15_RAG_CoT.pdf
+https://dl4ds.github.io/sp2024/static_files/lectures/21_RL_RLHF_v2.pdf