XThomasBU committed on
Commit
57b7b8d
1 Parent(s): fe158b7

Modularized dataloader + Added Chroma

code/config.yml CHANGED
@@ -1,6 +1,5 @@
 embedding_options:
   embedd_files: False # bool
-  persist_directory: null # str or None
   data_path: 'storage/data' # str
   url_file_path: 'storage/data/urls.txt' # str
   expand_urls: True # bool
@@ -8,8 +7,9 @@ embedding_options:
   db_path : 'vectorstores' # str
   model : 'sentence-transformers/all-MiniLM-L6-v2' # str [sentence-transformers/all-MiniLM-L6-v2, text-embedding-ada-002']
   search_top_k : 3 # int
+  score_threshold : 0.5 # float
 llm_params:
-  use_history: False # bool
+  use_history: True # bool
   memory_window: 3 # int
   llm_loader: 'local_llm' # str [local_llm, openai]
   openai_params:
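
The two new knobs sit under existing sections: score_threshold gates retrieval by relevance score and use_history switches the tutor onto the conversational chain. A minimal sketch of how downstream modules might read them (the yaml.safe_load call and the file path are assumptions; the keys and values come from the config above):

import yaml

# Hypothetical loading step; the repo's modules receive this dict as `config`.
with open("code/config.yml") as f:
    config = yaml.safe_load(f)

score_threshold = config["embedding_options"]["score_threshold"]  # 0.5
top_k = config["embedding_options"]["search_top_k"]               # 3
use_history = config["llm_params"]["use_history"]                 # True
print(score_threshold, top_k, use_history)
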
code/modules/data_loader.py CHANGED
@@ -1,6 +1,7 @@
+import os
 import re
+import requests
 import pysrt
-from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain.document_loaders import (
     PyMuPDFLoader,
     Docx2txtLoader,
@@ -8,49 +9,32 @@ from langchain.document_loaders import (
     WebBaseLoader,
     TextLoader,
 )
+from langchain_community.document_loaders import UnstructuredMarkdownLoader
+from llama_parse import LlamaParse
 from langchain.schema import Document
-import tempfile
-from tempfile import NamedTemporaryFile
 import logging
-import requests
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain_experimental.text_splitter import SemanticChunker
+from langchain_openai.embeddings import OpenAIEmbeddings
 
 logger = logging.getLogger(__name__)
 
 
-class DataLoader:
-    def __init__(self, config):
-        """
-        Class for handling all data extraction and chunking
-        Inputs:
-            config - dictionary from yaml file, containing all important parameters
-        """
-        self.config = config
-        self.remove_leftover_delimiters = config["splitter_options"][
-            "remove_leftover_delimiters"
-        ]
-
-        # Main list of all documents
-        self.document_chunks_full = []
-        self.document_names = []
-
-        if config["splitter_options"]["use_splitter"]:
-            if config["splitter_options"]["split_by_token"]:
-                self.splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
-                    chunk_size=config["splitter_options"]["chunk_size"],
-                    chunk_overlap=config["splitter_options"]["chunk_overlap"],
-                    separators=config["splitter_options"]["chunk_separators"],
-                    disallowed_special=()
-                )
-            else:
-                self.splitter = RecursiveCharacterTextSplitter(
-                    chunk_size=config["splitter_options"]["chunk_size"],
-                    chunk_overlap=config["splitter_options"]["chunk_overlap"],
-                    separators=config["splitter_options"]["chunk_separators"],
-                    disallowed_special=()
-                )
-        else:
-            self.splitter = None
-        logger.info("InfoLoader instance created")
+class PDFReader:
+    def __init__(self):
+        pass
+
+    def get_loader(self, pdf_path):
+        loader = PyMuPDFLoader(pdf_path)
+        return loader
+
+    def get_documents(self, loader):
+        return loader.load()
+
+
+class FileReader:
+    def __init__(self):
+        self.pdf_reader = PDFReader()
 
     def extract_text_from_pdf(self, pdf_path):
         text = ""
@@ -73,215 +57,173 @@ class DataLoader:
             print("Failed to download PDF from URL:", pdf_url)
             return None
 
-    def get_chunks(self, uploaded_files, weblinks):
-        # Main list of all documents
-        self.document_chunks_full = []
-        self.document_names = []
-
-        def remove_delimiters(document_chunks: list):
-            """
-            Helper function to remove remaining delimiters in document chunks
-            """
-            for chunk in document_chunks:
-                for delimiter in self.config["splitter_options"][
-                    "delimiters_to_remove"
-                ]:
-                    chunk.page_content = re.sub(delimiter, " ", chunk.page_content)
-            return document_chunks
-
-        def remove_chunks(document_chunks: list):
-            """
-            Helper function to remove any unwanted document chunks after splitting
-            """
-            front = self.config["splitter_options"]["front_chunk_to_remove"]
-            end = self.config["splitter_options"]["last_chunks_to_remove"]
-            # Remove pages
-            for _ in range(front):
-                del document_chunks[0]
-            for _ in range(end):
-                document_chunks.pop()
-            logger.info(f"\tNumber of pages after skipping: {len(document_chunks)}")
-            return document_chunks
-
-        def get_pdf_from_url(pdf_url: str):
-            temp_pdf_path = self.download_pdf_from_url(pdf_url)
-            if temp_pdf_path:
-                title, document_chunks = get_pdf(temp_pdf_path, pdf_url)
-                os.remove(temp_pdf_path)
-                return title, document_chunks
-
-        def get_pdf(temp_file_path: str, title: str):
-            """
-            Function to process PDF files
-            """
-            loader = PyMuPDFLoader(
-                temp_file_path
-            )  # This loader preserves more metadata
-
-            if self.splitter:
-                document_chunks = self.splitter.split_documents(loader.load())
-            else:
-                document_chunks = loader.load()
-
-            if "title" in document_chunks[0].metadata.keys():
-                title = document_chunks[0].metadata["title"]
-
-            logger.info(
-                f"\t\tOriginal no. of pages: {document_chunks[0].metadata['total_pages']}"
-            )
-
-            return title, document_chunks
-
-        def get_txt(temp_file_path: str, title: str):
-            """
-            Function to process TXT files
-            """
-            loader = TextLoader(temp_file_path, autodetect_encoding=True)
-
-            if self.splitter:
-                document_chunks = self.splitter.split_documents(loader.load())
-            else:
-                document_chunks = loader.load()
-
-            # Update the metadata
-            for chunk in document_chunks:
-                chunk.metadata["source"] = title
-                chunk.metadata["page"] = "N/A"
-
-            return title, document_chunks
-
-        def get_srt(temp_file_path: str, title: str):
-            """
-            Function to process SRT files
-            """
-            subs = pysrt.open(temp_file_path)
-
-            text = ""
-            for sub in subs:
-                text += sub.text
-            document_chunks = [Document(page_content=text)]
-
-            if self.splitter:
-                document_chunks = self.splitter.split_documents(document_chunks)
-
-            # Update the metadata
-            for chunk in document_chunks:
-                chunk.metadata["source"] = title
-                chunk.metadata["page"] = "N/A"
-
-            return title, document_chunks
-
-        def get_docx(temp_file_path: str, title: str):
-            """
-            Function to process DOCX files
-            """
-            loader = Docx2txtLoader(temp_file_path)
-
-            if self.splitter:
-                document_chunks = self.splitter.split_documents(loader.load())
-            else:
-                document_chunks = loader.load()
-
-            # Update the metadata
-            for chunk in document_chunks:
-                chunk.metadata["source"] = title
-                chunk.metadata["page"] = "N/A"
-
-            return title, document_chunks
-
-        def get_youtube_transcript(url: str):
-            """
-            Function to retrieve youtube transcript and process text
-            """
-            loader = YoutubeLoader.from_youtube_url(
-                url, add_video_info=True, language=["en"], translation="en"
-            )
-
-            if self.splitter:
-                document_chunks = self.splitter.split_documents(loader.load())
-            else:
-                document_chunks = loader.load_and_split()
-
-            # Replace the source with title (for display in st UI later)
-            for chunk in document_chunks:
-                chunk.metadata["source"] = chunk.metadata["title"]
-                logger.info(chunk.metadata["title"])
-
-            return title, document_chunks
-
-        def get_html(url: str):
-            """
-            Function to process websites via HTML files
-            """
-            loader = WebBaseLoader(url)
-
-            if self.splitter:
-                document_chunks = self.splitter.split_documents(loader.load())
-            else:
-                document_chunks = loader.load_and_split()
-
-            title = document_chunks[0].metadata["title"]
-            logger.info(document_chunks[0].metadata)
-
-            return title, document_chunks
-
-        # Handle file by file
-        for file_index, file_path in enumerate(uploaded_files):
-
-            file_name = file_path.split("/")[-1]
-            file_type = file_name.split(".")[-1]
-
-            # Handle different file types
-            if file_type == "pdf":
-                try:
-                    title, document_chunks = get_pdf(file_path, file_name)
-                except:
-                    title, document_chunks = get_pdf_from_url(file_path)
-            elif file_type == "txt":
-                title, document_chunks = get_txt(file_path, file_name)
-            elif file_type == "docx":
-                title, document_chunks = get_docx(file_path, file_name)
-            elif file_type == "srt":
-                title, document_chunks = get_srt(file_path, file_name)
-
-            # Additional wrangling - Remove leftover delimiters and any specified chunks
-            if self.remove_leftover_delimiters:
-                document_chunks = remove_delimiters(document_chunks)
-            if self.config["splitter_options"]["remove_chunks"]:
-                document_chunks = remove_chunks(document_chunks)
-
-            logger.info(f"\t\tExtracted no. of chunks: {len(document_chunks)} from {file_name}")
-            self.document_names.append(title)
-            self.document_chunks_full.extend(document_chunks)
-
-        # Handle youtube links:
-        if weblinks[0] != "":
-            logger.info(f"Splitting weblinks: total of {len(weblinks)}")
-
-            # Handle link by link
-            for link_index, link in enumerate(weblinks):
-                try:
-                    logger.info(f"\tSplitting link {link_index+1} : {link}")
-                    if "youtube" in link:
-                        title, document_chunks = get_youtube_transcript(link)
-                    else:
-                        title, document_chunks = get_html(link)
-
-                    # Additional wrangling - Remove leftover delimiters and any specified chunks
-                    if self.remove_leftover_delimiters:
-                        document_chunks = remove_delimiters(document_chunks)
-                    if self.config["splitter_options"]["remove_chunks"]:
-                        document_chunks = remove_chunks(document_chunks)
-
-                    print(f"\t\tExtracted no. of chunks: {len(document_chunks)}")
-                    self.document_names.append(title)
-                    self.document_chunks_full.extend(document_chunks)
-                except:
-                    logger.info(f"\t\tError splitting link {link_index+1} : {link}")
-                    exit()
-
-        logger.info(
-            f"\tNumber of document chunks extracted in total: {len(self.document_chunks_full)}\n\n"
-        )
-
-        return self.document_chunks_full, self.document_names
+    def read_pdf(self, temp_file_path: str):
+        # parser = LlamaParse(
+        #     api_key="",
+        #     result_type="markdown",
+        #     num_workers=4,
+        #     verbose=True,
+        #     language="en",
+        # )
+        # documents = parser.load_data(temp_file_path)
+
+        # with open("temp/output.md", "a") as f:
+        #     for doc in documents:
+        #         f.write(doc.text + "\n")
+
+        # markdown_path = "temp/output.md"
+        # loader = UnstructuredMarkdownLoader(markdown_path)
+        # loader = PyMuPDFLoader(temp_file_path)  # This loader preserves more metadata
+        # return loader.load()
+        loader = self.pdf_reader.get_loader(temp_file_path)
+        documents = self.pdf_reader.get_documents(loader)
+        return documents
+
+    def read_txt(self, temp_file_path: str):
+        loader = TextLoader(temp_file_path, autodetect_encoding=True)
+        return loader.load()
+
+    def read_docx(self, temp_file_path: str):
+        loader = Docx2txtLoader(temp_file_path)
+        return loader.load()
+
+    def read_srt(self, temp_file_path: str):
+        subs = pysrt.open(temp_file_path)
+        text = ""
+        for sub in subs:
+            text += sub.text
+        return [Document(page_content=text)]
+
+    def read_youtube_transcript(self, url: str):
+        loader = YoutubeLoader.from_youtube_url(
+            url, add_video_info=True, language=["en"], translation="en"
+        )
+        return loader.load()
+
+    def read_html(self, url: str):
+        loader = WebBaseLoader(url)
+        return loader.load()
+
+
+class ChunkProcessor:
+    def __init__(self, config):
+        self.config = config
+        self.remove_leftover_delimiters = config["splitter_options"][
+            "remove_leftover_delimiters"
+        ]
+        self.document_chunks_full = []
+        self.document_names = []
+
+        if config["splitter_options"]["use_splitter"]:
+            if config["splitter_options"]["split_by_token"]:
+                self.splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
+                    chunk_size=config["splitter_options"]["chunk_size"],
+                    chunk_overlap=config["splitter_options"]["chunk_overlap"],
+                    separators=config["splitter_options"]["chunk_separators"],
+                    disallowed_special=(),
+                )
+            else:
+                self.splitter = RecursiveCharacterTextSplitter(
+                    chunk_size=config["splitter_options"]["chunk_size"],
+                    chunk_overlap=config["splitter_options"]["chunk_overlap"],
+                    separators=config["splitter_options"]["chunk_separators"],
+                    disallowed_special=(),
+                )
+        else:
+            self.splitter = None
+        logger.info("ChunkProcessor instance created")
+
+    def remove_delimiters(self, document_chunks: list):
+        for chunk in document_chunks:
+            for delimiter in self.config["splitter_options"]["delimiters_to_remove"]:
+                chunk.page_content = re.sub(delimiter, " ", chunk.page_content)
+        return document_chunks
+
+    def remove_chunks(self, document_chunks: list):
+        front = self.config["splitter_options"]["front_chunk_to_remove"]
+        end = self.config["splitter_options"]["last_chunks_to_remove"]
+        for _ in range(front):
+            del document_chunks[0]
+        for _ in range(end):
+            document_chunks.pop()
+        logger.info(f"\tNumber of pages after skipping: {len(document_chunks)}")
+        return document_chunks
+
+    def process_chunks(self, documents):
+        if self.splitter:
+            document_chunks = self.splitter.split_documents(documents)
+        else:
+            document_chunks = documents
+
+        if self.remove_leftover_delimiters:
+            document_chunks = self.remove_delimiters(document_chunks)
+        if self.config["splitter_options"]["remove_chunks"]:
+            document_chunks = self.remove_chunks(document_chunks)
+
+        return document_chunks
+
+    def get_chunks(self, file_reader, uploaded_files, weblinks):
+        self.document_chunks_full = []
+        self.document_names = []
+
+        for file_index, file_path in enumerate(uploaded_files):
+            file_name = os.path.basename(file_path)
+            file_type = file_name.split(".")[-1].lower()
+
+            try:
+                if file_type == "pdf":
+                    documents = file_reader.read_pdf(file_path)
+                elif file_type == "txt":
+                    documents = file_reader.read_txt(file_path)
+                elif file_type == "docx":
+                    documents = file_reader.read_docx(file_path)
+                elif file_type == "srt":
+                    documents = file_reader.read_srt(file_path)
+                else:
+                    logger.warning(f"Unsupported file type: {file_type}")
+                    continue
+
+                document_chunks = self.process_chunks(documents)
+                self.document_names.append(file_name)
+                self.document_chunks_full.extend(document_chunks)
+
+            except Exception as e:
+                logger.error(f"Error processing file {file_name}: {str(e)}")
+
+        self.process_weblinks(file_reader, weblinks)
+
+        logger.info(
+            f"Total document chunks extracted: {len(self.document_chunks_full)}"
+        )
+        return self.document_chunks_full, self.document_names
+
+    def process_weblinks(self, file_reader, weblinks):
+        if weblinks[0] != "":
+            logger.info(f"Splitting weblinks: total of {len(weblinks)}")
+
+            for link_index, link in enumerate(weblinks):
+                try:
+                    logger.info(f"\tSplitting link {link_index+1} : {link}")
+                    if "youtube" in link:
+                        documents = file_reader.read_youtube_transcript(link)
+                    else:
+                        documents = file_reader.read_html(link)
+
+                    document_chunks = self.process_chunks(documents)
+                    self.document_names.append(link)
+                    self.document_chunks_full.extend(document_chunks)
+                except Exception as e:
+                    logger.error(
+                        f"Error splitting link {link_index+1} : {link}: {str(e)}"
+                    )
+
+
+class DataLoader:
+    def __init__(self, config):
+        self.file_reader = FileReader()
+        self.chunk_processor = ChunkProcessor(config)
+
+    def get_chunks(self, uploaded_files, weblinks):
+        return self.chunk_processor.get_chunks(
+            self.file_reader, uploaded_files, weblinks
+        )
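
The monolithic DataLoader is now split into FileReader (raw loading per format), ChunkProcessor (splitting and cleanup), and a thin DataLoader facade with the old get_chunks signature. A minimal usage sketch, assuming code/config.yml also carries the splitter_options block the class reads; the file path and weblink are illustrative:

import yaml
from modules.data_loader import DataLoader

config = yaml.safe_load(open("code/config.yml"))  # assumed path

loader = DataLoader(config)
# Files are routed by extension to FileReader.read_pdf / read_txt / read_docx / read_srt;
# weblinks go to read_youtube_transcript or read_html, then through ChunkProcessor.
chunks, names = loader.get_chunks(
    ["storage/data/lecture_01.pdf"],   # illustrative file path
    ["https://example.com/course/"],   # illustrative weblink
)
print(len(chunks), names)
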
code/modules/embedding_model_loader.py CHANGED
@@ -1,6 +1,7 @@
 from langchain_community.embeddings import OpenAIEmbeddings
 from langchain.embeddings import HuggingFaceEmbeddings
 from langchain.embeddings import LlamaCppEmbeddings
+
 try:
     from modules.constants import *
 except:
@@ -19,6 +20,7 @@ class EmbeddingModelLoader:
                 model=self.config["embedding_options"]["model"],
                 show_progress_bar=True,
                 openai_api_key=OPENAI_API_KEY,
+                disallowed_special=(),
             )
         else:
             embedding_model = HuggingFaceEmbeddings(
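
disallowed_special=() mirrors the splitter settings in data_loader.py: it stops tiktoken from raising when a document happens to contain a literal special-token string. A small standalone sketch of the behaviour this opts into (plain tiktoken, not the repo's code path; cl100k_base is the encoding used by the OpenAI embedding models):

import tiktoken

enc = tiktoken.get_encoding("cl100k_base")
text = "notes that happen to contain <|endoftext|> verbatim"

try:
    enc.encode(text)  # default: special tokens in plain text raise ValueError
except ValueError as err:
    print("rejected:", err)

# With disallowed_special=() the marker is encoded as ordinary text,
# which is what the OpenAIEmbeddings call above now allows.
print(len(enc.encode(text, disallowed_special=())))
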
code/modules/helpers.py CHANGED
@@ -4,6 +4,7 @@ from tqdm import tqdm
 from urllib.parse import urlparse
 import chainlit as cl
 from langchain import PromptTemplate
+
 try:
     from modules.constants import *
 except:
@@ -60,7 +61,7 @@ class WebpageCrawler:
 
     def get_subpage_links(self, l, base_url):
         for link in tqdm(l):
-            print('checking link:', link)
+            print("checking link:", link)
             if not link.endswith("/"):
                 l[link] = "Checked"
                 dict_links_subpages = {}
@@ -109,6 +110,7 @@ def get_base_url(url):
     base_url = f"{parsed_url.scheme}://{parsed_url.netloc}/"
     return base_url
 
+
 def get_prompt(config):
     if config["llm_params"]["use_history"]:
         if config["llm_params"]["llm_loader"] == "local_llm":
@@ -134,6 +136,7 @@ def get_prompt(config):
     )
     return prompt
 
+
 def get_sources(res, answer):
     source_elements_dict = {}
     source_elements = []
@@ -144,21 +147,22 @@ def get_sources(res, answer):
     for idx, source in enumerate(res["source_documents"]):
         source_metadata = source.metadata
         url = source_metadata["source"]
+        score = source_metadata.get("score", "N/A")
 
         if url not in source_dict:
-            source_dict[url] = [source.page_content]
+            source_dict[url] = [(source.page_content, score)]
         else:
-            source_dict[url].append(source.page_content)
+            source_dict[url].append((source.page_content, score))
 
     for source_idx, (url, text_list) in enumerate(source_dict.items()):
         full_text = ""
-        for url_idx, text in enumerate(text_list):
-            full_text += f"Source {url_idx+1}:\n {text}\n\n\n"
+        for url_idx, (text, score) in enumerate(text_list):
+            full_text += f"Source {url_idx + 1} (Score: {score}):\n{text}\n\n\n"
         source_elements.append(cl.Text(name=url, content=full_text))
-        found_sources.append(url)
+        found_sources.append(f"{url} (Score: {score})")
 
     if found_sources:
-        answer += f"\n\nSources: {', '.join(found_sources)} "
+        answer += f"\n\nSources: {', '.join(found_sources)}"
     else:
         answer += f"\n\nNo source found."
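
get_sources now surfaces the relevance score that VectorDBScore writes into each retrieved document's metadata. A toy sketch of the resulting citation string, using stand-in documents rather than a real chain result (file name and scores are illustrative):

from langchain.schema import Document

res = {
    "source_documents": [
        Document(page_content="Backprop is ...", metadata={"source": "lecture_03.pdf", "score": 0.72}),
        Document(page_content="Chain rule ...", metadata={"source": "lecture_03.pdf", "score": 0.61}),
    ]
}

# Simplified version of the tail of get_sources(): one entry per retrieved chunk.
found_sources = [
    f'{doc.metadata["source"]} (Score: {doc.metadata.get("score", "N/A")})'
    for doc in res["source_documents"]
]
print("Sources: " + ", ".join(found_sources))
# Sources: lecture_03.pdf (Score: 0.72), lecture_03.pdf (Score: 0.61)
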
code/modules/llm_tutor.py CHANGED
@@ -12,7 +12,7 @@ import os
 from modules.constants import *
 from modules.helpers import get_prompt
 from modules.chat_model_loader import ChatModelLoader
-from modules.vector_db import VectorDB
+from modules.vector_db import VectorDB, VectorDBScore
 
 
 class LLMTutor:
@@ -34,19 +34,25 @@ class LLMTutor:
 
     # Retrieval QA Chain
     def retrieval_qa_chain(self, llm, prompt, db):
+        retriever = VectorDBScore(
+            vectorstore=db,
+            search_type="similarity_score_threshold",
+            search_kwargs={
+                "score_threshold": self.config["embedding_options"]["score_threshold"],
+                "k": self.config["embedding_options"]["search_top_k"],
+            },
+        )
         if self.config["llm_params"]["use_history"]:
             memory = ConversationBufferWindowMemory(
-                k = self.config["llm_params"]["memory_window"],
-                memory_key="chat_history", return_messages=True, output_key="answer"
+                k=self.config["llm_params"]["memory_window"],
+                memory_key="chat_history",
+                return_messages=True,
+                output_key="answer",
             )
             qa_chain = ConversationalRetrievalChain.from_llm(
                 llm=llm,
                 chain_type="stuff",
-                retriever=db.as_retriever(
-                    search_kwargs={
-                        "k": self.config["embedding_options"]["search_top_k"]
-                    }
-                ),
+                retriever=retriever,
                 return_source_documents=True,
                 memory=memory,
                 combine_docs_chain_kwargs={"prompt": prompt},
@@ -55,11 +61,7 @@ class LLMTutor:
             qa_chain = RetrievalQA.from_chain_type(
                 llm=llm,
                 chain_type="stuff",
-                retriever=db.as_retriever(
-                    search_kwargs={
-                        "k": self.config["embedding_options"]["search_top_k"]
-                    }
-                ),
+                retriever=retriever,
                 return_source_documents=True,
                 chain_type_kwargs={"prompt": prompt},
             )
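
The retriever swap means retrieval is now thresholded as well as top-k: with search_type="similarity_score_threshold", only chunks whose relevance score clears score_threshold are returned, up to search_top_k of them. A plain-LangChain sketch of the same semantics against any vector store db (db is assumed to exist; VectorDBScore adds only the extra step of copying the score into each document's metadata):

# Equivalent filtering with a stock LangChain retriever.
retriever = db.as_retriever(
    search_type="similarity_score_threshold",
    search_kwargs={"score_threshold": 0.5, "k": 3},
)
docs = retriever.get_relevant_documents("explain gradient descent")
# At most 3 documents, each with relevance score >= 0.5; fewer (possibly none)
# if the store has no sufficiently similar chunks.
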
code/modules/vector_db.py CHANGED
@@ -1,7 +1,10 @@
 import logging
 import os
 import yaml
-from langchain.vectorstores import FAISS
+from langchain.vectorstores import FAISS, Chroma
+from langchain.schema.vectorstore import VectorStoreRetriever
+from langchain.callbacks.manager import CallbackManagerForRetrieverRun
+from langchain.schema.document import Document
 
 try:
     from modules.embedding_model_loader import EmbeddingModelLoader
@@ -15,6 +18,24 @@ except:
     from helpers import *
 
 
+class VectorDBScore(VectorStoreRetriever):
+    # See https://github.com/langchain-ai/langchain/blob/61dd92f8215daef3d9cf1734b0d1f8c70c1571c3/libs/langchain/langchain/vectorstores/base.py#L500
+    def _get_relevant_documents(
+        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
+    ):
+        docs_and_similarities = (
+            self.vectorstore.similarity_search_with_relevance_scores(
+                query, **self.search_kwargs
+            )
+        )
+        # Make the score part of the document metadata
+        for doc, similarity in docs_and_similarities:
+            doc.metadata["score"] = similarity
+
+        docs = [doc for doc, _ in docs_and_similarities]
+        return docs
+
+
 class VectorDB:
     def __init__(self, config, logger=None):
         self.config = config
@@ -61,10 +82,12 @@ class VectorDB:
         return files, urls
 
     def clean_url_list(self, urls):
-        # get lecture pdf links
+        # get lecture pdf links
         lecture_pdfs = [link for link in urls if link.endswith(".pdf")]
         lecture_pdfs = [link for link in lecture_pdfs if "lecture" in link.lower()]
-        urls = [link for link in urls if link.endswith("/")] # only keep links that end with a '/'. Extract Files Seperately
+        urls = [
+            link for link in urls if link.endswith("/")
+        ]  # only keep links that end with a '/'. Extract Files Seperately
 
         return urls, lecture_pdfs
 
@@ -81,6 +104,18 @@ class VectorDB:
             self.vector_db = FAISS.from_documents(
                 documents=document_chunks, embedding=self.embedding_model
             )
+        elif self.db_option == "Chroma":
+            self.vector_db = Chroma.from_documents(
+                documents=document_chunks,
+                embedding=self.embedding_model,
+                persist_directory=os.path.join(
+                    self.config["embedding_options"]["db_path"],
+                    "db_"
+                    + self.config["embedding_options"]["db_option"]
+                    + "_"
+                    + self.config["embedding_options"]["model"],
+                ),
+            )
         self.logger.info("Completed initializing vector_db")
 
     def create_database(self):
@@ -89,7 +124,8 @@ class VectorDB:
         files, urls = self.load_files()
         urls, lecture_pdfs = self.clean_url_list(urls)
         files += lecture_pdfs
-        files.remove('storage/data/urls.txt')
+        if "storage/data/urls.txt" in files:
+            files.remove("storage/data/urls.txt")
         document_chunks, document_names = data_loader.get_chunks(files, urls)
         self.logger.info("Completed loading data")
 
@@ -97,29 +133,46 @@ class VectorDB:
         self.initialize_database(document_chunks, document_names)
 
     def save_database(self):
-        self.vector_db.save_local(
-            os.path.join(
-                self.config["embedding_options"]["db_path"],
-                "db_"
-                + self.config["embedding_options"]["db_option"]
-                + "_"
-                + self.config["embedding_options"]["model"],
-            )
-        )
+        if self.db_option == "FAISS":
+            self.vector_db.save_local(
+                os.path.join(
+                    self.config["embedding_options"]["db_path"],
+                    "db_"
+                    + self.config["embedding_options"]["db_option"]
+                    + "_"
+                    + self.config["embedding_options"]["model"],
+                )
+            )
+        elif self.db_option == "Chroma":
+            # db is saved in the persist directory during initialization
+            pass
         self.logger.info("Saved database")
 
     def load_database(self):
        self.create_embedding_model()
-        self.vector_db = FAISS.load_local(
-            os.path.join(
-                self.config["embedding_options"]["db_path"],
-                "db_"
-                + self.config["embedding_options"]["db_option"]
-                + "_"
-                + self.config["embedding_options"]["model"],
-            ),
-            self.embedding_model,
-        )
+        if self.db_option == "FAISS":
+            self.vector_db = FAISS.load_local(
+                os.path.join(
+                    self.config["embedding_options"]["db_path"],
+                    "db_"
+                    + self.config["embedding_options"]["db_option"]
+                    + "_"
+                    + self.config["embedding_options"]["model"],
+                ),
+                self.embedding_model,
+                allow_dangerous_deserialization=True,
+            )
+        elif self.db_option == "Chroma":
+            self.vector_db = Chroma(
+                persist_directory=os.path.join(
+                    self.config["embedding_options"]["db_path"],
+                    "db_"
+                    + self.config["embedding_options"]["db_option"]
+                    + "_"
+                    + self.config["embedding_options"]["model"],
+                ),
+                embedding_function=self.embedding_model,
+            )
         self.logger.info("Loaded database")
         return self.vector_db
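
Chroma persists itself under persist_directory at build time, so save_database becomes a no-op for that backend, while load_database re-opens the same directory; FAISS keeps the explicit save_local/load_local round-trip. A minimal end-to-end sketch, assuming an index has already been built (embedd_files run once), db_option set to 'FAISS' or 'Chroma' in embedding_options, and the config path shown here:

import yaml
from modules.vector_db import VectorDB, VectorDBScore

config = yaml.safe_load(open("code/config.yml"))  # assumed path

vector_db = VectorDB(config)
# vector_db.create_database()   # one-off: chunks files/URLs and builds the store
db = vector_db.load_database()  # FAISS.load_local(...) or Chroma(persist_directory=...)

retriever = VectorDBScore(
    vectorstore=db,
    search_type="similarity_score_threshold",
    search_kwargs={
        "score_threshold": config["embedding_options"]["score_threshold"],
        "k": config["embedding_options"]["search_top_k"],
    },
)
for doc in retriever.get_relevant_documents("what is a convolution?"):
    # VectorDBScore copies the relevance score into metadata, which get_sources()
    # later surfaces next to each citation in the chat UI.
    print(doc.metadata["source"], round(doc.metadata["score"], 3))
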