-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsouthwest_agent.py
228 lines (186 loc) · 6.96 KB
/
southwest_agent.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
import boto3
from langchain.agents import AgentExecutor, create_structured_chat_agent
from langchain_community.chat_models import BedrockChat
from langchain_community.chat_message_histories import StreamlitChatMessageHistory
from langchain_core.prompts.chat import ChatPromptTemplate, MessagesPlaceholder
from langchain_community.callbacks import StreamlitCallbackHandler
from langchain_core.runnables import RunnableConfig
from langchain.memory import ConversationBufferMemory
from langchain.tools import tool
from langchain.agents import Tool
import requests
import json
# ------------------------------------------------------------------------
# Constants
# Bedrock model id
MODEL_ID = "anthropic.claude-3-sonnet-20240229-v1:0"
# Bedrock model inference parameters
MODEL_KWARGS = {
    "max_tokens": 2048,
    "temperature": 0.0,  # deterministic sampling for reproducible agent behavior
    "top_k": 250,
    "top_p": 1,
    "stop_sequences": ["\n\nHuman"],  # stop before the model writes the next human turn itself
}
# Southwest API URL
# NOTE(review): localhost with no port or path — presumably a local mock of the
# flight-search service; confirm the real endpoint before deploying.
SOUTHWEST_API_URL = "http://127.0.0.1"
# LangChain
@tool
def search_southwest_flights(event: str) -> str:
    """Search Southwest Airlines for flights on the departure date \
between the origination airport and the destination airport \
for the number of passengers and the number of adults.
    event: str --> The event in the format of a JSON String with the following keys: \
departure_date: str --> The date of the flight in the format yyyy-mm-dd. \
origination: str --> The origination airport 3-letter code. Examples: SAN, LAX, SFO. \
destination: str --> The destination airport 3-letter code. Examples: DAL, PHX, LGA. \
passenger_count: int --> The number of passengers. \
adult_count: int --> The number of adults.
    """
    data = json.loads(event)
    response = requests.post(
        SOUTHWEST_API_URL,
        json=data,
        # Without a timeout, requests waits forever and the agent hangs if
        # the flight API is unresponsive.
        timeout=30,
    )
    # Surface HTTP errors explicitly instead of trying to parse an error page.
    response.raise_for_status()
    payload = response.json()
    # Tolerate a payload without the expected 'message' key rather than
    # raising KeyError inside the agent's tool call.
    return payload.get('message', json.dumps(payload))
def initialize_tools():
    """Return the list of LangChain tools the agent may call."""
    flight_search = Tool(
        name="SearchSouthwestFlightsTool",
        func=search_southwest_flights,
        description="""
        Use this tool with a JSON-encoded string argument like \
"{{"departure_date": "yyyy-mm-dd", "origination": "XXX", "destination": "YYY", "passenger_count": 1, "adult_count": 1}}" \
when you need to search for flights on Southwest Airlines. The input will always be a JSON encoded string with those arguments.
        """,
    )
    return [flight_search]
def initialize_bedrock_runtime():
    """Create and return a boto3 Bedrock runtime client (us-west-2)."""
    return boto3.client(
        service_name="bedrock-runtime",
        region_name="us-west-2",
    )
def initialize_model(bedrock_runtime, model_id, model_kwargs):
    """Wrap a Bedrock runtime client in a LangChain BedrockChat model."""
    return BedrockChat(
        client=bedrock_runtime,
        model_id=model_id,
        model_kwargs=model_kwargs,
    )
def initialize_streamlit_memory():
    """Create a chat-message history stored in Streamlit session state."""
    return StreamlitChatMessageHistory()
def initialize_memory(streamlit_memory):
    """Wrap a Streamlit-backed history in a ConversationBufferMemory.

    Keys match what AgentExecutor expects: history under "chat_history",
    agent replies under "output".
    """
    return ConversationBufferMemory(
        chat_memory=streamlit_memory,
        return_messages=True,
        memory_key="chat_history",
        output_key="output",
    )
