diff --git a/cookbooks/HuggingFace/hf.py b/cookbooks/HuggingFace/hf.py
index 8257b326e..64e84e3f5 100644
--- a/cookbooks/HuggingFace/hf.py
+++ b/cookbooks/HuggingFace/hf.py
@@ -207,7 +207,7 @@ async def deserialize(
         Returns:
             dict: Model-specific completion parameters.
         """
-        resolved_prompt = resolve_prompt(prompt, params, aiconfig)
+        resolved_prompt = resolve_prompt(prompt, params if params is not None else {}, aiconfig)
 
         # Build Completion data
         model_settings = self.get_model_settings(prompt, aiconfig)
diff --git a/python/src/aiconfig/default_parsers/dalle.py b/python/src/aiconfig/default_parsers/dalle.py
index 47acccac3..9b403d9b1 100644
--- a/python/src/aiconfig/default_parsers/dalle.py
+++ b/python/src/aiconfig/default_parsers/dalle.py
@@ -25,7 +25,7 @@
 
 
 # TODO: Centralize this into utils function with the HuggingFaceTextParser class
-def refine_image_completion_params(model_settings):
+def refine_image_completion_params(model_settings, aiconfig, prompt):
     """
     Refines the completion params for the Dall-E request API. Removes any unsupported params.
     The supported keys were found by looking at the OpenAI Dall-E API: https://platform.openai.com/docs/api-reference/images/create?lang=python`
@@ -37,6 +37,11 @@ def refine_image_completion_params(model_settings):
         if key.lower() in supported_keys:
             completion_data[key.lower()] = model_settings[key]
 
+    # Explicitly set the model to use if not already specified
+    if completion_data.get("model") is None:
+        model_name = aiconfig.get_model_name(prompt)
+        completion_data["model"] = model_name
+
     return completion_data
 
 
@@ -141,11 +146,11 @@ async def deserialize(self, prompt: Prompt, aiconfig: "AIConfigRuntime", params:
             dict: Model-specific completion parameters.
         """
         # Get inputs from aiconfig
-        resolved_prompt = resolve_prompt(prompt, params, aiconfig)
+        resolved_prompt = resolve_prompt(prompt, params if params is not None else {}, aiconfig)
         model_settings = self.get_model_settings(prompt, aiconfig)
 
         # Build Completion data
-        completion_data = refine_image_completion_params(model_settings)
+        completion_data = refine_image_completion_params(model_settings, aiconfig, prompt)
         completion_data["prompt"] = resolved_prompt
 
         return completion_data
diff --git a/python/src/aiconfig/default_parsers/openai.py b/python/src/aiconfig/default_parsers/openai.py
index e3e4d95ac..efbd99b80 100644
--- a/python/src/aiconfig/default_parsers/openai.py
+++ b/python/src/aiconfig/default_parsers/openai.py
@@ -160,7 +160,7 @@ async def deserialize(self, prompt: Prompt, aiconfig: "AIConfigRuntime", params:
 
         # Build Completion params
         model_settings = self.get_model_settings(prompt, aiconfig)
-        completion_params = refine_chat_completion_params(model_settings)
+        completion_params = refine_chat_completion_params(model_settings, aiconfig, prompt)
 
         # In the case thhat the messages array weren't saves as part of the model settings, build it here. Messages array is used for conversation history.
         if not completion_params.get("messages"):
@@ -407,10 +407,10 @@ def multi_choice_message_reducer(messages: Union[Dict[int, dict], None], chunk:
     return messages
 
 
-def refine_chat_completion_params(model_settings):
+def refine_chat_completion_params(model_settings, aiconfig, prompt):
     # completion parameters to be used for openai's chat completion api
     # system prompt handled separately
-    # streaming handled seperatly.
+    # streaming handled separately.
     # Messages array built dynamically and handled separately
     supported_keys = {
         "frequency_penalty",
@@ -436,6 +436,11 @@ def refine_chat_completion_params(model_settings):
         if key in model_settings:
             completion_data[key] = model_settings[key]
 
+    # Explicitly set the model to use if not already specified
+    if completion_data.get("model") is None:
+        model_name = aiconfig.get_model_name(prompt)
+        completion_data["model"] = model_name
+
     return completion_data
 
 
diff --git a/python/src/aiconfig/default_parsers/palm.py b/python/src/aiconfig/default_parsers/palm.py
index edf5fb323..749c2e540 100644
--- a/python/src/aiconfig/default_parsers/palm.py
+++ b/python/src/aiconfig/default_parsers/palm.py
@@ -90,7 +90,7 @@ async def deserialize(self, prompt: Prompt, aiconfig: "AIConfigRuntime", params:
 
         # Build Completion data
         model_settings = self.get_model_settings(prompt, aiconfig)
-        completion_data = refine_chat_completion_params(model_settings)
+        completion_data = refine_chat_completion_params(model_settings, aiconfig, prompt)
 
         prompt_str = resolve_prompt(prompt, params, aiconfig)
 
@@ -246,7 +246,7 @@ async def deserialize(self, prompt: Prompt, aiconfig: "AIConfigRuntime", params:
 
         # Build Completion data
         model_settings = self.get_model_settings(prompt, aiconfig)
-        completion_data = refine_chat_completion_params(model_settings)
+        completion_data = refine_chat_completion_params(model_settings, aiconfig, prompt)
 
         # TODO: handle if user specifies previous messages in settings
         completion_data["messages"] = []
@@ -366,7 +366,7 @@ def get_output_text(
     return ""
 
 
-def refine_chat_completion_params(model_settings):
+def refine_chat_completion_params(model_settings, aiconfig, prompt):
     # completion parameters to be used for Palm's chat completion api
     # messages handled seperately
     supported_keys = {
@@ -384,6 +384,11 @@ def refine_chat_completion_params(model_settings):
         if key.lower() in supported_keys:
             completion_data[key.lower()] = model_settings[key]
 
+    # Explicitly set the model to use if not already specified
+    if completion_data.get("model") is None:
+        model_name = aiconfig.get_model_name(prompt)
+        completion_data["model"] = model_name
+
     return completion_data
 
 
diff --git a/python/src/aiconfig/model_parser.py b/python/src/aiconfig/model_parser.py
index 13e559020..f5c817e05 100644
--- a/python/src/aiconfig/model_parser.py
+++ b/python/src/aiconfig/model_parser.py
@@ -174,11 +174,13 @@ def get_model_settings(self, prompt: Prompt, aiconfig: "AIConfigRuntime") -> Dic
         else:
             # Merge config and prompt settings with prompt settings taking precedent
             model_settings = {}
-            global_settings = aiconfig.get_global_settings(model_metadata.name)
-            prompt_setings = prompt.metadata.model.settings if prompt.metadata.model.settings is not None else {}
+            prompt_settings = prompt.metadata.model.settings if prompt.metadata.model.settings is not None else {}
+
+            model_name = prompt_settings.get("model", model_metadata.name)
+            global_settings = aiconfig.get_global_settings(model_name)
 
             model_settings.update(global_settings)
-            model_settings.update(prompt_setings)
+            model_settings.update(prompt_settings)
 
             return model_settings
 
diff --git a/python/src/aiconfig/schema.py b/python/src/aiconfig/schema.py
index 72dbe57ba..2e16f32fd 100644
--- a/python/src/aiconfig/schema.py
+++ b/python/src/aiconfig/schema.py
@@ -289,7 +289,8 @@ def get_model_name(self, prompt: Union[str, Prompt]) -> str:
             return prompt.metadata.model
         else:
             # Expect a ModelMetadata object
-            return prompt.metadata.model.name
+            prompt_settings = prompt.metadata.model.settings if prompt.metadata.model.settings is not None else {}
+            return prompt_settings.get("model", prompt.metadata.model.name)
 
     def set_default_model(self, model_name: Union[str, None]):
         """
diff --git a/python/tests/parsers/test_openai_util.py b/python/tests/parsers/test_openai_util.py
index e8ae540db..81a92f0c6 100644
--- a/python/tests/parsers/test_openai_util.py
+++ b/python/tests/parsers/test_openai_util.py
@@ -4,20 +4,32 @@
 from aiconfig.default_parsers.openai import refine_chat_completion_params
 from mock import patch
 
-from aiconfig.schema import ExecuteResult, Prompt, PromptInput, PromptMetadata
+from aiconfig.schema import ExecuteResult, Prompt, PromptInput, PromptMetadata, ModelMetadata
 
 from ..conftest import mock_openai_chat_completion
 from ..util.file_path_utils import get_absolute_file_path_from_relative
 
 
 def test_refine_chat_completion_params():
-    model_settings_with_stream_and_system_prompt = {
-        "n": "3",
-        "stream": True,
-        "system_prompt": "system_prompt",
-        "random_attribute": "value_doesn't_matter",
-    }
-    refined_params = refine_chat_completion_params(model_settings_with_stream_and_system_prompt)
+    prompt = Prompt(
+        name="test",
+        input="test",
+        metadata=PromptMetadata(
+            model=ModelMetadata(
+                name="gpt-4",
+                settings={
+                    "n": "3",
+                    "stream": True,
+                    "system_prompt": "system_prompt",
+                    "random_attribute": "value_doesn't_matter",
+                },
+            )
+        ),
+    )
+
+    aiconfig = AIConfigRuntime.create(name="test_refine_chat_completion_params", prompts=[prompt])
+
+    refined_params = refine_chat_completion_params(prompt.metadata.model.settings, aiconfig, prompt)
     assert "system_prompt" not in refined_params
     assert "stream" in refined_params
 
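Taken together, these changes make a prompt-level settings["model"] take precedence over the name registered in ModelMetadata, and have each refine helper backfill completion_data["model"] via aiconfig.get_model_name(prompt) when settings omit it. Below is a minimal, self-contained sketch of that resolution order; StubConfig, refine_params, and the trimmed SUPPORTED_KEYS set are illustrative stand-ins, not part of the aiconfig API:

# Illustrative sketch only: StubConfig and refine_params stand in for
# AIConfigRuntime and the refine_*_completion_params helpers in this diff.

SUPPORTED_KEYS = {"model", "n", "temperature"}  # trimmed list, for brevity


class StubConfig:
    """Mimics the two-step model-name lookup the diff adds to get_model_name."""

    def get_model_name(self, prompt: dict) -> str:
        settings = prompt.get("settings") or {}
        # Prefer the prompt-level settings["model"]; fall back to the
        # name registered on the prompt's model metadata.
        return settings.get("model", prompt["name"])


def refine_params(model_settings: dict, aiconfig: StubConfig, prompt: dict) -> dict:
    # Keep only supported keys, as the real refine helpers do.
    completion_data = {k: v for k, v in model_settings.items() if k in SUPPORTED_KEYS}
    # Explicitly set the model to use if not already specified
    # (the same fallback the diff adds to each parser's refine helper).
    if completion_data.get("model") is None:
        completion_data["model"] = aiconfig.get_model_name(prompt)
    return completion_data


prompt = {"name": "gpt-4", "settings": {"n": 3, "system_prompt": "ignored"}}
params = refine_params(prompt["settings"], StubConfig(), prompt)
assert params == {"n": 3, "model": "gpt-4"}  # unsupported key dropped, model backfilled

override = {"name": "gpt-4", "settings": {"model": "gpt-3.5-turbo"}}
assert refine_params(override["settings"], StubConfig(), override)["model"] == "gpt-3.5-turbo"

The two asserts exercise both paths: the first shows the backfill from the registered model name, the second shows a settings-level "model" winning, which is the behavior the updated test_refine_chat_completion_params covers against the real API.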