forked from business-science/ai-data-science-team
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathapp.py
197 lines (144 loc) · 5.64 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
# BUSINESS SCIENCE
# SQL Database Agent App
# -----------------------
# This app is designed to help you query your SQL database and return data frames that you can interactively inspect and download.
# Imports
# !pip install git+https://github.com/business-science/ai-data-science-team.git --upgrade
from openai import OpenAI
import streamlit as st
import sqlalchemy as sql
import pandas as pd
import asyncio
from langchain_community.chat_message_histories import StreamlitChatMessageHistory
from langchain_openai import ChatOpenAI
from ai_data_science_team.agents import SQLDatabaseAgent
# * APP INPUTS ----

# Display name -> SQLAlchemy connection URL.
# MODIFY THIS (or add entries) IF YOU WANT TO USE A DIFFERENT DATABASE;
# the keys are what the sidebar selector shows.
DB_OPTIONS = {"Northwind Database": "sqlite:///data/northwind.db"}

# OpenAI chat models offered in the sidebar; the first entry is preselected.
MODEL_LIST = ["gpt-4o-mini", "gpt-4o"]

# Used as both the browser page title and the on-page heading.
TITLE = "Your SQL Database Agent"
# * STREAMLIT APP SETUP ----

# Page-level configuration, then the heading and a short introduction.
st.set_page_config(page_title=TITLE, page_icon="📊")
st.title(TITLE)
st.markdown(
    "\n"
    "Welcome to the SQL Database Agent. This AI agent is designed to help you query your SQL database and return data frames that you can interactively inspect and download.\n"
)

# Sample prompts, shown in an expander that starts collapsed.
with st.expander("Example Questions", expanded=False):
    st.write(
        "\n"
        "- What tables exist in the database?\n"
        "- What are the first 10 rows in the territory table?\n"
        "- Aggregate sales for each territory.\n"
        "- Aggregate sales by month for each territory.\n"
    )
# * STREAMLIT APP SIDEBAR ----

# Database Selection: the sidebar offers the display names from DB_OPTIONS.
db_option = st.sidebar.selectbox("Select a Database", list(DB_OPTIONS))

# Keep the connection URL of the chosen database in session state.
st.session_state["PATH_DB"] = DB_OPTIONS.get(db_option)

# Build an engine for that URL and open the connection that is later handed
# to the SQL agent.
sql_engine = sql.create_engine(st.session_state["PATH_DB"])
conn = sql_engine.connect()
# * OpenAI API Key
st.sidebar.header("Enter your OpenAI API Key")
st.session_state["OPENAI_API_KEY"] = st.sidebar.text_input(
    "API Key",
    type="password",
    help="Your OpenAI API key is required for the app to function.",
)

# Without a key nothing below can run, so end this script run here and wait
# for the user to enter one.
if not st.session_state["OPENAI_API_KEY"]:
    st.info("Please enter your OpenAI API Key to proceed.")
    st.stop()

# Test OpenAI API Key: listing the models serves as a probe. A failure is
# reported to the user, but the script carries on afterwards.
client = OpenAI(api_key=st.session_state["OPENAI_API_KEY"])
try:
    models = client.models.list()
    st.success("API Key is valid!")
except Exception as e:
    st.error(f"Invalid API Key: {e}")
# * OpenAI Model Selection
# The sidebar lists MODEL_LIST with its first entry preselected.
model_option = st.sidebar.selectbox("Choose OpenAI model", MODEL_LIST, index=0)

# Chat model built from the chosen model name and the key the user entered.
OPENAI_LLM = ChatOpenAI(
    model=model_option,
    api_key=st.session_state["OPENAI_API_KEY"],
)

# Alias under which the model is passed to the agent.
llm = OPENAI_LLM
# * STREAMLIT
# Set up memory: the chat history is created with the key "langchain_messages".
msgs = StreamlitChatMessageHistory(key="langchain_messages")

# An empty history gets an opening greeting from the assistant.
if not msgs.messages:
    msgs.add_ai_message("How can I help you?")

# Initialize dataframe storage in session state. Chat messages refer to
# result frames by their position in this list.
if "dataframes" not in st.session_state:
    st.session_state["dataframes"] = []
