|
import os |
|
import streamlit as st |
|
from streamlit_chat import message as st_message |
|
from sqlalchemy import create_engine |
|
|
|
from langchain.agents import Tool, initialize_agent |
|
from langchain.chains.conversation.memory import ConversationBufferMemory |
|
|
|
from llama_index import GPTSQLStructStoreIndex, LLMPredictor, ServiceContext |
|
from llama_index import SQLDatabase as llama_SQLDatabase |
|
from llama_index.indices.struct_store import SQLContextContainerBuilder |
|
|
|
from constants import ( |
|
DEFAULT_SQL_PATH, |
|
DEFAULT_BUSINESS_TABLE_DESCRP, |
|
DEFAULT_VIOLATIONS_TABLE_DESCRP, |
|
DEFAULT_INSPECTIONS_TABLE_DESCRP, |
|
DEFAULT_LC_TOOL_DESCRP |
|
) |
|
from utils import get_sql_index_tool, get_llm |
|
|
|
|
|
@st.cache_resource
def initialize_index(llm_name, model_temperature, table_context_dict, api_key, sql_path=DEFAULT_SQL_PATH):
    """Build a GPTSQLStructStoreIndex over the database at ``sql_path``.

    Cached via ``st.cache_resource`` so Streamlit reruns with identical
    arguments reuse the same index object instead of rebuilding it.
    """
    language_model = get_llm(llm_name, model_temperature, api_key)

    db_engine = create_engine(sql_path)
    database = llama_SQLDatabase(db_engine)

    # Only wrap table descriptions in a context container when the caller
    # actually supplied them (the "Use table descriptions?" checkbox).
    container = None
    if table_context_dict is not None:
        builder = SQLContextContainerBuilder(database, context_dict=table_context_dict)
        container = builder.build_context_container()

    predictor = LLMPredictor(llm=language_model)
    return GPTSQLStructStoreIndex(
        [],
        sql_database=database,
        sql_context_container=container,
        service_context=ServiceContext.from_defaults(llm_predictor=predictor),
    )
|
|
|
|
|
@st.cache_resource
def initialize_chain(llm_name, model_temperature, lc_descrp, api_key, _sql_index):
    """Wire the SQL index into a LangChain conversational agent.

    The leading underscore on ``_sql_index`` tells Streamlit's resource cache
    to skip hashing that argument. Returns the initialized agent chain.
    """
    index_tool = Tool(
        name="SQL Index",
        func=get_sql_index_tool(_sql_index, _sql_index.sql_context_container.context_dict),
        description=lc_descrp,
    )

    chat_llm = get_llm(llm_name, model_temperature, api_key=api_key)
    buffer = ConversationBufferMemory(memory_key="chat_history", return_messages=True)

    return initialize_agent(
        [index_tool],
        chat_llm,
        agent="chat-conversational-react-description",
        verbose=True,
        memory=buffer,
    )
|
|
|
|
|
# Page header. BUG FIX: the title emoji was mojibake ("π¦" is the UTF-8
# llama emoji misdecoded through a single-byte codepage); restored to 🦙.
st.title("🦙 Llama Index SQL Sandbox 🦙")
# BUG FIX: adjacent string literals concatenated with no separating space
# ("San Francisco.This data", "Langchain.The other tabs") and a missing comma
# after the Llama Index link; also "LlamaIndexes" -> "LlamaIndex's".
st.markdown((
    "This sandbox uses a sqlite database by default, powered by [Llama Index](https://gpt-index.readthedocs.io/en/latest/index.html), ChatGPT, and LangChain.\n\n"
    "The database contains information on health violations and inspections at restaurants in San Francisco. "
    "This data is spread across three tables - businesses, inspections, and violations.\n\n"
    "Using the setup page, you can adjust LLM settings, change the context for the SQL tables, and change the tool description for Langchain. "
    "The other tabs will perform chatbot and text2sql operations.\n\n"
    "Read more about LlamaIndex's structured data support [here!](https://gpt-index.readthedocs.io/en/latest/guides/tutorials/sql_guide.html)"
))


