From 77ce9ec5b0e1c7a10f68e2ac78ccedb3442bffaf Mon Sep 17 00:00:00 2001 From: "Rossdan Craig rossdan@lastmileai.dev" <> Date: Wed, 27 Dec 2023 15:49:18 -0500 Subject: [PATCH] [python] Use OutputDataWithValue type for function calls This is extending https://github.com/lastmile-ai/aiconfig/pull/605 where I converted from model-specific response types --> pure strings. Now I'm going to also support function calls --> OutputData types Same as last diff (https://github.com/lastmile-ai/aiconfig/pull/610) except now for python files instead of typescript Don't need to do this for the image outputs since I already did that in https://github.com/lastmile-ai/aiconfig/pull/608 Actually like typescript, the only model parser we need to support for function calling is `openai.py` --- .../src/aiconfig_extension_gemini/Gemini.py | 37 +++-- .../local_inference/text_generation.py | 44 +++-- .../text_generation.py | 53 +++--- .../LLamaGuard.py | 49 ++++-- extensions/llama/python/llama.py | 64 +++++-- python/src/aiconfig/default_parsers/hf.py | 33 ++-- python/src/aiconfig/default_parsers/openai.py | 156 +++++++++++++----- python/src/aiconfig/default_parsers/palm.py | 24 ++- python/tests/parsers/test_openai_util.py | 10 +- 9 files changed, 329 insertions(+), 141 deletions(-) diff --git a/extensions/Gemini/python/src/aiconfig_extension_gemini/Gemini.py b/extensions/Gemini/python/src/aiconfig_extension_gemini/Gemini.py index de17fb5c7..3f33ba125 100644 --- a/extensions/Gemini/python/src/aiconfig_extension_gemini/Gemini.py +++ b/extensions/Gemini/python/src/aiconfig_extension_gemini/Gemini.py @@ -1,17 +1,27 @@ # Define a Model Parser for LLama-Guard from typing import TYPE_CHECKING, Dict, List, Optional, Any import copy +import json import google.generativeai as genai -import copy +from google.generativeai.types import content_types +from google.protobuf.json_format import MessageToDict +from aiconfig import ( + AIConfigRuntime, + CallbackEvent, + get_api_key_from_environment, +) from aiconfig.default_parsers.parameterized_model_parser import ParameterizedModelParser from aiconfig.model_parser import InferenceOptions -from aiconfig.schema import ExecuteResult, Output, Prompt +from aiconfig.schema import ( + ExecuteResult, + Output, + OutputDataWithValue, + Prompt, + PromptInput, +) from aiconfig.util.params import resolve_prompt, resolve_prompt_string -from aiconfig import CallbackEvent, get_api_key_from_environment, AIConfigRuntime, PromptMetadata, PromptInput -from google.generativeai.types import content_types -from google.protobuf.json_format import MessageToDict # Circuluar Dependency Type Hints if TYPE_CHECKING: @@ -218,8 +228,8 @@ async def serialize( ExecuteResult( **{ "output_type": "execute_result", - "data": model_message_parts[0], - "metadata": {"rawResponse": model_message} + "data": model_message_parts[0], + "metadata": {"rawResponse": model_message}, } ) ] @@ -359,11 +369,14 @@ def get_output_text( output_data = output.data if isinstance(output_data, str): return output_data - else: - raise ValueError("Not Implemented") - - else: - return "" + if isinstance(output_data, OutputDataWithValue): + if isinstance(output_data.value, str): + return output_data.value + # Gemini does not support function calls so shouldn't + # get here, but just being safe + return json.dumps(output_data.value, indent=2) + raise ValueError("Not Implemented") + return "" def _construct_chat_history( self, prompt: Prompt, aiconfig: "AIConfigRuntime", params: Dict diff --git 
a/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/text_generation.py b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/text_generation.py index 372f67480..aac197de8 100644 --- a/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/text_generation.py +++ b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/text_generation.py @@ -1,11 +1,23 @@ import copy -from typing import TYPE_CHECKING, Any, Dict, List, Optional -from transformers import AutoTokenizer, Pipeline, pipeline, TextIteratorStreamer +import json import threading +from typing import TYPE_CHECKING, Any, Dict, List, Optional +from transformers import ( + AutoTokenizer, + Pipeline, + pipeline, + TextIteratorStreamer, +) from aiconfig.default_parsers.parameterized_model_parser import ParameterizedModelParser from aiconfig.model_parser import InferenceOptions -from aiconfig.schema import ExecuteResult, Output, Prompt, PromptMetadata +from aiconfig.schema import ( + ExecuteResult, + Output, + OutputDataWithValue, + Prompt, + PromptMetadata, +) from aiconfig.util.params import resolve_prompt # Circuluar Dependency Type Hints @@ -105,19 +117,20 @@ def construct_stream_output( options (InferenceOptions): The inference options. Used to determine the stream callback. """ - accumulated_message = "" output = ExecuteResult( **{ "output_type": "execute_result", - "data": accumulated_message, + "data": "", # We update this below "execution_count": 0, #Multiple outputs are not supported for streaming "metadata": {}, } ) + accumulated_message = "" for new_text in streamer: - accumulated_message += new_text - options.stream_callback(new_text, accumulated_message, 0) - output.data = accumulated_message + if isinstance(new_text, str): + accumulated_message += new_text + options.stream_callback(new_text, accumulated_message, 0) + output.data = accumulated_message return output @@ -283,7 +296,14 @@ def get_output_text( # TODO (rossdanlm): Handle multiple outputs in list # https://github.com/lastmile-ai/aiconfig/issues/467 if output.output_type == "execute_result": - if isinstance(output.data, str): - return output.data - else: - return "" + output_data = output.data + if isinstance(output_data, str): + return output_data + if isinstance(output_data, OutputDataWithValue): + if isinstance(output_data.value, str): + return output_data.value + # HuggingFace Text generation does not support function + # calls so shouldn't get here, but just being safe + return json.dumps(output_data.value, indent=2) + raise ValueError("Not Implemented") + return "" diff --git a/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/remote_inference_client/text_generation.py b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/remote_inference_client/text_generation.py index 010a22ec9..e1fb7ce9e 100644 --- a/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/remote_inference_client/text_generation.py +++ b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/remote_inference_client/text_generation.py @@ -1,9 +1,6 @@ import copy -from typing import TYPE_CHECKING, Any, Dict, List, Optional - -from aiconfig.default_parsers.parameterized_model_parser import ParameterizedModelParser -from aiconfig.model_parser import InferenceOptions -from aiconfig.util.config_utils import get_api_key_from_environment +import json +from typing import TYPE_CHECKING, Any, Iterable, List, Optional # HuggingFace API 
imports from huggingface_hub import InferenceClient @@ -12,13 +9,19 @@ TextGenerationStreamResponse, ) -from aiconfig.schema import ExecuteResult, Output, Prompt, PromptMetadata from aiconfig import CallbackEvent - +from aiconfig.default_parsers.parameterized_model_parser import ParameterizedModelParser +from aiconfig.model_parser import InferenceOptions +from aiconfig.schema import ( + ExecuteResult, + Output, + OutputDataWithValue, + Prompt, + PromptMetadata, +) +from aiconfig.util.config_utils import get_api_key_from_environment from aiconfig.util.params import resolve_prompt -# ModelParser Utils -# Type hint imports # Circuluar Dependency Type Hints if TYPE_CHECKING: @@ -63,7 +66,7 @@ def refine_chat_completion_params(model_settings: dict[Any, Any]) -> dict[str, A def construct_stream_output( - response: TextGenerationStreamResponse, + response: Iterable[TextGenerationStreamResponse], response_includes_details: bool, options: InferenceOptions, ) -> Output: @@ -79,26 +82,25 @@ def construct_stream_output( accumulated_message = "" for iteration in response: metadata = {} - data = iteration + # If response_includes_details is false, `iteration` will be a string, + # otherwise, `iteration` is a TextGenerationStreamResponse + new_text = iteration if response_includes_details: iteration: TextGenerationStreamResponse - data = iteration.token.text + new_text = iteration.token.text metadata = {"token": iteration.token, "details": iteration.details} - else: - data: str # Reduce - accumulated_message += data + accumulated_message += new_text index = 0 # HF Text Generation api doesn't support multiple outputs - delta = data if options and options.stream_callback: - options.stream_callback(delta, accumulated_message, index) + options.stream_callback(new_text, accumulated_message, index) output = ExecuteResult( **{ "output_type": "execute_result", - "data": copy.deepcopy(accumulated_message), + "data": accumulated_message, "execution_count": index, "metadata": metadata, } @@ -316,7 +318,14 @@ def get_output_text( return "" if output.output_type == "execute_result": - if isinstance(output.data, str): - return output.data - else: - return "" + output_data = output.data + if isinstance(output_data, str): + return output_data + if isinstance(output_data, OutputDataWithValue): + if isinstance(output_data.value, str): + return output_data.value + # HuggingFace Text generation does not support function + # calls so shouldn't get here, but just being safe + return json.dumps(output_data.value, indent=2) + raise ValueError("Not Implemented") + return "" diff --git a/extensions/LLama-Guard/python/python/src/aiconfig_extension_llama_guard/LLamaGuard.py b/extensions/LLama-Guard/python/python/src/aiconfig_extension_llama_guard/LLamaGuard.py index 6a62cea2e..816accf18 100644 --- a/extensions/LLama-Guard/python/python/src/aiconfig_extension_llama_guard/LLamaGuard.py +++ b/extensions/LLama-Guard/python/python/src/aiconfig_extension_llama_guard/LLamaGuard.py @@ -1,26 +1,29 @@ -# Define a Model Parser for LLama-Guard - +"""Define a Model Parser for LLama-Guard""" import copy +import json from typing import TYPE_CHECKING, Any, Dict, List, Optional -from transformers import AutoTokenizer +from transformers import AutoTokenizer, AutoModelForCausalLM +import torch + +from aiconfig import CallbackEvent from aiconfig.default_parsers.parameterized_model_parser import ParameterizedModelParser from aiconfig.model_parser import InferenceOptions -from aiconfig.schema import ExecuteResult, Output, Prompt, PromptMetadata +from 
aiconfig.schema import ( + ExecuteResult, + Output, + OutputDataWithValue, + Prompt, + PromptMetadata, +) from aiconfig.util.params import resolve_prompt -from aiconfig import CallbackEvent - -from transformers import AutoTokenizer, AutoModelForCausalLM -import torch # Circuluar Dependency Type Hints if TYPE_CHECKING: from aiconfig.Config import AIConfigRuntime - - # Step 1: define Helpers def refine_chat_completion_params(model_settings: Dict[str, Any]) -> Dict[str, Any]: """ @@ -242,17 +245,21 @@ async def run_inference( output_text = self.tokenizer.decode( response[0][prompt_len:], skip_special_tokens=True ) - - Output = ExecuteResult( + output_data_content : str = '' + if isinstance(output_text, str): + output_data_content = output_text + else: + raise ValueError(f"Output {output_text} needs to be of type 'str' but is of type: {type(output_text)}") + output = ExecuteResult( **{ "output_type": "execute_result", - "data": output_text, + "data": output_data_content, "execution_count": 0, "metadata": {}, } ) - prompt.outputs = [Output] + prompt.outputs = [output] return prompt.outputs def get_output_text( @@ -268,7 +275,13 @@ def get_output_text( return "" if output.output_type == "execute_result": - if isinstance(output.data, str): - return output.data - else: - return "" + output_data = output.data + if isinstance(output_data, str): + return output_data + if isinstance(output_data, OutputDataWithValue): + if isinstance(output_data.value, str): + return output_data.value + # LlamaGuard does not support function + # calls so shouldn't get here, but just being safe + return json.dumps(output_data.value, indent=2) + return "" diff --git a/extensions/llama/python/llama.py b/extensions/llama/python/llama.py index daf11ea9f..232f5920c 100644 --- a/extensions/llama/python/llama.py +++ b/extensions/llama/python/llama.py @@ -1,14 +1,18 @@ +import json from typing import Any, List from aiconfig.Config import AIConfigRuntime from aiconfig.default_parsers.parameterized_model_parser import ParameterizedModelParser from aiconfig.model_parser import InferenceOptions +from aiconfig.schema import ( + ExecuteResult, + OutputDataWithValue, + Output, + Prompt, +) from aiconfig.util.params import resolve_prompt from llama_cpp import Llama -from aiconfig import Output, Prompt -from aiconfig.schema import ExecuteResult - class LlamaModelParser(ParameterizedModelParser): def __init__(self, model_path: str) -> None: @@ -83,9 +87,15 @@ async def _run_inference_helper(self, model_input, options) -> List[Output]: if options: options.stream_callback(data, acc, index) print(flush=True) + + output_data_value : str = '' + if isinstance(acc, str): + output_data_value = acc + else: + raise ValueError(f"Output {acc} needs to be of type 'str' but is of type: {type(acc)}") return ExecuteResult( output_type="execute_result", - data=acc, + data=output_data_value, metadata={} ) else: @@ -95,25 +105,43 @@ async def _run_inference_helper(self, model_input, options) -> List[Output]: except TypeError: texts = [response["choices"][0]["text"]] + output_data_value : str = '' + last_response = texts[-1] + if isinstance(last_response, str): + output_data_value = last_response + else: + raise ValueError(f"Output {last_response} needs to be of type 'str' but is of type: {type(last_response)}") return ExecuteResult( output_type="execute_result", - data=texts[-1], + data=output_data_value, metadata={} ) def get_output_text( self, prompt: Prompt, aiconfig: AIConfigRuntime, output: Output | None = None ) -> str: - match output: - case 
ExecuteResult(data=d): - if isinstance(d, str): - return d - - # Doing this to be backwards-compatible with old output format - # where we used to save a list in output.data - # See https://github.com/lastmile-ai/aiconfig/issues/630 - elif isinstance(d, list): - return d - return "" - case _: - raise ValueError(f"Unexpected output type: {type(output)}") + if not output: + output = aiconfig.get_latest_output(prompt) + + if not output: + return "" + + if output.output_type == "execute_result": + output_data = output.data + if isinstance(output_data, str): + return output_data + if isinstance(output_data, OutputDataWithValue): + if isinstance(output_data.value, str): + return output_data.value + # HuggingFace Text generation does not support function + # calls so shouldn't get here, but just being safe + return json.dumps(output_data.value, indent=2) + + # Doing this to be backwards-compatible with old output format + # where we used to save a list in output.data + # See https://github.com/lastmile-ai/aiconfig/issues/630 + if isinstance(output_data, list): + # TODO: This is an error since list is not compatible with str type + return output_data + return "" + raise ValueError(f"Output is an unexpected output type: {type(output)}") diff --git a/python/src/aiconfig/default_parsers/hf.py b/python/src/aiconfig/default_parsers/hf.py index c07d01824..4c2467cee 100644 --- a/python/src/aiconfig/default_parsers/hf.py +++ b/python/src/aiconfig/default_parsers/hf.py @@ -1,10 +1,7 @@ import copy +import json from typing import TYPE_CHECKING, Any, Dict, List, Optional -from aiconfig.default_parsers.parameterized_model_parser import ParameterizedModelParser -from aiconfig.model_parser import InferenceOptions -from aiconfig.util.config_utils import get_api_key_from_environment - # HuggingFace API imports from huggingface_hub import InferenceClient from huggingface_hub.inference._text_generation import ( @@ -12,13 +9,19 @@ TextGenerationStreamResponse, ) -from aiconfig.schema import ExecuteResult, Output, Prompt, PromptMetadata from aiconfig import CallbackEvent - +from aiconfig.default_parsers.parameterized_model_parser import ParameterizedModelParser +from aiconfig.model_parser import InferenceOptions +from aiconfig.schema import ( + ExecuteResult, + Output, + OutputDataWithValue, + Prompt, + PromptMetadata, +) +from aiconfig.util.config_utils import get_api_key_from_environment from aiconfig.util.params import resolve_prompt -# ModelParser Utils -# Type hint imports # Circuluar Dependency Type Hints if TYPE_CHECKING: @@ -94,11 +97,10 @@ def construct_stream_output( delta = data if options and options.stream_callback: options.stream_callback(delta, accumulated_message, index) - output = ExecuteResult( **{ "output_type": "execute_result", - "data": copy.deepcopy(accumulated_message), + "data": accumulated_message or '', "execution_count": index, "metadata": metadata, } @@ -115,7 +117,7 @@ def construct_regular_output(response: TextGenerationResponse, response_includes output = ExecuteResult( **{ "output_type": "execute_result", - "data": response.generated_text, + "data": response.generated_text or '', "execution_count": 0, "metadata": metadata, } @@ -316,14 +318,19 @@ def get_output_text( return "" if output.output_type == "execute_result": - assert isinstance(output, ExecuteResult) output_data = output.data if isinstance(output_data, str): return output_data + if isinstance(output_data, OutputDataWithValue): + if isinstance(output_data.value, str): + return output_data.value + # HuggingFace Text 
generation does not support function + # calls so shouldn't get here, but just being safe + return json.dumps(output_data.value, indent=2) # Doing this to be backwards-compatible with old output format # where we used to save the TextGenerationResponse or # TextGenerationStreamResponse in output.data - elif hasattr(output_data, "generated_text"): + if hasattr(output_data, "generated_text"): return output_data.generated_text return "" diff --git a/python/src/aiconfig/default_parsers/openai.py b/python/src/aiconfig/default_parsers/openai.py index fe15f37c5..1426239d4 100644 --- a/python/src/aiconfig/default_parsers/openai.py +++ b/python/src/aiconfig/default_parsers/openai.py @@ -1,4 +1,5 @@ import copy +import json from abc import abstractmethod from typing import TYPE_CHECKING, Dict, List, Optional, Union @@ -7,7 +8,17 @@ from aiconfig.callback import CallbackEvent from aiconfig.default_parsers.parameterized_model_parser import ParameterizedModelParser from aiconfig.model_parser import InferenceOptions -from aiconfig.schema import ExecuteResult, Output, Prompt, PromptInput, PromptMetadata +from aiconfig.schema import ( + ExecuteResult, + FunctionCallData, + Output, + OutputDataWithValue, + OutputDataWithToolCallsValue, + Prompt, + PromptInput, + PromptMetadata, + ToolCallData, +) from aiconfig.util.config_utils import get_api_key_from_environment from aiconfig.util.params import ( resolve_prompt, @@ -94,7 +105,7 @@ async def serialize( role = messsage["role"] if role == "user" or role == "function": # Serialize User message as a prompt and save the assistant response as an output - assistant_response = None + assistant_response : Union[ChatCompletionMessage, None] = None if i + 1 < len(conversation_data["messages"]): next_message = conversation_data["messages"][i + 1] if next_message["role"] == "assistant": @@ -108,13 +119,16 @@ async def serialize( assistant_output = [] if assistant_response is not None: - assistant_content = assistant_response.get("content", "") + output_data = build_output_data(assistant_response) + metadata = {"rawResponse": assistant_response} + if hasattr(assistant_response, "role") and assistant_response.role is not None: + metadata["role"] = assistant_response.role assistant_output = [ ExecuteResult( output_type="execute_result", execution_count=None, - data=assistant_content, - metadata={"rawResponse": assistant_response}, + data=output_data, + metadata=metadata, ) ] prompt = Prompt( @@ -272,18 +286,25 @@ async def run_inference( if key != "choices" } for i, choice in enumerate(response.get("choices")): + output_message = choice["message"] + output_data = build_output_data(output_message) + response_without_choices.update( {"finish_reason": choice.get("finish_reason")} ) - output_message = choice["message"] - # TODO (rossdanlm): Support function call handling as output_data - output_content = output_message.get("content", "") + metadata = { + "rawResponse": output_message, + **response_without_choices + } + if hasattr(output_message, "role") and output_message.role is not None: + metadata["role"] = output_message.role + output = ExecuteResult( **{ "output_type": "execute_result", - "data": output_content, + "data": output_data, "execution_count": i, - "metadata": {"rawResponse": output_message, **response_without_choices}, + "metadata": metadata, } ) @@ -299,7 +320,7 @@ async def run_inference( for i, choice in enumerate(chunk["choices"]): index = choice.get("index") - accumulated_message_for_choice = messages.get(index, {}) + accumulated_message_for_choice = 
messages.get(index, "") delta = choice.get("delta") if options and options.stream_callback: @@ -310,7 +331,7 @@ async def run_inference( output = ExecuteResult( **{ "output_type": "execute_result", - "data": copy.deepcopy(accumulated_message_for_choice), + "data": accumulated_message_for_choice, "execution_count": index, "metadata": {"finish_reason": choice.get("finish_reason")}, } @@ -353,19 +374,22 @@ def get_output_text( return "" if output.output_type == "execute_result": - message = output.data - if isinstance(message, str): - return message - elif output.metadata["rawResponse"].get("function_call", None): - return output.metadata["rawResponse"].function_call + output_data = output.data + if isinstance(output_data, str): + return output_data + if isinstance(output_data, OutputDataWithValue): + if isinstance(output_data.value, str): + return output_data.value + # If we get here that means it must be of kind tool_calls + return json.dumps(output_data.value, indent=2) # Doing this to be backwards-compatible with old output format # where we used to save the ChatCompletionMessage in output.data - if isinstance(message, ChatCompletionMessage): - if hasattr(message, "content") and message.content is not None: - return message.content - elif message.function_call is not None: - return str(message.function_call) + if isinstance(output_data, ChatCompletionMessage): + if hasattr(output_data, "content") and output_data.content is not None: + return output_data.content + elif output_data.function_call is not None: + return str(output_data.function_call) return "" @@ -505,27 +529,48 @@ def add_prompt_as_message( if output.output_type == "execute_result": assert isinstance(output, ExecuteResult) output_data = output.data - # TODO (rossdanlm): Support function call handling as output_data - if isinstance(output_data, str): - if output.metadata["rawResponse"].get("role", "") == "assistant": - output_message = { - "content": output_data, - "role": "assistant", - } - function_call = output.metadata["rawResponse"].get("function_call", None) - tool_calls = output.metadata["rawResponse"].get("tool_calls", None) - if function_call is not None: - output_message["function_call"] = function_call - if tool_calls is not None: - output_message["tool_calls"] = tool_calls - messages.append(output_message) + role = output.metadata.get("assistant", None) or \ + ("rawResponse" in output.metadata and + output.metadata["rawResponse"].get("assistant", None)) + + if role == "assistant": + output_message = {} + content : Union[str, None] = None + function_call : Union[FunctionCallData, None] = None + if isinstance(output_data, str): + content = output_data + elif isinstance(output_data, OutputDataWithValue): + if isinstance(output_data.value, str): + content = output_data.value + elif output_data.kind == "tool_calls": + assert isinstance(output, OutputDataWithToolCallsValue) + function_call = \ + output_data.value[len(output_data.value) - 1] \ + .function + + output_message["content"] = content + output_message["role"] = "assistant" + + # We should update this to use ChatCompletionAssistantMessageParam + # object with field `tool_calls. 
See comment for details: + # https://github.com/lastmile-ai/aiconfig/pull/610#discussion_r1437174736 + if function_call is not None: + output_message["function_call"] = function_call + + name = output.metadata.get("name", None) or \ + ("rawResponse" in output.metadata and + output.metadata["rawResponse"].get("name", None)) + if name is not None: + output_message["name"] = name + + messages.append(output_message) # Doing this to be backwards-compatible with old output format # where we used to save the ChatCompletionMessage in output.data elif isinstance(output_data, ChatCompletionMessage): if output_data.role == "assistant": messages.append(output_data) - + return messages @@ -536,3 +581,40 @@ def is_prompt_template(prompt: Prompt): return isinstance(prompt.input, str) or ( hasattr(prompt.input, "data") and isinstance(prompt.input.data, str) ) + +def build_output_data( + message: Union[ChatCompletionMessage, None], +) -> Union[OutputDataWithValue, str, None]: + output_data : Union[OutputDataWithValue, str, None] = None + if message is not None: + if message.content is not None: + output_data = message.content + elif message.tool_calls is not None: + tool_calls = [ToolCallData( + id = item.id, + function=FunctionCallData( + arguments=item.function.arguments, + name=item.function.name, + ), + type='function' + ) for item in message.tool_calls] + output_data = OutputDataWithToolCallsValue( + kind="tool_calls", + value = tool_calls, + ) + + # Deprecated, use tool_calls instead + elif message.function_call is not None: + tool_calls = [ToolCallData( + id = "function_call_data", + function=FunctionCallData( + arguments=message.function_call.arguments, + name=message.function_call.name, + ), + type='function' + )] + output_data = OutputDataWithToolCallsValue( + kind="tool_calls", + value = tool_calls, + ) + return output_data diff --git a/python/src/aiconfig/default_parsers/palm.py b/python/src/aiconfig/default_parsers/palm.py index 408388c7f..97b927253 100644 --- a/python/src/aiconfig/default_parsers/palm.py +++ b/python/src/aiconfig/default_parsers/palm.py @@ -1,16 +1,22 @@ +import json from typing import TYPE_CHECKING, Dict, List, Optional import google.generativeai as palm from google.generativeai.text import Completion from google.generativeai.types.discuss_types import MessageDict +from aiconfig.callback import CallbackEvent from aiconfig.default_parsers.parameterized_model_parser import ParameterizedModelParser +from aiconfig.model_parser import InferenceOptions +from aiconfig.schema import ( + ExecuteResult, + Output, + OutputDataWithValue, + Prompt, + PromptMetadata, +) from aiconfig.util.params import resolve_parameters, resolve_prompt -from ..callback import CallbackEvent -from ..model_parser import InferenceOptions -from ..schema import ExecuteResult, Output, Prompt, PromptMetadata - if TYPE_CHECKING: from aiconfig.Config import AIConfigRuntime @@ -377,15 +383,19 @@ def get_output_text( return "" if output.output_type == "execute_result": - assert isinstance(output, ExecuteResult) output_data = output.data - if isinstance(output_data, str): return output_data + if isinstance(output_data, OutputDataWithValue): + if isinstance(output_data.value, str): + return output_data.value + # HuggingFace Text generation does not support function + # calls so shouldn't get here, but just being safe + return json.dumps(output_data.value, indent=2) # Doing this to be backwards-compatible with old output format # where we used to save the MessageDict in output.data - elif isinstance(output_data, 
MessageDict): + if isinstance(output_data, MessageDict): if output_data.get("content"): return output_data("content") return "" diff --git a/python/tests/parsers/test_openai_util.py b/python/tests/parsers/test_openai_util.py index fbb7467da..623ded3f8 100644 --- a/python/tests/parsers/test_openai_util.py +++ b/python/tests/parsers/test_openai_util.py @@ -2,9 +2,15 @@ import pytest from aiconfig.Config import AIConfigRuntime from aiconfig.default_parsers.openai import refine_chat_completion_params +from aiconfig.schema import ( + ExecuteResult, + OutputDataWithValue, + Prompt, + PromptInput, + PromptMetadata, +) from mock import patch -from aiconfig import ExecuteResult, Prompt, PromptInput, PromptMetadata from ..conftest import mock_openai_chat_completion from ..util.file_path_utils import get_absolute_file_path_from_relative @@ -277,7 +283,7 @@ async def test_serialize(set_temporary_env_vars): ], } - + prompts = await aiconfig.serialize("gpt-3.5-turbo", completion_params, "prompt") new_prompt = prompts[1]
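
For readers skimming the diff, a minimal sketch of the output shape these parsers now emit for function/tool calls, using the schema types referenced above (`OutputDataWithToolCallsValue`, `ToolCallData`, `FunctionCallData`). The concrete id, function name, and arguments below are hypothetical; only the field names and the `ExecuteResult` layout are taken from the patch:

from aiconfig.schema import (
    ExecuteResult,
    FunctionCallData,
    OutputDataWithToolCallsValue,
    ToolCallData,
)

# One tool call, shaped the way build_output_data() in openai.py constructs it
# from a ChatCompletionMessage whose `tool_calls` field is set.
tool_call = ToolCallData(
    id="call_0",  # hypothetical id
    type="function",
    function=FunctionCallData(
        name="get_current_weather",          # hypothetical function name
        arguments='{"location": "Boston"}',  # hypothetical arguments (JSON string)
    ),
)

output = ExecuteResult(
    output_type="execute_result",
    execution_count=0,
    data=OutputDataWithToolCallsValue(kind="tool_calls", value=[tool_call]),
    metadata={},
)

# get_output_text() returns output.data.value directly when it is a string;
# for tool calls it falls back to json.dumps(output.data.value, indent=2).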