From 4925521546a619bbca90c1e7bd110b6600408704 Mon Sep 17 00:00:00 2001 From: Sarmad Qadri Date: Fri, 5 Jan 2024 17:04:11 -0500 Subject: [PATCH] Explicitly set "model" property for models that support it This change allows us to automatically set the "model" for the following model parsers: * OpenAI * Anyscale Endpoint * Hugging Face text generation * PaLM * Dall-E Without this change the local editor breaks because if a user doesn't explicitly specify the "model" property in settings metadata, inference fails for (some of) the above model parsers. This isn't ideal, because the user has already selected the model they want to use from the model selector, so shouldn't need to additionally specify a model id as well. There is still a larger issue of how we distinguish a model ID from a model parser ID. Filed #782 to come up with a proper fix for that. --- cookbooks/HuggingFace/hf.py | 2 +- python/src/aiconfig/default_parsers/dalle.py | 11 ++++++-- python/src/aiconfig/default_parsers/openai.py | 11 ++++++-- python/src/aiconfig/default_parsers/palm.py | 11 ++++++-- python/src/aiconfig/model_parser.py | 8 ++++-- python/src/aiconfig/schema.py | 3 +- python/tests/parsers/test_openai_util.py | 28 +++++++++++++------ 7 files changed, 52 insertions(+), 22 deletions(-) diff --git a/cookbooks/HuggingFace/hf.py b/cookbooks/HuggingFace/hf.py index 8257b326e..64e84e3f5 100644 --- a/cookbooks/HuggingFace/hf.py +++ b/cookbooks/HuggingFace/hf.py @@ -207,7 +207,7 @@ async def deserialize( Returns: dict: Model-specific completion parameters. """ - resolved_prompt = resolve_prompt(prompt, params, aiconfig) + resolved_prompt = resolve_prompt(prompt, params if params is not None else {}, aiconfig) # Build Completion data model_settings = self.get_model_settings(prompt, aiconfig) diff --git a/python/src/aiconfig/default_parsers/dalle.py b/python/src/aiconfig/default_parsers/dalle.py index 47acccac3..9b403d9b1 100644 --- a/python/src/aiconfig/default_parsers/dalle.py +++ b/python/src/aiconfig/default_parsers/dalle.py @@ -25,7 +25,7 @@ # TODO: Centralize this into utils function with the HuggingFaceTextParser class -def refine_image_completion_params(model_settings): +def refine_image_completion_params(model_settings, aiconfig, prompt): """ Refines the completion params for the Dall-E request API. Removes any unsupported params. The supported keys were found by looking at the OpenAI Dall-E API: https://platform.openai.com/docs/api-reference/images/create?lang=python` @@ -37,6 +37,11 @@ def refine_image_completion_params(model_settings): if key.lower() in supported_keys: completion_data[key.lower()] = model_settings[key] + # Explicitly set the model to use if not already specified + if completion_data.get("model") is None: + model_name = aiconfig.get_model_name(prompt) + completion_data["model"] = model_name + return completion_data @@ -141,11 +146,11 @@ async def deserialize(self, prompt: Prompt, aiconfig: "AIConfigRuntime", params: dict: Model-specific completion parameters. """ # Get inputs from aiconfig - resolved_prompt = resolve_prompt(prompt, params, aiconfig) + resolved_prompt = resolve_prompt(prompt, params if params is not None else {}, aiconfig) model_settings = self.get_model_settings(prompt, aiconfig) # Build Completion data - completion_data = refine_image_completion_params(model_settings) + completion_data = refine_image_completion_params(model_settings, aiconfig, prompt) completion_data["prompt"] = resolved_prompt return completion_data diff --git a/python/src/aiconfig/default_parsers/openai.py b/python/src/aiconfig/default_parsers/openai.py index e3e4d95ac..efbd99b80 100644 --- a/python/src/aiconfig/default_parsers/openai.py +++ b/python/src/aiconfig/default_parsers/openai.py @@ -160,7 +160,7 @@ async def deserialize(self, prompt: Prompt, aiconfig: "AIConfigRuntime", params: # Build Completion params model_settings = self.get_model_settings(prompt, aiconfig) - completion_params = refine_chat_completion_params(model_settings) + completion_params = refine_chat_completion_params(model_settings, aiconfig, prompt) # In the case thhat the messages array weren't saves as part of the model settings, build it here. Messages array is used for conversation history. if not completion_params.get("messages"): @@ -407,10 +407,10 @@ def multi_choice_message_reducer(messages: Union[Dict[int, dict], None], chunk: return messages -def refine_chat_completion_params(model_settings): +def refine_chat_completion_params(model_settings, aiconfig, prompt): # completion parameters to be used for openai's chat completion api # system prompt handled separately - # streaming handled seperatly. + # streaming handled separately. # Messages array built dynamically and handled separately supported_keys = { "frequency_penalty", @@ -436,6 +436,11 @@ def refine_chat_completion_params(model_settings): if key in model_settings: completion_data[key] = model_settings[key] + # Explicitly set the model to use if not already specified + if completion_data.get("model") is None: + model_name = aiconfig.get_model_name(prompt) + completion_data["model"] = model_name + return completion_data diff --git a/python/src/aiconfig/default_parsers/palm.py b/python/src/aiconfig/default_parsers/palm.py index edf5fb323..749c2e540 100644 --- a/python/src/aiconfig/default_parsers/palm.py +++ b/python/src/aiconfig/default_parsers/palm.py @@ -90,7 +90,7 @@ async def deserialize(self, prompt: Prompt, aiconfig: "AIConfigRuntime", params: # Build Completion data model_settings = self.get_model_settings(prompt, aiconfig) - completion_data = refine_chat_completion_params(model_settings) + completion_data = refine_chat_completion_params(model_settings, aiconfig, prompt) prompt_str = resolve_prompt(prompt, params, aiconfig) @@ -246,7 +246,7 @@ async def deserialize(self, prompt: Prompt, aiconfig: "AIConfigRuntime", params: # Build Completion data model_settings = self.get_model_settings(prompt, aiconfig) - completion_data = refine_chat_completion_params(model_settings) + completion_data = refine_chat_completion_params(model_settings, aiconfig, prompt) # TODO: handle if user specifies previous messages in settings completion_data["messages"] = [] @@ -366,7 +366,7 @@ def get_output_text( return "" -def refine_chat_completion_params(model_settings): +def refine_chat_completion_params(model_settings, aiconfig, prompt): # completion parameters to be used for Palm's chat completion api # messages handled seperately supported_keys = { @@ -384,6 +384,11 @@ def refine_chat_completion_params(model_settings): if key.lower() in supported_keys: completion_data[key.lower()] = model_settings[key] + # Explicitly set the model to use if not already specified + if completion_data.get("model") is None: + model_name = aiconfig.get_model_name(prompt) + completion_data["model"] = model_name + return completion_data diff --git a/python/src/aiconfig/model_parser.py b/python/src/aiconfig/model_parser.py index 13e559020..f5c817e05 100644 --- a/python/src/aiconfig/model_parser.py +++ b/python/src/aiconfig/model_parser.py @@ -174,11 +174,13 @@ def get_model_settings(self, prompt: Prompt, aiconfig: "AIConfigRuntime") -> Dic else: # Merge config and prompt settings with prompt settings taking precedent model_settings = {} - global_settings = aiconfig.get_global_settings(model_metadata.name) - prompt_setings = prompt.metadata.model.settings if prompt.metadata.model.settings is not None else {} + prompt_settings = prompt.metadata.model.settings if prompt.metadata.model.settings is not None else {} + + model_name = prompt_settings.get("model", model_metadata.name) + global_settings = aiconfig.get_global_settings(model_name) model_settings.update(global_settings) - model_settings.update(prompt_setings) + model_settings.update(prompt_settings) return model_settings diff --git a/python/src/aiconfig/schema.py b/python/src/aiconfig/schema.py index 72dbe57ba..2e16f32fd 100644 --- a/python/src/aiconfig/schema.py +++ b/python/src/aiconfig/schema.py @@ -289,7 +289,8 @@ def get_model_name(self, prompt: Union[str, Prompt]) -> str: return prompt.metadata.model else: # Expect a ModelMetadata object - return prompt.metadata.model.name + prompt_settings = prompt.metadata.model.settings if prompt.metadata.model.settings is not None else {} + return prompt_settings.get("model", prompt.metadata.model.name) def set_default_model(self, model_name: Union[str, None]): """ diff --git a/python/tests/parsers/test_openai_util.py b/python/tests/parsers/test_openai_util.py index e8ae540db..81a92f0c6 100644 --- a/python/tests/parsers/test_openai_util.py +++ b/python/tests/parsers/test_openai_util.py @@ -4,20 +4,32 @@ from aiconfig.default_parsers.openai import refine_chat_completion_params from mock import patch -from aiconfig.schema import ExecuteResult, Prompt, PromptInput, PromptMetadata +from aiconfig.schema import ExecuteResult, Prompt, PromptInput, PromptMetadata, ModelMetadata from ..conftest import mock_openai_chat_completion from ..util.file_path_utils import get_absolute_file_path_from_relative def test_refine_chat_completion_params(): - model_settings_with_stream_and_system_prompt = { - "n": "3", - "stream": True, - "system_prompt": "system_prompt", - "random_attribute": "value_doesn't_matter", - } - refined_params = refine_chat_completion_params(model_settings_with_stream_and_system_prompt) + prompt = Prompt( + name="test", + input="test", + metadata=PromptMetadata( + model=ModelMetadata( + name="gpt-4", + settings={ + "n": "3", + "stream": True, + "system_prompt": "system_prompt", + "random_attribute": "value_doesn't_matter", + }, + ) + ), + ) + + aiconfig = AIConfigRuntime.create(name="test_refine_chat_completion_params", prompts=[prompt]) + + refined_params = refine_chat_completion_params(prompt.metadata.model.settings, aiconfig, prompt) assert "system_prompt" not in refined_params assert "stream" in refined_params