import json
import operator
from operator import itemgetter
from typing import Annotated, Sequence, TypedDict

import chainlit as cl
from dotenv import load_dotenv
from langchain.retrievers import ParentDocumentRetriever
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnablePassthrough
from langchain.schema.runnable.config import RunnableConfig
from langchain.storage import InMemoryStore

# from langchain_core.output_parsers import StrOutputParser
from langchain.tools import tool
from langchain_community.document_loaders import ArxivLoader
from langchain_community.tools.arxiv.tool import ArxivQueryRun
from langchain_community.tools.ddg_search import DuckDuckGoSearchRun
from langchain_community.tools.pubmed.tool import PubmedQueryRun

# from langgraph.graph.message import add_messages
from langchain_core.messages import (
    BaseMessage,
    FunctionMessage,
    SystemMessage,
)
from langchain_core.prompts import ChatPromptTemplate, PromptTemplate
from langchain_core.utils.function_calling import convert_to_openai_function
from langchain_openai import ChatOpenAI
from langchain_openai.embeddings import OpenAIEmbeddings
from langchain_qdrant import Qdrant
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langgraph.graph import END, StateGraph
from langgraph.checkpoint.aiosqlite import AsyncSqliteSaver

from langgraph.prebuilt import ToolExecutor, ToolInvocation
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams

# GLOBAL SCOPE - ENTIRE APPLICATION HAS ACCESS TO VALUES SET IN THIS SCOPE #
# ---- ENV VARIABLES ---- #
"""
This function will load our environment file (.env) if it is present.

NOTE: Make sure that .env is in your .gitignore file - it is by default, but please ensure it remains there.
"""
load_dotenv()

"""
We will load our environment variables here.
"""

# ---- GLOBAL DECLARATIONS ---- #


# -- RETRIEVAL -- #
"""
1. Load Documents from Text File
2. Split Documents into Chunks
3. Load HuggingFace Embeddings (remember to use the URL we set above)
4. Index Files if they do not exist, otherwise load the vectorstore
"""
### 1. CREATE TEXT LOADER AND LOAD DOCUMENTS
### NOTE: PAY ATTENTION TO THE PATH THEY ARE IN.


docs = ArxivLoader(
    query='"mental health counseling" AND (data OR analytics OR "machine learning")',
    load_max_docs=2,
    sort_by="submittedDate",
    sort_order="descending",
).load()
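

# Quick sanity check of what was loaded (commented out; uncomment to inspect).
# NOTE: depending on the langchain_community version, the sort_by/sort_order kwargs
# above may be silently ignored by the underlying ArxivAPIWrapper.
# for doc in docs:
#     print(doc.metadata.get("Title"), "-", len(doc.page_content), "chars")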


### 2. CREATE QDRANT CLIENT AND VECTOR STORE

client = QdrantClient(":memory:")
client.create_collection(
    collection_name="split_parents",
    # 1536 matches the dimensionality of OpenAI's text-embedding-3-small
    vectors_config=VectorParams(size=1536, distance=Distance.COSINE),
)

vectorstore = Qdrant(
    client,
    collection_name="split_parents",
    embeddings=OpenAIEmbeddings(model="text-embedding-3-small"),
)

store = InMemoryStore()

### 3. CREATE PARENT/CHILD TEXT SPLITTERS AND INITIALIZE THE PARENT DOCUMENT RETRIEVER

parent_document_retriever = ParentDocumentRetriever(
    vectorstore=vectorstore,
    docstore=store,
    child_splitter=RecursiveCharacterTextSplitter(chunk_size=400),
    parent_splitter=RecursiveCharacterTextSplitter(chunk_size=2000),
)
parent_document_retriever.add_documents(docs)
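

# How this works: the child splitter's 400-char chunks are embedded and searched in
# Qdrant, but the retriever returns the surrounding 2000-char parent chunks from the
# InMemoryStore. A hedged sanity check (requires OPENAI_API_KEY for embeddings):
# retrieved = parent_document_retriever.get_relevant_documents("counseling analytics")
# print(f"retrieved {len(retrieved)} parent chunks")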

### 4. CREATE PROMPT OBJECT
RAG_PROMPT = """\
You are a professional mental health advisor. Use the following context to answer the user's query. If you cannot answer the question, please respond with 'I don't know'.

Question:
{question}

Context:
{context}
"""

rag_prompt = ChatPromptTemplate.from_template(RAG_PROMPT)

### 5. CREATE THE RAG CHAIN PIPELINE OVER THE RETRIEVER

openai_chat_model = ChatOpenAI(model="gpt-3.5-turbo", streaming=True)


def create_qa_chain(retriever):
    mentalhealth_qa_llm = openai_chat_model

    created_qa_chain = (
        {
            "context": itemgetter("question") | retriever,
            "question": itemgetter("question"),
        }
        | RunnablePassthrough.assign(context=itemgetter("context"))
        | {
            "response": rag_prompt | mentahealth_qa_llm | StrOutputParser(),
            "context": itemgetter("context"),
        }
    )
    return created_qa_chain
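

# Example invocation (sketch): the chain expects a dict with a "question" key and
# returns both the generated "response" and the retrieved "context" documents:
# qa_chain = create_qa_chain(parent_document_retriever)
# result = qa_chain.invoke({"question": "How is machine learning used in counseling?"})
# print(result["response"])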


### 6. DEFINE THE LIST OF TOOLS AVAILABLE TO THE AGENT AND WRAP THEM IN A TOOL EXECUTOR


@tool
async def rag_tool(question: str) -> str:
    """Use this tool to retrieve relevant information from the knowledge base."""
    parent_document_retriever_qa_chain = create_qa_chain(parent_document_retriever)
    response = await parent_document_retriever_qa_chain.ainvoke({"question": question})

    return response["response"]


tool_belt = [
    rag_tool,
    PubmedQueryRun(),
    ArxivQueryRun(),
    DuckDuckGoSearchRun(),
]

tool_executor = ToolExecutor(tool_belt)
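

# The agent refers to tools by their registered names; printing them is a quick way
# to keep the system prompt in convert_inputs() below aligned with reality:
# for t in tool_belt:
#     print(t.name)  # expected: "rag_tool", "pub_med", "arxiv", "duckduckgo_search"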


### 7. CONVERT TOOLS INTO THE FORMAT COMPATIBLE WITH OPENAI'S FUNCTION-CALLING API, THEN BIND THEM TO THE MODEL FOR USE DURING GENERATION
model = ChatOpenAI(temperature=0, streaming=True)

functions = [convert_to_openai_function(t) for t in tool_belt]
model = model.bind_functions(functions)
model = model.with_config(tags=["final_node"])

### 8. USING TypedDict FROM THE typing MODULE AND THE langchain_core.messages MODULE, A CUSTOM TYPE NAMED AgentState IS CREATED.
# THE AgentState TYPE HAS A FIELD NAMED <messages> OF TYPE Annotated[Sequence[BaseMessage], operator.add].
# Sequence[BaseMessage]: INDICATES THAT MESSAGES ARE A SEQUENCE OF BaseMessage OBJECTS.
# Annotated: ATTACHES METADATA TO THE TYPE; HERE operator.add TELLS LANGGRAPH TO MERGE STATE UPDATES BY CONCATENATING MESSAGE LISTS RATHER THAN REPLACING THEM.


class AgentState(TypedDict):
    messages: Annotated[Sequence[BaseMessage], operator.add]
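
# Illustration of the reducer: because "messages" is annotated with operator.add,
# LangGraph merges each node's returned update into the state by list concatenation
# rather than replacement, e.g.:
#   state  = {"messages": [msg_a]}   # current state
#   update = {"messages": [msg_b]}   # a node's return value
#   merged = operator.add(state["messages"], update["messages"])  # -> [msg_a, msg_b]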


### 9. TWO FUNCTIONS DEFINED: 1. call_model AND 2. call_tool
# 1. INVOKES THE MODEL WITH THE MESSAGES EXTRACTED FROM THE STATE, RETURNING A DICT CONTAINING THE RESPONSE MESSAGE.
# 2.1 A ToolInvocation OBJECT IS CREATED USING THE NAME AND ARGUMENTS EXTRACTED FROM THE LAST MESSAGE IN THE STATE.
# 2.2 tool_executor IS INVOKED WITH THE CREATED ToolInvocation OBJECT.
# 2.3 A FunctionMessage OBJECT IS CREATED FROM THE tool_executor RESPONSE AND THE NAME OF THAT TOOL.
# 2.4 THE RETURN VALUE IS A DICT CONTAINING THE FunctionMessage OBJECT.


async def call_model(state):
    messages = state["messages"]
    response = await model.ainvoke(messages)
    return {"messages": [response]}


async def call_tool(state):
    last_message = state["messages"][-1]

    action = ToolInvocation(
        tool=last_message.additional_kwargs["function_call"]["name"],
        tool_input=json.loads(
            last_message.additional_kwargs["function_call"]["arguments"]
        ),
    )
    
    # debug: log which tool the agent selected
    print(f"Invoking tool: {action.tool}")
    response = await tool_executor.ainvoke(action)

    function_message = FunctionMessage(content=str(response), name=action.tool)

    return {"messages": [function_message]}


### 10. GRAPH CREATION WITH HELPFULNESS EVALUATION
# should_continue CHECKS WHETHER THE LAST MESSAGE IN THE STATE REQUESTS A TOOL CALL ("function_call" IN additional_kwargs) OR NOT.
# THE add_conditional_edges() METHOD ROUTES ON THIS RESPONSE: EITHER TRANSITION TO THE action NODE OR END.


def should_continue(state):
    last_message = state["messages"][-1]

    if "function_call" not in last_message.additional_kwargs:
        return "end"

    return "continue"


async def check_helpfulness(state):
    initial_query = state["messages"][0]
    final_response = state["messages"][-1]

    # guard against infinite agent/helpfulness loops: give up after 20 messages

    if len(state["messages"]) > 20:
        return "end"

    prompt_template = """\
  Given an initial query and a final response, determine if the final response is extremely helpful or not. Please indicate helpfulness with a 'Y'\
  and unhelpfulness as an 'N'.

  Initial Query:
  {initial_query}

  Final Response:
  {final_response}"""

    prompt_template = PromptTemplate.from_template(prompt_template)

    helpfulness_check_model = ChatOpenAI(model="gpt-4")

    helpfulness_check_chain = (
        prompt_template | helpfulness_check_model | StrOutputParser()
    )

    helpfulness_response = await helpfulness_check_chain.ainvoke(
        {"initial_query": initial_query, "final_response": final_response}
    )

    if "Y" in helpfulness_response:
        print("helpful!")
        return "end"

    else:
        print("not helpful!")
        return "continue"


def dummy_node(state):
    # passthrough node: performs no state update; it exists only so the
    # helpfulness check can hang off a conditional edge
    return


### 11. SETTING UP THE GRAPH WORKFLOW:
# 1. AN INSTANCE OF StateGraph IS CREATED WITH THE AgentState TYPE. THREE NODES ARE ADDED WITH THE add_node() METHOD:
# 1.1 THE "agent" NODE IS ASSOCIATED WITH THE call_model FUNCTION.
# 1.2 THE "action" NODE IS ASSOCIATED WITH THE call_tool FUNCTION.
# 1.3 THE "passthrough" NODE IS A CUSTOM NODE USED FOR THE HELPFULNESS CHECK.
# 2. THE EDGES:
# 2.1 A CONDITIONAL EDGE FROM THE agent NODE TO EITHER THE action NODE OR THE passthrough NODE.
# 2.2 A CONDITIONAL EDGE FROM THE passthrough NODE TO EITHER THE agent NODE OR THE END NODE.
# 2.3 A PLAIN EDGE FROM action BACK TO agent, SO THE MODEL SEES EACH TOOL RESULT BEFORE RESPONDING.
def get_state_update_bot():
    workflow = StateGraph(AgentState)

    workflow.add_node("agent", call_model)  # agent node has access to llm
    workflow.add_node("action", call_tool)  # action node has access to tools
    workflow.set_entry_point("agent")
    workflow.add_conditional_edges(
        "agent",
        should_continue,
        {
            "continue": "action",  # tools
            "end": END,
        },
    )
    workflow.add_edge("action", "agent")  # tools
    state_update_bot = workflow.compile()

    return state_update_bot
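

# Hedged usage sketch (commented out; assumes HumanMessage, imported further below):
# bot = get_state_update_bot()
# final_state = await bot.ainvoke({"messages": [HumanMessage(content="I feel anxious lately")]})
# print(final_state["messages"][-1].content)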


#   --------------------------------------------------
from langgraph.checkpoint.memory import MemorySaver

def get_state_update_bot_with_helpfulness_node():
    # memory = MemorySaver()

    graph_with_helpfulness_check = StateGraph(AgentState)

    graph_with_helpfulness_check.add_node("agent", call_model)
    graph_with_helpfulness_check.add_node("action", call_tool)
    graph_with_helpfulness_check.add_node("passthrough", dummy_node)

    graph_with_helpfulness_check.set_entry_point("agent")

    graph_with_helpfulness_check.add_conditional_edges(
        "agent", should_continue, {"continue": "action", "end": "passthrough"}
    )

    graph_with_helpfulness_check.add_conditional_edges(
        "passthrough", check_helpfulness, {"continue": "agent", "end": END}
    )

    graph_with_helpfulness_check.add_edge("action", "agent")
    # NOTE: in recent langgraph releases AsyncSqliteSaver.from_conn_string() is an
    # async context manager rather than a plain constructor; this direct call
    # assumes an older release that returns the saver directly.
    memory = AsyncSqliteSaver.from_conn_string(":memory:")
    return graph_with_helpfulness_check.compile(checkpointer=memory)
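

# With a checkpointer attached, every invocation must carry a thread id so the saver
# can persist and restore conversation state (see start_chat below for the real wiring).
# Sketch, with bot_with_memory as a hypothetical name for the compiled graph above:
# config = {"configurable": {"thread_id": str(uuid.uuid4())}}
# await bot_with_memory.ainvoke({"messages": [...]}, config=config)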


### 12. CONVERT THE RAW USER INPUT INTO A SYSTEM PROMPT FOR THE AGENT
# (an earlier ReAct-style version of this prompt is kept commented out below for reference)
# def convert_inputs(input_object):
#     system_prompt = f"""You are a qualified psychologist providing mental health advice. Be empathetic in your responses. 
#     Always provide a complete response. Be empathetic and provide a follow-up question to find a resolution. 
#     First, look up the RAG (retrieval-augmented generation) and then arxiv research or use InternetSearch:



# You will operate in a loop of Thought, Action, PAUSE, and Observation. At the end of the loop, you will provide an Answer.

# Instructions:

# Thought: Describe your thoughts about the user's question.
# Action: Choose one of the available actions to gather information or provide insights.
# PAUSE: Pause to allow the action to complete.
# Observation: Review the results of the action.

# Available Actions:

# Use the tools at your disposal to look up information or resolve the consultancy. You are allowed to make multiple calls (either together or in sequence).:

# 1. rag_tool: RAG (Retrieval-Augmented Generation) to access relevant mental health information.
# 2. DuckDuckGoSearchRun: Perform an online search: InternetSearch to find up-to-date resources and recommendations.
# 3. ArxivQueryRun: Find relevant research or content.
# 4. PubmedQueryRun: Find specific coping strategies or management techniques from research papers.

# You may make multiple calls to these tools as needed to provide comprehensive advice.

# Present your final response in a clear, structured format, including a chart of recommended actions if appropriate.

#     User's question: {input_object["messages"]}

#     Response: When responding to users' personal issues or concerns, your task is to:

# 1. Begin with a brief, empathetic acknowledgment of the user's situation, then
# 2. Provide practical, actionable advice, which often includes:
# 3. Suggesting professional help (e.g., therapists, counselors) when appropriate,
# 4. Encouraging open communication and dialogue with involved parties,
# 5. Recommending self-reflection or exploration of emotions and values, and
# 6. Offering specific coping strategies or management techniques.
# """
#     return {"messages": [SystemMessage(content=system_prompt)]}
def convert_inputs(input_object):
    system_prompt = f"""You are a qualified psychologist providing mental health advice. Be empathetic in your responses.
    Always provide a complete response. Be empathetic and provide a follow-up question to find a resolution.

    You must use the tools at your disposal.
    You must consult pub_med, then rag_tool, then duckduckgo_search.
    You must make multiple calls to these tools as needed to provide comprehensive advice.


    User's question: {input_object["messages"]}
    """
    return {"messages": [SystemMessage(content=system_prompt)]}


# Define the function to parse the output
def parse_output(input_state):
    return input_state


# bot_with_helpfulness_check = get_state_update_bot_with_helpfulness_node()
# bot=get_state_update_bot()

# Create the agent chain
# agent_chain = convert_inputs | bot_with_helpfulness_check# | StrOutputParser()#| parse_output

# Run the agent chain with the input
# messages=agent_chain.invoke({"question": mental_health_counseling_data['test'][14]['Context']})
import uuid
# ---------------------------------------------------------------------------------------------------------
#                                       DEPLOYMENT
# ---------------------------------------------------------------------------------------------------------
from langchain_core.messages import HumanMessage

@cl.author_rename
def rename(original_author: str):
    """
    This function can be used to rename the 'author' of a message.

    In this case, we're overriding the 'Assistant' author to be 'Mental Health Advisor Bot'.
    """
    rename_dict = {"Assistant": "Mental Health Advisor Bot"}
    return rename_dict.get(original_author, original_author)


@cl.on_chat_start
async def start_chat():
    """
    This function will be called at the start of every user session.

    We will build our LangGraph agent chain here, and store it in the user session.

    The user session is a dictionary that is unique to each user session, and is stored in the memory of the server.
    """

    ### BUILD THE LANGGRAPH AGENT CHAIN (PROMPT WRAPPER | COMPILED GRAPH)
    bot_with_helpfulness_check = get_state_update_bot_with_helpfulness_node()
    lcel_agent_langgraph_chain = convert_inputs | bot_with_helpfulness_check

    # bot=get_state_update_bot()

    # lcel_agent_chain = convert_inputs | bot| parse_output# StrOutputParser()

    cl.user_session.set("langgraph_agent_chain", lcel_agent_langgraph_chain)

    # Create a thread id and pass it as configuration
    # so LangGraph's checkpointer can persist this conversation
    conversation_id = str(uuid.uuid4())
    config = {"configurable": {"thread_id": conversation_id}}
    cl.user_session.set("config", config)



@cl.on_message
async def main(message: cl.Message):
    """
    This function will be called every time a message is recieved from a session.

    """
    # msg is the human message, could be mixed with system message.
    # agent_message is the agent's response.

    graph = cl.user_session.get("langgraph_agent_chain")
    config = cl.user_session.get("config")
    final_output=""

    # inputs = {"messages": [("user", message.content)]}
    inputs={"messages": [HumanMessage(message.content)]}

    agent_message = cl.Message(content="")
    await agent_message.send()


    # final_output=""

    async for event in graph.astream_events(
        inputs,
        config=config,  # alternatively: RunnableConfig(callbacks=[cl.LangchainCallbackHandler()])
        version="v2",
    ):
        
        kind = event["event"]
        tags = event.get("tags", [])  # only referenced by the commented-out variants below
        name = event.get("name", "")
        print(f"\nReceived event: {event}\n")  # debugging statement
        if kind == "on_chain_start":
            if (
                event["name"] == "Agent"
            ):  # Was assigned when creating the agent with `.with_config({"run_name": "Agent"})`
                print(
                    f"Starting agent: {event['name']} with input: {event['data'].get('input')}"
                )
        
        elif kind == "on_chain_end" and name == "RunnableSequence":  # alternatively: "tool_end" in tags
            if "output" in event["data"] and "agent" in event["data"]["output"]:
                agent_output = event["data"]["output"]["agent"]
                if "messages" in agent_output and agent_output["messages"]:
                    final_output = agent_output["messages"][0].content
                    await agent_message.stream_token(final_output)

        # elif kind=="on_chain_stream":
        #     data=event['data']
        #     if data["chunk"].content:
        #         print(f"Streaming content: {data['chunk'].content}")  
        #         await agent_message.stream_token(data["chunk"].content)


    await agent_message.send()

# docker build -t llm-app-langgraph-react-chainlit-mentalmindbt .
# docker run -it -p 7860:7860 llm-app-langgraph-react-chainlit-mentalmindbt:latest
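# Note: the container needs an OpenAI key at runtime, e.g. (placeholder value):
# docker run -it -p 7860:7860 -e OPENAI_API_KEY=sk-... llm-app-langgraph-react-chainlit-mentalmindbt:latest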