Skip to content

Commit

Permalink
FIX: support multiple tool calls (#502)
Browse files Browse the repository at this point in the history
* FIX: support multiple tool calls

Prior to this change we had a hard limit of 1 tool call per llm
round trip. This meant you could not google multiple things at
once or perform searches across two tools.

Also:

- Hint when Google stops working
- Log topic_id / post_id when performing completions

* Also track id for title
  • Loading branch information
SamSaffron authored Mar 1, 2024
1 parent b72ee80 commit c02794c
Show file tree
Hide file tree
Showing 13 changed files with 276 additions and 59 deletions.
78 changes: 46 additions & 32 deletions lib/ai_bot/bot.rb
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ class Bot

BOT_NOT_FOUND = Class.new(StandardError)
MAX_COMPLETIONS = 5
MAX_TOOLS = 5

def self.as(bot_user, persona: DiscourseAi::AiBot::Personas::General.new, model: nil)
new(bot_user, persona, model)
Expand All @@ -21,14 +22,19 @@ def initialize(bot_user, persona, model = nil)
attr_reader :bot_user
attr_accessor :persona

def get_updated_title(conversation_context, post_user)
def get_updated_title(conversation_context, post)
system_insts = <<~TEXT.strip
You are titlebot. Given a topic, you will figure out a title.
You will never respond with anything but 7 word topic title.
TEXT

title_prompt =
DiscourseAi::Completions::Prompt.new(system_insts, messages: conversation_context)
DiscourseAi::Completions::Prompt.new(
system_insts,
messages: conversation_context,
topic_id: post.topic_id,
post_id: post.id,
)

title_prompt.push(
type: :user,
Expand All @@ -38,7 +44,7 @@ def get_updated_title(conversation_context, post_user)

DiscourseAi::Completions::Llm
.proxy(model)
.generate(title_prompt, user: post_user)
.generate(title_prompt, user: post.user)
.strip
.split("\n")
.last
Expand All @@ -64,37 +70,14 @@ def reply(context, &update_blk)

result =
llm.generate(prompt, **llm_kwargs) do |partial, cancel|
if (tool = persona.find_tool(partial))
tools = persona.find_tools(partial)

if (tools.present?)
tool_found = true
ongoing_chain = tool.chain_next_response?
tool_call_id = tool.tool_call_id
invocation_result_json = invoke_tool(tool, llm, cancel, &update_blk).to_json

tool_call_message = {
type: :tool_call,
id: tool_call_id,
content: { name: tool.name, arguments: tool.parameters }.to_json,
}

tool_message = { type: :tool, id: tool_call_id, content: invocation_result_json }

if tool.standalone?
standalone_context =
context.dup.merge(
conversation_context: [
context[:conversation_context].last,
tool_call_message,
tool_message,
],
)
prompt = persona.craft_prompt(standalone_context)
else
prompt.push(**tool_call_message)
prompt.push(**tool_message)
tools[0..MAX_TOOLS].each do |tool|
ongoing_chain &&= tool.chain_next_response?
process_tool(tool, raw_context, llm, cancel, update_blk, prompt)
end

raw_context << [tool_call_message[:content], tool_call_id, "tool_call"]
raw_context << [invocation_result_json, tool_call_id, "tool"]
else
update_blk.call(partial, cancel, nil)
end
Expand All @@ -115,6 +98,37 @@ def reply(context, &update_blk)

private

def process_tool(tool, raw_context, llm, cancel, update_blk, prompt)
tool_call_id = tool.tool_call_id
invocation_result_json = invoke_tool(tool, llm, cancel, &update_blk).to_json

tool_call_message = {
type: :tool_call,
id: tool_call_id,
content: { name: tool.name, arguments: tool.parameters }.to_json,
}

tool_message = { type: :tool, id: tool_call_id, content: invocation_result_json }

if tool.standalone?
standalone_context =
context.dup.merge(
conversation_context: [
context[:conversation_context].last,
tool_call_message,
tool_message,
],
)
prompt = persona.craft_prompt(standalone_context)
else
prompt.push(**tool_call_message)
prompt.push(**tool_message)
end

raw_context << [tool_call_message[:content], tool_call_id, "tool_call"]
raw_context << [invocation_result_json, tool_call_id, "tool"]
end

def invoke_tool(tool, llm, cancel, &update_blk)
update_blk.call("", cancel, build_placeholder(tool.summary, ""))

Expand Down
12 changes: 11 additions & 1 deletion lib/ai_bot/personas/persona.rb
Original file line number Diff line number Diff line change
Expand Up @@ -117,15 +117,25 @@ def craft_prompt(context)
#{available_tools.map(&:custom_system_message).compact_blank.join("\n")}
TEXT
messages: context[:conversation_context].to_a,
topic_id: context[:topic_id],
post_id: context[:post_id],
)

prompt.tools = available_tools.map(&:signature) if available_tools

prompt
end

def find_tool(partial)
def find_tools(partial)
return [] if !partial.include?("</invoke>")

parsed_function = Nokogiri::HTML5.fragment(partial)
parsed_function.css("invoke").map { |fragment| find_tool(fragment) }.compact
end

protected

def find_tool(parsed_function)
function_id = parsed_function.at("tool_id")&.text
function_name = parsed_function.at("tool_name")&.text
return false if function_name.nil?
Expand Down
4 changes: 3 additions & 1 deletion lib/ai_bot/playground.rb
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def title_playground(post)
context = conversation_context(post)

bot
.get_updated_title(context, post.user)
.get_updated_title(context, post)
.tap do |new_title|
PostRevisor.new(post.topic.first_post, post.topic).revise!(
bot.bot_user,
Expand All @@ -182,6 +182,8 @@ def reply_to(post)
participants: post.topic.allowed_users.map(&:username).join(", "),
conversation_context: conversation_context(post),
user: post.user,
post_id: post.id,
topic_id: post.topic_id,
}

reply_user = bot.bot_user
Expand Down
14 changes: 14 additions & 0 deletions lib/ai_bot/tools/google.rb
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def invoke(bot_user, llm)
URI(
"https://www.googleapis.com/customsearch/v1?key=#{api_key}&cx=#{cx}&q=#{escaped_query}&num=10",
)

body = Net::HTTP.get(uri)

parse_search_json(body, escaped_query, llm)
Expand Down Expand Up @@ -65,6 +66,19 @@ def minimize_field(result, field, llm, max_tokens: 100)

def parse_search_json(json_data, escaped_query, llm)
parsed = JSON.parse(json_data)
error_code = parsed.dig("error", "code")
if error_code == 429
Rails.logger.warn(
"Google Custom Search is Rate Limited, no search can be performed at the moment. #{json_data[0..1000]}",
)
return(
"Google Custom Search is Rate Limited, no search can be performed at the moment. Let the user know there is a problem."
)
elsif error_code
Rails.logger.warn("Google Custom Search returned an error. #{json_data[0..1000]}")
return "Google Custom Search returned an error. Let the user know there is a problem."
end

results = parsed["items"]

@results_count = parsed.dig("searchInformation", "totalResults").to_i
Expand Down
4 changes: 3 additions & 1 deletion lib/completions/dialects/dialect.rb
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,11 @@ def max_prompt_tokens
raise NotImplemented
end

attr_reader :prompt

private

attr_reader :prompt, :model_name, :opts
attr_reader :model_name, :opts

def trim_messages(messages)
prompt_limit = max_prompt_tokens
Expand Down
20 changes: 14 additions & 6 deletions lib/completions/endpoints/base.rb
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@ def perform_completion!(dialect, user, model_params = {})
user_id: user&.id,
raw_request_payload: request_body,
request_tokens: prompt_size(prompt),
topic_id: dialect.prompt.topic_id,
post_id: dialect.prompt.post_id,
)