def intialize_prompt():
    """Build the structured-chat prompt: system instructions, optional chat
    history placeholder, and the human turn with the agent scratchpad.
    """
    # NOTE(review): function name has a typo ("intialize") — kept as-is
    # because the module-level call site uses this spelling.
    system = '''You are a Southwest Airlines customer support agent. You help customers find flights and book them.
Your goal is to generate an answer to the employee's message in a friendly, customer support like tone.
All tool inputs are in the format of a JSON string.
Do not use any tools if you can answer the employee's latest message without them.
You have access to the following tools:
{tools}
Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
Valid "action" values: "Final Answer" or {tool_names}
Provide only ONE action per $JSON_BLOB, as shown:
```
{{
  "action": $TOOL_NAME,
  "action_input": $INPUT
}}
```
Follow this format:
Question: input question to answer
Thought: consider previous and subsequent steps
Action:
```
$JSON_BLOB
```
Observation: action result
... (repeat Thought/Action/Observation N times)
Thought: I know what to respond
Action:
```
{{
  "action": "Final Answer",
  "action_input": "Final response to human"
}}
```
Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation'''
    # Fixed: the Final Answer example above previously opened a ``` fence
    # without closing it, leaving the format example unbalanced.
    human = '''
{input}
{agent_scratchpad}
(reminder to respond in a JSON blob no matter what)'''
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system),
            MessagesPlaceholder("chat_history", optional=True),
            ("human", human),
        ]
    )
    return prompt
# Initialize the Model
# NOTE: everything below runs at import time — this module doubles as the
# Streamlit app entry point, so the agent is wired up on load.
bedrock_runtime = initialize_bedrock_runtime()
model = initialize_model(bedrock_runtime, MODEL_ID, MODEL_KWARGS)
# Initialize the Memory
streamlit_memory = initialize_streamlit_memory()
memory = initialize_memory(streamlit_memory)
# Initialize the Tools
tools = initialize_tools()
# Initialize the Agent
system_prompt = intialize_prompt()
agent = create_structured_chat_agent(
    model,
    tools,
    system_prompt
)
# NOTE(review): `executor` is never referenced again in this file — only
# `agent_executor` is used below; the double assignment looks vestigial.
executor = agent_executor = AgentExecutor(
    agent=agent,
    tools=tools,
    memory=memory,
    verbose=False,
    return_intermediate_steps=True,
    # Feed malformed LLM output back to the model instead of raising.
    handle_parsing_errors=True,
)
# ------------------------------------------------------------------------
# Streamlit
import streamlit as st

# Page title
st.set_page_config(page_title="Southwest Generative AI Agent Demo", page_icon=":plane:")
st.title("Southwest Generative AI Agent Demo")
st.caption("This is a demo of a Generative AI Assistant that can use Tools to interact with Southwest Airlines.")

# Persist intermediate agent steps across Streamlit reruns, keyed by message index.
if "steps" not in st.session_state:
    st.session_state.steps = {}

# Replay the conversation so far (history survives reruns via session state).
for message in streamlit_memory.messages:
    with st.chat_message(message.type):
        st.write(message.content)

# Chat Input - User Prompt
if user_input := st.chat_input("Message"):
    with st.chat_message("human"):
        st.write(user_input)
    # As usual, new messages are added to StreamlitChatMessageHistory when the Chain is called.
    with st.chat_message("assistant"):
        # Stream the agent's thoughts and tool calls into the assistant bubble.
        st_cb = StreamlitCallbackHandler(st.container(), expand_new_thoughts=False)
        cfg = RunnableConfig()
        cfg["callbacks"] = [st_cb]
        chat_history = memory.buffer_as_messages
        response = agent_executor.invoke(
            input={
                "input": user_input,
                "chat_history": chat_history,
            },
            config=cfg,
        )
        st.write(response["output"])
        # Index the steps by the last message so they can be re-rendered later.
        st.session_state.steps[str(len(streamlit_memory.messages) - 1)] = response["intermediate_steps"]