# Function to display chat messages, including stored result dataframes
def display_chat_history():
    """Render every message in ``msgs`` as a chat bubble.

    An AI message whose content is exactly ``DATAFRAME_INDEX:<n>`` is a
    placeholder for the n-th entry of ``st.session_state.dataframes`` and is
    rendered as an interactive dataframe. Every other message is written as
    text.

    The marker is honoured only as the prefix of an AI message and only when
    ``<n>`` is a valid position in the stored list. A message that merely
    mentions the marker (for example a user question quoting it) is rendered
    as plain text instead of raising ``ValueError`` / ``IndexError``. Since
    the history persists across reruns, such an error would otherwise recur
    on every rerun.
    """
    marker = "DATAFRAME_INDEX:"
    stored_frames = st.session_state.dataframes

    for msg in msgs.messages:
        with st.chat_message(msg.type):
            content = msg.content

            # Resolve the placeholder to a list position, or None for text.
            df_index = None
            if (
                msg.type == "ai"
                and isinstance(content, str)
                and content.startswith(marker)
            ):
                suffix = content[len(marker):].strip()
                if suffix.isdecimal() and int(suffix) < len(stored_frames):
                    df_index = int(suffix)

            if df_index is None:
                st.write(content)
            else:
                st.dataframe(stored_frames[df_index])
# Render current messages from StreamlitChatMessageHistory
display_chat_history()

# Create the SQL Database Agent, wired to the selected chat model and the
# open database connection.
# NOTE(review): the meaning of n_samples, log and bypass_recommended_steps is
# defined by SQLDatabaseAgent and is not visible here; confirm against its docs.
sql_db_agent = SQLDatabaseAgent(
    model=llm,
    connection=conn,
    n_samples=1,
    log=False,
    bypass_recommended_steps=True,
)
# Handle the question async
async def handle_question(question):
    """Run the SQL agent on a single user question.

    Awaits ``ainvoke_agent`` on the module-level ``sql_db_agent`` with
    *question* passed as ``user_instructions``, then returns that same agent
    object so the caller can read the results from it.
    """
    agent = sql_db_agent
    await agent.ainvoke_agent(user_instructions=question)
    return agent
# Handle a newly submitted question. The walrus binds `question` only when the
# chat input returned something on this script run.
if st.session_state["PATH_DB"] and (question := st.chat_input("Enter your question here:", key="query_input")):

    # Defensive re-check of the API key before doing any work.
    if not st.session_state["OPENAI_API_KEY"]:
        st.error("Please enter your OpenAI API Key to proceed.")
        st.stop()

    with st.spinner("Thinking..."):

        # Echo the question and record it in the chat history.
        st.chat_message("human").write(question)
        msgs.add_user_message(question)

        # Run the agent. Only the agent call is guarded; the success path is
        # in the `else` branch, so no separate error flag is needed.
        try:
            print(st.session_state["PATH_DB"])
            result = asyncio.run(handle_question(question))
        except Exception as e:
            print(e)
            response_text = (
                "\n"
                "I'm sorry. I am having difficulty answering that question. You can try providing more details and I'll do my best to provide an answer.\n"
                f"Error: {e}\n"
            )
            msgs.add_ai_message(response_text)
            st.chat_message("ai").write(response_text)
            st.error(f"Error: {e}")
        else:
            # Generate the Results
            sql_query = result.get_sql_query_code()
            response_df = result.get_data_sql()

            if sql_query:
                # Text part of the answer: the generated SQL
                response_1 = f"### SQL Results:\n\nSQL Query:\n\n```sql\n{sql_query}\n```\n\nResult:"

                # Store the result dataframe and keep its index so that the
                # chat history can refer to it by position.
                df_index = len(st.session_state.dataframes)
                st.session_state.dataframes.append(response_df)

                # Store response: the text, then a placeholder message that
                # display_chat_history() turns back into the dataframe.
                msgs.add_ai_message(response_1)
                msgs.add_ai_message(f"DATAFRAME_INDEX:{df_index}")

                # Write Results. The dataframe is rendered in its own AI chat
                # bubble, matching how the history replays it on later reruns.
                st.chat_message("ai").write(response_1)
                st.chat_message("ai").dataframe(response_df)
            else:
                # The agent finished without producing a query. Say so
                # instead of leaving the question silently unanswered.
                response_text = (
                    "I wasn't able to generate a SQL query for that question. "
                    "Please try rephrasing it or adding more detail."
                )
                msgs.add_ai_message(response_text)
                st.chat_message("ai").write(response_text)