diff --git a/extensions/Gemini/python/src/aiconfig_extension_gemini/Gemini.py b/extensions/Gemini/python/src/aiconfig_extension_gemini/Gemini.py
index 51539639b..9766ee4d8 100644
--- a/extensions/Gemini/python/src/aiconfig_extension_gemini/Gemini.py
+++ b/extensions/Gemini/python/src/aiconfig_extension_gemini/Gemini.py
@@ -71,9 +71,7 @@ def construct_regular_outputs(response: "AsyncGenerateContentResponse") -> list[
     return output_list


-async def construct_stream_outputs(
-    response: "AsyncGenerateContentResponse", options: InferenceOptions
-) -> list[Output]:
+async def construct_stream_outputs(response: "AsyncGenerateContentResponse", options: InferenceOptions) -> list[Output]:
     """
     Construct Outputs while also streaming the response with stream callback
@@ -132,17 +130,17 @@ async def serialize(
         The data passed in is the completion params a user would use to call the Gemini API directly.
         If the user wanted to call the Gemini API directly, they might do something like this:
-
+
         ```
         model = genai.GenerativeModel('gemini-pro')
         completion_params = {"contents": "Hello"}
-
+
         model.generate_content(**completion_params)
-        # Note: The above line is the same as doing this:
+        # Note: The above line is the same as doing this:
         model.generate_content(contents="Hello")
         ```
         * Important: The contents field is what contains the input data. In this case, prompt input would be the contents field.
-
+
         Args:
             prompt (str): The prompt to be serialized.
@@ -152,7 +150,7 @@ async def serialize(
             str: Serialized representation of the prompt and inference settings.

         Sample Usage:
-            1.
+            1.
             completion_params = {"contents": "Hello"}
             serialized_prompts = await ai_config.serialize("prompt", completion_params, "gemini-pro")
@@ -188,7 +186,7 @@ async def serialize(
         data = copy.deepcopy(data)
         contents = data.pop("contents", None)

-        model_name = self.model.model_name[len("models/"):]
+        model_name = self.model.model_name[len("models/") :]
         model_metadata = ai_config.get_model_metadata(data, model_name)

         prompts = []
@@ -202,7 +200,9 @@ async def serialize(
         # }
         contents_is_role_dict = isinstance(contents, dict) and "role" in contents and "parts"
         # Multi Turn means that the contents is a list of dicts with alternating role and parts. See for more info: https://ai.google.dev/tutorials/python_quickstart#multi-turn_conversations
-        contents_is_multi_turn = isinstance(contents, list) and all(isinstance(item, dict) and "role" in item and "parts" in item for item in contents)
+        contents_is_multi_turn = isinstance(contents, list) and all(
+            isinstance(item, dict) and "role" in item and "parts" in item for item in contents
+        )

         if contents is None:
             raise ValueError("No contents found in data. Gemini api request requires a contents field")
@@ -226,14 +226,21 @@ async def serialize(
                 outputs = [
                     ExecuteResult(
                         **{
-                            "output_type": "execute_result",
+                            "output_type": "execute_result",
                             "data": model_message_parts[0],
-                            "metadata": {"rawResponse": model_message},
+                            "metadata": {"raw_response": model_message},
                         }
                     )
                 ]
                 i += 1
-            prompt = Prompt(**{"name": f'{prompt_name}_{len(prompts) + 1}', "input": user_message_parts, "metadata": {"model": model_metadata}, "outputs": outputs})
+            prompt = Prompt(
+                **{
+                    "name": f"{prompt_name}_{len(prompts) + 1}",
+                    "input": user_message_parts,
+                    "metadata": {"model": model_metadata},
+                    "outputs": outputs,
+                }
+            )
             prompts.append(prompt)
             i += 1
         else:
@@ -255,11 +262,7 @@ async def deserialize(self, prompt: Prompt, aiconfig: "AIConfigRuntime", params:
         Returns:
             dict: Model-specific completion parameters.
         """
-        await aiconfig.callback_manager.run_callbacks(
-            CallbackEvent(
-                "on_deserialize_start", __name__, {"prompt": prompt, "params": params}
-            )
-        )
+        await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_deserialize_start", __name__, {"prompt": prompt, "params": params}))

         # Build Completion data
         model_settings = self.get_model_settings(prompt, aiconfig)
@@ -270,9 +273,9 @@ async def deserialize(self, prompt: Prompt, aiconfig: "AIConfigRuntime", params:
             messages = self._construct_chat_history(prompt, aiconfig, params)
             resolved_prompt = resolve_prompt(prompt, params, aiconfig)
-
+
             messages.append({"role": "user", "parts": [{"text": resolved_prompt}]})
-
+
             completion_data["contents"] = messages
         else:
             # If contents is already set, do not construct chat history. TODO: @Ankush-lastmile
@@ -287,15 +290,13 @@ async def deserialize(self, prompt: Prompt, aiconfig: "AIConfigRuntime", params:
             # This is checking attributes and not a dict like object. in schema.py, PromptInput allows arbitrary attributes/data, and gets serialized as an attribute because it is a pydantic type
             if not hasattr(prompt_input, "contents"):
                 # The source code show cases this more than the docs. This curl request docs similar to python sdk: https://cloud.google.com/vertex-ai/docs/generative-ai/model-reference/gemini#request_body
-                raise ValueError("Unable to deserialize input. Prompt input type is not a string, Gemini Model Parser expects prompt input to contain a 'contents' field as expected by Gemini API")
+                raise ValueError(
+                    "Unable to deserialize input. Prompt input type is not a string, Gemini Model Parser expects prompt input to contain a 'contents' field as expected by Gemini API"
+                )

-            completion_data['contents'] = parameterize_supported_gemini_input_data(prompt_input.contents, prompt, aiconfig, params)
+            completion_data["contents"] = parameterize_supported_gemini_input_data(prompt_input.contents, prompt, aiconfig, params)

-        await aiconfig.callback_manager.run_callbacks(
-            CallbackEvent(
-                "on_deserialize_complete", __name__, {"output": completion_data}
-            )
-        )
+        await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_deserialize_complete", __name__, {"output": completion_data}))
         return completion_data

     async def run_inference(
@@ -347,9 +348,7 @@ async def run_inference(
             outputs = construct_regular_outputs(response)

         prompt.outputs = outputs
-        await aiconfig.callback_manager.run_callbacks(
-            CallbackEvent("on_run_complete", __name__, {"result": prompt.outputs})
-        )
+        await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_run_complete", __name__, {"result": prompt.outputs}))
         return prompt.outputs

     def get_output_text(
@@ -375,27 +374,21 @@ def get_output_text(
             # We use this method in the model parser functionality itself,
             # so we must error to let the user know non-string
             # formats are not supported
-            error_message = \
-"""
+            error_message = """
 We currently only support chats where prior messages from a user
 are only a single message string instance, please see docstring
 for more details: https://github.com/lastmile-ai/aiconfig/blob/v1.1.8/extensions/Gemini/python/src/aiconfig_extension_gemini/Gemini.py#L31-L33
 """
             raise ValueError(error_message)

-    def _construct_chat_history(
-        self, prompt: Prompt, aiconfig: "AIConfigRuntime", params: Dict
-    ) -> List:
+    def _construct_chat_history(self, prompt: Prompt, aiconfig: "AIConfigRuntime", params: Dict) -> List:
         """
         Constructs the chat history for the model
         """
         messages = []
         # Default to always use chat context
-        remember_chat_context = not hasattr(
-            prompt.metadata, "remember_chat_context"
-        ) or (
-            hasattr(prompt.metadata, "remember_chat_context")
-            and prompt.metadata.remember_chat_context != False
+        remember_chat_context = not hasattr(prompt.metadata, "remember_chat_context") or (
+            hasattr(prompt.metadata, "remember_chat_context") and prompt.metadata.remember_chat_context != False
         )
         if remember_chat_context:
             # handle chat history. check previous prompts for the same model. if same model, add prompt and its output to completion data if it has a completed output
@@ -404,21 +397,13 @@ def _construct_chat_history(
                 if previous_prompt.name == prompt.name:
                     break

-                previous_prompt_is_same_model = aiconfig.get_model_name(
-                    previous_prompt
-                ) == aiconfig.get_model_name(prompt)
+                previous_prompt_is_same_model = aiconfig.get_model_name(previous_prompt) == aiconfig.get_model_name(prompt)
                 if previous_prompt_is_same_model:
-                    previous_prompt_template = resolve_prompt(
-                        previous_prompt, params, aiconfig
-                    )
+                    previous_prompt_template = resolve_prompt(previous_prompt, params, aiconfig)
                     previous_prompt_output = aiconfig.get_latest_output(previous_prompt)
-                    previous_prompt_output_text = self.get_output_text(
-                        previous_prompt, aiconfig, previous_prompt_output
-                    )
+                    previous_prompt_output_text = self.get_output_text(previous_prompt, aiconfig, previous_prompt_output)

-                    messages.append(
-                        {"role": "user", "parts": [{"text": previous_prompt_template}]}
-                    )
+                    messages.append({"role": "user", "parts": [{"text": previous_prompt_template}]})
                     messages.append(
                         {
                             "role": "model",
@@ -426,7 +411,7 @@ def _construct_chat_history(
                         }
                     )

-        return messages
+        return messages

     def get_prompt_template(self, prompt: Prompt, aiConfig: "AIConfigRuntime") -> str:
         """
@@ -445,26 +430,18 @@ def get_prompt_template(self, prompt: Prompt, aiConfig: "AIConfigRuntime") -> st
             elif isinstance(contents, list):
                 return " ".join(contents)
             elif isinstance(contents, dict):
-                parts= contents["parts"]
+                parts = contents["parts"]
                 if isinstance(parts, str):
                     return parts
                 elif isinstance(parts, list):
                     return " ".join(parts)
                 else:
-                    raise Exception(
-                        f"Cannot get prompt template string from prompt input: {prompt.input}"
-                    )
+                    raise Exception(f"Cannot get prompt template string from prompt input: {prompt.input}")
             else:
-                raise Exception(
-                    f"Cannot get prompt template string from prompt input: {prompt.input}"
-                )
+                raise Exception(f"Cannot get prompt template string from prompt input: {prompt.input}")
-
-
         else:
-            raise Exception(
-                f"Cannot get prompt template string from prompt input: {prompt.input}"
-            )
+            raise Exception(f"Cannot get prompt template string from prompt input: {prompt.input}")


 def refine_chat_completion_params(model_settings):
@@ -522,9 +499,7 @@ def contains_prompt_template(prompt: Prompt):
     """
     Check if a prompt's input is a valid string.
     """
-    return isinstance(prompt.input, str) or (
-        hasattr(prompt.input, "data") and isinstance(prompt.input.data, str)
-    )
+    return isinstance(prompt.input, str) or (hasattr(prompt.input, "data") and isinstance(prompt.input.data, str))


 AIConfigRuntime.register_model_parser(GeminiModelParser("gemini-pro"), "gemini-pro")
diff --git a/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/remote_inference_client/text_generation.py b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/remote_inference_client/text_generation.py
index 5f05e655c..434648d0d 100644
--- a/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/remote_inference_client/text_generation.py
+++ b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/remote_inference_client/text_generation.py
@@ -109,7 +109,7 @@ def construct_stream_output(


 def construct_regular_output(response: TextGenerationResponse, response_includes_details: bool) -> Output:
-    metadata = {"rawResponse": response}
+    metadata = {"raw_response": response}
     if response_includes_details:
         metadata["details"] = response.details
@@ -209,15 +209,13 @@ async def serialize(
         prompt = Prompt(
             name=prompt_name,
             input=prompt_input,
-            metadata=PromptMetadata(
-                model=model_metadata, parameters=parameters, **kwargs
-            ),
+            metadata=PromptMetadata(model=model_metadata, parameters=parameters, **kwargs),
         )
         prompts.append(prompt)
-
-        await ai_config.callback_manager.run_callbacks(CallbackEvent("on_serialize_complete", __name__, {"result": prompts }))
-
+
+        await ai_config.callback_manager.run_callbacks(CallbackEvent("on_serialize_complete", __name__, {"result": prompts}))
+
         return prompts

     async def deserialize(
@@ -251,9 +249,7 @@ async def deserialize(

         return completion_data

-    async def run_inference(
-        self, prompt: Prompt, aiconfig: "AIConfigRuntime", options: InferenceOptions, parameters: dict[Any, Any]
-    ) -> List[Output]:
+    async def run_inference(self, prompt: Prompt, aiconfig: "AIConfigRuntime", options: InferenceOptions, parameters: dict[Any, Any]) -> List[Output]:
         """
         Invoked to run a prompt in the .aiconfig. This method should perform
         the actual model inference based on the provided prompt and inference settings.
diff --git a/extensions/HuggingFace/typescript/hf.ts b/extensions/HuggingFace/typescript/hf.ts
index cf25f8674..0bf1993b8 100644
--- a/extensions/HuggingFace/typescript/hf.ts
+++ b/extensions/HuggingFace/typescript/hf.ts
@@ -311,7 +311,7 @@ function constructOutput(response: TextGenerationOutput): Output {
     output_type: "execute_result",
     data: response.generated_text,
     execution_count: 0,
-    metadata: { rawResponse: response },
+    metadata: { raw_response: response },
   } as ExecuteResult;

   return output;
diff --git a/python/src/aiconfig/default_parsers/hf.py b/python/src/aiconfig/default_parsers/hf.py
index 43a77b289..f45332b5d 100644
--- a/python/src/aiconfig/default_parsers/hf.py
+++ b/python/src/aiconfig/default_parsers/hf.py
@@ -109,14 +109,14 @@ def construct_stream_output(


 def construct_regular_output(response: TextGenerationResponse, response_includes_details: bool) -> Output:
-    metadata = {"rawResponse": response}
+    metadata = {"raw_response": response}
     if response_includes_details:
         metadata["details"] = response.details

     output = ExecuteResult(
         **{
             "output_type": "execute_result",
-            "data": response.generated_text or '',
+            "data": response.generated_text or "",
             "execution_count": 0,
             "metadata": metadata,
         }
     )
@@ -209,15 +209,13 @@ async def serialize(
         prompt = Prompt(
             name=prompt_name,
             input=prompt_input,
-            metadata=PromptMetadata(
-                model=model_metadata, parameters=parameters, **kwargs
-            ),
+            metadata=PromptMetadata(model=model_metadata, parameters=parameters, **kwargs),
         )
         prompts.append(prompt)
-
-        await ai_config.callback_manager.run_callbacks(CallbackEvent("on_serialize_complete", __name__, {"result": prompts }))
-
+
+        await ai_config.callback_manager.run_callbacks(CallbackEvent("on_serialize_complete", __name__, {"result": prompts}))
+
         return prompts

     async def deserialize(
@@ -251,9 +249,7 @@ async def deserialize(

         return completion_data

-    async def run_inference(
-        self, prompt: Prompt, aiconfig: "AIConfigRuntime", options: InferenceOptions, parameters: dict[Any, Any]
-    ) -> List[Output]:
+    async def run_inference(self, prompt: Prompt, aiconfig: "AIConfigRuntime", options: InferenceOptions, parameters: dict[Any, Any]) -> List[Output]:
         """
         Invoked to run a prompt in the .aiconfig. This method should perform
         the actual model inference based on the provided prompt and inference settings.
diff --git a/python/src/aiconfig/default_parsers/openai.py b/python/src/aiconfig/default_parsers/openai.py
index 8569f8845..ec1afc621 100644
--- a/python/src/aiconfig/default_parsers/openai.py
+++ b/python/src/aiconfig/default_parsers/openai.py
@@ -87,9 +87,7 @@ async def serialize(
                 break

         # Get the global settings for the model
-        model_name = (
-            conversation_data["model"] if "model" in conversation_data else self.id()
-        )
+        model_name = conversation_data["model"] if "model" in conversation_data else self.id()
         model_metadata = ai_config.get_model_metadata(conversation_data, model_name)

         # Remove messages array from model metadata. Handled separately
@@ -105,7 +103,7 @@ async def serialize(
             role = messsage["role"]
             if role == "user" or role == "function":
                 # Serialize User message as a prompt and save the assistant response as an output
-                assistant_response : Union[ChatCompletionMessage, None] = None
+                assistant_response: Union[ChatCompletionMessage, None] = None
                 if i + 1 < len(conversation_data["messages"]):
                     next_message = conversation_data["messages"][i + 1]
                     if next_message["role"] == "assistant":
@@ -113,14 +111,12 @@ async def serialize(
                         i += 1

                 new_prompt_name = f"{prompt_name}_{len(prompts) + 1}"
-                input = (
-                    messsage["content"] if role == "user" else PromptInput(**messsage)
-                )
+                input = messsage["content"] if role == "user" else PromptInput(**messsage)

                 assistant_output = []
                 if assistant_response is not None:
                     output_data = build_output_data(assistant_response)
-                    metadata = {"rawResponse": assistant_response}
+                    metadata = {"raw_response": assistant_response}
                     if assistant_response.get("role", None) is not None:
                         metadata["role"] = assistant_response.get("role")
                     assistant_output = [
@@ -153,9 +149,7 @@ async def serialize(
         await ai_config.callback_manager.run_callbacks(event)
         return prompts

-    async def deserialize(
-        self, prompt: Prompt, aiconfig: "AIConfigRuntime", params: Optional[Dict] = {}
-    ) -> Dict:
+    async def deserialize(self, prompt: Prompt, aiconfig: "AIConfigRuntime", params: Optional[Dict] = {}) -> Dict:
         """
         Defines how to parse a prompt in the .aiconfig for a particular model
         and constructs the completion params for that model.
@@ -166,11 +160,7 @@ async def deserialize(
         Returns:
             dict: Model-specific completion parameters.
         """
-        await aiconfig.callback_manager.run_callbacks(
-            CallbackEvent(
-                "on_deserialize_start", __name__, {"prompt": prompt, "params": params}
-            )
-        )
+        await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_deserialize_start", __name__, {"prompt": prompt, "params": params}))

         # Build Completion params
         model_settings = self.get_model_settings(prompt, aiconfig)
@@ -186,17 +176,12 @@ async def deserialize(
             if isinstance(system_prompt, dict):
                 # If system prompt is an object, then it should have content and role attributes
                 system_prompt = system_prompt["content"]
-            resolved_system_prompt = resolve_system_prompt(
-                prompt, system_prompt, params, aiconfig
-            )
-            completion_params["messages"].append(
-                {"content": resolved_system_prompt, "role": "system"}
-            )
+            resolved_system_prompt = resolve_system_prompt(prompt, system_prompt, params, aiconfig)
+            completion_params["messages"].append({"content": resolved_system_prompt, "role": "system"})

         # Default to always use chat context
         if not hasattr(prompt.metadata, "remember_chat_context") or (
-            hasattr(prompt.metadata, "remember_chat_context")
-            and prompt.metadata.remember_chat_context != False
+            hasattr(prompt.metadata, "remember_chat_context") and prompt.metadata.remember_chat_context != False
         ):
             # handle chat history. check previous prompts for the same model. if same model, add prompt and its output to completion data if it has a completed output
             for i, previous_prompt in enumerate(aiconfig.prompts):
@@ -204,9 +189,7 @@ async def deserialize(
                 if previous_prompt.name == prompt.name:
                     break

-                if aiconfig.get_model_name(
-                    previous_prompt
-                ) == aiconfig.get_model_name(prompt):
+                if aiconfig.get_model_name(previous_prompt) == aiconfig.get_model_name(prompt):
                     # Add prompt and its output to completion data. Constructing this prompt will take into account available parameters.
                     add_prompt_as_message(
                         previous_prompt,
@@ -226,11 +209,7 @@ async def deserialize(
         # Add in the latest prompt
         add_prompt_as_message(prompt, aiconfig, completion_params["messages"], params)
-        await aiconfig.callback_manager.run_callbacks(
-            CallbackEvent(
-                "on_deserialize_complete", __name__, {"output": completion_params}
-            )
-        )
+        await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_deserialize_complete", __name__, {"output": completion_params}))
         return completion_params

     async def run_inference(
@@ -280,22 +259,13 @@ async def run_inference(
             #  # OpenAI>1.0.0 uses pydantic models for response
             response = response.model_dump(exclude_none=True)

-            response_without_choices = {
-                key: copy.deepcopy(value)
-                for key, value in response.items()
-                if key != "choices"
-            }
+            response_without_choices = {key: copy.deepcopy(value) for key, value in response.items() if key != "choices"}
             for i, choice in enumerate(response.get("choices")):
                 output_message = choice["message"]
                 output_data = build_output_data(output_message)

-                response_without_choices.update(
-                    {"finish_reason": choice.get("finish_reason")}
-                )
-                metadata = {
-                    "rawResponse": output_message,
-                    **response_without_choices
-                }
+                response_without_choices.update({"finish_reason": choice.get("finish_reason")})
+                metadata = {"raw_response": output_message, **response_without_choices}
                 if output_message.get("role", None) is not None:
                     metadata["role"] = output_message.get("role")
@@ -324,9 +294,7 @@ async def run_inference(
                     delta = choice.get("delta")

                     if options and options.stream_callback:
-                        options.stream_callback(
-                            delta, accumulated_message_for_choice, index
-                        )
+                        options.stream_callback(delta, accumulated_message_for_choice, index)

                     output = ExecuteResult(
                         **{
@@ -342,9 +310,7 @@ async def run_inference(
         # rewrite or extend list of outputs?
         prompt.outputs = outputs

-        await aiconfig.callback_manager.run_callbacks(
-            CallbackEvent("on_run_complete", __name__, {"result": prompt.outputs})
-        )
+        await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_run_complete", __name__, {"result": prompt.outputs}))
         return prompt.outputs

     def get_prompt_template(self, prompt: Prompt, aiconfig: "AIConfigRuntime") -> str:
@@ -353,9 +319,7 @@ def get_prompt_template(self, prompt: Prompt, aiconfig: "AIConfigRuntime") -> st
         """
         if isinstance(prompt.input, str):
             return prompt.input
-        elif isinstance(prompt.input, PromptInput) and isinstance(
-            prompt.input.data, str
-        ):
+        elif isinstance(prompt.input, PromptInput) and isinstance(prompt.input.data, str):
             return prompt.input.data
         else:
             message = prompt.input
@@ -437,9 +401,7 @@ def reduce(acc, delta):
     return acc


-def multi_choice_message_reducer(
-    messages: Union[Dict[int, dict], None], chunk: dict
-) -> Dict[int, dict]:
+def multi_choice_message_reducer(messages: Union[Dict[int, dict], None], chunk: dict) -> Dict[int, dict]:
     if messages is None:
         messages = {}
@@ -485,9 +447,7 @@ def refine_chat_completion_params(model_settings):
     return completion_data


-def add_prompt_as_message(
-    prompt: Prompt, aiconfig: "AIConfigRuntime", messages: List, params=None
-):
+def add_prompt_as_message(prompt: Prompt, aiconfig: "AIConfigRuntime", messages: List, params=None):
     """
     Converts a given prompt to a message and adds it to the specified messages list.
@@ -501,17 +461,11 @@ def add_prompt_as_message(
         messages.append({"content": resolved_prompt, "role": "user"})
     else:
         # Assumes Prompt input will be in the format of ChatCompletionMessageParam (with content, role, function_name, and name attributes)
-        resolved_prompt = resolve_prompt_string(
-            prompt, params, aiconfig, prompt.input.content
-        )
+        resolved_prompt = resolve_prompt_string(prompt, params, aiconfig, prompt.input.content)

         prompt_input = prompt.input
         role = prompt_input.role if hasattr(prompt_input, "role") else "user"
-        fn_call = (
-            prompt_input.function_call
-            if hasattr(prompt_input, "function_call")
-            else None
-        )
+        fn_call = prompt_input.function_call if hasattr(prompt_input, "function_call") else None
         name = prompt_input.name if hasattr(prompt_input, "name") else None

         message_data = {"content": resolved_prompt, "role": role}
@@ -529,14 +483,12 @@ def add_prompt_as_message(
         if output.output_type == "execute_result":
             assert isinstance(output, ExecuteResult)
             output_data = output.data
-            role = output.metadata.get("role", None) or \
-                ("rawResponse" in output.metadata and
-                 output.metadata["rawResponse"].get("role", None))
+            role = output.metadata.get("role", None) or ("raw_response" in output.metadata and output.metadata["raw_response"].get("role", None))

             if role == "assistant":
                 output_message = {}
-                content : Union[str, None] = None
-                function_call : Union[FunctionCallData, None] = None
+                content: Union[str, None] = None
+                function_call: Union[FunctionCallData, None] = None
                 if isinstance(output_data, str):
                     content = output_data
                 elif isinstance(output_data, OutputDataWithValue):
@@ -544,22 +496,18 @@ def add_prompt_as_message(
                         content = output_data.value
                     elif output_data.kind == "tool_calls":
                         assert isinstance(output, OutputDataWithToolCallsValue)
-                        function_call = \
-                            output_data.value[len(output_data.value) - 1] \
-                                .function
+                        function_call = output_data.value[len(output_data.value) - 1].function

                 output_message["content"] = content
                 output_message["role"] = role

                 # We should update this to use ChatCompletionAssistantMessageParam
                 # object with field `tool_calls. See comment for details:
-                # https://github.com/lastmile-ai/aiconfig/pull/610#discussion_r1437174736
+                # https://github.com/lastmile-ai/aiconfig/pull/610#discussion_r1437174736
                 if function_call is not None:
                     output_message["function_call"] = function_call

-                name = output.metadata.get("name", None) or \
-                    ("rawResponse" in output.metadata and
-                     output.metadata["rawResponse"].get("name", None))
+                name = output.metadata.get("name", None) or ("raw_response" in output.metadata and output.metadata["raw_response"].get("name", None))
                 if name is not None:
                     output_message["name"] = name
@@ -578,9 +526,8 @@ def is_prompt_template(prompt: Prompt):
     """
     Check if a prompt's input is a valid string.
     """
-    return isinstance(prompt.input, str) or (
-        hasattr(prompt.input, "data") and isinstance(prompt.input.data, str)
-    )
+    return isinstance(prompt.input, str) or (hasattr(prompt.input, "data") and isinstance(prompt.input.data, str))
+

 def build_output_data(
     message: Union[ChatCompletionMessage, None],
@@ -588,28 +535,28 @@ def build_output_data(
     if message is None:
         return None

-    output_data : Union[OutputDataWithValue, str, None] = None
+    output_data: Union[OutputDataWithValue, str, None] = None
     if message.get("content") is not None:
-        output_data = message.get("content") #string
+        output_data = message.get("content")  # string
     elif message.get("tool_calls") is not None:
         tool_calls = [
-            ToolCallData(
-                id = item.id,
-                function=FunctionCallData(
-                    arguments=item.function.arguments,
-                    name=item.function.name,
-                ),
-                type='function'
-            )
+            ToolCallData(
+                id=item.id,
+                function=FunctionCallData(
+                    arguments=item.function.arguments,
+                    name=item.function.name,
+                ),
+                type="function",
+            )
             for item in message.get("tool_calls")
-            # It's possible that ChatCompletionMessageToolCall may
-            # support more than just function calls in the future
-            # so filter out other types
-            if item.type == 'function'
+            # It's possible that ChatCompletionMessageToolCall may
+            # support more than just function calls in the future
+            # so filter out other types
+            if item.type == "function"
         ]
         output_data = OutputDataWithToolCallsValue(
             kind="tool_calls",
-            value = tool_calls,
+            value=tool_calls,
         )

     # Deprecated, use tool_calls instead
@@ -617,16 +564,16 @@ def build_output_data(
         function_call = message.get("function_call")
         tool_calls = [
             ToolCallData(
-                id = "function_call_data", # value here does not matter
+                id="function_call_data",  # value here does not matter
                 function=FunctionCallData(
                     arguments=function_call["arguments"],
                     name=function_call["name"],
                 ),
-                type='function'
+                type="function",
             )
         ]
         output_data = OutputDataWithToolCallsValue(
             kind="tool_calls",
-            value = tool_calls,
+            value=tool_calls,
         )
     return output_data
diff --git a/python/src/aiconfig/default_parsers/palm.py b/python/src/aiconfig/default_parsers/palm.py
index feb21e974..34416c27e 100644
--- a/python/src/aiconfig/default_parsers/palm.py
+++ b/python/src/aiconfig/default_parsers/palm.py
@@ -8,10 +8,10 @@ from aiconfig.default_parsers.parameterized_model_parser import ParameterizedModelParser
 from aiconfig.model_parser import InferenceOptions
 from aiconfig.schema import (
-    ExecuteResult,
-    Output,
-    OutputDataWithValue,
-    Prompt,
+    ExecuteResult,
+    Output,
+    OutputDataWithValue,
+    Prompt,
     PromptMetadata,
 )
 from aiconfig.util.params import resolve_parameters, resolve_prompt
@@ -81,9 +81,7 @@ async def serialize(

         return prompts

-    async def deserialize(
-        self, prompt: Prompt, aiconfig: "AIConfigRuntime", params: Optional[Dict] = {}
-    ) -> Dict:
+    async def deserialize(self, prompt: Prompt, aiconfig: "AIConfigRuntime", params: Optional[Dict] = {}) -> Dict:
         """
         Defines how to parse a prompt in the .aiconfig for a particular model
         and constructs the completion params for that model.
@@ -94,11 +92,7 @@ async def deserialize(
         Returns:
             dict: Model-specific completion parameters.
         """
-        await aiconfig.callback_manager.run_callbacks(
-            CallbackEvent(
-                "on_deserialize_start", __name__, {"prompt": prompt, "params": params}
-            )
-        )
+        await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_deserialize_start", __name__, {"prompt": prompt, "params": params}))

         # Build Completion data
         model_settings = self.get_model_settings(prompt, aiconfig)
@@ -109,11 +103,7 @@ async def deserialize(
         # pass in the user prompt
         completion_data["prompt"] = prompt_str

-        await aiconfig.callback_manager.run_callbacks(
-            CallbackEvent(
-                "on_deserialize_complete", __name__, {"output": completion_data}
-            )
-        )
+        await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_deserialize_complete", __name__, {"output": completion_data}))
         return completion_data

     async def run_inference(
@@ -154,22 +144,20 @@ async def run_inference(
         outputs = []
         # completion.candidates has all outputs. Candidates is an attribute of completion. Candidates is a dict. Taken from Google API impl
         for i, candidate in enumerate(completion.candidates):
-            # candidate is a TextCompletion obj (https://shorturl.at/emuG2),
+            # candidate is a TextCompletion obj (https://shorturl.at/emuG2),
             # but Pydantic TypedDict breaks for Python v<3.12 so using
             # generic Dict type
-            candidate: Dict
+            candidate: Dict
             output = ExecuteResult(
                 output_type="execute_result",
                 execution_count=i,
                 data=candidate.get("output", ""),
-                metadata={"rawResponse": candidate},
+                metadata={"raw_response": candidate},
             )
             outputs.append(output)

         prompt.outputs = outputs
-        await aiconfig.callback_manager.run_callbacks(
-            CallbackEvent("on_run_complete", __name__, {"result": prompt.outputs})
-        )
+        await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_run_complete", __name__, {"result": prompt.outputs}))
         return outputs

     def get_output_text(
@@ -247,9 +235,7 @@ async def serialize(
         await ai_config.callback_manager.run_callbacks(event)
         return prompts

-    async def deserialize(
-        self, prompt: Prompt, aiconfig: "AIConfigRuntime", params: Optional[Dict] = {}
-    ) -> Dict:
+    async def deserialize(self, prompt: Prompt, aiconfig: "AIConfigRuntime", params: Optional[Dict] = {}) -> Dict:
         """
         Defines how to parse a prompt in the .aiconfig for a particular model
         and constructs the completion params for that model.
@@ -260,11 +246,7 @@ async def deserialize(
         Returns:
             dict: Model-specific completion parameters.
         """
-        await aiconfig.callback_manager.run_callbacks(
-            CallbackEvent(
-                "on_deserialize_start", __name__, {"prompt": prompt, "params": params}
-            )
-        )
+        await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_deserialize_start", __name__, {"prompt": prompt, "params": params}))
         resolved_prompt = resolve_prompt(prompt, params, aiconfig)

         # Build Completion data
@@ -277,8 +259,7 @@ async def deserialize(

         # Default to always use chat contextjkl;
         if not hasattr(prompt.metadata, "remember_chat_context") or (
-            hasattr(prompt.metadata, "remember_chat_context")
-            and prompt.metadata.remember_chat_context != False
+            hasattr(prompt.metadata, "remember_chat_context") and prompt.metadata.remember_chat_context != False
         ):
             # handle chat history. check previous prompts for the same model. if same model, add prompt and its output to completion data if it has a completed output
             for i, previous_prompt in enumerate(aiconfig.prompts):
@@ -293,12 +274,8 @@ async def deserialize(
                     # check if prompt has an output. PaLM Api requires this
                     if len(previous_prompt.outputs) > 0:
-                        resolved_previous_prompt = resolve_parameters(
-                            {}, previous_prompt, aiconfig
-                        )
-                        completion_data["messages"].append(
-                            {"content": resolved_previous_prompt, "author": "0"}
-                        )
+                        resolved_previous_prompt = resolve_parameters({}, previous_prompt, aiconfig)
+                        completion_data["messages"].append({"content": resolved_previous_prompt, "author": "0"})

                         completion_data["messages"].append(
                             {
@@ -312,11 +289,7 @@ async def deserialize(
         # pass in the user prompt
         completion_data["messages"].append({"content": resolved_prompt, "author": "0"})
-        await aiconfig.callback_manager.run_callbacks(
-            CallbackEvent(
-                "on_deserialize_complete", __name__, {"output": completion_data}
-            )
-        )
+        await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_deserialize_complete", __name__, {"output": completion_data}))
         return completion_data

     async def run_inference(
@@ -350,7 +323,7 @@ async def run_inference(
         response = palm.chat(**completion_data)
         outputs = []
         for i, candidate in enumerate(response.candidates):
-            # candidate is a MessageDict obj (https://shorturl.at/jKY35),
+            # candidate is a MessageDict obj (https://shorturl.at/jKY35),
             # but Pydantic TypedDict breaks for Python v<3.12 so using
             # generic Dict type
             candidate: Dict
@@ -359,15 +332,13 @@ async def run_inference(
                     "output_type": "execute_result",
                     "data": candidate.get("content", ""),
                     "execution_count": i,
-                    "metadata": {"rawResponse": response},
+                    "metadata": {"raw_response": response},
                 }
             )
             outputs.append(output)

         prompt.outputs = outputs
-        await aiconfig.callback_manager.run_callbacks(
-            CallbackEvent("on_run_complete", __name__, {"result": prompt.outputs})
-        )
+        await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_run_complete", __name__, {"result": prompt.outputs}))
         return prompt.outputs

     def get_output_text(
diff --git a/python/tests/parsers/test_openai_util.py b/python/tests/parsers/test_openai_util.py
index 0e73bc5dd..9064d9cb1 100644
--- a/python/tests/parsers/test_openai_util.py
+++ b/python/tests/parsers/test_openai_util.py
@@ -23,9 +23,7 @@ def test_refine_chat_completion_params():
         "system_prompt": "system_prompt",
         "random_attribute": "value_doesn't_matter",
     }
-    refined_params = refine_chat_completion_params(
-        model_settings_with_stream_and_system_prompt
-    )
+    refined_params = refine_chat_completion_params(model_settings_with_stream_and_system_prompt)

     assert "system_prompt" not in refined_params
     assert "stream" in refined_params
@@ -35,13 +33,9 @@ def test_refine_chat_completion_params():

 @pytest.mark.asyncio
 async def test_get_output_text(set_temporary_env_vars):
-    with patch.object(
-        openai.chat.completions, "create", side_effect=mock_openai_chat_completion
-    ):
+    with patch.object(openai.chat.completions, "create", side_effect=mock_openai_chat_completion):
         config_relative_path = "../aiconfigs/basic_chatgpt_query_config.json"
-        config_absolute_path = get_absolute_file_path_from_relative(
-            __file__, config_relative_path
-        )
+        config_absolute_path = get_absolute_file_path_from_relative(__file__, config_relative_path)
         aiconfig = AIConfigRuntime.load(config_absolute_path)

         await aiconfig.run("prompt1", {})
@@ -56,9 +50,7 @@ async def test_get_output_text(set_temporary_env_vars):

 @pytest.mark.asyncio
 async def test_serialize(set_temporary_env_vars):
-    with patch.object(
-        openai.chat.completions, "create", side_effect=mock_openai_chat_completion
-    ):
+    with patch.object(openai.chat.completions, "create", side_effect=mock_openai_chat_completion):
         # Test with one input prompt and system. No output
         completion_params = {
             "model": "gpt-3.5-turbo",
@@ -71,9 +63,7 @@ async def test_serialize(set_temporary_env_vars):
         }

         aiconfig = AIConfigRuntime.create()
-        serialized_prompts = await aiconfig.serialize(
-            "gpt-3.5-turbo", completion_params, prompt_name="the prompt"
-        )
+        serialized_prompts = await aiconfig.serialize("gpt-3.5-turbo", completion_params, prompt_name="the prompt")
         new_prompt = serialized_prompts[0]

         # assert prompt serialized correctly into config
@@ -112,9 +102,7 @@ async def test_serialize(set_temporary_env_vars):
             ],
         }

-        serialized_prompts = await aiconfig.serialize(
-            "gpt-3.5-turbo", completion_params, "prompt"
-        )
+        serialized_prompts = await aiconfig.serialize("gpt-3.5-turbo", completion_params, "prompt")
         new_prompt = serialized_prompts[0]

         expected_prompt = Prompt(
@@ -141,13 +129,10 @@ async def test_serialize(set_temporary_env_vars):
                 ExecuteResult(
                     output_type="execute_result",
                     execution_count=None,
-                    data='Hello! How can I assist you today?',
+                    data="Hello! How can I assist you today?",
                     metadata={
-                        'rawResponse': {
-                            'role': 'assistant',
-                            'content': 'Hello! How can I assist you today?'
-                        },
-                        'role': 'assistant',
+                        "raw_response": {"role": "assistant", "content": "Hello! How can I assist you today?"},
+                        "role": "assistant",
                     },
                     mime_type=None,
                 )
@@ -159,7 +144,6 @@ async def test_serialize(set_temporary_env_vars):
         assert new_prompt.name == expected_prompt.name
         assert new_prompt == expected_prompt
-

         # Test completion params with a function call input
         completion_params = {
@@ -192,9 +176,7 @@ async def test_serialize(set_temporary_env_vars):
             ],
         }

-        serialized_prompts = await aiconfig.serialize(
-            "gpt-3.5-turbo", completion_params, "prompt"
-        )
+        serialized_prompts = await aiconfig.serialize("gpt-3.5-turbo", completion_params, "prompt")
         new_prompt = serialized_prompts[0]
         assert new_prompt == Prompt(
             name="prompt",
@@ -285,10 +267,9 @@ async def test_serialize(set_temporary_env_vars):
             ],
         }
-
         prompts = await aiconfig.serialize("gpt-3.5-turbo", completion_params, "prompt")
         new_prompt = prompts[1]
-
+
         expected_prompt = Prompt(
             name="prompt",
             input=PromptInput(
@@ -339,11 +320,11 @@ async def test_serialize(set_temporary_env_vars):
                     execution_count=None,
                     data="The current weather in Boston is 22 degrees Celsius and sunny.",
                     metadata={
-                        'rawResponse': {
-                            'role': 'assistant',
-                            'content': 'The current weather in Boston is 22 degrees Celsius and sunny.',
+                        "raw_response": {
+                            "role": "assistant",
+                            "content": "The current weather in Boston is 22 degrees Celsius and sunny.",
                         },
-                        'role': 'assistant',
+                        "role": "assistant",
                     },
                     mime_type=None,
                 )
diff --git a/typescript/__tests__/parsers/hf/hf.test.ts b/typescript/__tests__/parsers/hf/hf.test.ts
index 26ecda8d6..ce6e675f2 100644
--- a/typescript/__tests__/parsers/hf/hf.test.ts
+++ b/typescript/__tests__/parsers/hf/hf.test.ts
@@ -260,7 +260,7 @@ describe("HuggingFaceTextGeneration ModelParser", () => {
         data: "Test text generation",
         execution_count: 0,
         metadata: {
-          rawResponse: {
+          raw_response: {
             generated_text: "Test text generation",
           },
         },
diff --git a/typescript/lib/parsers/hf.ts b/typescript/lib/parsers/hf.ts
index cd8712967..a990548d8 100644
--- a/typescript/lib/parsers/hf.ts
+++ b/typescript/lib/parsers/hf.ts
@@ -305,7 +305,7 @@ function constructOutput(response: TextGenerationOutput): Output {
     output_type: "execute_result",
     data: response.generated_text,
     execution_count: 0,
-    metadata: { rawResponse: response },
+    metadata: { raw_response: response },
   } as ExecuteResult;
   return output;
 }
diff --git a/typescript/lib/parsers/openai.ts b/typescript/lib/parsers/openai.ts
index d13d14f31..79f4ec049 100644
--- a/typescript/lib/parsers/openai.ts
+++ b/typescript/lib/parsers/openai.ts
@@ -183,7 +183,7 @@ export class OpenAIModelParser extends ParameterizedModelParser