From 1f6d2d7149ac8700ab7a54d674fd35aece27463f Mon Sep 17 00:00:00 2001
From: lgesuellip <102637283+lgesuellip@users.noreply.github.com>
Date: Tue, 7 Jan 2025 06:00:29 -0300
Subject: [PATCH 1/3] vertexai: Add ChatAnthropicVertex Text Caching Support
 (#672)

---
 .../_anthropic_utils.py                       |  58 ++--
 .../langchain_google_vertexai/model_garden.py |  20 +-
 .../integration_tests/test_anthropic_cache.py | 147 ++++++++++
 .../tests/unit_tests/test_anthropic_utils.py  | 260 +++++++++++++++++-
 .../tests/unit_tests/test_chat_models.py      |  20 +-
 5 files changed, 478 insertions(+), 27 deletions(-)
 create mode 100644 libs/vertexai/tests/integration_tests/test_anthropic_cache.py

diff --git a/libs/vertexai/langchain_google_vertexai/_anthropic_utils.py b/libs/vertexai/langchain_google_vertexai/_anthropic_utils.py
index fd3721d2..af2b10cb 100644
--- a/libs/vertexai/langchain_google_vertexai/_anthropic_utils.py
+++ b/libs/vertexai/langchain_google_vertexai/_anthropic_utils.py
@@ -59,14 +59,39 @@ def _format_image(image_url: str) -> Dict:
     }
 
 
-def _format_message_anthropic(message: Union[HumanMessage, AIMessage]):
-    role = _message_type_lookups[message.type]
+def _get_cache_control(message: BaseMessage) -> Optional[Dict[str, Any]]:
+    """Extract cache control from the message's additional_kwargs, if present."""
+    return (
+        message.additional_kwargs.get("cache_control")
+        if isinstance(message.additional_kwargs, dict)
+        else None
+    )
+
+
+def _format_text_content(text: str) -> Dict[str, Union[str, Dict[str, Any]]]:
+    """Format text content."""
+    content: Dict[str, Union[str, Dict[str, Any]]] = {"type": "text", "text": text}
+    return content
+
+
+def _format_message_anthropic(message: Union[HumanMessage, AIMessage, SystemMessage]):
+    """Format a message for the Anthropic API.
+
+    Args:
+        message: The message to format. Can be HumanMessage, AIMessage, or SystemMessage.
+
+    Returns:
+        A dictionary with the formatted message for human and AI messages, a list of content blocks for system messages, or None if the message is empty.
+ """ # noqa: E501 content: List[Dict[str, Any]] = [] if isinstance(message.content, str): if not message.content.strip(): return None - content.append({"type": "text", "text": message.content}) + message_dict = _format_text_content(message.content) + if cache_control := _get_cache_control(message): + message_dict["cache_control"] = cache_control + content.append(message_dict) elif isinstance(message.content, list): for block in message.content: if isinstance(block, str): @@ -75,9 +100,8 @@ def _format_message_anthropic(message: Union[HumanMessage, AIMessage]): # https://github.com/anthropics/anthropic-sdk-python/issues/461 if not block.strip(): continue - content.append({"type": "text", "text": block}) - - if isinstance(block, dict): + content.append(_format_text_content(block)) + elif isinstance(block, dict): if "type" not in block: raise ValueError("Dict content block must have a type key") @@ -113,25 +137,26 @@ def _format_message_anthropic(message: Union[HumanMessage, AIMessage]): if not is_unique: continue - # all other block types content.append(block) else: raise ValueError("Message should be a str, list of str or list of dicts") - # adding all tool calls if isinstance(message, AIMessage) and message.tool_calls: for tc in message.tool_calls: tu = cast(Dict[str, Any], _lc_tool_call_to_anthropic_tool_use_block(tc)) content.append(tu) - return {"role": role, "content": content} + if message.type == "system": + return content + else: + return {"role": _message_type_lookups[message.type], "content": content} def _format_messages_anthropic( messages: List[BaseMessage], -) -> Tuple[Optional[str], List[Dict]]: +) -> Tuple[Optional[Dict[str, Any]], List[Dict]]: """Formats messages for anthropic.""" - system_message: Optional[str] = None + system_messages: Optional[Dict[str, Any]] = None formatted_messages: List[Dict] = [] merged_messages = _merge_messages(messages) @@ -139,12 +164,9 @@ def _format_messages_anthropic( if message.type == "system": if i != 0: raise ValueError("System message must be at beginning of message list.") - if not isinstance(message.content, str): - raise ValueError( - "System message must be a string, " - f"instead was: {type(message.content)}" - ) - system_message = message.content + fm = _format_message_anthropic(message) + if fm: + system_messages = fm continue fm = _format_message_anthropic(message) @@ -152,7 +174,7 @@ def _format_messages_anthropic( continue formatted_messages.append(fm) - return system_message, formatted_messages + return system_messages, formatted_messages class AnthropicTool(TypedDict): diff --git a/libs/vertexai/langchain_google_vertexai/model_garden.py b/libs/vertexai/langchain_google_vertexai/model_garden.py index 9070ac05..fd7938d6 100644 --- a/libs/vertexai/langchain_google_vertexai/model_garden.py +++ b/libs/vertexai/langchain_google_vertexai/model_garden.py @@ -32,6 +32,7 @@ AIMessage, BaseMessage, ) +from langchain_core.messages.ai import UsageMetadata from langchain_core.outputs import ( ChatGeneration, ChatGenerationChunk, @@ -61,6 +62,13 @@ from langchain_google_vertexai._base import _BaseVertexAIModelGarden, _VertexAICommon +class CacheUsageMetadata(UsageMetadata): + cache_creation_input_tokens: Optional[int] + """The number of input tokens used to create the cache entry.""" + cache_read_input_tokens: Optional[int] + """The number of input tokens read from the cache.""" + + class VertexAIModelGarden(_BaseVertexAIModelGarden, BaseLLM): """Large language models served from Vertex AI Model Garden.""" @@ -225,11 +233,13 @@ def 
_format_output(self, data: Any, **kwargs: Any) -> ChatResult: else: msg = AIMessage(content=content) # Collect token usage - msg.usage_metadata = { - "input_tokens": data.usage.input_tokens, - "output_tokens": data.usage.output_tokens, - "total_tokens": data.usage.input_tokens + data.usage.output_tokens, - } + msg.usage_metadata = CacheUsageMetadata( + input_tokens=data.usage.input_tokens, + output_tokens=data.usage.output_tokens, + total_tokens=data.usage.input_tokens + data.usage.output_tokens, + cache_creation_input_tokens=data.usage.cache_creation_input_tokens, + cache_read_input_tokens=data.usage.cache_read_input_tokens, + ) return ChatResult( generations=[ChatGeneration(message=msg)], llm_output=llm_output, diff --git a/libs/vertexai/tests/integration_tests/test_anthropic_cache.py b/libs/vertexai/tests/integration_tests/test_anthropic_cache.py new file mode 100644 index 00000000..9399331b --- /dev/null +++ b/libs/vertexai/tests/integration_tests/test_anthropic_cache.py @@ -0,0 +1,147 @@ +"""Integration tests for Anthropic cache control functionality.""" +import os +from typing import Dict, List, Union + +import pytest +from langchain_core.messages import AIMessage, HumanMessage, SystemMessage +from langchain_core.prompts import ChatPromptTemplate + +from langchain_google_vertexai.model_garden import ChatAnthropicVertex + + +@pytest.mark.extended +def test_anthropic_system_cache() -> None: + """Test chat with system message having cache control.""" + project = os.environ["PROJECT_ID"] + location = "us-central1" + model = ChatAnthropicVertex( + project=project, + location=location, + ) + + context = SystemMessage( + content="You are my personal assistant. Be helpful and concise.", + additional_kwargs={"cache_control": {"type": "ephemeral"}}, + ) + message = HumanMessage(content="Hello! What can you do for me?") + + response = model.invoke( + [context, message], model_name="claude-3-5-sonnet-v2@20241022" + ) + assert isinstance(response, AIMessage) + assert isinstance(response.content, str) + assert "usage_metadata" in response.additional_kwargs + assert "cache_creation_input_tokens" in response.additional_kwargs["usage_metadata"] + + +@pytest.mark.extended +def test_anthropic_mixed_cache() -> None: + """Test chat with different cache control types.""" + project = os.environ["PROJECT_ID"] + location = "us-central1" + model = ChatAnthropicVertex( + project=project, + location=location, + ) + + context = SystemMessage( + content=[ + { + "type": "text", + "text": "You are my personal assistant.", + "cache_control": {"type": "ephemeral"}, + } + ] + ) + message = HumanMessage( + content=[ + { + "type": "text", + "text": "What's your name and what can you help me with?", + "cache_control": {"type": "ephemeral"}, + } + ] + ) + + response = model.invoke( + [context, message], model_name="claude-3-5-sonnet-v2@20241022" + ) + assert isinstance(response, AIMessage) + assert isinstance(response.content, str) + assert "usage_metadata" in response.additional_kwargs + + +@pytest.mark.extended +def test_anthropic_conversation_cache() -> None: + """Test chat conversation with cache control.""" + project = os.environ["PROJECT_ID"] + location = "us-central1" + model = ChatAnthropicVertex( + project=project, + location=location, + ) + + context = SystemMessage( + content="You are my personal assistant. 
My name is Peter.", + additional_kwargs={"cache_control": {"type": "ephemeral"}}, + ) + messages = [ + context, + HumanMessage( + content=[ + { + "type": "text", + "text": "What's my name?", + "cache_control": {"type": "ephemeral"}, + } + ] + ), + AIMessage(content="Your name is Peter."), + HumanMessage( + content=[ + { + "type": "text", + "text": "Can you repeat my name?", + "cache_control": {"type": "ephemeral"}, + } + ] + ), + ] + + response = model.invoke(messages, model_name="claude-3-5-sonnet-v2@20241022") + assert isinstance(response, AIMessage) + assert isinstance(response.content, str) + assert "peter" in response.content.lower() # Should remember the name + + +@pytest.mark.extended +def test_anthropic_chat_template_cache() -> None: + """Test chat template with structured content and cache control.""" + project = os.environ["PROJECT_ID"] + location = "us-central1" + model = ChatAnthropicVertex( + project=project, + location=location, + ) + + content: List[Union[Dict[str, Union[str, Dict[str, str]]], str]] = [ + { + "text": "You are a helpful assistant. Be concise and clear.", + "type": "text", + "cache_control": {"type": "ephemeral"}, + } + ] + + prompt = ChatPromptTemplate.from_messages( + [SystemMessage(content=content), ("human", "{input}")] + ) + + chain = prompt | model + + response = chain.invoke( + {"input": "What's the capital of France?"}, + ) + + assert isinstance(response, AIMessage) + assert isinstance(response.content, str) + assert "Paris" in response.content diff --git a/libs/vertexai/tests/unit_tests/test_anthropic_utils.py b/libs/vertexai/tests/unit_tests/test_anthropic_utils.py index 7f83bfb3..5e493823 100644 --- a/libs/vertexai/tests/unit_tests/test_anthropic_utils.py +++ b/libs/vertexai/tests/unit_tests/test_anthropic_utils.py @@ -1,3 +1,5 @@ +"""Unit tests for _anthropic_utils.py.""" + import pytest from langchain_core.messages import ( AIMessage, @@ -7,7 +9,260 @@ ) from langchain_core.messages.tool import tool_call as create_tool_call -from langchain_google_vertexai.model_garden import _format_messages_anthropic +from langchain_google_vertexai._anthropic_utils import ( + _format_message_anthropic, + _format_messages_anthropic, +) + + +def test_format_message_anthropic_with_cache_control_in_kwargs(): + """Test formatting a message with cache control in additional_kwargs.""" + message = HumanMessage( + content="Hello", additional_kwargs={"cache_control": {"type": "semantic"}} + ) + result = _format_message_anthropic(message) + assert result == { + "role": "user", + "content": [ + {"type": "text", "text": "Hello", "cache_control": {"type": "semantic"}} + ], + } + + +def test_format_message_anthropic_with_cache_control_in_block(): + """Test formatting a message with cache control in content block.""" + message = HumanMessage( + content=[ + {"type": "text", "text": "Hello", "cache_control": {"type": "semantic"}} + ] + ) + result = _format_message_anthropic(message) + assert result == { + "role": "user", + "content": [ + {"type": "text", "text": "Hello", "cache_control": {"type": "semantic"}} + ], + } + + +def test_format_message_anthropic_with_mixed_blocks(): + """Test formatting a message with mixed blocks, some with cache control.""" + message = HumanMessage( + content=[ + {"type": "text", "text": "Hello", "cache_control": {"type": "semantic"}}, + {"type": "text", "text": "World"}, + "Plain text", + ] + ) + result = _format_message_anthropic(message) + assert result == { + "role": "user", + "content": [ + {"type": "text", "text": "Hello", "cache_control": 
{"type": "semantic"}}, + {"type": "text", "text": "World"}, + {"type": "text", "text": "Plain text"}, + ], + } + + +def test_format_messages_anthropic_with_system_cache_control(): + """Test formatting messages with system message having cache control.""" + messages = [ + SystemMessage( + content="System message", + additional_kwargs={"cache_control": {"type": "ephemeral"}}, + ), + HumanMessage(content="Hello"), + ] + system_messages, formatted_messages = _format_messages_anthropic(messages) + + assert system_messages == [ + { + "type": "text", + "text": "System message", + "cache_control": {"type": "ephemeral"}, + } + ] + + assert formatted_messages == [ + {"role": "user", "content": [{"type": "text", "text": "Hello"}]} + ] + + +def test_format_message_anthropic_system(): + """Test formatting a system message.""" + message = SystemMessage( + content="System message", + additional_kwargs={"cache_control": {"type": "ephemeral"}}, + ) + result = _format_message_anthropic(message) + assert result == [ + { + "type": "text", + "text": "System message", + "cache_control": {"type": "ephemeral"}, + } + ] + + +def test_format_message_anthropic_system_list(): + """Test formatting a system message with list content.""" + message = SystemMessage( + content=[ + { + "type": "text", + "text": "System rule 1", + "cache_control": {"type": "ephemeral"}, + }, + {"type": "text", "text": "System rule 2"}, + ] + ) + result = _format_message_anthropic(message) + assert result == [ + { + "type": "text", + "text": "System rule 1", + "cache_control": {"type": "ephemeral"}, + }, + {"type": "text", "text": "System rule 2"}, + ] + + +def test_format_messages_anthropic_with_system_string(): + """Test formatting messages with system message as string.""" + messages = [ + SystemMessage(content="System message"), + HumanMessage(content="Hello"), + ] + system_messages, formatted_messages = _format_messages_anthropic(messages) + + assert system_messages == [{"type": "text", "text": "System message"}] + + assert formatted_messages == [ + {"role": "user", "content": [{"type": "text", "text": "Hello"}]} + ] + + +def test_format_messages_anthropic_with_system_list(): + """Test formatting messages with system message as a list.""" + messages = [ + SystemMessage( + content=[ + { + "type": "text", + "text": "System rule 1", + "cache_control": {"type": "ephemeral"}, + }, + {"type": "text", "text": "System rule 2"}, + ] + ), + HumanMessage(content="Hello"), + ] + system_messages, formatted_messages = _format_messages_anthropic(messages) + + assert system_messages == [ + { + "type": "text", + "text": "System rule 1", + "cache_control": {"type": "ephemeral"}, + }, + {"type": "text", "text": "System rule 2"}, + ] + + assert formatted_messages == [ + {"role": "user", "content": [{"type": "text", "text": "Hello"}]} + ] + + +def test_format_messages_anthropic_with_system_mixed_list(): + """Test formatting messages with system message as a mixed list.""" + messages = [ + SystemMessage( + content=[ + "Plain system rule", + { + "type": "text", + "text": "Formatted system rule", + "cache_control": {"type": "ephemeral"}, + }, + ] + ), + HumanMessage(content="Hello"), + ] + system_messages, formatted_messages = _format_messages_anthropic(messages) + + assert system_messages == [ + {"type": "text", "text": "Plain system rule"}, + { + "type": "text", + "text": "Formatted system rule", + "cache_control": {"type": "ephemeral"}, + }, + ] + + assert formatted_messages == [ + {"role": "user", "content": [{"type": "text", "text": "Hello"}]} + ] + + +def 
test_format_messages_anthropic_with_mixed_messages(): + """Test formatting a conversation with various message types and cache controls.""" + messages = [ + SystemMessage( + content=[ + { + "type": "text", + "text": "System message", + "cache_control": {"type": "ephemeral"}, + } + ] + ), + HumanMessage( + content=[ + { + "type": "text", + "text": "Human message", + "cache_control": {"type": "semantic"}, + } + ] + ), + AIMessage( + content="AI response", + additional_kwargs={"cache_control": {"type": "semantic"}}, + ), + ] + system_messages, formatted_messages = _format_messages_anthropic(messages) + + assert system_messages == [ + { + "type": "text", + "text": "System message", + "cache_control": {"type": "ephemeral"}, + } + ] + + assert formatted_messages == [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "Human message", + "cache_control": {"type": "semantic"}, + } + ], + }, + { + "role": "assistant", + "content": [ + { + "type": "text", + "text": "AI response", + "cache_control": {"type": "semantic"}, + } + ], + }, + ] @pytest.mark.parametrize( @@ -113,7 +368,7 @@ content="Mike age is 30", ), ], - "test1", + [{"type": "text", "text": "test1"}], [ { "role": "assistant", @@ -473,6 +728,7 @@ def test_format_messages_anthropic( source_history, expected_sm, expected_history ) -> None: + """Test the original format_messages_anthropic functionality.""" sm, result_history = _format_messages_anthropic(source_history) for result, expected in zip(result_history, expected_history): diff --git a/libs/vertexai/tests/unit_tests/test_chat_models.py b/libs/vertexai/tests/unit_tests/test_chat_models.py index 06567e71..81e9eebf 100644 --- a/libs/vertexai/tests/unit_tests/test_chat_models.py +++ b/libs/vertexai/tests/unit_tests/test_chat_models.py @@ -1077,6 +1077,8 @@ def test_anthropic_format_output() -> None: class Usage: input_tokens: int output_tokens: int + cache_creation_input_tokens: Optional[int] + cache_read_input_tokens: Optional[int] @dataclass class Message: @@ -1092,13 +1094,25 @@ def model_dump(self): ], "model": "baz", "role": "assistant", - "usage": Usage(input_tokens=2, output_tokens=1), + "usage": Usage( + input_tokens=2, + output_tokens=1, + cache_creation_input_tokens=1, + cache_read_input_tokens=1, + ), "type": "message", } usage: Usage - test_msg = Message(usage=Usage(input_tokens=2, output_tokens=1)) + test_msg = Message( + usage=Usage( + input_tokens=2, + output_tokens=1, + cache_creation_input_tokens=1, + cache_read_input_tokens=1, + ) + ) model = ChatAnthropicVertex(project="test-project", location="test-location") result = model._format_output(test_msg) @@ -1113,4 +1127,6 @@ def model_dump(self): "input_tokens": 2, "output_tokens": 1, "total_tokens": 3, + "cache_creation_input_tokens": 1, + "cache_read_input_tokens": 1, } From 79047b2ca4d0063eb5ec0d9e1be6c44672be41d2 Mon Sep 17 00:00:00 2001 From: Piotr Gloger <36278157+pedrito87@users.noreply.github.com> Date: Tue, 7 Jan 2025 10:00:54 +0100 Subject: [PATCH 2/3] genai: Fix multiple tool calls in a single AIMessage (#671) --- .../langchain_google_genai/chat_models.py | 116 +++++++++--------- .../tests/unit_tests/test_chat_models.py | 81 ++++++++++-- 2 files changed, 131 insertions(+), 66 deletions(-) diff --git a/libs/genai/langchain_google_genai/chat_models.py b/libs/genai/langchain_google_genai/chat_models.py index 7ae17fa8..48608c47 100644 --- a/libs/genai/langchain_google_genai/chat_models.py +++ b/libs/genai/langchain_google_genai/chat_models.py @@ -301,6 +301,49 @@ def _convert_to_parts( return parts 
 
 
+def _convert_tool_message_to_part(message: ToolMessage | FunctionMessage) -> Part:
+    """Converts a tool or function message to a google part."""
+    name = message.name
+    response: Any
+    if not isinstance(message.content, str):
+        response = message.content
+    else:
+        try:
+            response = json.loads(message.content)
+        except json.JSONDecodeError:
+            response = message.content  # leave as str representation
+    part = Part(
+        function_response=FunctionResponse(
+            name=name,
+            response=(
+                {"output": response} if not isinstance(response, dict) else response
+            ),
+        )
+    )
+    return part
+
+
+def _get_ai_message_tool_messages_parts(
+    tool_messages: Sequence[ToolMessage], ai_message: AIMessage
+) -> list[Part]:
+    """
+    Finds relevant tool messages for the AI message and converts them to a single
+    list of Parts.
+    """
+    # We are interested only in the tool messages that are part of the AI message
+    tool_calls_ids = [tool_call["id"] for tool_call in ai_message.tool_calls]
+    parts = []
+    for i, message in enumerate(tool_messages):
+        if not tool_calls_ids:
+            break
+        if message.tool_call_id in tool_calls_ids:
+            # remove the id from the list, so that we do not iterate over it again
+            tool_calls_ids.remove(message.tool_call_id)
+            part = _convert_tool_message_to_part(message)
+            parts.append(part)
+    return parts
+
+
 def _parse_chat_history(
     input_messages: Sequence[BaseMessage], convert_system_message_to_human: bool = False
 ) -> Tuple[Optional[Content], List[Content]]:
@@ -310,14 +353,20 @@ def _parse_chat_history(
         warnings.warn("Convert_system_message_to_human will be deprecated!")
 
     system_instruction: Optional[Content] = None
-    for i, message in enumerate(input_messages):
+    messages_without_tool_messages = [
+        message for message in input_messages if not isinstance(message, ToolMessage)
+    ]
+    tool_messages = [
+        message for message in input_messages if isinstance(message, ToolMessage)
+    ]
+    for i, message in enumerate(messages_without_tool_messages):
         if i == 0 and isinstance(message, SystemMessage):
             system_instruction = Content(parts=_convert_to_parts(message.content))
             continue
         elif isinstance(message, AIMessage):
             role = "model"
             if message.tool_calls:
-                parts = []
+                ai_message_parts = []
                 for tool_call in message.tool_calls:
                     function_call = FunctionCall(
                         {
@@ -325,7 +374,13 @@
                             "name": tool_call["name"],
                             "args": tool_call["args"],
                         }
                     )
-                    parts.append(Part(function_call=function_call))
+                    ai_message_parts.append(Part(function_call=function_call))
+                tool_messages_parts = _get_ai_message_tool_messages_parts(
+                    tool_messages=tool_messages, ai_message=message
+                )
+                messages.append(Content(role=role, parts=ai_message_parts))
+                messages.append(Content(role="user", parts=tool_messages_parts))
+                continue
             elif raw_function_call := message.additional_kwargs.get("function_call"):
                 function_call = FunctionCall(
                     {
@@ -344,60 +399,7 @@
                 system_instruction = None
         elif isinstance(message, FunctionMessage):
             role = "user"
-            response: Any
-            if not isinstance(message.content, str):
-                response = message.content
-            else:
-                try:
-                    response = json.loads(message.content)
-                except json.JSONDecodeError:
-                    response = message.content  # leave as str representation
-            parts = [
-                Part(
-                    function_response=FunctionResponse(
-                        name=message.name,
-                        response=(
-                            {"output": response}
-                            if not isinstance(response, dict)
-                            else response
-                        ),
-                    )
-                )
-            ]
-        elif isinstance(message, ToolMessage):
-            role = "user"
-            prev_message: Optional[BaseMessage] = (
-                input_messages[i - 1] if i > 0 else None
-            )
-            if (
-                prev_message
-                and isinstance(prev_message, AIMessage)
-                and prev_message.tool_calls
-            ):
-                # message.name can be null for ToolMessage
-                name: str = prev_message.tool_calls[0]["name"]
-            else:
-                name = message.name  # type: ignore
-            tool_response: Any
-            if not isinstance(message.content, str):
-                tool_response = message.content
-            else:
-                try:
-                    tool_response = json.loads(message.content)
-                except json.JSONDecodeError:
-                    tool_response = message.content  # leave as str representation
-            parts = [
-                Part(
-                    function_response=FunctionResponse(
-                        name=name,
-                        response=(
-                            {"output": tool_response}
-                            if not isinstance(tool_response, dict)
-                            else tool_response
-                        ),
-                    )
-                )
-            ]
+            parts = [_convert_tool_message_to_part(message)]
         else:
             raise ValueError(
                 f"Unexpected message with type {type(message)} at the position {i}."
diff --git a/libs/genai/tests/unit_tests/test_chat_models.py b/libs/genai/tests/unit_tests/test_chat_models.py
index 7be08fa3..1fbf0184 100644
--- a/libs/genai/tests/unit_tests/test_chat_models.py
+++ b/libs/genai/tests/unit_tests/test_chat_models.py
@@ -133,7 +133,8 @@ def test_parse_history(convert_system_message_to_human: bool) -> None:
     function_name = "calculator"
     function_call_1 = {
         "name": function_name,
-        "arguments": json.dumps({"arg1": "2", "arg2": "2", "op": "+"}),
+        "args": {"arg1": "2", "arg2": "2", "op": "+"},
+        "id": "0",
     }
     function_answer1 = json.dumps({"result": 4})
     function_call_2 = {
@@ -141,18 +142,28 @@ def test_parse_history(convert_system_message_to_human: bool) -> None:
         "arguments": json.dumps({"arg1": "2", "arg2": "2", "op": "*"}),
     }
     function_answer2 = json.dumps({"result": 4})
+    function_call_3 = {
+        "name": function_name,
+        "args": {"arg1": "2", "arg2": "2", "op": "*"},
+        "id": "1",
+    }
+    function_answer_3 = json.dumps({"result": 4})
+    function_call_4 = {
+        "name": function_name,
+        "args": {"arg1": "2", "arg2": "3", "op": "*"},
+        "id": "2",
+    }
+    function_answer_4 = json.dumps({"result": 6})
     text_answer1 = "They are same"
 
     system_message = SystemMessage(content=system_input)
     message1 = HumanMessage(content=text_question1)
     message2 = AIMessage(
         content="",
-        additional_kwargs={
-            "function_call": function_call_1,
-        },
+        tool_calls=[function_call_1],
     )
     message3 = ToolMessage(
-        name="calculator", content=function_answer1, tool_call_id="1"
+        name="calculator", content=function_answer1, tool_call_id="0"
     )
     message4 = AIMessage(
         content="",
         additional_kwargs={
@@ -161,7 +172,14 @@ def test_parse_history(convert_system_message_to_human: bool) -> None:
         },
     )
     message5 = FunctionMessage(name="calculator", content=function_answer2)
-    message6 = AIMessage(content=text_answer1)
+    message6 = AIMessage(content="", tool_calls=[function_call_3, function_call_4])
+    message7 = ToolMessage(
+        name="calculator", content=function_answer_3, tool_call_id="1"
+    )
+    message8 = ToolMessage(
+        name="calculator", content=function_answer_4, tool_call_id="2"
+    )
+    message9 = AIMessage(content=text_answer1)
     messages = [
         system_message,
         message1,
@@ -170,11 +188,14 @@ def test_parse_history(convert_system_message_to_human: bool) -> None:
         message4,
         message5,
         message6,
+        message7,
+        message8,
+        message9,
     ]
     system_instruction, history = _parse_chat_history(
         messages, convert_system_message_to_human=convert_system_message_to_human
     )
-    assert len(history) == 6
+    assert len(history) == 8
     if convert_system_message_to_human:
         assert history[0] == glm.Content(
             role="user",
@@ -191,7 +212,7 @@ def test_parse_history(convert_system_message_to_human: bool) -> None:
                 function_call=glm.FunctionCall(
                     {
                         "name": "calculator",
-                        "args": json.loads(function_call_1["arguments"]),
+                        "args": function_call_1["args"],
                     }
                 )
             )
@@ -236,7 +257,49 @@ def test_parse_history(convert_system_message_to_human: bool) -> None:
                 )
             ],
         )
-    assert history[5] == glm.Content(role="model", parts=[glm.Part(text=text_answer1)])
+    assert history[5] == glm.Content(
+        role="model",
+        parts=[
+            glm.Part(
+                function_call=glm.FunctionCall(
+                    {
+                        "name": "calculator",
+                        "args": function_call_3["args"],
+                    }
+                )
+            ),
+            glm.Part(
+                function_call=glm.FunctionCall(
+                    {
+                        "name": "calculator",
+                        "args": function_call_4["args"],
+                    }
+                )
+            ),
+        ],
+    )
+    assert history[6] == glm.Content(
+        role="user",
+        parts=[
+            glm.Part(
+                function_response=glm.FunctionResponse(
+                    {
+                        "name": "calculator",
+                        "response": {"result": 4},
+                    }
+                )
+            ),
+            glm.Part(
+                function_response=glm.FunctionResponse(
+                    {
+                        "name": "calculator",
+                        "response": {"result": 6},
+                    }
+                )
+            ),
+        ],
+    )
+    assert history[7] == glm.Content(role="model", parts=[glm.Part(text=text_answer1)])
     if convert_system_message_to_human:
         assert system_instruction is None
     else:

From 258d6246d1c72c925940740fd156ed947f4bdd0d Mon Sep 17 00:00:00 2001
From: dotrunghieu96
Date: Tue, 7 Jan 2025 16:04:13 +0700
Subject: [PATCH 3/3] anthropic on vertexai: Update parsed parameter from
 `image_url` to `image` (#669)

---
 libs/vertexai/langchain_google_vertexai/_anthropic_utils.py | 4 ++--
 libs/vertexai/tests/unit_tests/test_anthropic_utils.py      | 2 +-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/libs/vertexai/langchain_google_vertexai/_anthropic_utils.py b/libs/vertexai/langchain_google_vertexai/_anthropic_utils.py
index af2b10cb..d542237a 100644
--- a/libs/vertexai/langchain_google_vertexai/_anthropic_utils.py
+++ b/libs/vertexai/langchain_google_vertexai/_anthropic_utils.py
@@ -123,8 +123,8 @@ def _format_message_anthropic(message: Union[HumanMessage, AIMessage, SystemMess
                 if block["type"] == "image_url":
                     # convert format
-                    new_block["source"] = _format_image(block["image_url"]["url"])
-                    content.append(new_block)
+                    source = _format_image(block["image_url"]["url"])
+                    content.append({"type": "image", "source": source})
                     continue
 
                 if block["type"] == "tool_use":
diff --git a/libs/vertexai/tests/unit_tests/test_anthropic_utils.py b/libs/vertexai/tests/unit_tests/test_anthropic_utils.py
index 5e493823..08ba9aba 100644
--- a/libs/vertexai/tests/unit_tests/test_anthropic_utils.py
+++ b/libs/vertexai/tests/unit_tests/test_anthropic_utils.py
@@ -415,7 +415,7 @@ def test_format_messages_anthropic_with_mixed_messages():
                     "role": "user",
                     "content": [
                         {
-                            "type": "image_url",
+                            "type": "image",
                             "source": {
                                 "type": "base64",
                                 "media_type": "image/png",
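
A minimal usage sketch of the text caching support added in PATCH 1/3. This is
illustrative only: it assumes a GCP project with Vertex AI access and the
PROJECT_ID environment variable set, and the region and model name simply
mirror the integration tests above rather than being required values.

    import os

    from langchain_core.messages import HumanMessage, SystemMessage
    from langchain_google_vertexai.model_garden import ChatAnthropicVertex

    model = ChatAnthropicVertex(
        project=os.environ["PROJECT_ID"],  # assumed to be configured
        location="us-central1",
    )

    # Marking the system message with cache_control asks Anthropic to cache it,
    # so repeated calls can reuse the cached prefix instead of reprocessing it.
    system = SystemMessage(
        content="You are my personal assistant. Be helpful and concise.",
        additional_kwargs={"cache_control": {"type": "ephemeral"}},
    )
    message = HumanMessage(content="Hello! What can you do for me?")

    response = model.invoke(
        [system, message], model_name="claude-3-5-sonnet-v2@20241022"
    )

    # With the patch applied, usage metadata carries the new cache counters
    # (CacheUsageMetadata) alongside the standard token counts.
    print(response.usage_metadata["cache_creation_input_tokens"])
    print(response.usage_metadata["cache_read_input_tokens"])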