forked from Brainana/LexBudget
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathBudgetChatbot_AssistAPI.py
268 lines (230 loc) · 10.8 KB
/
BudgetChatbot_AssistAPI.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
from openai import OpenAI
import streamlit as st
import configparser
import time
from datetime import datetime
import pytz
# import libraries for user feedback
from trubrics.integrations.streamlit import FeedbackCollector
from streamlit_feedback import streamlit_feedback
# import regex library
import re
# Load the app configuration (page title, avatar, placeholder text, assistant id,
# debug flag) from config.ini in the working directory.
config = configparser.ConfigParser()
if not config.read('config.ini'):
    # ConfigParser.read silently skips files it cannot open and returns the list of
    # files it did read; without this check a missing config.ini would only surface
    # later as a confusing NoSectionError.
    raise FileNotFoundError("config.ini not found; cannot load the app configuration")
# Debug mode is kept as the raw config string; the chat script compares it with "true".
debug = config.get('Server', 'debug')
# Set page title
st.title(config.get('Template', 'title'))
# Initialize the OpenAI client with the API key stored in Streamlit secrets
client = OpenAI(api_key=st.secrets["OPENAI_API_KEY"])
# Initialize the Trubrics feedback collector (credentials also come from Streamlit secrets)
collector = FeedbackCollector(
    project="default",
    email=st.secrets.TRUBRICS_EMAIL,
    password=st.secrets.TRUBRICS_PASSWORD,
)
# handle feedback submissions
def _submit_feedback():
    """Send the feedback form's thumbs rating and comment to Trubrics.

    Used as the ``on_click`` callback of the feedback form's submit button. The
    feedback is attached to the prompt stored in ``st.session_state.logged_prompt``.
    Does nothing if no prompt has been logged yet.
    """
    logged_prompt = st.session_state.get("logged_prompt")
    if logged_prompt is None:
        # No question has been answered and logged yet, so there is no prompt id
        # to attach this feedback to.
        return
    # Work on a copy instead of mutating the value owned by the 'feedback_key'
    # widget; an empty type is sent when the user did not click a thumb.
    user_response = dict(st.session_state.get("feedback_key") or {'type': ""})
    user_response['text'] = st.session_state.get("feedback_response", "")
    collector.log_feedback(
        component="default",
        model=logged_prompt.config_model.model,
        user_response=user_response,
        prompt_id=logged_prompt.id,
    )
# Helper function to format a Unix timestamp as US/Eastern local time
def convert_to_est(unix_timestamp):
    """Format a Unix timestamp as a US/Eastern date-time string.

    Args:
        unix_timestamp: Seconds since the epoch, e.g. a value from ``time.time()``.

    Returns:
        A string in the format ``'%B %d, %Y %H:%M:%S %Z'``. The zone abbreviation
        is whatever US/Eastern uses at that instant (EST, or EDT during daylight
        saving time).
    """
    est_timezone = pytz.timezone('US/Eastern')
    # A tz-aware fromtimestamp converts straight from UTC, replacing the naive
    # datetime.utcfromtimestamp (deprecated since Python 3.12) plus manual UTC tagging.
    est_datetime = datetime.fromtimestamp(unix_timestamp, tz=est_timezone)
    return est_datetime.strftime('%B %d, %Y %H:%M:%S %Z')
# Assistants API run statuses in which the run is still working. Any other status
# (e.g. "completed", "failed", "cancelled", "expired", "requires_action") ends polling;
# waiting only for "completed" would spin forever on a failed run.
ACTIVE_RUN_STATUSES = frozenset({"queued", "in_progress", "cancelling"})

# Shown to the user (and logged) whenever no usable assistant answer is available.
ERROR_RESPONSE = """OpenAI's Assistants API failed to return a response, please change your query and try again.<br>
Tips to improve your queries:<br>
- Provide a specific year or years<br>
- Provide a specific area of the town budget<br>
- Avoid abbreviations"""


def _wait_for_run(thread_id, run):
    """Poll ``run`` until it leaves an active status, showing a progress bar meanwhile.

    Returns the last retrieved run object; the caller must inspect its ``status``.
    """
    progress_text = "Retrieving in progress. Please wait."
    progress_value = 0
    progress_bar = st.progress(0, text=progress_text)
    while run.status in ACTIVE_RUN_STATUSES:
        # The real completion time is unknown, so cycle the bar through 1..99.
        progress_value = progress_value + 1 if progress_value < 99 else 1
        progress_bar.progress(progress_value, text=progress_text)
        time.sleep(0.1)
        run = client.beta.threads.runs.retrieve(thread_id=thread_id, run_id=run.id)
    # Remove progress bar
    progress_bar.empty()
    return run


def _list_messages_newest_first(thread_id):
    """Return the thread's message listing ordered with the latest message first."""
    return client.beta.threads.messages.list(thread_id=thread_id, order='desc')


def _latest_assistant_message(thread_id):
    """Return ``(message, listing)``; ``message`` is the newest thread message if the
    assistant wrote it, otherwise ``None``."""
    messages = _list_messages_newest_first(thread_id)
    if not messages.data or messages.data[0].role != 'assistant':
        # The assistant's reply may not be in the thread yet; give the Assistants API
        # time to add it, then look once more.
        time.sleep(2)
        messages = _list_messages_newest_first(thread_id)
    if messages.data and messages.data[0].role == 'assistant':
        return messages.data[0], messages
    return None, messages


def _first_text_part(message):
    """Return the first text content part of ``message`` (with ``value`` and
    ``annotations``), or ``None`` if the message has no text part."""
    for part in message.content:
        if getattr(part, "type", None) == "text":
            return part.text
    return None