if !@streaming_mode
Expand Down Expand Up @@ -273,16 +275,22 @@ def extract_prompt_for_tokenizer(prompt)
def build_buffer
Nokogiri::HTML5.fragment(<<~TEXT)
<function_calls>
<invoke>
<tool_name></tool_name>
<tool_id></tool_id>
<parameters>
</parameters>
</invoke>
#{noop_function_call_text}
</function_calls>
TEXT
end

def noop_function_call_text
(<<~TEXT).strip
<invoke>
<tool_name></tool_name>
<tool_id></tool_id>
<parameters>
</parameters>
</invoke>
TEXT
end

def has_tool?(response)
response.include?("<function")
end
Expand Down
24 changes: 21 additions & 3 deletions lib/completions/endpoints/open_ai.rb
Original file line number Diff line number Diff line change
Expand Up @@ -167,8 +167,26 @@ def add_to_buffer(function_buffer, _response_data, partial)
@args_buffer ||= +""

f_name = partial.dig(:function, :name)
function_buffer.at("tool_name").content = f_name if f_name
function_buffer.at("tool_id").content = partial[:id] if partial[:id]

@current_function ||= function_buffer.at("invoke")

if f_name
current_name = function_buffer.at("tool_name").content

if current_name.blank?
# first call
else
# we have a previous function, so we need to add a noop
@args_buffer = +""
@current_function =
function_buffer.at("function_calls").add_child(
Nokogiri::HTML5::DocumentFragment.parse(noop_function_call_text + "\n"),
)
end
end

@current_function.at("tool_name").content = f_name if f_name
@current_function.at("tool_id").content = partial[:id] if partial[:id]

args = partial.dig(:function, :arguments)

Expand All @@ -185,7 +203,7 @@ def add_to_buffer(function_buffer, _response_data, partial)
end
argument_fragments << "\n"

function_buffer.at("parameters").children =
@current_function.at("parameters").children =
Nokogiri::HTML5::DocumentFragment.parse(argument_fragments)
rescue JSON::ParserError
return function_buffer
Expand Down
16 changes: 13 additions & 3 deletions lib/completions/prompt.rb
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,22 @@ class Prompt
INVALID_TURN = Class.new(StandardError)

attr_reader :messages
attr_accessor :tools

def initialize(system_message_text = nil, messages: [], tools: [], skip_validations: false)
attr_accessor :tools, :topic_id, :post_id

def initialize(
system_message_text = nil,
messages: [],
tools: [],
skip_validations: false,
topic_id: nil,
post_id: nil
)
raise ArgumentError, "messages must be an array" if !messages.is_a?(Array)
raise ArgumentError, "tools must be an array" if !tools.is_a?(Array)

@topic_id = topic_id
@post_id = post_id

@messages = []
@skip_validations = skip_validations

Expand Down
2 changes: 0 additions & 2 deletions spec/lib/completions/endpoints/endpoint_compliance.rb
Original file line number Diff line number Diff line change
Expand Up @@ -178,11 +178,9 @@ def regular_mode_simple_prompt(mock)
def regular_mode_tools(mock)
prompt = generic_prompt(tools: [mock.tool])
a_dialect = dialect(prompt: prompt)

mock.stub_tool_call(a_dialect.translate)

completion_response = endpoint.perform_completion!(a_dialect, user)

expect(completion_response).to eq(mock.invocation_response)
end

Expand Down
Loading

0 comments on commit c02794c

Please sign in to comment.