From da1c2bf12e7f3eab10c08690fb0566de5914ef83 Mon Sep 17 00:00:00 2001
From: Eric Zhu <ekzhu@users.noreply.github.com>
Date: Tue, 21 Jan 2025 06:06:19 -0800
Subject: [PATCH] fix: use tool_calls field to detect tool calls in OpenAI
 client; add integration tests for OpenAI and Gemini (#5122)

* fix: use tool_calls field to detect tool calls in OpenAI client

* Add unit tests for tool calling and integration tests for OpenAI and Gemini
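
A minimal usage sketch of the fixed behavior, mirroring the new integration test below
(the get_weather tool is hypothetical; OPENAI_API_KEY is assumed to be set):

    import asyncio

    from autogen_core.models import UserMessage
    from autogen_core.tools import FunctionTool
    from autogen_ext.models.openai import OpenAIChatCompletionClient


    def get_weather(city: str) -> str:  # hypothetical example tool
        return f"It is sunny in {city}."


    async def main() -> None:
        # The client reads OPENAI_API_KEY from the environment.
        client = OpenAIChatCompletionClient(model="gpt-4o-mini")
        weather_tool = FunctionTool(get_weather, description="Get the weather for a city.")
        result = await client.create(
            messages=[UserMessage(content="What is the weather in Paris?", source="user")],
            tools=[weather_tool],
        )
        # With this fix, tool calls are detected from message.tool_calls rather than
        # finish_reason, so result.content is a list of FunctionCall objects and
        # result.finish_reason is "function_calls".
        print(result.finish_reason, result.content)


    asyncio.run(main())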
---
 .../models/openai/_openai_client.py           |  30 +-
 .../tests/models/test_openai_model_client.py  | 331 +++++++++++++++++-
 2 files changed, 349 insertions(+), 12 deletions(-)

diff --git a/python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py b/python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py
index c44d03711b75..79c13442c7dc 100644
--- a/python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py
+++ b/python/packages/autogen-ext/src/autogen_ext/models/openai/_openai_client.py
@@ -539,20 +539,33 @@ async def create(
         if self._resolved_model is not None:
             if self._resolved_model != result.model:
                 warnings.warn(
-                    f"Resolved model mismatch: {self._resolved_model} != {result.model}. Model mapping may be incorrect.",
+                    f"Resolved model mismatch: {self._resolved_model} != {result.model}. "
+                    "Model mapping in autogen_ext.models.openai may be incorrect.",
                     stacklevel=2,
                 )
 
         # Limited to a single choice currently.
         choice: Union[ParsedChoice[Any], ParsedChoice[BaseModel], Choice] = result.choices[0]
-        if choice.finish_reason == "function_call":
-            raise ValueError("Function calls are not supported in this context")
 
+        # Detect whether the response is a tool call or a plain completion.
+        # We don't rely on choice.finish_reason because it is not always accurate, depending on the API used.
         content: Union[str, List[FunctionCall]]
-        if choice.finish_reason == "tool_calls":
-            assert choice.message.tool_calls is not None
-            assert choice.message.function_call is None
-
+        if choice.message.function_call is not None:
+            raise ValueError("function_call is deprecated and is not supported by this model client.")
+        elif choice.message.tool_calls is not None:
+            if choice.finish_reason != "tool_calls":
+                warnings.warn(
+                    f"Finish reason mismatch: {choice.finish_reason} != tool_calls "
+                    "when tool_calls are present. Finish reason may not be accurate. "
+                    "This may be due to the API used that is not returning the correct finish reason.",
+                    stacklevel=2,
+                )
+            if choice.message.content is not None and choice.message.content != "":
+                warnings.warn(
+                    "Both tool_calls and content are present in the message. "
+                    "This is unexpected. content will be ignored, tool_calls will be used.",
+                    stacklevel=2,
+                )
             # NOTE: If OAI response type changes, this will need to be updated
             content = [
                 FunctionCall(
@@ -562,10 +575,11 @@ async def create(
                 )
                 for x in choice.message.tool_calls
             ]
-            finish_reason = "function_calls"
+            finish_reason = "tool_calls"
         else:
             finish_reason = choice.finish_reason
             content = choice.message.content or ""
+
         logprobs: Optional[List[ChatCompletionTokenLogprob]] = None
         if choice.logprobs and choice.logprobs.content:
             logprobs = [
diff --git a/python/packages/autogen-ext/tests/models/test_openai_model_client.py b/python/packages/autogen-ext/tests/models/test_openai_model_client.py
index 18312e1c1614..d629cdc428a5 100644
--- a/python/packages/autogen-ext/tests/models/test_openai_model_client.py
+++ b/python/packages/autogen-ext/tests/models/test_openai_model_client.py
@@ -1,10 +1,11 @@
 import asyncio
 import json
-from typing import Annotated, Any, AsyncGenerator, Generic, List, Literal, Tuple, TypeVar
+import os
+from typing import Annotated, Any, AsyncGenerator, Dict, Generic, List, Literal, Tuple, TypeVar
 from unittest.mock import MagicMock
 
 import pytest
-from autogen_core import CancellationToken, Image
+from autogen_core import CancellationToken, FunctionCall, Image
 from autogen_core.models import (
     AssistantMessage,
     CreateResult,
@@ -26,10 +27,31 @@
 from openai.types.chat.chat_completion_chunk import ChatCompletionChunk, ChoiceDelta
 from openai.types.chat.chat_completion_chunk import Choice as ChunkChoice
 from openai.types.chat.chat_completion_message import ChatCompletionMessage
+from openai.types.chat.chat_completion_message_tool_call import (
+    ChatCompletionMessageToolCall,
+    Function,
+)
 from openai.types.chat.parsed_chat_completion import ParsedChatCompletion, ParsedChatCompletionMessage, ParsedChoice
 from openai.types.completion_usage import CompletionUsage
 from pydantic import BaseModel, Field
 
+
+class _MockChatCompletion:
+    def __init__(self, chat_completions: List[ChatCompletion]) -> None:
+        self._saved_chat_completions = chat_completions
+        self.curr_index = 0
+        self.calls: List[Dict[str, Any]] = []
+
+    async def mock_create(
+        self, *args: Any, **kwargs: Any
+    ) -> ChatCompletion | AsyncGenerator[ChatCompletionChunk, None]:
+        self.calls.append(kwargs)  # Save the call
+        await asyncio.sleep(0.1)
+        completion = self._saved_chat_completions[self.curr_index]
+        self.curr_index += 1
+        return completion
+
+
 ResponseFormatT = TypeVar("ResponseFormatT", bound=BaseModel)
 
 
@@ -37,20 +59,32 @@ class _MockBetaChatCompletion(Generic[ResponseFormatT]):
     def __init__(self, chat_completions: List[ParsedChatCompletion[ResponseFormatT]]) -> None:
         self._saved_chat_completions = chat_completions
         self.curr_index = 0
-        self.calls: List[List[LLMMessage]] = []
+        self.calls: List[Dict[str, Any]] = []
 
     async def mock_parse(
         self,
         *args: Any,
         **kwargs: Any,
     ) -> ParsedChatCompletion[ResponseFormatT]:
-        self.calls.append(kwargs["messages"])
+        self.calls.append(kwargs)  # Save the call
         await asyncio.sleep(0.1)
         completion = self._saved_chat_completions[self.curr_index]
         self.curr_index += 1
         return completion
 
 
+def _pass_function(input: str) -> str:
+    return "pass"
+
+
+async def _fail_function(input: str) -> str:
+    return "fail"
+
+
+async def _echo_function(input: str) -> str:
+    return input
+
+
 class MyResult(BaseModel):
     result: str = Field(description="The other description.")
 
@@ -432,3 +466,292 @@ class AgentResponse(BaseModel):
         == "The user explicitly states that they are happy without any indication of sadness or neutrality."
     )
     assert response.response == "happy"
+
+
+@pytest.mark.asyncio
+async def test_tool_calling(monkeypatch: pytest.MonkeyPatch) -> None:
+    model = "gpt-4o-2024-05-13"
+    chat_completions = [
+        # Successful completion, single tool call
+        ChatCompletion(
+            id="id1",
+            choices=[
+                Choice(
+                    finish_reason="tool_calls",
+                    index=0,
+                    message=ChatCompletionMessage(
+                        content=None,
+                        tool_calls=[
+                            ChatCompletionMessageToolCall(
+                                id="1",
+                                type="function",
+                                function=Function(
+                                    name="_pass_function",
+                                    arguments=json.dumps({"input": "task"}),
+                                ),
+                            )
+                        ],
+                        role="assistant",
+                    ),
+                )
+            ],
+            created=0,
+            model=model,
+            object="chat.completion",
+            usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
+        ),
+        # Successful completion, parallel tool calls
+        ChatCompletion(
+            id="id2",
+            choices=[
+                Choice(
+                    finish_reason="tool_calls",
+                    index=0,
+                    message=ChatCompletionMessage(
+                        content=None,
+                        tool_calls=[
+                            ChatCompletionMessageToolCall(
+                                id="1",
+                                type="function",
+                                function=Function(
+                                    name="_pass_function",
+                                    arguments=json.dumps({"input": "task"}),
+                                ),
+                            ),
+                            ChatCompletionMessageToolCall(
+                                id="2",
+                                type="function",
+                                function=Function(
+                                    name="_fail_function",
+                                    arguments=json.dumps({"input": "task"}),
+                                ),
+                            ),
+                            ChatCompletionMessageToolCall(
+                                id="3",
+                                type="function",
+                                function=Function(
+                                    name="_echo_function",
+                                    arguments=json.dumps({"input": "task"}),
+                                ),
+                            ),
+                        ],
+                        role="assistant",
+                    ),
+                )
+            ],
+            created=0,
+            model=model,
+            object="chat.completion",
+            usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
+        ),
+        # Warning completion when finish reason is not tool_calls.
+        ChatCompletion(
+            id="id3",
+            choices=[
+                Choice(
+                    finish_reason="stop",
+                    index=0,
+                    message=ChatCompletionMessage(
+                        content=None,
+                        tool_calls=[
+                            ChatCompletionMessageToolCall(
+                                id="1",
+                                type="function",
+                                function=Function(
+                                    name="_pass_function",
+                                    arguments=json.dumps({"input": "task"}),
+                                ),
+                            )
+                        ],
+                        role="assistant",
+                    ),
+                )
+            ],
+            created=0,
+            model=model,
+            object="chat.completion",
+            usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
+        ),
+        # Warning completion when content is not None.
+        ChatCompletion(
+            id="id4",
+            choices=[
+                Choice(
+                    finish_reason="tool_calls",
+                    index=0,
+                    message=ChatCompletionMessage(
+                        content="I should make a tool call.",
+                        tool_calls=[
+                            ChatCompletionMessageToolCall(
+                                id="1",
+                                type="function",
+                                function=Function(
+                                    name="_pass_function",
+                                    arguments=json.dumps({"input": "task"}),
+                                ),
+                            )
+                        ],
+                        role="assistant",
+                    ),
+                )
+            ],
+            created=0,
+            model=model,
+            object="chat.completion",
+            usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
+        ),
+    ]
+    mock = _MockChatCompletion(chat_completions)
+    monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
+    pass_tool = FunctionTool(_pass_function, description="pass tool.")
+    fail_tool = FunctionTool(_fail_function, description="fail tool.")
+    echo_tool = FunctionTool(_echo_function, description="echo tool.")
+    model_client = OpenAIChatCompletionClient(model=model, api_key="")
+
+    # Single tool call
+    create_result = await model_client.create(messages=[UserMessage(content="Hello", source="user")], tools=[pass_tool])
+    assert create_result.content == [FunctionCall(id="1", arguments=r'{"input": "task"}', name="_pass_function")]
+    # Verify that the tool schema was passed to the model client.
+    kwargs = mock.calls[0]
+    assert kwargs["tools"] == [{"function": pass_tool.schema, "type": "function"}]
+    # Verify finish reason
+    assert create_result.finish_reason == "function_calls"
+
+    # Parallel tool calls
+    create_result = await model_client.create(
+        messages=[UserMessage(content="Hello", source="user")], tools=[pass_tool, fail_tool, echo_tool]
+    )
+    assert create_result.content == [
+        FunctionCall(id="1", arguments=r'{"input": "task"}', name="_pass_function"),
+        FunctionCall(id="2", arguments=r'{"input": "task"}', name="_fail_function"),
+        FunctionCall(id="3", arguments=r'{"input": "task"}', name="_echo_function"),
+    ]
+    # Verify that the tool schema was passed to the model client.
+    kwargs = mock.calls[1]
+    assert kwargs["tools"] == [
+        {"function": pass_tool.schema, "type": "function"},
+        {"function": fail_tool.schema, "type": "function"},
+        {"function": echo_tool.schema, "type": "function"},
+    ]
+    # Verify finish reason
+    assert create_result.finish_reason == "function_calls"
+
+    # Warning completion when finish reason is not tool_calls.
+    with pytest.warns(UserWarning, match="Finish reason mismatch"):
+        create_result = await model_client.create(
+            messages=[UserMessage(content="Hello", source="user")], tools=[pass_tool]
+        )
+        assert create_result.content == [FunctionCall(id="1", arguments=r'{"input": "task"}', name="_pass_function")]
+        assert create_result.finish_reason == "function_calls"
+
+    # Warning completion when content is not None.
+    with pytest.warns(UserWarning, match="Both tool_calls and content are present in the message"):
+        create_result = await model_client.create(
+            messages=[UserMessage(content="Hello", source="user")], tools=[pass_tool]
+        )
+        assert create_result.content == [FunctionCall(id="1", arguments=r'{"input": "task"}', name="_pass_function")]
+        assert create_result.finish_reason == "function_calls"
+
+
+async def _test_model_client(model_client: OpenAIChatCompletionClient) -> None:
+    # Test basic completion
+    create_result = await model_client.create(
+        messages=[
+            SystemMessage(content="You are a helpful assistant."),
+            UserMessage(content="Explain to me how AI works.", source="user"),
+        ]
+    )
+    assert isinstance(create_result.content, str)
+    assert len(create_result.content) > 0
+
+    # Test tool calling
+    pass_tool = FunctionTool(_pass_function, name="pass_tool", description="pass session.")
+    fail_tool = FunctionTool(_fail_function, name="fail_tool", description="fail session.")
+    messages: List[LLMMessage] = [UserMessage(content="Call the pass tool with input 'task'", source="user")]
+    create_result = await model_client.create(messages=messages, tools=[pass_tool, fail_tool])
+    assert isinstance(create_result.content, list)
+    assert len(create_result.content) == 1
+    assert isinstance(create_result.content[0], FunctionCall)
+    assert create_result.content[0].name == "pass_tool"
+    assert json.loads(create_result.content[0].arguments) == {"input": "task"}
+    assert create_result.finish_reason == "function_calls"
+    assert create_result.usage is not None
+
+    # Test reflection on tool call response.
+    messages.append(AssistantMessage(content=create_result.content, source="assistant"))
+    messages.append(
+        FunctionExecutionResultMessage(
+            content=[FunctionExecutionResult(content="passed", call_id=create_result.content[0].id)]
+        )
+    )
+    create_result = await model_client.create(messages=messages)
+    assert isinstance(create_result.content, str)
+    assert len(create_result.content) > 0
+
+    # Test parallel tool calling
+    messages = [
+        UserMessage(
+            content="Call both the pass tool with input 'task' and the fail tool also with input 'task'", source="user"
+        )
+    ]
+    create_result = await model_client.create(messages=messages, tools=[pass_tool, fail_tool])
+    assert isinstance(create_result.content, list)
+    assert len(create_result.content) == 2
+    assert isinstance(create_result.content[0], FunctionCall)
+    assert create_result.content[0].name == "pass_tool"
+    assert json.loads(create_result.content[0].arguments) == {"input": "task"}
+    assert isinstance(create_result.content[1], FunctionCall)
+    assert create_result.content[1].name == "fail_tool"
+    assert json.loads(create_result.content[1].arguments) == {"input": "task"}
+    assert create_result.finish_reason == "function_calls"
+    assert create_result.usage is not None
+
+    # Test reflection on parallel tool call response.
+    messages.append(AssistantMessage(content=create_result.content, source="assistant"))
+    messages.append(
+        FunctionExecutionResultMessage(
+            content=[
+                FunctionExecutionResult(content="passed", call_id=create_result.content[0].id),
+                FunctionExecutionResult(content="failed", call_id=create_result.content[1].id),
+            ]
+        )
+    )
+    create_result = await model_client.create(messages=messages)
+    assert isinstance(create_result.content, str)
+    assert len(create_result.content) > 0
+
+
+@pytest.mark.asyncio
+async def test_openai() -> None:
+    api_key = os.getenv("OPENAI_API_KEY")
+    if not api_key:
+        pytest.skip("OPENAI_API_KEY not found in environment variables")
+
+    model_client = OpenAIChatCompletionClient(
+        model="gpt-4o-mini",
+        api_key=api_key,
+    )
+    await _test_model_client(model_client)
+
+
+@pytest.mark.asyncio
+async def test_gemini() -> None:
+    api_key = os.getenv("GEMINI_API_KEY")
+    if not api_key:
+        pytest.skip("GEMINI_API_KEY not found in environment variables")
+
+    model_client = OpenAIChatCompletionClient(
+        model="gemini-1.5-flash",
+        api_key=api_key,
+        base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
+        model_info={
+            "function_calling": True,
+            "json_output": True,
+            "vision": True,
+            "family": ModelFamily.UNKNOWN,
+        },
+    )
+    await _test_model_client(model_client)
+
+
+# TODO: add integration tests for Azure OpenAI using AAD token.