XThomasBU committed on
Commit
2719f21
1 Parent(s): 9207578

fixed follow-up logging

Browse files
code/main.py CHANGED
@@ -395,11 +395,20 @@ class Chatbot:
395
 
396
  if self.config["llm_params"]["generate_follow_up"]:
397
  start_time = time.time()
398
- list_of_questions = self.question_generator.generate_questions(
 
 
 
 
 
 
 
 
399
  query=user_query_dict["input"],
400
  response=answer,
401
  chat_history=res.get("chat_history"),
402
  context=res.get("context"),
 
403
  )
404
 
405
  for question in list_of_questions:
 
395
 
396
  if self.config["llm_params"]["generate_follow_up"]:
397
  start_time = time.time()
398
+ config = {
399
+ "callbacks": (
400
+ [cl.LangchainCallbackHandler()]
401
+ if cl_data._data_layer and self.config["chat_logging"]["callbacks"]
402
+ else None
403
+ )
404
+ }
405
+
406
+ list_of_questions = await self.question_generator.generate_questions(
407
  query=user_query_dict["input"],
408
  response=answer,
409
  chat_history=res.get("chat_history"),
410
  context=res.get("context"),
411
+ config=config,
412
  )
413
 
414
  for question in list_of_questions:
code/modules/chat/langchain/langchain_rag.py CHANGED
@@ -100,8 +100,8 @@ class QuestionGenerator:
100
  def __init__(self):
101
  pass
102
 
103
- def generate_questions(self, query, response, chat_history, context):
104
- questions = return_questions(query, response, chat_history, context)
105
  return questions
106
 
107
 
@@ -204,7 +204,7 @@ class Langchain_RAG_V2(BaseRAG):
204
  is_shared=True,
205
  ),
206
  ],
207
- )
208
 
209
  if callbacks is not None:
210
  self.rag_chain = self.rag_chain.with_config(callbacks=callbacks)
 
100
  def __init__(self):
101
  pass
102
 
103
+ def generate_questions(self, query, response, chat_history, context, config):
104
+ questions = return_questions(query, response, chat_history, context, config)
105
  return questions
106
 
107
 
 
204
  is_shared=True,
205
  ),
206
  ],
207
+ ).with_config(run_name="Langchain_RAG_V2")
208
 
209
  if callbacks is not None:
210
  self.rag_chain = self.rag_chain.with_config(callbacks=callbacks)
code/modules/chat/langchain/utils.py CHANGED
@@ -280,7 +280,8 @@ def create_retrieval_chain(
280
  return retrieval_chain
281
 
282
 
283
- def return_questions(query, response, chat_history_str, context):
 
284
 
285
  system = (
286
  "You are someone that suggests a question based on the student's input and chat history. "
@@ -303,13 +304,17 @@ def return_questions(query, response, chat_history_str, context):
303
  )
304
  llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
305
  question_generator = prompt | llm | StrOutputParser()
306
- new_questions = question_generator.invoke(
 
 
 
307
  {
308
  "chat_history_str": chat_history_str,
309
  "context": context,
310
  "query": query,
311
  "response": response,
312
- }
 
313
  )
314
 
315
  list_of_questions = new_questions.split("...")
 
280
  return retrieval_chain
281
 
282
 
283
+ # TODO: Remove Hard-coded values
284
+ async def return_questions(query, response, chat_history_str, context, config):
285
 
286
  system = (
287
  "You are someone that suggests a question based on the student's input and chat history. "
 
304
  )
305
  llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
306
  question_generator = prompt | llm | StrOutputParser()
307
+ question_generator = question_generator.with_config(
308
+ run_name="follow_up_question_generator"
309
+ )
310
+ new_questions = await question_generator.ainvoke(
311
  {
312
  "chat_history_str": chat_history_str,
313
  "context": context,
314
  "query": query,
315
  "response": response,
316
+ },
317
+ config=config,
318
  )
319
 
320
  list_of_questions = new_questions.split("...")