From 79047b2ca4d0063eb5ec0d9e1be6c44672be41d2 Mon Sep 17 00:00:00 2001 From: Piotr Gloger <36278157+pedrito87@users.noreply.github.com> Date: Tue, 7 Jan 2025 10:00:54 +0100 Subject: [PATCH] genai: Fix multiple tool calls in a single AIMessage (#671) --- .../langchain_google_genai/chat_models.py | 116 +++++++++--------- .../tests/unit_tests/test_chat_models.py | 81 ++++++++++-- 2 files changed, 131 insertions(+), 66 deletions(-) diff --git a/libs/genai/langchain_google_genai/chat_models.py b/libs/genai/langchain_google_genai/chat_models.py index 7ae17fa87..48608c47a 100644 --- a/libs/genai/langchain_google_genai/chat_models.py +++ b/libs/genai/langchain_google_genai/chat_models.py @@ -301,6 +301,49 @@ def _convert_to_parts( return parts +def _convert_tool_message_to_part(message: ToolMessage | FunctionMessage) -> Part: + """Converts a tool or function message to a google part.""" + name = message.name + response: Any + if not isinstance(message.content, str): + response = message.content + else: + try: + response = json.loads(message.content) + except json.JSONDecodeError: + response = message.content # leave as str representation + part = Part( + function_response=FunctionResponse( + name=name, + response=( + {"output": response} if not isinstance(response, dict) else response + ), + ) + ) + return part + + +def _get_ai_message_tool_messages_parts( + tool_messages: Sequence[ToolMessage], ai_message: AIMessage +) -> list[Part]: + """ + Finds relevant tool messages for the AI message and converts them to a single + list of Parts. + """ + # We are interested only in the tool messages that are part of the AI message + tool_calls_ids = [tool_call["id"] for tool_call in ai_message.tool_calls] + parts = [] + for i, message in enumerate(tool_messages): + if not tool_calls_ids: + break + if message.tool_call_id in tool_calls_ids: + # remove the id from the list, so that we do not iterate over it again + tool_calls_ids.remove(message.tool_call_id) + part = _convert_tool_message_to_part(message) + parts.append(part) + return parts + + def _parse_chat_history( input_messages: Sequence[BaseMessage], convert_system_message_to_human: bool = False ) -> Tuple[Optional[Content], List[Content]]: @@ -310,14 +353,20 @@ def _parse_chat_history( warnings.warn("Convert_system_message_to_human will be deprecated!") system_instruction: Optional[Content] = None - for i, message in enumerate(input_messages): + messages_without_tool_messages = [ + message for message in input_messages if not isinstance(message, ToolMessage) + ] + tool_messages = [ + message for message in input_messages if isinstance(message, ToolMessage) + ] + for i, message in enumerate(messages_without_tool_messages): if i == 0 and isinstance(message, SystemMessage): system_instruction = Content(parts=_convert_to_parts(message.content)) continue elif isinstance(message, AIMessage): role = "model" if message.tool_calls: - parts = [] + ai_message_parts = [] for tool_call in message.tool_calls: function_call = FunctionCall( { @@ -325,7 +374,13 @@ def _parse_chat_history( "args": tool_call["args"], } ) - parts.append(Part(function_call=function_call)) + ai_message_parts.append(Part(function_call=function_call)) + tool_messages_parts = _get_ai_message_tool_messages_parts( + tool_messages=tool_messages, ai_message=message + ) + messages.append(Content(role=role, parts=ai_message_parts)) + messages.append(Content(role="user", parts=tool_messages_parts)) + continue elif raw_function_call := message.additional_kwargs.get("function_call"): function_call = FunctionCall( { @@ -344,60 +399,7 @@ def _parse_chat_history( system_instruction = None elif isinstance(message, FunctionMessage): role = "user" - response: Any - if not isinstance(message.content, str): - response = message.content - else: - try: - response = json.loads(message.content) - except json.JSONDecodeError: - response = message.content # leave as str representation - parts = [ - Part( - function_response=FunctionResponse( - name=message.name, - response=( - {"output": response} - if not isinstance(response, dict) - else response - ), - ) - ) - ] - elif isinstance(message, ToolMessage): - role = "user" - prev_message: Optional[BaseMessage] = ( - input_messages[i - 1] if i > 0 else None - ) - if ( - prev_message - and isinstance(prev_message, AIMessage) - and prev_message.tool_calls - ): - # message.name can be null for ToolMessage - name: str = prev_message.tool_calls[0]["name"] - else: - name = message.name # type: ignore - tool_response: Any - if not isinstance(message.content, str): - tool_response = message.content - else: - try: - tool_response = json.loads(message.content) - except json.JSONDecodeError: - tool_response = message.content # leave as str representation - parts = [ - Part( - function_response=FunctionResponse( - name=name, - response=( - {"output": tool_response} - if not isinstance(tool_response, dict) - else tool_response - ), - ) - ) - ] + parts = [_convert_tool_message_to_part(message)] else: raise ValueError( f"Unexpected message with type {type(message)} at the position {i}." diff --git a/libs/genai/tests/unit_tests/test_chat_models.py b/libs/genai/tests/unit_tests/test_chat_models.py index 7be08fa36..1fbf01842 100644 --- a/libs/genai/tests/unit_tests/test_chat_models.py +++ b/libs/genai/tests/unit_tests/test_chat_models.py @@ -133,7 +133,8 @@ def test_parse_history(convert_system_message_to_human: bool) -> None: function_name = "calculator" function_call_1 = { "name": function_name, - "arguments": json.dumps({"arg1": "2", "arg2": "2", "op": "+"}), + "args": {"arg1": "2", "arg2": "2", "op": "+"}, + "id": "0", } function_answer1 = json.dumps({"result": 4}) function_call_2 = { @@ -141,18 +142,28 @@ def test_parse_history(convert_system_message_to_human: bool) -> None: "arguments": json.dumps({"arg1": "2", "arg2": "2", "op": "*"}), } function_answer2 = json.dumps({"result": 4}) + function_call_3 = { + "name": function_name, + "args": {"arg1": "2", "arg2": "2", "op": "*"}, + "id": "1", + } + function_answer_3 = json.dumps({"result": 4}) + function_call_4 = { + "name": function_name, + "args": {"arg1": "2", "arg2": "3", "op": "*"}, + "id": "2", + } + function_answer_4 = json.dumps({"result": 6}) text_answer1 = "They are same" system_message = SystemMessage(content=system_input) message1 = HumanMessage(content=text_question1) message2 = AIMessage( content="", - additional_kwargs={ - "function_call": function_call_1, - }, + tool_calls=[function_call_1], ) message3 = ToolMessage( - name="calculator", content=function_answer1, tool_call_id="1" + name="calculator", content=function_answer1, tool_call_id="0" ) message4 = AIMessage( content="", @@ -161,7 +172,14 @@ def test_parse_history(convert_system_message_to_human: bool) -> None: }, ) message5 = FunctionMessage(name="calculator", content=function_answer2) - message6 = AIMessage(content=text_answer1) + message6 = AIMessage(content="", tool_calls=[function_call_3, function_call_4]) + message7 = ToolMessage( + name="calculator", content=function_answer_3, tool_call_id="1" + ) + message8 = ToolMessage( + name="calculator", content=function_answer_4, tool_call_id="2" + ) + message9 = AIMessage(content=text_answer1) messages = [ system_message, message1, @@ -170,11 +188,14 @@ def test_parse_history(convert_system_message_to_human: bool) -> None: message4, message5, message6, + message7, + message8, + message9, ] system_instruction, history = _parse_chat_history( messages, convert_system_message_to_human=convert_system_message_to_human ) - assert len(history) == 6 + assert len(history) == 8 if convert_system_message_to_human: assert history[0] == glm.Content( role="user", @@ -191,7 +212,7 @@ def test_parse_history(convert_system_message_to_human: bool) -> None: function_call=glm.FunctionCall( { "name": "calculator", - "args": json.loads(function_call_1["arguments"]), + "args": function_call_1["args"], } ) ) @@ -236,7 +257,49 @@ def test_parse_history(convert_system_message_to_human: bool) -> None: ) ], ) - assert history[5] == glm.Content(role="model", parts=[glm.Part(text=text_answer1)]) + assert history[5] == glm.Content( + role="model", + parts=[ + glm.Part( + function_call=glm.FunctionCall( + { + "name": "calculator", + "args": function_call_3["args"], + } + ) + ), + glm.Part( + function_call=glm.FunctionCall( + { + "name": "calculator", + "args": function_call_4["args"], + } + ) + ), + ], + ) + assert history[6] == glm.Content( + role="user", + parts=[ + glm.Part( + function_response=glm.FunctionResponse( + { + "name": "calculator", + "response": {"result": 4}, + } + ) + ), + glm.Part( + function_response=glm.FunctionResponse( + { + "name": "calculator", + "response": {"result": 6}, + } + ) + ), + ], + ) + assert history[7] == glm.Content(role="model", parts=[glm.Part(text=text_answer1)]) if convert_system_message_to_human: assert system_instruction is None else: