From 6022f24b0ad6099febd96d143ffbffd5a0ea9d75 Mon Sep 17 00:00:00 2001
From: Sarmad Qadri
Date: Fri, 5 Jan 2024 15:46:02 -0500
Subject: [PATCH] Explicitly set "model" property for models that support it

This change allows us to automatically set the "model" for the following
model parsers:

* OpenAI
* Anyscale Endpoint
* Hugging Face text generation
* PaLM
* Dall-E

Without this change the local editor breaks because if a user doesn't
explicitly specify the "model" property in settings metadata, inference
fails for (some of) the above model parsers.

This isn't ideal, because the user has already selected the model they
want to use from the model selector, so shouldn't need to additionally
specify a model id as well.

There is still a larger issue of how we distinguish a model ID from a
model parser ID. Filed #782 to come up with a proper fix for that.
---
 cookbooks/HuggingFace/hf.py                   | 32 +++++------
 python/src/aiconfig/default_parsers/dalle.py  | 11 +++++--
 python/src/aiconfig/default_parsers/openai.py | 11 +++++--
 python/src/aiconfig/default_parsers/palm.py   | 11 +++++--
 python/src/aiconfig/model_parser.py           | 33 ++++++-------
 python/src/aiconfig/schema.py                 |  3 +-
 6 files changed, 48 insertions(+), 53 deletions(-)

diff --git a/cookbooks/HuggingFace/hf.py b/cookbooks/HuggingFace/hf.py
index aad4e1f2b..d0a055e00 100644
--- a/cookbooks/HuggingFace/hf.py
+++ b/cookbooks/HuggingFace/hf.py
@@ -29,7 +29,7 @@
 # Step 1: define Helpers


-def refine_chat_completion_params(model_settings):
+def refine_chat_completion_params(model_settings, aiconfig, prompt):
     """
     Refines the completion params for the HF text generation api. Removes any unsupported params.
     The supported keys were found by looking at the HF text generation api. `huggingface_hub.InferenceClient.text_generation()`
@@ -60,6 +60,11 @@ def refine_chat_completion_params(model_settings):
         if key.lower() in supported_keys:
             completion_data[key.lower()] = model_settings[key]

+    # Explicitly set the model to use if not already specified
+    if completion_data.get("model") is None:
+        model_name = aiconfig.get_model_name(prompt)
+        completion_data["model"] = model_name
+
     return completion_data


@@ -166,14 +171,7 @@ def id(self) -> str:
         """
         return "HuggingFaceTextParser"

-    def serialize(
-        self,
-        prompt_name: str,
-        data: Any,
-        ai_config: "AIConfigRuntime",
-        parameters: Optional[Dict] = None,
-        **kwargs
-    ) -> List[Prompt]:
+    def serialize(self, prompt_name: str, data: Any, ai_config: "AIConfigRuntime", parameters: Optional[Dict] = None, **kwargs) -> List[Prompt]:
         """
         Defines how a prompt and model inference settings get serialized in the .aiconfig.

@@ -196,9 +194,7 @@ def serialize(
         prompt = Prompt(
             name=prompt_name,
             input=prompt_input,
-            metadata=PromptMetadata(
-                model=model_metadata, parameters=parameters, **kwargs
-            ),
+            metadata=PromptMetadata(model=model_metadata, parameters=parameters, **kwargs),
         )

         return [prompt]
@@ -219,20 +215,18 @@ async def deserialize(

         Returns:
             dict: Model-specific completion parameters.
""" - resolved_prompt = resolve_prompt(prompt, params, aiconfig) + resolved_prompt = resolve_prompt(prompt, params if params is not None else {}, aiconfig) # Build Completion data model_settings = self.get_model_settings(prompt, aiconfig) - completion_data = refine_chat_completion_params(model_settings) + completion_data = refine_chat_completion_params(model_settings, aiconfig, prompt) completion_data["prompt"] = resolved_prompt return completion_data - async def run_inference( - self, prompt: Prompt, aiconfig, options, parameters - ) -> List[Output]: + async def run_inference(self, prompt: Prompt, aiconfig, options, parameters) -> List[Output]: """ Invoked to run a prompt in the .aiconfig. This method should perform the actual model inference based on the provided prompt and inference settings. @@ -247,9 +241,7 @@ async def run_inference( completion_data = await self.deserialize(prompt, aiconfig, options, parameters) # if stream enabled in runtime options and config, then stream. Otherwise don't stream. - stream = (options.stream if options else False) and ( - not "stream" in completion_data or completion_data.get("stream") != False - ) + stream = (options.stream if options else False) and (not "stream" in completion_data or completion_data.get("stream") != False) response = self.client.text_generation(**completion_data) response_is_detailed = completion_data.get("details", False) diff --git a/python/src/aiconfig/default_parsers/dalle.py b/python/src/aiconfig/default_parsers/dalle.py index 47acccac3..9b403d9b1 100644 --- a/python/src/aiconfig/default_parsers/dalle.py +++ b/python/src/aiconfig/default_parsers/dalle.py @@ -25,7 +25,7 @@ # TODO: Centralize this into utils function with the HuggingFaceTextParser class -def refine_image_completion_params(model_settings): +def refine_image_completion_params(model_settings, aiconfig, prompt): """ Refines the completion params for the Dall-E request API. Removes any unsupported params. The supported keys were found by looking at the OpenAI Dall-E API: https://platform.openai.com/docs/api-reference/images/create?lang=python` @@ -37,6 +37,11 @@ def refine_image_completion_params(model_settings): if key.lower() in supported_keys: completion_data[key.lower()] = model_settings[key] + # Explicitly set the model to use if not already specified + if completion_data.get("model") is None: + model_name = aiconfig.get_model_name(prompt) + completion_data["model"] = model_name + return completion_data @@ -141,11 +146,11 @@ async def deserialize(self, prompt: Prompt, aiconfig: "AIConfigRuntime", params: dict: Model-specific completion parameters. 
""" # Get inputs from aiconfig - resolved_prompt = resolve_prompt(prompt, params, aiconfig) + resolved_prompt = resolve_prompt(prompt, params if params is not None else {}, aiconfig) model_settings = self.get_model_settings(prompt, aiconfig) # Build Completion data - completion_data = refine_image_completion_params(model_settings) + completion_data = refine_image_completion_params(model_settings, aiconfig, prompt) completion_data["prompt"] = resolved_prompt return completion_data diff --git a/python/src/aiconfig/default_parsers/openai.py b/python/src/aiconfig/default_parsers/openai.py index e3e4d95ac..efbd99b80 100644 --- a/python/src/aiconfig/default_parsers/openai.py +++ b/python/src/aiconfig/default_parsers/openai.py @@ -160,7 +160,7 @@ async def deserialize(self, prompt: Prompt, aiconfig: "AIConfigRuntime", params: # Build Completion params model_settings = self.get_model_settings(prompt, aiconfig) - completion_params = refine_chat_completion_params(model_settings) + completion_params = refine_chat_completion_params(model_settings, aiconfig, prompt) # In the case thhat the messages array weren't saves as part of the model settings, build it here. Messages array is used for conversation history. if not completion_params.get("messages"): @@ -407,10 +407,10 @@ def multi_choice_message_reducer(messages: Union[Dict[int, dict], None], chunk: return messages -def refine_chat_completion_params(model_settings): +def refine_chat_completion_params(model_settings, aiconfig, prompt): # completion parameters to be used for openai's chat completion api # system prompt handled separately - # streaming handled seperatly. + # streaming handled separately. # Messages array built dynamically and handled separately supported_keys = { "frequency_penalty", @@ -436,6 +436,11 @@ def refine_chat_completion_params(model_settings): if key in model_settings: completion_data[key] = model_settings[key] + # Explicitly set the model to use if not already specified + if completion_data.get("model") is None: + model_name = aiconfig.get_model_name(prompt) + completion_data["model"] = model_name + return completion_data diff --git a/python/src/aiconfig/default_parsers/palm.py b/python/src/aiconfig/default_parsers/palm.py index edf5fb323..749c2e540 100644 --- a/python/src/aiconfig/default_parsers/palm.py +++ b/python/src/aiconfig/default_parsers/palm.py @@ -90,7 +90,7 @@ async def deserialize(self, prompt: Prompt, aiconfig: "AIConfigRuntime", params: # Build Completion data model_settings = self.get_model_settings(prompt, aiconfig) - completion_data = refine_chat_completion_params(model_settings) + completion_data = refine_chat_completion_params(model_settings, aiconfig, prompt) prompt_str = resolve_prompt(prompt, params, aiconfig) @@ -246,7 +246,7 @@ async def deserialize(self, prompt: Prompt, aiconfig: "AIConfigRuntime", params: # Build Completion data model_settings = self.get_model_settings(prompt, aiconfig) - completion_data = refine_chat_completion_params(model_settings) + completion_data = refine_chat_completion_params(model_settings, aiconfig, prompt) # TODO: handle if user specifies previous messages in settings completion_data["messages"] = [] @@ -366,7 +366,7 @@ def get_output_text( return "" -def refine_chat_completion_params(model_settings): +def refine_chat_completion_params(model_settings, aiconfig, prompt): # completion parameters to be used for Palm's chat completion api # messages handled seperately supported_keys = { @@ -384,6 +384,11 @@ def refine_chat_completion_params(model_settings): if key.lower() 
in supported_keys: completion_data[key.lower()] = model_settings[key] + # Explicitly set the model to use if not already specified + if completion_data.get("model") is None: + model_name = aiconfig.get_model_name(prompt) + completion_data["model"] = model_name + return completion_data diff --git a/python/src/aiconfig/model_parser.py b/python/src/aiconfig/model_parser.py index 1f2a6d694..f412e502c 100644 --- a/python/src/aiconfig/model_parser.py +++ b/python/src/aiconfig/model_parser.py @@ -33,7 +33,7 @@ async def serialize( **kwargs, ) -> List[Prompt]: """ - Serialize a prompt and additional metadata/model settings into a + Serialize a prompt and additional metadata/model settings into a list of Prompt objects that can be saved in the AIConfig. Args: @@ -150,9 +150,7 @@ def get_output_text( str: The output text from the model inference response. """ - def get_model_settings( - self, prompt: Prompt, aiconfig: "AIConfigRuntime" - ) -> Dict[str, Any]: + def get_model_settings(self, prompt: Prompt, aiconfig: "AIConfigRuntime") -> Dict[str, Any]: """ Extracts the AI model's settings from the configuration. If both prompt and config level settings are defined, merge them with prompt settings taking precedence. @@ -166,10 +164,7 @@ def get_model_settings( return aiconfig.get_global_settings(self.id()) # Check if the prompt exists in the config - if ( - prompt.name not in aiconfig.prompt_index - or aiconfig.prompt_index[prompt.name] != prompt - ): + if prompt.name not in aiconfig.prompt_index or aiconfig.prompt_index[prompt.name] != prompt: raise IndexError(f"Prompt '{prompt.name}' not in config.") model_metadata = prompt.metadata.model if prompt.metadata else None @@ -178,9 +173,7 @@ def get_model_settings( # Use Default Model default_model = aiconfig.get_default_model() if not default_model: - raise KeyError( - f"No default model specified in AIConfigMetadata, and prompt `{prompt.name}` does not specify a model." - ) + raise KeyError(f"No default model specified in AIConfigMetadata, and prompt `{prompt.name}` does not specify a model.") return aiconfig.get_global_settings(default_model) elif isinstance(model_metadata, str): # Use Global settings @@ -188,15 +181,13 @@ def get_model_settings( else: # Merge config and prompt settings with prompt settings taking precedent model_settings = {} - global_settings = aiconfig.get_global_settings(model_metadata.name) - prompt_setings = ( - prompt.metadata.model.settings - if prompt.metadata.model.settings is not None - else {} - ) + prompt_settings = prompt.metadata.model.settings if prompt.metadata.model.settings is not None else {} + + model_name = prompt_settings.get("model", model_metadata.name) + global_settings = aiconfig.get_global_settings(model_name) model_settings.update(global_settings) - model_settings.update(prompt_setings) + model_settings.update(prompt_settings) return model_settings @@ -205,11 +196,7 @@ def print_stream_callback(data, accumulated_data, index: int): """ Default streamCallback function that prints the output to the console. 
""" - print( - "\ndata: {}\naccumulated_data:{}\nindex:{}\n".format( - data, accumulated_data, index - ) - ) + print("\ndata: {}\naccumulated_data:{}\nindex:{}\n".format(data, accumulated_data, index)) def print_stream_delta(data, accumulated_data, index: int): diff --git a/python/src/aiconfig/schema.py b/python/src/aiconfig/schema.py index 72dbe57ba..2e16f32fd 100644 --- a/python/src/aiconfig/schema.py +++ b/python/src/aiconfig/schema.py @@ -289,7 +289,8 @@ def get_model_name(self, prompt: Union[str, Prompt]) -> str: return prompt.metadata.model else: # Expect a ModelMetadata object - return prompt.metadata.model.name + prompt_settings = prompt.metadata.model.settings if prompt.metadata.model.settings is not None else {} + return prompt_settings.get("model", prompt.metadata.model.name) def set_default_model(self, model_name: Union[str, None]): """