def _with_footnotes(message_id, text):
    """Return ``text.value`` with citation markers replaced by numbered footnotes and
    a References section appended (HTML, rendered with unsafe_allow_html)."""
    value = text.value
    annotations = text.annotations
    citations = []
    # When there are multiple files associated with an assistant, annotations will return as empty:
    # see https://community.openai.com/t/assistant-api-always-return-empty-annotations/489285
    # In that case strip the raw source markers from the text.
    if not annotations:
        value = re.sub(r'【[\d:]+†source】', '', value)
    for index, annotation in enumerate(annotations):
        file_citation = getattr(annotation, 'file_citation', None)
        if not (file_citation and file_citation.file_id):
            # No footnote when the cited file is missing or its id is empty
            value = value.replace(annotation.text, '')
            continue
        anchor = f"cite_note-{message_id}-{index}"
        value = value.replace(annotation.text, f' <sup><a href="#{anchor}">[{index}]</a></sup>')
        cited_file = client.files.retrieve(file_citation.file_id)
        # The quote may be absent or None; print nothing rather than "None".
        quote = getattr(file_citation, 'quote', None) or ''
        citations.append(f'<div id="{anchor}" style="font-size: 90%">[{index}]: {cited_file.filename} <br><br> {quote}</div>')
    if citations:
        value += '<h5 style="border-bottom: 1px solid">References</h5>'
        value += '\n\n' + '\n'.join(citations)
    return value


# Cache the thread on session state, so we don't keep creating
# new thread for the same browser session
if "openai_thread" not in st.session_state:
    st.session_state["openai_thread"] = client.beta.threads.create()
thread = st.session_state["openai_thread"]

# Get all previous messages in session state
if "messages" not in st.session_state:
    st.session_state.messages = []

# Display all previous messages upon page refresh (assistant messages get the custom avatar)
assistantAvatar = config.get('Template', 'assistantAvatar')
for past_message in st.session_state.messages:
    avatar = assistantAvatar if past_message["role"] == "assistant" else None
    with st.chat_message(past_message["role"], avatar=avatar):
        st.markdown(past_message["content"], unsafe_allow_html=True)

# We use a predefined assistant with uploaded files
# Should not create a new assistant every time the page refreshes
assistantId = config.get('OpenAI', 'assistantId')

# Display the input text box
chatInputPlaceholder = config.get('Template', 'chatInputPlaceholder')
if prompt := st.chat_input(chatInputPlaceholder):
    # User has entered a question -> save it to the session state
    st.session_state.messages.append({"role": "user", "content": prompt})
    # Copy the user's question in the chat window
    with st.chat_message("user"):
        st.markdown(prompt)

    with st.chat_message("assistant", avatar=assistantAvatar):
        # Add the user's question to the OpenAI thread
        client.beta.threads.messages.create(
            thread_id=thread.id,
            role="user",
            content=prompt
        )
        start_time = time.time()
        run = client.beta.threads.runs.create(
            thread_id=thread.id,
            assistant_id=assistantId
        )
        run = _wait_for_run(thread.id, run)
        end_time = time.time()
        query_time = end_time - start_time

        # construct metadata to be logged
        metadata = {
            "query_time": f"{query_time:.2f} sec",
            "start_time": convert_to_est(start_time),
            "end_time": convert_to_est(end_time),
            "assistant_id": assistantId,
            "run_status": run.status,
        }

        message, messages = None, None
        if run.status == "completed":
            message, messages = _latest_assistant_message(thread.id)
        elif run.status == "requires_action":
            # This app never submits tool outputs. Cancel the run so it does not stay
            # active on the thread and block the next question.
            client.beta.threads.runs.cancel(thread_id=thread.id, run_id=run.id)

        text = _first_text_part(message) if message is not None else None
        if text is None:
            # Show the error text directly. Do not post it into the thread: a fake
            # "user" message there would be read by the assistant on later turns.
            response = ERROR_RESPONSE
            if messages is not None:
                # Stringified to keep the logged metadata plain data (the raw
                # listing object may not serialize) -- TODO confirm with trubrics.
                metadata["threadmsgs"] = str(messages.data)
        else:
            response = _with_footnotes(message.id, text)

        if debug == "true":
            st.write(message if message is not None else run)

        # Prevent LaTeX formatting by replacing $ with the HTML dollar entity
        # (the previous replace('$', '$') was a no-op).
        response = response.replace('$', '&#36;')
        # Display assistant message
        st.markdown(response, unsafe_allow_html=True)
        # Save the assistant's message in session state (we do this in addition to
        # saving the thread because we processed the message after retrieving it, e.g. for citations)
        st.session_state.messages.append({"role": "assistant", "content": response})
        # log user query + assistant response + metadata, recording the model the run used
        st.session_state.logged_prompt = collector.log_prompt(
            config_model={"model": getattr(run, "model", None) or "gpt-4-turbo-preview"},
            prompt=prompt,
            generation=response,
            metadata=metadata
        )

# Feedback form for the most recent logged answer. The trubrics collector.st_feedback
# widget was dropped because its feedback came back empty. Only render the form once
# a prompt has been logged, so that the submitted feedback has a prompt to attach to.
if "logged_prompt" in st.session_state:
    with st.form('form'):
        streamlit_feedback(
            feedback_type="thumbs",
            align="flex-start",
            key='feedback_key'
        )
        st.text_input(
            label="Please elaborate on your response.",
            key="feedback_response"
        )
        st.form_submit_button('Submit', on_click=_submit_feedback)