# Three top-level tabs: configuration, direct text2sql, and the chat agent.
setup_tab, llama_tab, lc_tab = st.tabs(["Setup", "Llama Index", "Langchain+Llama Index"])
|
|
|
with setup_tab:
    # LLM configuration: credentials, model choice, and sampling temperature.
    st.subheader("LLM Setup")
    api_key = st.text_input("Enter your OpenAI API key here", type="password")
    llm_name = st.selectbox('Which LLM?', ["text-davinci-003", "gpt-3.5-turbo", "gpt-4"])
    model_temperature = st.slider("LLM Temperature", min_value=0.0, max_value=1.0, step=0.1)

    # Per-table natural-language descriptions fed to the index as context.
    # BUG FIX: the violations and inspections text areas previously reused the
    # copy-pasted label "Business table description"; each now has its own label.
    st.subheader("Table Setup")
    business_table_descrp = st.text_area("Business table description", value=DEFAULT_BUSINESS_TABLE_DESCRP)
    violations_table_descrp = st.text_area("Violations table description", value=DEFAULT_VIOLATIONS_TABLE_DESCRP)
    inspections_table_descrp = st.text_area("Inspections table description", value=DEFAULT_INSPECTIONS_TABLE_DESCRP)

    # Keys must match the actual table names in the sqlite database.
    table_context_dict = {"businesses": business_table_descrp,
                          "inspections": inspections_table_descrp,
                          "violations": violations_table_descrp}

    use_table_descrp = st.checkbox("Use table descriptions?", value=True)
    lc_descrp = st.text_area("LangChain Tool Description", value=DEFAULT_LC_TOOL_DESCRP)
|
|
|
with llama_tab:
    # Direct text2sql: build the index on demand, then run one-off queries.
    st.subheader("Text2SQL with Llama Index")
    if st.button("Initialize Index", key="init_index_1"):
        st.session_state['llama_index'] = initialize_index(
            llm_name,
            model_temperature,
            table_context_dict if use_table_descrp else None,
            api_key,
        )

    if "llama_index" in st.session_state:
        user_query = st.text_input("Query:", value="Which restaurant has the most violations?")
        if st.button("Run Query") and user_query:
            with st.spinner("Getting response..."):
                try:
                    result = st.session_state['llama_index'].query(user_query)
                    answer_md = str(result)
                    generated_sql = result.extra_info['sql_query']
                except Exception as err:
                    # Surface the failure in the result pane and show the
                    # exception text where the SQL would normally appear.
                    answer_md = "Error running SQL Query."
                    generated_sql = str(err)

            # Side-by-side panes: natural-language answer and generated SQL.
            left, right = st.columns(2)
            with left:
                st.text("SQL Result:")
                st.markdown(answer_md)

            with right:
                st.text("SQL Query:")
                st.markdown(generated_sql)
|
|
|
with lc_tab:
    # Chatbot mode: the SQL index wrapped as a LangChain agent tool.
    st.subheader("Langchain + Llama Index SQL Demo")

    if st.button("Initialize Agent"):
        st.session_state['llama_index'] = initialize_index(
            llm_name,
            model_temperature,
            table_context_dict if use_table_descrp else None,
            api_key,
        )
        st.session_state['lc_agent'] = initialize_chain(
            llm_name,
            model_temperature,
            lc_descrp,
            api_key,
            st.session_state['llama_index'],
        )
        st.session_state['chat_history'] = []

    user_message = st.text_input("Message:", value="Which restaurant has the most violations?")
    if 'lc_agent' in st.session_state and st.button("Send"):
        # Tag user turns with a "User: " prefix so they can be told apart
        # from agent replies when the history is rendered below.
        tagged = "User: " + user_message
        st.session_state['chat_history'].append(tagged)
        with st.spinner("Getting response..."):
            reply = st.session_state['lc_agent'].run(input=tagged)
            st.session_state['chat_history'].append(reply)

    if 'chat_history' in st.session_state:
        for entry in st.session_state['chat_history']:
            st_message(entry.split("User: ")[-1], is_user="User: " in entry)
|
|
|
|