From ee2fe2616280bdd3717f7ab01d7ebc47e36a5ecb Mon Sep 17 00:00:00 2001 From: Rohan Thacker Date: Mon, 16 Dec 2024 00:25:53 +0530 Subject: [PATCH 1/7] Rebase to latest main branch --- python/packages/autogen-ext/pyproject.toml | 3 + .../models/_azure/_azure_ai_client.py | 380 ++++++++++++++++++ .../models/_azure/config/__init__.py | 37 ++ .../models/test_azure_ai_model_client.py | 0 4 files changed, 420 insertions(+) create mode 100644 python/packages/autogen-ext/src/autogen_ext/models/_azure/_azure_ai_client.py create mode 100644 python/packages/autogen-ext/src/autogen_ext/models/_azure/config/__init__.py create mode 100644 python/packages/autogen-ext/tests/models/test_azure_ai_model_client.py diff --git a/python/packages/autogen-ext/pyproject.toml b/python/packages/autogen-ext/pyproject.toml index 34a71f917ca7..192cd4d9edbd 100644 --- a/python/packages/autogen-ext/pyproject.toml +++ b/python/packages/autogen-ext/pyproject.toml @@ -52,6 +52,9 @@ video-surfer = [ grpc = [ "grpcio~=1.62.0", # TODO: update this once we have a stable version. ] +azure-ai-inference = [ + "azure-ai-inference>=1.0.0b6", +] [tool.hatch.build.targets.wheel] packages = ["src/autogen_ext"] diff --git a/python/packages/autogen-ext/src/autogen_ext/models/_azure/_azure_ai_client.py b/python/packages/autogen-ext/src/autogen_ext/models/_azure/_azure_ai_client.py new file mode 100644 index 000000000000..de2885a0612d --- /dev/null +++ b/python/packages/autogen-ext/src/autogen_ext/models/_azure/_azure_ai_client.py @@ -0,0 +1,380 @@ +import asyncio +from asyncio import Task +from typing import Sequence, Optional, Mapping, Any, List, Unpack, Dict, cast + +from azure.ai.inference.aio import ChatCompletionsClient +from azure.ai.inference.models import ( + ChatCompletions, + CompletionsFinishReason, + ChatCompletionsToolCall, + ChatCompletionsToolDefinition, + FunctionDefinition, + ContentItem, + TextContentItem, + ImageContentItem, + ImageUrl, + ImageDetailLevel, + StreamingChatCompletionsUpdate, + SystemMessage as AzureSystemMessage, + UserMessage as AzureUserMessage, + AssistantMessage as AzureAssistantMessage, + ToolMessage as AzureToolMessage, + FunctionCall as AzureFunctionCall, +) +from azure.ai.inference.models import ( + ChatCompletionsResponseFormatJSON, +) +from typing_extensions import AsyncGenerator, Union + +from autogen_core.base import CancellationToken +from autogen_core.components import FunctionCall, Image +from autogen_core.components.models import ( + ChatCompletionClient, + LLMMessage, + CreateResult, + ModelCapabilities, + RequestUsage, + UserMessage, + SystemMessage, + AssistantMessage, + FunctionExecutionResultMessage, +) +from autogen_core.components.tools import Tool, ToolSchema +from autogen_ext.models._azure.config import AzureAIConfig + + +def convert_tools(tools: Sequence[Tool | ToolSchema]) -> List[ChatCompletionsToolDefinition]: + result: List[ChatCompletionsToolDefinition] = [] + for tool in tools: + if isinstance(tool, Tool): + tool_schema = tool.schema.copy() + else: + assert isinstance(tool, dict) + tool_schema = tool.copy() + # tool_schema["parameters"] = {k:v for k,v in tool_schema["parameters"].items()} + # azure_ai_schema = {k:v for k,v in tool_schema["parameters"].items()} + + for key, value in tool_schema["parameters"]["properties"].items(): + if "title" in value.keys(): + del value["title"] + + result.append( + ChatCompletionsToolDefinition( + function=FunctionDefinition( + name=tool_schema["name"], + description=(tool_schema["description"] if "description" in tool_schema else ""), + 
parameters=(tool_schema["parameters"]) if "parameters" in tool_schema else {}, + ), + ), + ) + return result + + +def _func_call_to_azure(message: FunctionCall) -> ChatCompletionsToolCall: + return ChatCompletionsToolCall( + id=message.id, + function=AzureFunctionCall(arguments=message.arguments, name=message.name), + ) + + +def _system_message_to_azure(message: SystemMessage) -> AzureSystemMessage: + return AzureSystemMessage(content=message.content) + + +def _user_message_to_azure(message: UserMessage) -> AzureUserMessage: + # assert_valid_name(message.source) + if isinstance(message.content, str): + return AzureUserMessage(content=message.content) + else: + parts: List[ContentItem] = [] + for part in message.content: + if isinstance(part, str): + parts.append(TextContentItem(text=part)) + elif isinstance(part, Image): + # TODO: support url based images + # TODO: support specifying details + parts.append(ImageContentItem(image_url=ImageUrl(url=part.data_uri, detail=ImageDetailLevel.AUTO))) + else: + raise ValueError(f"Unknown content type: {message.content}") + return AzureUserMessage(content=parts) + + +def _assistant_message_to_azure(message: AssistantMessage) -> AzureAssistantMessage: + # assert_valid_name(message.source) + if isinstance(message.content, list): + return AzureAssistantMessage( + tool_calls=[_func_call_to_azure(x) for x in message.content], + ) + else: + return AzureAssistantMessage(content=message.content) + + +def _tool_message_to_azure(message: FunctionExecutionResultMessage) -> Sequence[AzureToolMessage]: + return [AzureToolMessage(content=x.content, tool_call_id=x.call_id) for x in message.content] + + +def to_azure_message(message: LLMMessage): + if isinstance(message, SystemMessage): + return [_system_message_to_azure(message)] + elif isinstance(message, UserMessage): + return [_user_message_to_azure(message)] + elif isinstance(message, AssistantMessage): + return [_assistant_message_to_azure(message)] + else: + return _tool_message_to_azure(message) + + +class AzureAIChatCompletionClient(ChatCompletionClient): + def __init__(self, **kwargs: Unpack[AzureAIConfig]): + if "endpoint" not in kwargs: + raise ValueError("endpoint must be provided") + if "credential" not in kwargs: + raise ValueError("credential must be provided") + if "model_capabilities" not in kwargs: + raise ValueError("model_capabilities must be provided") + + self._model_capabilities = kwargs["model_capabilities"] + # TODO: Change + _endpoint = kwargs.pop("endpoint") + _credential = kwargs.pop("credential") + self.create_args = kwargs.copy() + + self._client = ChatCompletionsClient(_endpoint, _credential, **self.create_args) + self._actual_usage = RequestUsage(prompt_tokens=0, completion_tokens=0) + self._total_usage = RequestUsage(prompt_tokens=0, completion_tokens=0) + + async def create( + self, + messages: Sequence[LLMMessage], + tools: Sequence[Tool | ToolSchema] = [], + json_output: Optional[bool] = None, + extra_create_args: Mapping[str, Any] = {}, + cancellation_token: Optional[CancellationToken] = None, + ) -> CreateResult: + # TODO: Validate Args + + if self.capabilities["vision"] is False: + for message in messages: + if isinstance(message, UserMessage): + if isinstance(message.content, list) and any(isinstance(x, Image) for x in message.content): + raise ValueError("Model does not support vision and image was provided") + args = {} + + if json_output is not None: + if self.capabilities["json_output"] is False and json_output is True: + raise ValueError("Model does not support JSON 
output") + + if json_output is True: + # TODO: JSON OUTPUT + args["response_format"] = ChatCompletionsResponseFormatJSON() + + if self.capabilities["json_output"] is False and json_output is True: + raise ValueError("Model does not support JSON output") + if self.capabilities["function_calling"] is False and len(tools) > 0: + raise ValueError("Model does not support function calling") + + azure_messages_nested = [to_azure_message(msg) for msg in messages] + azure_messages = [item for sublist in azure_messages_nested for item in sublist] + + task: Task[ChatCompletions] + + if len(tools) > 0: + converted_tools = convert_tools(tools) + task = asyncio.create_task( + self._client.complete( + messages=azure_messages, + tools=converted_tools, + # TODO: Add extra_create_args + ) + ) + else: + task = asyncio.create_task( + self._client.complete( + messages=azure_messages, + max_tokens=20, + **args, + # TODO: Add extra_create_args + ) + ) + + if cancellation_token is not None: + cancellation_token.link_future(task) + + result: ChatCompletions = await task + + usage = RequestUsage( + prompt_tokens=result.usage.prompt_tokens if result.usage else 0, + completion_tokens=result.usage.completion_tokens if result.usage else 0, + ) + + choice = result.choices[0] + if choice.finish_reason == CompletionsFinishReason.TOOL_CALLS: + assert choice.message.tool_calls is not None + + content = [ + FunctionCall( + id=x.id, + arguments=x.function.arguments, + name=x.function.name, + ) + for x in choice.message.tool_calls + ] + finish_reason = "function_calls" + else: + finish_reason = choice.finish_reason.value + content = choice.message.content or "" + + response = CreateResult( + finish_reason=finish_reason, # type: ignore + content=content, + usage=usage, + cached=False, + ) + return response + + async def create_stream( + self, + messages: Sequence[LLMMessage], + tools: Sequence[Tool | ToolSchema] = [], + json_output: Optional[bool] = None, + extra_create_args: Mapping[str, Any] = {}, + cancellation_token: Optional[CancellationToken] = None, + ) -> AsyncGenerator[Union[str, CreateResult], None]: + # TODO: Validate Args + + if self.capabilities["vision"] is False: + for message in messages: + if isinstance(message, UserMessage): + if isinstance(message.content, list) and any(isinstance(x, Image) for x in message.content): + raise ValueError("Model does not support vision and image was provided") + args = {} + + if json_output is not None: + if self.capabilities["json_output"] is False and json_output is True: + raise ValueError("Model does not support JSON output") + + if json_output is True: + # TODO: JSON OUTPUT + args["response_format"] = ChatCompletionsResponseFormatJSON() + + if self.capabilities["json_output"] is False and json_output is True: + raise ValueError("Model does not support JSON output") + if self.capabilities["function_calling"] is False and len(tools) > 0: + raise ValueError("Model does not support function calling") + + # azure_messages = [to_azure_message(m) for m in messages] + azure_messages_nested = [to_azure_message(msg) for msg in messages] + azure_messages = [item for sublist in azure_messages_nested for item in sublist] + + # task: Task[StreamingChatCompletionsUpdate] + + if len(tools) > 0: + converted_tools = convert_tools(tools) + task = asyncio.create_task( + self._client.complete( + messages=azure_messages, + tools=converted_tools, + stream=True, + # TODO: Add extra_create_args + ) + ) + else: + task = asyncio.create_task( + self._client.complete( + messages=azure_messages, + 
max_tokens=20, + stream=True, + **args, + # TODO: Add extra_create_args + ) + ) + + if cancellation_token is not None: + cancellation_token.link_future(task) + + # result: ChatCompletions = await task + finish_reason = None + content_deltas: List[str] = [] + full_tool_calls: Dict[str, FunctionCall] = {} + prompt_tokens = 0 + completion_tokens = 0 + chunk: Optional[StreamingChatCompletionsUpdate] = None + async for chunk in await task: + choice = (chunk.choices[0] + if len(chunk.choices) > 0 + else cast(StreamingChatCompletionsUpdate, None)) + if choice.finish_reason is not None: + finish_reason = choice.finish_reason.value + + # We first try to load the content + if choice.delta.content is not None: + content_deltas.append(choice.delta.content) + yield choice.delta.content + # Otherwise, we try to load the tool calls + if choice.delta.tool_calls is not None: + for tool_call_chunk in choice.delta.tool_calls: + idx = tool_call_chunk.id + if idx not in full_tool_calls: + full_tool_calls[idx] = FunctionCall(id="", arguments="", name="") + # + if tool_call_chunk.id is not None: + full_tool_calls[idx].id += tool_call_chunk.id + + if tool_call_chunk.function is not None: + if tool_call_chunk.function.name is not None: + full_tool_calls[idx].name += tool_call_chunk.function.name + if tool_call_chunk.function.arguments is not None: + full_tool_calls[idx].arguments += tool_call_chunk.function.arguments + + if chunk and chunk.usage: + prompt_tokens = chunk.usage.prompt_tokens + + if finish_reason is None: + raise ValueError("No stop reason found") + + content: Union[str, List[FunctionCall]] + + if len(content_deltas) > 1: + content = "".join(content_deltas) + if chunk and chunk.usage: + completion_tokens = chunk.usage.completion_tokens + else: + completion_tokens = 0 + else: + content = list(full_tool_calls.values()) + + usage = RequestUsage( + completion_tokens=completion_tokens, + prompt_tokens=prompt_tokens, + ) + + result = CreateResult( + finish_reason=finish_reason, # type: ignore + content=content, + usage=usage, + cached=False, + ) + yield result + + def actual_usage(self) -> RequestUsage: + return self._actual_usage + + def total_usage(self) -> RequestUsage: + return self._total_usage + + def count_tokens(self, messages: Sequence[LLMMessage], tools: Sequence[Tool | ToolSchema] = []) -> int: + pass + + def remaining_tokens(self, messages: Sequence[LLMMessage], tools: Sequence[Tool | ToolSchema] = []) -> int: + pass + + @property + def capabilities(self) -> ModelCapabilities: + return self._model_capabilities + + def __del__(self): + # TODO: This is a hack to close the open client + try: + asyncio.get_running_loop().create_task(self._client.close()) + except RuntimeError: + asyncio.run(self._client.close()) diff --git a/python/packages/autogen-ext/src/autogen_ext/models/_azure/config/__init__.py b/python/packages/autogen-ext/src/autogen_ext/models/_azure/config/__init__.py new file mode 100644 index 000000000000..25b81e9e74f9 --- /dev/null +++ b/python/packages/autogen-ext/src/autogen_ext/models/_azure/config/__init__.py @@ -0,0 +1,37 @@ +from typing import TypedDict, Union, Optional, List, Dict, Any +from azure.ai.inference.models import ( + ChatCompletionsResponseFormat, + ChatCompletionsToolDefinition, + ChatCompletionsToolChoicePreset, + ChatCompletionsNamedToolChoice, +) + +from azure.core.credentials import AzureKeyCredential +from azure.core.credentials_async import AsyncTokenCredential + +from autogen_core.components.models import ModelCapabilities + + +class 
AzureAIClientArguments(TypedDict, total=False): + endpoint: str + credential: Union[AzureKeyCredential, AsyncTokenCredential] + model_capabilities: ModelCapabilities + + +class AzureAIRequestArguments(TypedDict, total=False): + frequency_penalty: Optional[float] + presence_penalty: Optional[float] + temperature: Optional[float] + top_p: Optional[float] + max_tokens: Optional[int] + response_format: Optional[ChatCompletionsResponseFormat] + stop: Optional[List[str]] + tools: Optional[List[ChatCompletionsToolDefinition]] + tool_choice: Optional[Union[str, ChatCompletionsToolChoicePreset, ChatCompletionsNamedToolChoice]] + seed: Optional[int] + model: Optional[str] + model_extras: Optional[Dict[str, Any]] + + +class AzureAIConfig(AzureAIClientArguments, AzureAIRequestArguments): + pass diff --git a/python/packages/autogen-ext/tests/models/test_azure_ai_model_client.py b/python/packages/autogen-ext/tests/models/test_azure_ai_model_client.py new file mode 100644 index 000000000000..e69de29bb2d1 From b135d9aae7b6fd0b8b199fc4175f2a4ace4314ef Mon Sep 17 00:00:00 2001 From: Rohan Thacker Date: Mon, 16 Dec 2024 01:17:20 +0530 Subject: [PATCH 2/7] Moved _azure module to azure --- .../src/autogen_ext/models/azure/__init__.py | 5 +++++ .../{_azure => azure}/_azure_ai_client.py | 21 ++++++++++++------- .../{_azure => azure}/config/__init__.py | 2 +- 3 files changed, 20 insertions(+), 8 deletions(-) create mode 100644 python/packages/autogen-ext/src/autogen_ext/models/azure/__init__.py rename python/packages/autogen-ext/src/autogen_ext/models/{_azure => azure}/_azure_ai_client.py (96%) rename python/packages/autogen-ext/src/autogen_ext/models/{_azure => azure}/config/__init__.py (95%) diff --git a/python/packages/autogen-ext/src/autogen_ext/models/azure/__init__.py b/python/packages/autogen-ext/src/autogen_ext/models/azure/__init__.py new file mode 100644 index 000000000000..08080cacb23e --- /dev/null +++ b/python/packages/autogen-ext/src/autogen_ext/models/azure/__init__.py @@ -0,0 +1,5 @@ +from ._azure_ai_client import AzureAIChatCompletionClient + +__all__ = [ + "AzureAIChatCompletionClient" +] \ No newline at end of file diff --git a/python/packages/autogen-ext/src/autogen_ext/models/_azure/_azure_ai_client.py b/python/packages/autogen-ext/src/autogen_ext/models/azure/_azure_ai_client.py similarity index 96% rename from python/packages/autogen-ext/src/autogen_ext/models/_azure/_azure_ai_client.py rename to python/packages/autogen-ext/src/autogen_ext/models/azure/_azure_ai_client.py index de2885a0612d..38eb8f5d4c35 100644 --- a/python/packages/autogen-ext/src/autogen_ext/models/_azure/_azure_ai_client.py +++ b/python/packages/autogen-ext/src/autogen_ext/models/azure/_azure_ai_client.py @@ -26,9 +26,9 @@ ) from typing_extensions import AsyncGenerator, Union -from autogen_core.base import CancellationToken -from autogen_core.components import FunctionCall, Image -from autogen_core.components.models import ( +from autogen_core import CancellationToken +from autogen_core import FunctionCall, Image +from autogen_core.models import ( ChatCompletionClient, LLMMessage, CreateResult, @@ -39,8 +39,8 @@ AssistantMessage, FunctionExecutionResultMessage, ) -from autogen_core.components.tools import Tool, ToolSchema -from autogen_ext.models._azure.config import AzureAIConfig +from autogen_core.tools import Tool, ToolSchema +from autogen_ext.models.azure.config import AzureAIConfig def convert_tools(tools: Sequence[Tool | ToolSchema]) -> List[ChatCompletionsToolDefinition]: @@ -313,10 +313,14 @@ async def 
create_stream( # Otherwise, we try to load the tool calls if choice.delta.tool_calls is not None: for tool_call_chunk in choice.delta.tool_calls: - idx = tool_call_chunk.id + # print(tool_call_chunk) + if "index" in tool_call_chunk: + idx = tool_call_chunk["index"] + else: + idx = tool_call_chunk.id if idx not in full_tool_calls: full_tool_calls[idx] = FunctionCall(id="", arguments="", name="") - # + if tool_call_chunk.id is not None: full_tool_calls[idx].id += tool_call_chunk.id @@ -332,6 +336,9 @@ async def create_stream( if finish_reason is None: raise ValueError("No stop reason found") + if choice and choice.finish_reason is CompletionsFinishReason.TOOL_CALLS: + finish_reason = "function_calls" + content: Union[str, List[FunctionCall]] if len(content_deltas) > 1: diff --git a/python/packages/autogen-ext/src/autogen_ext/models/_azure/config/__init__.py b/python/packages/autogen-ext/src/autogen_ext/models/azure/config/__init__.py similarity index 95% rename from python/packages/autogen-ext/src/autogen_ext/models/_azure/config/__init__.py rename to python/packages/autogen-ext/src/autogen_ext/models/azure/config/__init__.py index 25b81e9e74f9..f17ded19488c 100644 --- a/python/packages/autogen-ext/src/autogen_ext/models/_azure/config/__init__.py +++ b/python/packages/autogen-ext/src/autogen_ext/models/azure/config/__init__.py @@ -9,7 +9,7 @@ from azure.core.credentials import AzureKeyCredential from azure.core.credentials_async import AsyncTokenCredential -from autogen_core.components.models import ModelCapabilities +from autogen_core.models import ModelCapabilities class AzureAIClientArguments(TypedDict, total=False): From 09c071ea45010a75a0fe37838f0569ec65d655f6 Mon Sep 17 00:00:00 2001 From: Rohan Thacker Date: Mon, 16 Dec 2024 11:27:18 +0530 Subject: [PATCH 3/7] Validate extra_create_args in and json response --- .../src/autogen_ext/models/azure/__init__.py | 4 +- .../models/azure/_azure_ai_client.py | 71 ++++++++++--------- 2 files changed, 40 insertions(+), 35 deletions(-) diff --git a/python/packages/autogen-ext/src/autogen_ext/models/azure/__init__.py b/python/packages/autogen-ext/src/autogen_ext/models/azure/__init__.py index 08080cacb23e..02d4392e5a8b 100644 --- a/python/packages/autogen-ext/src/autogen_ext/models/azure/__init__.py +++ b/python/packages/autogen-ext/src/autogen_ext/models/azure/__init__.py @@ -1,5 +1,3 @@ from ._azure_ai_client import AzureAIChatCompletionClient -__all__ = [ - "AzureAIChatCompletionClient" -] \ No newline at end of file +__all__ = ["AzureAIChatCompletionClient"] diff --git a/python/packages/autogen-ext/src/autogen_ext/models/azure/_azure_ai_client.py b/python/packages/autogen-ext/src/autogen_ext/models/azure/_azure_ai_client.py index 38eb8f5d4c35..95e24c512c3e 100644 --- a/python/packages/autogen-ext/src/autogen_ext/models/azure/_azure_ai_client.py +++ b/python/packages/autogen-ext/src/autogen_ext/models/azure/_azure_ai_client.py @@ -1,7 +1,7 @@ import asyncio from asyncio import Task from typing import Sequence, Optional, Mapping, Any, List, Unpack, Dict, cast - +from inspect import getfullargspec from azure.ai.inference.aio import ChatCompletionsClient from azure.ai.inference.models import ( ChatCompletions, @@ -42,6 +42,10 @@ from autogen_core.tools import Tool, ToolSchema from autogen_ext.models.azure.config import AzureAIConfig +create_kwargs = set(getfullargspec(ChatCompletionsClient.complete).kwonlyargs) + + +# create_args def convert_tools(tools: Sequence[Tool | ToolSchema]) -> List[ChatCompletionsToolDefinition]: result: 
List[ChatCompletionsToolDefinition] = [] @@ -123,23 +127,23 @@ def to_azure_message(message: LLMMessage): else: return _tool_message_to_azure(message) - +# TODO: Add Support for Github Models class AzureAIChatCompletionClient(ChatCompletionClient): def __init__(self, **kwargs: Unpack[AzureAIConfig]): if "endpoint" not in kwargs: - raise ValueError("endpoint must be provided") + raise ValueError("endpoint is required for AzureAIChatCompletionClient") if "credential" not in kwargs: - raise ValueError("credential must be provided") + raise ValueError("credential is required for AzureAIChatCompletionClient") if "model_capabilities" not in kwargs: - raise ValueError("model_capabilities must be provided") + raise ValueError("model_capabilities is required for AzureAIChatCompletionClient") - self._model_capabilities = kwargs["model_capabilities"] # TODO: Change _endpoint = kwargs.pop("endpoint") _credential = kwargs.pop("credential") - self.create_args = kwargs.copy() + self._model_capabilities = kwargs.pop("model_capabilities") + self._create_args = kwargs.copy() - self._client = ChatCompletionsClient(_endpoint, _credential, **self.create_args) + self._client = ChatCompletionsClient(_endpoint, _credential, **self._create_args) self._actual_usage = RequestUsage(prompt_tokens=0, completion_tokens=0) self._total_usage = RequestUsage(prompt_tokens=0, completion_tokens=0) @@ -151,22 +155,26 @@ async def create( extra_create_args: Mapping[str, Any] = {}, cancellation_token: Optional[CancellationToken] = None, ) -> CreateResult: - # TODO: Validate Args + extra_create_args_keys = set(extra_create_args.keys()) + if not create_kwargs.issuperset(extra_create_args_keys): + raise ValueError(f"Extra create args are invalid: {extra_create_args_keys - create_kwargs}") + + # Copy the create args and overwrite anything in extra_create_args + create_args = self._create_args.copy() + create_args.update(extra_create_args) if self.capabilities["vision"] is False: for message in messages: if isinstance(message, UserMessage): if isinstance(message.content, list) and any(isinstance(x, Image) for x in message.content): raise ValueError("Model does not support vision and image was provided") - args = {} if json_output is not None: if self.capabilities["json_output"] is False and json_output is True: raise ValueError("Model does not support JSON output") if json_output is True: - # TODO: JSON OUTPUT - args["response_format"] = ChatCompletionsResponseFormatJSON() + create_args["response_format"] = ChatCompletionsResponseFormatJSON() if self.capabilities["json_output"] is False and json_output is True: raise ValueError("Model does not support JSON output") @@ -184,16 +192,14 @@ async def create( self._client.complete( messages=azure_messages, tools=converted_tools, - # TODO: Add extra_create_args + **create_args ) ) else: task = asyncio.create_task( self._client.complete( messages=azure_messages, - max_tokens=20, - **args, - # TODO: Add extra_create_args + **create_args, ) ) @@ -240,22 +246,26 @@ async def create_stream( extra_create_args: Mapping[str, Any] = {}, cancellation_token: Optional[CancellationToken] = None, ) -> AsyncGenerator[Union[str, CreateResult], None]: - # TODO: Validate Args + extra_create_args_keys = set(extra_create_args.keys()) + if not create_kwargs.issuperset(extra_create_args_keys): + raise ValueError(f"Extra create args are invalid: {extra_create_args_keys - create_kwargs}") + + create_args = self._create_args.copy() + create_args.update(extra_create_args) if self.capabilities["vision"] is False: - 
for message in messages: - if isinstance(message, UserMessage): - if isinstance(message.content, list) and any(isinstance(x, Image) for x in message.content): - raise ValueError("Model does not support vision and image was provided") - args = {} + for message in messages: + if isinstance(message, UserMessage): + if isinstance(message.content, list) and any(isinstance(x, Image) for x in message.content): + raise ValueError("Model does not support vision and image was provided") + if json_output is not None: if self.capabilities["json_output"] is False and json_output is True: raise ValueError("Model does not support JSON output") if json_output is True: - # TODO: JSON OUTPUT - args["response_format"] = ChatCompletionsResponseFormatJSON() + create_args["response_format"] = ChatCompletionsResponseFormatJSON() if self.capabilities["json_output"] is False and json_output is True: raise ValueError("Model does not support JSON output") @@ -275,7 +285,7 @@ async def create_stream( messages=azure_messages, tools=converted_tools, stream=True, - # TODO: Add extra_create_args + **create_args ) ) else: @@ -284,8 +294,7 @@ async def create_stream( messages=azure_messages, max_tokens=20, stream=True, - **args, - # TODO: Add extra_create_args + **create_args ) ) @@ -300,9 +309,7 @@ async def create_stream( completion_tokens = 0 chunk: Optional[StreamingChatCompletionsUpdate] = None async for chunk in await task: - choice = (chunk.choices[0] - if len(chunk.choices) > 0 - else cast(StreamingChatCompletionsUpdate, None)) + choice = chunk.choices[0] if len(chunk.choices) > 0 else cast(StreamingChatCompletionsUpdate, None) if choice.finish_reason is not None: finish_reason = choice.finish_reason.value @@ -370,10 +377,10 @@ def total_usage(self) -> RequestUsage: return self._total_usage def count_tokens(self, messages: Sequence[LLMMessage], tools: Sequence[Tool | ToolSchema] = []) -> int: - pass + return 0 def remaining_tokens(self, messages: Sequence[LLMMessage], tools: Sequence[Tool | ToolSchema] = []) -> int: - pass + return 0 @property def capabilities(self) -> ModelCapabilities: From a24901cca0edd5a509a7e4fa26add5669c061a84 Mon Sep 17 00:00:00 2001 From: Rohan Thacker Date: Mon, 16 Dec 2024 11:36:29 +0530 Subject: [PATCH 4/7] Added Support for Github Models --- .../models/azure/_azure_ai_client.py | 40 +++++++------------ .../models/azure/config/__init__.py | 7 ++-- 2 files changed, 19 insertions(+), 28 deletions(-) diff --git a/python/packages/autogen-ext/src/autogen_ext/models/azure/_azure_ai_client.py b/python/packages/autogen-ext/src/autogen_ext/models/azure/_azure_ai_client.py index 95e24c512c3e..41da4b8b2dc2 100644 --- a/python/packages/autogen-ext/src/autogen_ext/models/azure/_azure_ai_client.py +++ b/python/packages/autogen-ext/src/autogen_ext/models/azure/_azure_ai_client.py @@ -40,12 +40,14 @@ FunctionExecutionResultMessage, ) from autogen_core.tools import Tool, ToolSchema -from autogen_ext.models.azure.config import AzureAIConfig +from autogen_ext.models.azure.config import AzureAIChatCompletionClientConfig, GITHUB_MODELS_ENDPOINT create_kwargs = set(getfullargspec(ChatCompletionsClient.complete).kwonlyargs) -# create_args +def _is_github_model(endpoint: str) -> bool: + return endpoint == GITHUB_MODELS_ENDPOINT + def convert_tools(tools: Sequence[Tool | ToolSchema]) -> List[ChatCompletionsToolDefinition]: result: List[ChatCompletionsToolDefinition] = [] @@ -127,15 +129,18 @@ def to_azure_message(message: LLMMessage): else: return _tool_message_to_azure(message) + # TODO: Add Support for Github 
Models class AzureAIChatCompletionClient(ChatCompletionClient): - def __init__(self, **kwargs: Unpack[AzureAIConfig]): + def __init__(self, **kwargs: Unpack[AzureAIChatCompletionClientConfig]): if "endpoint" not in kwargs: raise ValueError("endpoint is required for AzureAIChatCompletionClient") if "credential" not in kwargs: raise ValueError("credential is required for AzureAIChatCompletionClient") if "model_capabilities" not in kwargs: raise ValueError("model_capabilities is required for AzureAIChatCompletionClient") + if _is_github_model(kwargs['endpoint']) and "model" not in kwargs: + raise ValueError("model is required for when using a Github model with AzureAIChatCompletionClient") # TODO: Change _endpoint = kwargs.pop("endpoint") @@ -189,11 +194,7 @@ async def create( if len(tools) > 0: converted_tools = convert_tools(tools) task = asyncio.create_task( - self._client.complete( - messages=azure_messages, - tools=converted_tools, - **create_args - ) + self._client.complete(messages=azure_messages, tools=converted_tools, **create_args) ) else: task = asyncio.create_task( @@ -254,11 +255,10 @@ async def create_stream( create_args.update(extra_create_args) if self.capabilities["vision"] is False: - for message in messages: - if isinstance(message, UserMessage): - if isinstance(message.content, list) and any(isinstance(x, Image) for x in message.content): - raise ValueError("Model does not support vision and image was provided") - + for message in messages: + if isinstance(message, UserMessage): + if isinstance(message.content, list) and any(isinstance(x, Image) for x in message.content): + raise ValueError("Model does not support vision and image was provided") if json_output is not None: if self.capabilities["json_output"] is False and json_output is True: @@ -281,21 +281,11 @@ async def create_stream( if len(tools) > 0: converted_tools = convert_tools(tools) task = asyncio.create_task( - self._client.complete( - messages=azure_messages, - tools=converted_tools, - stream=True, - **create_args - ) + self._client.complete(messages=azure_messages, tools=converted_tools, stream=True, **create_args) ) else: task = asyncio.create_task( - self._client.complete( - messages=azure_messages, - max_tokens=20, - stream=True, - **create_args - ) + self._client.complete(messages=azure_messages, max_tokens=20, stream=True, **create_args) ) if cancellation_token is not None: diff --git a/python/packages/autogen-ext/src/autogen_ext/models/azure/config/__init__.py b/python/packages/autogen-ext/src/autogen_ext/models/azure/config/__init__.py index f17ded19488c..6a60f07db3af 100644 --- a/python/packages/autogen-ext/src/autogen_ext/models/azure/config/__init__.py +++ b/python/packages/autogen-ext/src/autogen_ext/models/azure/config/__init__.py @@ -8,9 +8,10 @@ from azure.core.credentials import AzureKeyCredential from azure.core.credentials_async import AsyncTokenCredential - from autogen_core.models import ModelCapabilities +GITHUB_MODELS_ENDPOINT = "https://models.inference.ai.azure.com" + class AzureAIClientArguments(TypedDict, total=False): endpoint: str @@ -18,7 +19,7 @@ class AzureAIClientArguments(TypedDict, total=False): model_capabilities: ModelCapabilities -class AzureAIRequestArguments(TypedDict, total=False): +class AzureAICreateArguments(TypedDict, total=False): frequency_penalty: Optional[float] presence_penalty: Optional[float] temperature: Optional[float] @@ -33,5 +34,5 @@ class AzureAIRequestArguments(TypedDict, total=False): model_extras: Optional[Dict[str, Any]] -class 
AzureAIConfig(AzureAIClientArguments, AzureAIRequestArguments):
+class AzureAIChatCompletionClientConfig(AzureAIClientArguments, AzureAICreateArguments):
     pass

From bacab86e4c6a13c537710d059947d02303490479 Mon Sep 17 00:00:00 2001
From: Rohan Thacker
Date: Mon, 16 Dec 2024 11:42:45 +0530
Subject: [PATCH 5/7] Added normalize_name and assert_valid_name

---
 .../models/azure/_azure_ai_client.py          | 30 ++++++++++++++++---
 1 file changed, 26 insertions(+), 4 deletions(-)

diff --git a/python/packages/autogen-ext/src/autogen_ext/models/azure/_azure_ai_client.py b/python/packages/autogen-ext/src/autogen_ext/models/azure/_azure_ai_client.py
index 41da4b8b2dc2..4cd058808ddd 100644
--- a/python/packages/autogen-ext/src/autogen_ext/models/azure/_azure_ai_client.py
+++ b/python/packages/autogen-ext/src/autogen_ext/models/azure/_azure_ai_client.py
@@ -1,4 +1,5 @@
 import asyncio
+import re
 from asyncio import Task
 from typing import Sequence, Optional, Mapping, Any, List, Unpack, Dict, cast
 from inspect import getfullargspec
@@ -88,7 +89,7 @@ def _system_message_to_azure(message: SystemMessage) -> AzureSystemMessage:
 
 
 def _user_message_to_azure(message: UserMessage) -> AzureUserMessage:
-    # assert_valid_name(message.source)
+    assert_valid_name(message.source)
     if isinstance(message.content, str):
         return AzureUserMessage(content=message.content)
     else:
@@ -106,7 +107,7 @@ def _user_message_to_azure(message: UserMessage) -> AzureUserMessage:
 
 
 def _assistant_message_to_azure(message: AssistantMessage) -> AzureAssistantMessage:
-    # assert_valid_name(message.source)
+    assert_valid_name(message.source)
     if isinstance(message.content, list):
         return AzureAssistantMessage(
             tool_calls=[_func_call_to_azure(x) for x in message.content],
@@ -130,7 +131,28 @@ def to_azure_message(message: LLMMessage):
         return _tool_message_to_azure(message)
 
 
-# TODO: Add Support for Github Models
+def normalize_name(name: str) -> str:
+    """
+    LLMs sometimes call functions while ignoring their own format requirements; use this function to replace invalid characters with "_".
+
+    Prefer assert_valid_name for validating user configuration or input.
+    """
+    return re.sub(r"[^a-zA-Z0-9_-]", "_", name)[:64]
+
+
+def assert_valid_name(name: str) -> str:
+    """
+    Ensure that configured names are valid; raises ValueError if not.
+
+    For munging LLM responses, use normalize_name to ensure LLM-specified names don't break the API.
+    """
+    if not re.match(r"^[a-zA-Z0-9_-]+$", name):
+        raise ValueError(f"Invalid name: {name}. Only letters, numbers, '_' and '-' are allowed.")
+    if len(name) > 64:
+        raise ValueError(f"Invalid name: {name}. 
Name must be less than 64 characters.") + return name + + class AzureAIChatCompletionClient(ChatCompletionClient): def __init__(self, **kwargs: Unpack[AzureAIChatCompletionClientConfig]): if "endpoint" not in kwargs: @@ -222,7 +244,7 @@ async def create( FunctionCall( id=x.id, arguments=x.function.arguments, - name=x.function.name, + name=normalize_name(x.function.name), ) for x in choice.message.tool_calls ] From 06d3f9522ca46e93375601c22e35701cedb47ad8 Mon Sep 17 00:00:00 2001 From: Rohan Thacker Date: Mon, 16 Dec 2024 23:03:59 +0530 Subject: [PATCH 6/7] Added Tests for AzureAIChatCompletionClient --- .../models/test_azure_ai_model_client.py | 158 ++++++++++++++++++ 1 file changed, 158 insertions(+) diff --git a/python/packages/autogen-ext/tests/models/test_azure_ai_model_client.py b/python/packages/autogen-ext/tests/models/test_azure_ai_model_client.py index e69de29bb2d1..fd888d7f4b45 100644 --- a/python/packages/autogen-ext/tests/models/test_azure_ai_model_client.py +++ b/python/packages/autogen-ext/tests/models/test_azure_ai_model_client.py @@ -0,0 +1,158 @@ +import asyncio +from datetime import datetime +from typing import AsyncGenerator, Any + +import pytest +from azure.ai.inference.aio import ( + ChatCompletionsClient, +) + +from azure.ai.inference.models import ( + ChatChoice, + ChatResponseMessage, + CompletionsUsage, + +) +from azure.ai.inference.models import (ChatCompletions, + StreamingChatCompletionsUpdate, StreamingChatChoiceUpdate, + StreamingChatResponseMessageUpdate) + +from azure.core.credentials import AzureKeyCredential + +from autogen_core import CancellationToken +from autogen_core.models import UserMessage +from autogen_ext.models.azure import AzureAIChatCompletionClient + + +async def _mock_create_stream(*args: Any, **kwargs: Any) -> AsyncGenerator[StreamingChatCompletionsUpdate, None]: + mock_chunks_content = ["Hello", " Another Hello", " Yet Another Hello"] + + mock_chunks = [ + StreamingChatChoiceUpdate( + index=0, + finish_reason="stop", + delta=StreamingChatResponseMessageUpdate(role="assistant", content=chunk_content), + ) for chunk_content in mock_chunks_content + ] + + for mock_chunk in mock_chunks: + await asyncio.sleep(0.1) + yield StreamingChatCompletionsUpdate( + id="id", + choices=[mock_chunk], + created=datetime.now(), + model="model", + usage=CompletionsUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0), + ) + + +async def _mock_create(*args: Any, **kwargs: Any) -> ChatCompletions | AsyncGenerator[StreamingChatCompletionsUpdate, None]: + stream = kwargs.get("stream", False) + + if not stream: + await asyncio.sleep(0.1) + return ChatCompletions( + id="id", + created=datetime.now(), + model='model', + choices=[ + ChatChoice( + index=0, + finish_reason="stop", + message=ChatResponseMessage(content="Hello", role="assistant") + ) + ], + usage=CompletionsUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0), + ) + else: + return _mock_create_stream(*args, **kwargs) + + + +@pytest.mark.asyncio +async def test_azure_ai_chat_completion_client() -> None: + client = AzureAIChatCompletionClient( + endpoint="endpoint", + credential=AzureKeyCredential("api_key"), + model_capabilities = { + "json_output": False, + "function_calling": False, + "vision": False, + }, + ) + assert client + +@pytest.mark.asyncio +async def test_azure_ai_chat_completion_client_create(monkeypatch: pytest.MonkeyPatch) -> None: + # monkeypatch.setattr(AsyncCompletions, "create", _mock_create) + monkeypatch.setattr(ChatCompletionsClient, "complete", _mock_create) + client = 
AzureAIChatCompletionClient( + endpoint="endpoint", + credential=AzureKeyCredential("api_key"), + model_capabilities = { + "json_output": False, + "function_calling": False, + "vision": False, + }, + ) + result = await client.create(messages=[UserMessage(content="Hello", source="user")]) + assert result.content == "Hello" + +@pytest.mark.asyncio +async def test_azure_ai_chat_completion_client_create_stream(monkeypatch:pytest.MonkeyPatch) -> None: + monkeypatch.setattr(ChatCompletionsClient, "complete", _mock_create) + chunks = [] + client = AzureAIChatCompletionClient( + endpoint="endpoint", + credential=AzureKeyCredential("api_key"), + model_capabilities = { + "json_output": False, + "function_calling": False, + "vision": False, + }, + ) + async for chunk in client.create_stream(messages=[UserMessage(content="Hello", source="user")]): + chunks.append(chunk) + + assert chunks[0] == "Hello" + assert chunks[1] == " Another Hello" + assert chunks[2] == " Yet Another Hello" + +@pytest.mark.asyncio +async def test_azure_ai_chat_completion_client_create_cancel(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(ChatCompletionsClient, "complete", _mock_create) + cancellation_token = CancellationToken() + client = AzureAIChatCompletionClient( + endpoint="endpoint", + credential=AzureKeyCredential("api_key"), + model_capabilities={ + "json_output": False, + "function_calling": False, + "vision": False, + }, + ) + task = asyncio.create_task( + client.create(messages=[UserMessage(content="Hello", source="user")], cancellation_token=cancellation_token) + ) + cancellation_token.cancel() + with pytest.raises(asyncio.CancelledError): + await task + +@pytest.mark.asyncio +async def test_azure_ai_chat_completion_client_create_stream_cancel(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(ChatCompletionsClient, "complete", _mock_create) + cancellation_token = CancellationToken() + client = AzureAIChatCompletionClient( + endpoint="endpoint", + credential=AzureKeyCredential("api_key"), + model_capabilities={ + "json_output": False, + "function_calling": False, + "vision": False, + }, + ) + stream=client.create_stream(messages=[UserMessage(content="Hello", source="user")], cancellation_token=cancellation_token) + cancellation_token.cancel() + with pytest.raises(asyncio.CancelledError): + async for _ in stream: + pass From daf43de5283783cb102f16eebeda2d920a446f29 Mon Sep 17 00:00:00 2001 From: Rohan Thacker Date: Tue, 24 Dec 2024 13:15:38 +0530 Subject: [PATCH 7/7] WIP: Azure AI Client * Added: object-level usage data * Added: doc string * Added: check existing response_format value * Added: _validate_config and _create_client --- .../autogen-core/docs/src/reference/index.md | 1 + .../python/autogen_ext.models.azure.rst | 8 ++ .../src/autogen_ext/models/azure/__init__.py | 3 +- .../models/azure/_azure_ai_client.py | 109 +++++++++++++++--- .../models/test_azure_ai_model_client.py | 42 ++++--- 5 files changed, 131 insertions(+), 32 deletions(-) create mode 100644 python/packages/autogen-core/docs/src/reference/python/autogen_ext.models.azure.rst diff --git a/python/packages/autogen-core/docs/src/reference/index.md b/python/packages/autogen-core/docs/src/reference/index.md index cfe36eded2c2..4893b9964b93 100644 --- a/python/packages/autogen-core/docs/src/reference/index.md +++ b/python/packages/autogen-core/docs/src/reference/index.md @@ -48,6 +48,7 @@ python/autogen_ext.agents.video_surfer.tools python/autogen_ext.teams.magentic_one python/autogen_ext.models.openai 
python/autogen_ext.models.replay
+python/autogen_ext.models.azure
 python/autogen_ext.tools.langchain
 python/autogen_ext.code_executors.local
 python/autogen_ext.code_executors.docker
diff --git a/python/packages/autogen-core/docs/src/reference/python/autogen_ext.models.azure.rst b/python/packages/autogen-core/docs/src/reference/python/autogen_ext.models.azure.rst
new file mode 100644
index 000000000000..64c16a5a57d4
--- /dev/null
+++ b/python/packages/autogen-core/docs/src/reference/python/autogen_ext.models.azure.rst
@@ -0,0 +1,8 @@
+autogen\_ext.models.azure
+==========================
+
+
+.. automodule:: autogen_ext.models.azure
+   :members:
+   :undoc-members:
+   :show-inheritance:
diff --git a/python/packages/autogen-ext/src/autogen_ext/models/azure/__init__.py b/python/packages/autogen-ext/src/autogen_ext/models/azure/__init__.py
index 02d4392e5a8b..2dc7b9c70a98 100644
--- a/python/packages/autogen-ext/src/autogen_ext/models/azure/__init__.py
+++ b/python/packages/autogen-ext/src/autogen_ext/models/azure/__init__.py
@@ -1,3 +1,4 @@
 from ._azure_ai_client import AzureAIChatCompletionClient
+from .config import AzureAIChatCompletionClientConfig
 
-__all__ = ["AzureAIChatCompletionClient"]
+__all__ = ["AzureAIChatCompletionClient", "AzureAIChatCompletionClientConfig"]
diff --git a/python/packages/autogen-ext/src/autogen_ext/models/azure/_azure_ai_client.py b/python/packages/autogen-ext/src/autogen_ext/models/azure/_azure_ai_client.py
index 4cd058808ddd..1297060601fa 100644
--- a/python/packages/autogen-ext/src/autogen_ext/models/azure/_azure_ai_client.py
+++ b/python/packages/autogen-ext/src/autogen_ext/models/azure/_azure_ai_client.py
@@ -1,5 +1,6 @@
 import asyncio
 import re
+import warnings
 from asyncio import Task
 from typing import Sequence, Optional, Mapping, Any, List, Unpack, Dict, cast
 from inspect import getfullargspec
@@ -154,25 +155,95 @@ def assert_valid_name(name: str) -> str:
 
 
 class AzureAIChatCompletionClient(ChatCompletionClient):
+    """
+    Chat completion client for models hosted on Azure AI Foundry or GitHub Models.
+    See `here `_ for more info.
+
+    Args:
+        endpoint (str): The endpoint to use. **Required.**
+        credential (Union[AzureKeyCredential, AsyncTokenCredential]): The credential to use. **Required.**
+        model_capabilities (ModelCapabilities): The capabilities of the model. **Required.**
+        model (str): The name of the model. **Required if the model is hosted on GitHub Models.**
+        frequency_penalty: (optional, float)
+        presence_penalty: (optional, float)
+        temperature: (optional, float)
+        top_p: (optional, float)
+        max_tokens: (optional, int)
+        response_format: (optional, ChatCompletionsResponseFormat)
+        stop: (optional, List[str])
+        tools: (optional, List[ChatCompletionsToolDefinition])
+        tool_choice: (optional, Union[str, ChatCompletionsToolChoicePreset, ChatCompletionsNamedToolChoice])
+        seed: (optional, int)
+        model_extras: (optional, Dict[str, Any])
+
+    To use this client, you must install the `azure-ai-inference` extension:
+
+    .. code-block:: bash
+
+        pip install 'autogen-ext[azure-ai-inference]==0.4.0.dev11'
+
+    The following code snippet shows how to use the client:
+
+    .. 
code-block:: python + + from azure.core.credentials import AzureKeyCredential + from autogen_ext.models.azure import AzureAIChatCompletionClient + from autogen_core.models import UserMessage + + client = AzureAIChatCompletionClient( + endpoint="endpoint", + credential=AzureKeyCredential("api_key"), + model_capabilities={ + "json_output": False, + "function_calling": False, + "vision": False, + }, + ) + + result = await client.create([UserMessage(content="What is the capital of France?", source="user")]) # type: ignore + print(result) + + """ + def __init__(self, **kwargs: Unpack[AzureAIChatCompletionClientConfig]): - if "endpoint" not in kwargs: + config = self._validate_config(kwargs) + self._model_capabilities = config["model_capabilities"] + self._client = self._create_client(config) + self._create_args = self._prepare_create_args(config) + + self._actual_usage = RequestUsage(prompt_tokens=0, completion_tokens=0) + self._total_usage = RequestUsage(prompt_tokens=0, completion_tokens=0) + + @staticmethod + def _validate_config(config: Dict) -> AzureAIChatCompletionClientConfig: + if "endpoint" not in config: raise ValueError("endpoint is required for AzureAIChatCompletionClient") - if "credential" not in kwargs: + if "credential" not in config: raise ValueError("credential is required for AzureAIChatCompletionClient") - if "model_capabilities" not in kwargs: + if "model_capabilities" not in config: raise ValueError("model_capabilities is required for AzureAIChatCompletionClient") - if _is_github_model(kwargs['endpoint']) and "model" not in kwargs: + if _is_github_model(config["endpoint"]) and "model" not in config: raise ValueError("model is required for when using a Github model with AzureAIChatCompletionClient") - - # TODO: Change - _endpoint = kwargs.pop("endpoint") - _credential = kwargs.pop("credential") - self._model_capabilities = kwargs.pop("model_capabilities") - self._create_args = kwargs.copy() - - self._client = ChatCompletionsClient(_endpoint, _credential, **self._create_args) - self._actual_usage = RequestUsage(prompt_tokens=0, completion_tokens=0) - self._total_usage = RequestUsage(prompt_tokens=0, completion_tokens=0) + return config + + @staticmethod + def _create_client(config: AzureAIChatCompletionClientConfig): + return ChatCompletionsClient(**config) + + @staticmethod + def _prepare_create_args(config: Mapping[str, Any]) -> Mapping[str, Any]: + create_args = {k: v for k, v in config.items() if k in create_kwargs} + return create_args + # self._endpoint = config.pop("endpoint") + # self._credential = config.pop("credential") + # self._model_capabilities = config.pop("model_capabilities") + # self._create_args = config.copy() + + def add_usage(self, usage: RequestUsage): + self._total_usage = RequestUsage( + self._total_usage.prompt_tokens + usage.prompt_tokens, + self._total_usage.completion_tokens + usage.completion_tokens, + ) async def create( self, @@ -200,7 +271,7 @@ async def create( if self.capabilities["json_output"] is False and json_output is True: raise ValueError("Model does not support JSON output") - if json_output is True: + if json_output is True and "response_format" not in create_args: create_args["response_format"] = ChatCompletionsResponseFormatJSON() if self.capabilities["json_output"] is False and json_output is True: @@ -259,6 +330,9 @@ async def create( usage=usage, cached=False, ) + + self.add_usage(usage) + return response async def create_stream( @@ -286,7 +360,7 @@ async def create_stream( if self.capabilities["json_output"] is False and 
json_output is True: raise ValueError("Model does not support JSON output") - if json_output is True: + if json_output is True and "response_format" not in create_args: create_args["response_format"] = ChatCompletionsResponseFormatJSON() if self.capabilities["json_output"] is False and json_output is True: @@ -380,6 +454,9 @@ async def create_stream( usage=usage, cached=False, ) + + self.add_usage(usage) + yield result def actual_usage(self) -> RequestUsage: diff --git a/python/packages/autogen-ext/tests/models/test_azure_ai_model_client.py b/python/packages/autogen-ext/tests/models/test_azure_ai_model_client.py index fd888d7f4b45..22bd7cf74ee1 100644 --- a/python/packages/autogen-ext/tests/models/test_azure_ai_model_client.py +++ b/python/packages/autogen-ext/tests/models/test_azure_ai_model_client.py @@ -7,15 +7,20 @@ ChatCompletionsClient, ) + from azure.ai.inference.models import ( ChatChoice, ChatResponseMessage, CompletionsUsage, + ChatCompletionsResponseFormatJSON, +) +from azure.ai.inference.models import ( + ChatCompletions, + StreamingChatCompletionsUpdate, + StreamingChatChoiceUpdate, + StreamingChatResponseMessageUpdate, ) -from azure.ai.inference.models import (ChatCompletions, - StreamingChatCompletionsUpdate, StreamingChatChoiceUpdate, - StreamingChatResponseMessageUpdate) from azure.core.credentials import AzureKeyCredential @@ -32,7 +37,8 @@ async def _mock_create_stream(*args: Any, **kwargs: Any) -> AsyncGenerator[Strea index=0, finish_reason="stop", delta=StreamingChatResponseMessageUpdate(role="assistant", content=chunk_content), - ) for chunk_content in mock_chunks_content + ) + for chunk_content in mock_chunks_content ] for mock_chunk in mock_chunks: @@ -46,7 +52,9 @@ async def _mock_create_stream(*args: Any, **kwargs: Any) -> AsyncGenerator[Strea ) -async def _mock_create(*args: Any, **kwargs: Any) -> ChatCompletions | AsyncGenerator[StreamingChatCompletionsUpdate, None]: +async def _mock_create( + *args: Any, **kwargs: Any +) -> ChatCompletions | AsyncGenerator[StreamingChatCompletionsUpdate, None]: stream = kwargs.get("stream", False) if not stream: @@ -54,12 +62,10 @@ async def _mock_create(*args: Any, **kwargs: Any) -> ChatCompletions | AsyncGene return ChatCompletions( id="id", created=datetime.now(), - model='model', + model="model", choices=[ ChatChoice( - index=0, - finish_reason="stop", - message=ChatResponseMessage(content="Hello", role="assistant") + index=0, finish_reason="stop", message=ChatResponseMessage(content="Hello", role="assistant") ) ], usage=CompletionsUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0), @@ -68,20 +74,21 @@ async def _mock_create(*args: Any, **kwargs: Any) -> ChatCompletions | AsyncGene return _mock_create_stream(*args, **kwargs) - @pytest.mark.asyncio async def test_azure_ai_chat_completion_client() -> None: client = AzureAIChatCompletionClient( endpoint="endpoint", credential=AzureKeyCredential("api_key"), - model_capabilities = { + model_capabilities={ "json_output": False, "function_calling": False, "vision": False, }, + model="model", ) assert client + @pytest.mark.asyncio async def test_azure_ai_chat_completion_client_create(monkeypatch: pytest.MonkeyPatch) -> None: # monkeypatch.setattr(AsyncCompletions, "create", _mock_create) @@ -89,7 +96,7 @@ async def test_azure_ai_chat_completion_client_create(monkeypatch: pytest.Monkey client = AzureAIChatCompletionClient( endpoint="endpoint", credential=AzureKeyCredential("api_key"), - model_capabilities = { + model_capabilities={ "json_output": False, "function_calling": 
False, "vision": False, @@ -98,14 +105,15 @@ async def test_azure_ai_chat_completion_client_create(monkeypatch: pytest.Monkey result = await client.create(messages=[UserMessage(content="Hello", source="user")]) assert result.content == "Hello" + @pytest.mark.asyncio -async def test_azure_ai_chat_completion_client_create_stream(monkeypatch:pytest.MonkeyPatch) -> None: +async def test_azure_ai_chat_completion_client_create_stream(monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr(ChatCompletionsClient, "complete", _mock_create) chunks = [] client = AzureAIChatCompletionClient( endpoint="endpoint", credential=AzureKeyCredential("api_key"), - model_capabilities = { + model_capabilities={ "json_output": False, "function_calling": False, "vision": False, @@ -118,6 +126,7 @@ async def test_azure_ai_chat_completion_client_create_stream(monkeypatch:pytest. assert chunks[1] == " Another Hello" assert chunks[2] == " Yet Another Hello" + @pytest.mark.asyncio async def test_azure_ai_chat_completion_client_create_cancel(monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr(ChatCompletionsClient, "complete", _mock_create) @@ -138,6 +147,7 @@ async def test_azure_ai_chat_completion_client_create_cancel(monkeypatch: pytest with pytest.raises(asyncio.CancelledError): await task + @pytest.mark.asyncio async def test_azure_ai_chat_completion_client_create_stream_cancel(monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr(ChatCompletionsClient, "complete", _mock_create) @@ -151,7 +161,9 @@ async def test_azure_ai_chat_completion_client_create_stream_cancel(monkeypatch: "vision": False, }, ) - stream=client.create_stream(messages=[UserMessage(content="Hello", source="user")], cancellation_token=cancellation_token) + stream = client.create_stream( + messages=[UserMessage(content="Hello", source="user")], cancellation_token=cancellation_token + ) cancellation_token.cancel() with pytest.raises(asyncio.CancelledError): async for _ in stream:
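
A minimal end-to-end sketch of the streaming API added in this series, assuming the GitHub Models endpoint (the GITHUB_MODELS_ENDPOINT constant from the config module); the token and model name below are placeholders, not values taken from these patches:

    import asyncio

    from azure.core.credentials import AzureKeyCredential

    from autogen_core.models import CreateResult, UserMessage
    from autogen_ext.models.azure import AzureAIChatCompletionClient


    async def main() -> None:
        # "model" is required whenever the GitHub Models endpoint is used.
        client = AzureAIChatCompletionClient(
            endpoint="https://models.inference.ai.azure.com",
            credential=AzureKeyCredential("<github-token>"),  # placeholder token
            model="gpt-4o",  # placeholder model name
            model_capabilities={
                "json_output": False,
                "function_calling": False,
                "vision": False,
            },
        )
        # create_stream yields str chunks as they arrive, then one final
        # CreateResult carrying the finish reason and token usage.
        async for item in client.create_stream(messages=[UserMessage(content="Hello", source="user")]):
            if isinstance(item, CreateResult):
                print("\nfinish_reason:", item.finish_reason, "usage:", item.usage)
            else:
                print(item, end="", flush=True)


    asyncio.run(main())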