Thomas (Tom) Gardos committed
Commit 59567ad
2 parents: e17a5d0 ca828e7

Merge pull request #82 from DL4DS/remove_hardcoded

README.md CHANGED
@@ -37,7 +37,7 @@ Please visit [setup](https://dl4ds.github.io/dl4ds_tutor/guide/setup/) for more
 3. **To test Data Loading (Optional)**
 ```bash
 cd code
-python -m modules.dataloader.data_loader
+python -m modules.dataloader.data_loader --links "your_pdf_link"
 ```
 
 4. **Create the Vector Database**
@@ -47,9 +47,10 @@ Please visit [setup](https://dl4ds.github.io/dl4ds_tutor/guide/setup/) for more
 ```
 - Note: You need to run the above command when you add new data to the `storage/data` directory, or if the `storage/data/urls.txt` file is updated.
 
-5. **Run the Chainlit App**
+6. **Run the FastAPI App**
 ```bash
-chainlit run main.py
+cd code
+uvicorn app:app --port 7860
 ```
 
 ## Documentation
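For context, `uvicorn app:app --port 7860` expects a module `app.py` inside `code/` that exposes a FastAPI instance named `app`. A minimal sketch of such an entry point, assuming nothing about the real app beyond the `module:attribute` target and the port (the health route is purely illustrative):

```python
# Minimal sketch of what `uvicorn app:app --port 7860` expects: a module
# app.py exposing a FastAPI instance named `app`. Only the module:attribute
# target and the port come from the README; the route is illustrative.
from fastapi import FastAPI

app = FastAPI()


@app.get("/health")
def health():
    return {"status": "ok"}


if __name__ == "__main__":
    # Programmatic equivalent of: uvicorn app:app --port 7860
    import uvicorn

    uvicorn.run("app:app", port=7860)
```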
code/main.py CHANGED
@@ -505,7 +505,6 @@ class Chatbot:
         token_count += token_count_cb.total_tokens
 
         for question in list_of_questions:
-
             actions.append(
                 cl.Action(
                     name="follow up question",
@@ -549,7 +548,6 @@ class Chatbot:
 
 @cl.header_auth_callback
 def header_auth_callback(headers: dict) -> Optional[cl.User]:
-
     print("\n\n\nI am here\n\n\n")
     # try: # TODO: Add try-except block after testing
     # TODO: Implement to get the user information from the headers (not the cookie)
code/modules/chat/helpers.py CHANGED
@@ -42,7 +42,6 @@ def get_sources(res, answer, stream=True, view_sources=False):
     full_answer += answer
 
     if view_sources:
-
         # Then, display the sources
         # check if the answer has sources
        if len(source_dict) == 0:
@@ -51,7 +50,6 @@ def get_sources(res, answer, stream=True, view_sources=False):
        else:
            full_answer += "\n\n**Sources:**\n"
            for idx, (url_name, source_data) in enumerate(source_dict.items()):
-
                full_answer += f"\nSource {idx + 1} (Score: {source_data['score']}): {source_data['url']}\n"
 
                name = f"Source {idx + 1} Text\n"
code/modules/chat/langchain/langchain_rag.py CHANGED
@@ -19,7 +19,6 @@ from .utils import (
 
 
 class Langchain_RAG_V1(BaseRAG):
-
     def __init__(
         self,
         llm,
code/modules/chat/langchain/utils.py CHANGED
@@ -26,7 +26,6 @@ CHAT_TURN_TYPE = Union[Tuple[str, str], BaseMessage]
 
 
 class CustomConversationalRetrievalChain(ConversationalRetrievalChain):
-
     def _get_chat_history(self, chat_history: List[CHAT_TURN_TYPE]) -> str:
         _ROLE_MAP = {"human": "Student: ", "ai": "AI Tutor: "}
         buffer = ""
@@ -139,7 +138,6 @@ class CustomConversationalRetrievalChain(ConversationalRetrievalChain):
 
 
 class CustomRunnableWithHistory(RunnableWithMessageHistory):
-
     def _get_chat_history(self, chat_history: List[CHAT_TURN_TYPE]) -> str:
         _ROLE_MAP = {"human": "Student: ", "ai": "AI Tutor: "}
         buffer = ""
@@ -282,7 +280,6 @@ def create_retrieval_chain(
 
 # TODO: Remove Hard-coded values
 async def return_questions(query, response, chat_history_str, context, config):
-
     system = (
         "You are someone that suggests a question based on the student's input and chat history. "
        "Generate a question that is relevant to the student's input and chat history. "
code/modules/chat_processor/helpers.py CHANGED
@@ -156,7 +156,6 @@ async def update_user_info(user_info):
 
 
 async def check_user_cooldown(user_info, current_time):
-
     # # Check if no tokens left
     tokens_left = user_info.metadata.get("tokens_left", 0)
     if tokens_left > 0 and not user_info.metadata.get("in_cooldown", False):
@@ -214,7 +213,6 @@ async def reset_tokens_for_user(user_info):
 
     # Calculate how many tokens should have been regenerated proportionally
     if current_tokens < max_tokens:
-
         # Calculate the regeneration rate per second based on REGEN_TIME for full regeneration
         regeneration_rate_per_second = max_tokens / REGEN_TIME
 
code/modules/config/{user_config.yml → project_config.yml} RENAMED
@@ -1,3 +1,7 @@
 retriever:
   retriever_hf_paths:
     RAGatouille: "XThomasBU/Colbert_Index"
+
+metadata:
+  metadata_links: ["https://dl4ds.github.io/sp2024/lectures/", "https://dl4ds.github.io/sp2024/schedule/"]
+  slide_base_link: "https://dl4ds.github.io"
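Both entry points touched by this commit consume the renamed file the same way: load it alongside `config.yml`, then merge. A minimal sketch of that pattern, using only the paths and keys visible in this diff (note `dict.update` is shallow, so `project_config.yml` would win any top-level key collision):

```python
# Sketch of the config-merge pattern this commit introduces; file paths and
# keys come from the diff, everything else is illustrative.
import yaml

with open("modules/config/config.yml", "r") as f:
    config = yaml.safe_load(f)

with open("modules/config/project_config.yml", "r") as f:
    project_config = yaml.safe_load(f)

# Shallow merge: project_config wins on any top-level key collision.
config.update(project_config)

print(config["retriever"]["retriever_hf_paths"]["RAGatouille"])  # XThomasBU/Colbert_Index
print(config["metadata"]["slide_base_link"])                     # https://dl4ds.github.io
```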
code/modules/dataloader/data_loader.py CHANGED
@@ -222,8 +222,7 @@ class ChunkProcessor:
 
     def chunk_docs(self, file_reader, uploaded_files, weblinks):
         addl_metadata = get_metadata(
-            "https://dl4ds.github.io/sp2024/lectures/",
-            "https://dl4ds.github.io/sp2024/schedule/",
+            *self.config["metadata"]["metadata_links"], self.config
         )  # For any additional metadata
 
         # remove already processed files if reparse_files is False
@@ -325,7 +324,6 @@ class ChunkProcessor:
             return
 
         try:
-
             if file_path in self.document_data:
                 self.logger.warning(f"File {file_name} already processed")
                 documents = [
@@ -419,6 +417,15 @@ class DataLoader:
 
 if __name__ == "__main__":
     import yaml
+    import argparse
+
+    parser = argparse.ArgumentParser(description="Process some links.")
+    parser.add_argument(
+        "--links", nargs="+", required=True, help="List of links to process."
+    )
+
+    args = parser.parse_args()
+    links_to_process = args.links
 
     logger = logging.getLogger(__name__)
     logger.setLevel(logging.INFO)
@@ -426,6 +433,12 @@ if __name__ == "__main__":
     with open("../code/modules/config/config.yml", "r") as f:
         config = yaml.safe_load(f)
 
+    with open("../code/modules/config/project_config.yml", "r") as f:
+        project_config = yaml.safe_load(f)
+
+    # Combine project config with the main config
+    config.update(project_config)
+
     STORAGE_DIR = os.path.join(BASE_DIR, config["vectorstore"]["data_path"])
     uploaded_files = [
         os.path.join(STORAGE_DIR, file)
@@ -434,13 +447,15 @@ if __name__ == "__main__":
     ]
 
     data_loader = DataLoader(config, logger=logger)
-    document_chunks, document_names, documents, document_metadata = (
-        data_loader.get_chunks(
-            [
-                "https://dl4ds.github.io/fa2024/static_files/discussion_slides/00_discussion.pdf"
-            ],
-            [],
-        )
+    # Just for testing
+    (
+        document_chunks,
+        document_names,
+        documents,
+        document_metadata,
+    ) = data_loader.get_chunks(
+        links_to_process,
+        [],
     )
 
     print(document_names[:5])
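The `*self.config["metadata"]["metadata_links"]` call in `chunk_docs` only works because `metadata_links` holds exactly the two positional URLs `get_metadata` takes before `config`. A self-contained sketch of that call shape (the function body is stubbed; only the signature and config keys come from this diff):

```python
# Sketch: how star-unpacking metadata_links feeds get_metadata's positional
# parameters. Signature and config keys come from this diff; the body is a stub.
def get_metadata(lectures_url, schedule_url, config):
    # The real function scrapes the course pages; stubbed here.
    return {"lectures_url": lectures_url, "schedule_url": schedule_url}


config = {
    "metadata": {
        "metadata_links": [
            "https://dl4ds.github.io/sp2024/lectures/",
            "https://dl4ds.github.io/sp2024/schedule/",
        ],
        "slide_base_link": "https://dl4ds.github.io",
    }
}

# The list must stay length 2 and in (lectures, schedule) order,
# or the unpacking will misbind the parameters.
addl_metadata = get_metadata(*config["metadata"]["metadata_links"], config)
print(addl_metadata)
```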
code/modules/dataloader/helpers.py CHANGED
@@ -21,7 +21,8 @@ def get_base_url(url):
     return base_url
 
 
-def get_metadata(lectures_url, schedule_url):
+### THIS FUNCTION IS NOT GENERALIZABLE.. IT IS SPECIFIC TO THE COURSE WEBSITE ###
+def get_metadata(lectures_url, schedule_url, config):
     """
     Function to get the lecture metadata from the lectures and schedule URLs.
     """
@@ -50,7 +51,9 @@ def get_metadata(lectures_url, schedule_url):
         slides_link_tag = description_div.find("a", title="Download slides")
         slides_link = slides_link_tag["href"].strip() if slides_link_tag else None
         slides_link = (
-            f"https://dl4ds.github.io{slides_link}" if slides_link else None
+            f"{config['metadata']['slide_base_link']}{slides_link}"
+            if slides_link
+            else None
         )
         if slides_link:
             date_mapping[slides_link] = date
@@ -70,7 +73,9 @@ def get_metadata(lectures_url, schedule_url):
         slides_link_tag = block.find("a", title="Download slides")
         slides_link = slides_link_tag["href"].strip() if slides_link_tag else None
         slides_link = (
-            f"https://dl4ds.github.io{slides_link}" if slides_link else None
+            f"{config['metadata']['slide_base_link']}{slides_link}"
+            if slides_link
+            else None
         )
 
         # Extract the link to the lecture recording
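The net effect is that the slide-link prefix is read from config instead of being baked in. A tiny sketch of the resulting URL construction (the relative href is a made-up example; only the config key comes from this diff):

```python
# Sketch of the new slide-link construction; the relative href below is
# hypothetical, only config['metadata']['slide_base_link'] comes from the diff.
config = {"metadata": {"slide_base_link": "https://dl4ds.github.io"}}

slides_link = "/sp2024/lectures/lec01.pdf"  # hypothetical href scraped from the page
slides_link = (
    f"{config['metadata']['slide_base_link']}{slides_link}"
    if slides_link
    else None
)
print(slides_link)  # -> https://dl4ds.github.io/sp2024/lectures/lec01.pdf
```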
code/modules/retriever/helpers.py CHANGED
@@ -6,7 +6,6 @@ from typing import List
 
 
 class VectorStoreRetrieverScore(VectorStoreRetriever):
-
     # See https://github.com/langchain-ai/langchain/blob/61dd92f8215daef3d9cf1734b0d1f8c70c1571c3/libs/langchain/langchain/vectorstores/base.py#L500
     def _get_relevant_documents(
         self, query: str, *, run_manager: CallbackManagerForRetrieverRun
code/modules/vectorstore/store_manager.py CHANGED
@@ -47,7 +47,6 @@ class VectorStoreManager:
         return logger
 
     def load_files(self):
-
         files = os.listdir(self.config["vectorstore"]["data_path"])
         files = [
             os.path.join(self.config["vectorstore"]["data_path"], file)
@@ -69,7 +68,6 @@ class VectorStoreManager:
         return files, urls
 
     def create_embedding_model(self):
-
         self.logger.info("Creating embedding function")
         embedding_model_loader = EmbeddingModelLoader(self.config)
         embedding_model = embedding_model_loader.load_embedding_model()
@@ -100,7 +98,6 @@ class VectorStoreManager:
         )
 
     def create_database(self):
-
         start_time = time.time()  # Start time for creating database
         data_loader = DataLoader(self.config, self.logger)
         self.logger.info("Loading data")
@@ -110,9 +107,12 @@ class VectorStoreManager:
         self.logger.info(f"Number of webpages: {len(webpages)}")
         if f"{self.config['vectorstore']['url_file_path']}" in files:
             files.remove(f"{self.config['vectorstores']['url_file_path']}")  # cleanup
-        document_chunks, document_names, documents, document_metadata = (
-            data_loader.get_chunks(files, webpages)
-        )
+        (
+            document_chunks,
+            document_names,
+            documents,
+            document_metadata,
+        ) = data_loader.get_chunks(files, webpages)
         num_documents = len(document_chunks)
         self.logger.info(f"Number of documents in the DB: {num_documents}")
         metadata_keys = list(document_metadata[0].keys()) if document_metadata else []
@@ -128,7 +128,6 @@ class VectorStoreManager:
         )
 
     def load_database(self):
-
         start_time = time.time()  # Start time for loading database
         if self.config["vectorstore"]["db_option"] in ["FAISS", "Chroma", "RAPTOR"]:
             self.embedding_model = self.create_embedding_model()
@@ -168,19 +167,21 @@ if __name__ == "__main__":
 
     with open("modules/config/config.yml", "r") as f:
         config = yaml.safe_load(f)
-    with open("modules/config/user_config.yml", "r") as f:
-        user_config = yaml.safe_load(f)
+    with open("modules/config/project_config.yml", "r") as f:
+        project_config = yaml.safe_load(f)
+
+    # combine the two configs
+    config.update(project_config)
     print(config)
-    print(user_config)
     print(f"Trying to create database with config: {config}")
     vector_db = VectorStoreManager(config)
     if config["vectorstore"]["load_from_HF"]:
         if (
             config["vectorstore"]["db_option"]
-            in user_config["retriever"]["retriever_hf_paths"]
+            in config["retriever"]["retriever_hf_paths"]
         ):
             vector_db.load_from_HF(
-                HF_PATH=user_config["retriever"]["retriever_hf_paths"][
+                HF_PATH=config["retriever"]["retriever_hf_paths"][
                     config["vectorstore"]["db_option"]
                 ]
             )
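After the merge, `retriever_hf_paths` lives in the combined config, so the Hugging Face index path is looked up per `db_option`. A condensed sketch of that lookup (keys and values come from this diff; `VectorStoreManager` is stubbed out):

```python
# Sketch of the HF-path lookup after the config merge. Keys and values come
# from this diff; the real script calls vector_db.load_from_HF(HF_PATH=...).
config = {
    "vectorstore": {"load_from_HF": True, "db_option": "RAGatouille"},
    "retriever": {"retriever_hf_paths": {"RAGatouille": "XThomasBU/Colbert_Index"}},
}

if config["vectorstore"]["load_from_HF"]:
    db_option = config["vectorstore"]["db_option"]
    hf_paths = config["retriever"]["retriever_hf_paths"]
    if db_option in hf_paths:
        # vector_db.load_from_HF(HF_PATH=hf_paths[db_option]) in the real script
        print(f"would load {db_option} index from {hf_paths[db_option]}")
```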
docs/setup.md CHANGED
@@ -124,4 +124,4 @@ CHAINLIT_URL=<your_chainlit_url>
 # Configuration
 
 The configuration file `code/modules/config.yaml` contains the parameters that control the behaviour of your app.
-The configuration file `code/modules/user_config.yaml` contains user-defined parameters.
+The configuration file `code/modules/project_config.yaml` contains project-specific parameters.