File size: 6,983 Bytes
cd5daea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
import os
import openai
import chainlit as cl
import pandas as pd
import chromadb

from chainlit import user_session
from sqlalchemy import create_engine
from typing import List, Tuple, Any
from pydantic import BaseModel, Field
from llama_index import Document
from llama_index import SQLDatabase
from llama_index.agent import OpenAIAgent
from llama_index.tools.query_engine import QueryEngineTool
from llama_index.indices.struct_store.sql_query import NLSQLTableQueryEngine
from llama_index import ServiceContext
from llama_index.llms import OpenAI
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index import VectorStoreIndex
from llama_index.vector_stores import ChromaVectorStore
from llama_index.storage.storage_context import StorageContext
from llama_index.tools import FunctionTool
from llama_index.retrievers import VectorIndexRetriever
from llama_index.query_engine import RetrieverQueryEngine
from llama_index.vector_stores.types import (
    VectorStoreInfo,
    MetadataInfo,
    ExactMatchFilter,
    MetadataFilters,
)

# Configure the openai library's API key globally for this process.
# Subscripting os.environ (rather than .get) raises KeyError at import time
# when OPENAI_API_KEY is not set, so a missing key fails loudly up front.
openai.api_key = os.environ["OPENAI_API_KEY"]

# preparation
def get_df_from_workbook(sheet_name,
                         workbook_id='1MB1ZsQul4AB262AsaY4fHtGW4HWp2-56zB-E5xTbs2A'):
    """Download one sheet of a Google Sheets workbook as a DataFrame.

    The sheet is fetched through the Sheets "gviz" CSV export URL and parsed
    with ``pd.read_csv``. No credentials are sent, so the workbook presumably
    has to be shared publicly.

    Args:
        sheet_name: Name of the sheet (tab) to export.
        workbook_id: Spreadsheet id taken from the workbook's URL.

    Returns:
        The sheet contents as a ``pandas.DataFrame``.
    """
    from urllib.parse import quote

    # Percent-encode the sheet name so names containing spaces, '&', '#',
    # '/' or non-ASCII characters still form a valid query-string value.
    # Plain names such as 'Player_Stats' are left unchanged.
    encoded_sheet = quote(sheet_name, safe="")
    url = (
        f'https://docs.google.com/spreadsheets/d/{workbook_id}'
        f'/gviz/tq?tqx=out:csv&sheet={encoded_sheet}'
    )
    return pd.read_csv(url)

# Two hand-written sample "emails" that get indexed into the vector store.
# Each one carries a ``from_to`` metadata value of the form
# "<sender> to <recipient>", which the auto-retrieval tool filters on.
docAdditionalSamples = [
    Document(
        text="Hey KD, let's grab dinner after our next game, Steph",
        metadata={"from_to": "Stephen Curry to Kevin Durant"},
    ),
    Document(
        text="Yo Joker, you were a monster last year, can't wait to play against you in the opener! Draymond",
        metadata={"from_to": "Draymond Green to Nikola Jokic"},
    ),
]
# Individual handles on the same Document objects held in the list above.
docEmailSample, docEmailSample2 = docAdditionalSamples

class AutoRetrieveModel(BaseModel):
    # Argument schema for auto_retrieve_fn, handed to FunctionTool as its
    # fn_schema. Documented with comments rather than a class docstring so
    # the generated pydantic schema is left exactly as before.
    # filter_key_list[i] is paired with filter_value_list[i].
    query: str = Field(default=..., description="natural language query string")
    filter_key_list: List[str] = Field(
        default=...,
        description="List of metadata filter field names",
    )
    filter_value_list: List[str] = Field(
        default=...,
        description="List of metadata filter field values (corresponding to names specified in filter_key_list)",
    )
    
def auto_retrieve_fn(
    query: str, filter_key_list: List[str], filter_value_list: List[str]
) -> str:
    """Query the email vector index restricted by exact-match metadata filters.

    This is the function behind the agent's ``auto_retrieve_tool_emails``
    tool. The agent fills in the arguments following ``AutoRetrieveModel``.

    Args:
        query: Natural-language query. An empty value falls back to "Query".
        filter_key_list: Metadata field names to filter on.
        filter_value_list: Values paired positionally with
            ``filter_key_list``. If the lists differ in length, the extra
            entries are ignored (``zip`` semantics).

    Returns:
        The query engine's response converted to a string.
    """
    query = query or "Query"

    # The filters are handed to the retriever itself, not applied to the
    # results afterwards.
    exact_match_filters = [
        ExactMatchFilter(key=key, value=value)
        for key, value in zip(filter_key_list, filter_value_list)
    ]
    # VectorIndexRetriever's result-count parameter is ``similarity_top_k``.
    # The previous ``top_k=`` keyword was swallowed by **kwargs and ignored,
    # so the module-level top_k never took effect.
    retriever = VectorIndexRetriever(
        vector_index,
        similarity_top_k=top_k,
        filters=MetadataFilters(filters=exact_match_filters),
    )
    # Synthesize answers with the app's configured service context (LLM
    # settings) rather than the library defaults.
    query_engine = RetrieverQueryEngine.from_args(
        retriever, service_context=service_context
    )
    return str(query_engine.query(query))

# loading CSV data
sheet_names = ['Teams', 'Players', 'Schedule', 'Player_Stats']
dict_of_dfs = {sheet: get_df_from_workbook(sheet) for sheet in sheet_names}

# Each sheet becomes a table, named after the sheet, in an in-memory SQLite
# database that the NL-to-SQL engine queries.
engine = create_engine("sqlite+pysqlite:///:memory:")

for table_name, df in dict_of_dfs.items():
    # index=False: do not write the DataFrame's positional index as an extra
    # "index" column. It carries no information and would only clutter the
    # table schema the SQL engine works from.
    df.to_sql(table_name, con=engine, index=False)

sql_database = SQLDatabase(
    engine,
    include_tables=list(dict_of_dfs),
)

# setting up llm & service context
# Model settings shared by the vector index and the agent.
chunk_size = 1000
embed_model = OpenAIEmbedding()
llm = OpenAI(model="gpt-3.5-turbo", temperature=0, streaming=True)
service_context = ServiceContext.from_defaults(
    llm=llm,
    embed_model=embed_model,
    chunk_size=chunk_size,
)

# setting up vector store
chroma_client = chromadb.Client()
# get_or_create_collection instead of create_collection: create_collection
# raises when "all_data" already exists in this client (e.g. if this module
# is executed again in a process whose Chroma state survived).
chroma_collection = chroma_client.get_or_create_collection("all_data")
vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
storage_context = StorageContext.from_defaults(vector_store=vector_store)
# Start from an empty index backed by the Chroma collection, then add the
# sample emails.
vector_index = VectorStoreIndex(
    [], storage_context=storage_context, service_context=service_context
)

# insert_nodes (rather than insert) stores each sample Document as-is as a
# node; presumably no chunking happens, which is fine for these short texts.
# NOTE: if an existing collection is reused, the samples are added again.
vector_index.insert_nodes(docAdditionalSamples)

# setting up metadata
# Result count intended for auto_retrieve_fn's retriever. To take effect it
# must reach VectorIndexRetriever as ``similarity_top_k``.
top_k = 3
# Description of the email vector store. Its JSON form is embedded in the
# email tool's description so the agent knows which metadata key
# ("from_to") it can filter on.
# NOTE(review): the stored "from_to" values have the form
# "<sender> to <recipient>" (see the sample Documents), and auto_retrieve_fn
# matches them exactly via ExactMatchFilter. The description below, however,
# lists "<sender> to any NBA player". Confirm the LLM does not pass those
# literal strings as filter values, since they would match nothing.
info_emails_players = VectorStoreInfo(
    content_info="emails exchanged between NBA players",
    metadata_info=[
        MetadataInfo(
            name="from_to",
            type="str",
            # Text shown to the LLM verbatim; edit with care.
            description="""
email sent by a player of the Golden State Warriors to any other NBA player, one of [
Stephen Curry to any NBA player, 
Klay Thompson to any NBA player, 
Chris Paul to any NBA player, 
Andrew Wiggins to any NBA player, 
Draymond Green to any NBA player, 
Gary Payton II to any NBA player, 
Kevon Looney to any NBA player, 
Jonathan Kuminga to any NBA player, 
Moses Moody to any NBA player, 
Brandin Podziemski to any NBA player, 
Cory Joseph to any NBA player, 
Dario Šarić to any NBA player]"""
        ), 
    ]
)

@cl.on_chat_start
def main():
    """Build a fresh agent for a new chat session and store it in the session.

    The agent gets two tools:
      * a natural-language-to-SQL engine over the tables in ``sql_database``;
      * ``auto_retrieve_fn``, which queries the email vector index with
        metadata filters.

    NOTE(review): the on_message handler is also named ``main``. Each decorator
    receives its function object at definition time, so presumably both
    handlers stay registered even though the module-level name ends up bound
    to the second one. Distinct names would be clearer.
    """
    sql_query_engine = NLSQLTableQueryEngine(
        sql_database=sql_database,
        tables=list(dict_of_dfs.keys()),
        # Use the app's configured LLM settings rather than library defaults.
        service_context=service_context,
    )

    sql_nba_tool = QueryEngineTool.from_defaults(
        query_engine=sql_query_engine,
        name='sql_nba_tool',
        description=("""Useful for translating a natural language query into a SQL query over tables containing:
                        1. teams, containing information related to all NBA teams
                        2. players, containing information about the team that each player plays for
                        3. schedule, containing information related to the entire NBA game schedule
                        4. player_stats, containing information related to all NBA player stats
                        """
        ),
    )

    # The vector-store schema is inlined as JSON so the agent can see which
    # metadata keys are available for filtering.
    description_emails = f"""\
    Use this tool to look up information about emails exchanged between players of the Golden State Warriors and any other NBA player.
    The vector database schema is given below:
    {info_emails_players.json()}
    """
    auto_retrieve_tool_emails = FunctionTool.from_defaults(
        fn=auto_retrieve_fn,
        name='auto_retrieve_tool_emails',
        description=description_emails,
        fn_schema=AutoRetrieveModel,
    )

    agent = OpenAIAgent.from_tools(
        tools=[sql_nba_tool, auto_retrieve_tool_emails],
        llm=llm,
        verbose=True,
    )

    # One agent per session. The on_message handler reads it back under the
    # same "agent" key.
    cl.user_session.set("agent", agent)
    
@cl.on_message
async def main(message):
    """Answer one user message with this session's agent and send the reply.

    Args:
        message: The incoming user message. Both a plain ``str`` and an
            object exposing the text as ``.content`` are accepted.
            Presumably newer Chainlit releases pass ``cl.Message`` and
            older ones pass the raw text; confirm against the installed
            version.
    """
    agent = cl.user_session.get("agent")

    # Normalize to plain text. Passing a message object straight to
    # agent.chat fails because the agent expects a string.
    text = getattr(message, "content", message)

    # NOTE(review): agent.chat is synchronous and blocks the event loop while
    # it runs. It is deliberately NOT moved to a worker thread: the SQL tool
    # uses an in-memory SQLite engine, and with SQLAlchemy's default
    # per-thread pool another thread would see an empty database.
    response = agent.chat(text)

    # Send the agent's text, or an empty message if it produced none.
    await cl.Message(content=response.response or "").send()