Explicitly set "model" property for models that support it
This change allows us to automatically set the "model" for the following model parsers:
* OpenAI
* Anyscale Endpoint
* Hugging Face text generation
* PaLM
* Dall-E

Without this change the local editor breaks: if a user doesn't explicitly specify the "model" property in settings metadata, inference fails for some of the above model parsers. That isn't ideal, because the user has already selected the model they want from the model selector and shouldn't need to specify a model ID again.

There is still a larger issue of how we distinguish a model ID from a model parser ID. Filed #782 to come up with a proper fix for that.
saqadri committed Jan 5, 2024
1 parent b7a46a2 commit 6022f24
Showing 6 changed files with 48 additions and 53 deletions.
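
Each affected parser gains the same fallback in its refine_*_completion_params helper. As a standalone sketch (the helper name here is illustrative; in the diffs below the logic is inlined into each parser):

def set_model_if_missing(completion_data: dict, aiconfig, prompt) -> dict:
    # If the prompt's settings metadata omits "model", fall back to the
    # model the user already selected, via aiconfig.get_model_name(prompt).
    if completion_data.get("model") is None:
        completion_data["model"] = aiconfig.get_model_name(prompt)
    return completion_data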
32 changes: 12 additions & 20 deletions cookbooks/HuggingFace/hf.py
@@ -29,7 +29,7 @@
 # Step 1: define Helpers


-def refine_chat_completion_params(model_settings):
+def refine_chat_completion_params(model_settings, aiconfig, prompt):
     """
     Refines the completion params for the HF text generation api. Removes any unsupported params.
     The supported keys were found by looking at the HF text generation api. `huggingface_hub.InferenceClient.text_generation()`
@@ -60,6 +60,11 @@ def refine_chat_completion_params(model_settings)
         if key.lower() in supported_keys:
             completion_data[key.lower()] = model_settings[key]

+    # Explicitly set the model to use if not already specified
+    if completion_data.get("model") is None:
+        model_name = aiconfig.get_model_name(prompt)
+        completion_data["model"] = model_name
+
     return completion_data


@@ -166,14 +171,7 @@ def id(self) -> str:
         """
         return "HuggingFaceTextParser"

-    def serialize(
-        self,
-        prompt_name: str,
-        data: Any,
-        ai_config: "AIConfigRuntime",
-        parameters: Optional[Dict] = None,
-        **kwargs
-    ) -> List[Prompt]:
+    def serialize(self, prompt_name: str, data: Any, ai_config: "AIConfigRuntime", parameters: Optional[Dict] = None, **kwargs) -> List[Prompt]:
         """
         Defines how a prompt and model inference settings get serialized in the .aiconfig.
@@ -196,9 +194,7 @@ def serialize(
         prompt = Prompt(
             name=prompt_name,
             input=prompt_input,
-            metadata=PromptMetadata(
-                model=model_metadata, parameters=parameters, **kwargs
-            ),
+            metadata=PromptMetadata(model=model_metadata, parameters=parameters, **kwargs),
         )
         return [prompt]

@@ -219,20 +215,18 @@ async def deserialize(
         Returns:
             dict: Model-specific completion parameters.
         """
-        resolved_prompt = resolve_prompt(prompt, params, aiconfig)
+        resolved_prompt = resolve_prompt(prompt, params if params is not None else {}, aiconfig)

         # Build Completion data
         model_settings = self.get_model_settings(prompt, aiconfig)

-        completion_data = refine_chat_completion_params(model_settings)
+        completion_data = refine_chat_completion_params(model_settings, aiconfig, prompt)

         completion_data["prompt"] = resolved_prompt

         return completion_data

-    async def run_inference(
-        self, prompt: Prompt, aiconfig, options, parameters
-    ) -> List[Output]:
+    async def run_inference(self, prompt: Prompt, aiconfig, options, parameters) -> List[Output]:
         """
         Invoked to run a prompt in the .aiconfig. This method should perform
         the actual model inference based on the provided prompt and inference settings.
@@ -247,9 +241,7 @@ async def run_inference(
         completion_data = await self.deserialize(prompt, aiconfig, options, parameters)

         # if stream enabled in runtime options and config, then stream. Otherwise don't stream.
-        stream = (options.stream if options else False) and (
-            not "stream" in completion_data or completion_data.get("stream") != False
-        )
+        stream = (options.stream if options else False) and (not "stream" in completion_data or completion_data.get("stream") != False)

         response = self.client.text_generation(**completion_data)
         response_is_detailed = completion_data.get("details", False)
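
To make the effect concrete, here is a hypothetical call against the updated HF helper; the settings dict deliberately omits "model", which previously left completion_data without one (aiconfig and prompt are assumed to be an AIConfigRuntime and a Prompt already in scope):

model_settings = {"temperature": 0.7}  # note: no "model" key
completion_data = refine_chat_completion_params(model_settings, aiconfig, prompt)
# The helper now falls back to the name the user selected in the editor:
assert completion_data["model"] == aiconfig.get_model_name(prompt)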
11 changes: 8 additions & 3 deletions python/src/aiconfig/default_parsers/dalle.py
@@ -25,7 +25,7 @@


 # TODO: Centralize this into utils function with the HuggingFaceTextParser class
-def refine_image_completion_params(model_settings):
+def refine_image_completion_params(model_settings, aiconfig, prompt):
     """
     Refines the completion params for the Dall-E request API. Removes any unsupported params.
     The supported keys were found by looking at the OpenAI Dall-E API: https://platform.openai.com/docs/api-reference/images/create?lang=python`
@@ -37,6 +37,11 @@ def refine_image_completion_params(model_settings):
         if key.lower() in supported_keys:
             completion_data[key.lower()] = model_settings[key]

+    # Explicitly set the model to use if not already specified
+    if completion_data.get("model") is None:
+        model_name = aiconfig.get_model_name(prompt)
+        completion_data["model"] = model_name
+
     return completion_data


@@ -141,11 +146,11 @@ async def deserialize(self, prompt: Prompt, aiconfig: "AIConfigRuntime", params:
             dict: Model-specific completion parameters.
         """
         # Get inputs from aiconfig
-        resolved_prompt = resolve_prompt(prompt, params, aiconfig)
+        resolved_prompt = resolve_prompt(prompt, params if params is not None else {}, aiconfig)
         model_settings = self.get_model_settings(prompt, aiconfig)

         # Build Completion data
-        completion_data = refine_image_completion_params(model_settings)
+        completion_data = refine_image_completion_params(model_settings, aiconfig, prompt)
         completion_data["prompt"] = resolved_prompt
         return completion_data

11 changes: 8 additions & 3 deletions python/src/aiconfig/default_parsers/openai.py
@@ -160,7 +160,7 @@ async def deserialize(self, prompt: Prompt, aiconfig: "AIConfigRuntime", params:
         # Build Completion params
         model_settings = self.get_model_settings(prompt, aiconfig)

-        completion_params = refine_chat_completion_params(model_settings)
+        completion_params = refine_chat_completion_params(model_settings, aiconfig, prompt)

         # In the case that the messages array wasn't saved as part of the model settings, build it here. Messages array is used for conversation history.
         if not completion_params.get("messages"):
@@ -407,10 +407,10 @@ def multi_choice_message_reducer(messages: Union[Dict[int, dict], None], chunk:
     return messages


-def refine_chat_completion_params(model_settings):
+def refine_chat_completion_params(model_settings, aiconfig, prompt):
     # completion parameters to be used for openai's chat completion api
     # system prompt handled separately
-    # streaming handled seperatly.
+    # streaming handled separately.
     # Messages array built dynamically and handled separately
     supported_keys = {
         "frequency_penalty",
@@ -436,6 +436,11 @@ def refine_chat_completion_params(model_settings):
         if key in model_settings:
             completion_data[key] = model_settings[key]

+    # Explicitly set the model to use if not already specified
+    if completion_data.get("model") is None:
+        model_name = aiconfig.get_model_name(prompt)
+        completion_data["model"] = model_name
+
     return completion_data


11 changes: 8 additions & 3 deletions python/src/aiconfig/default_parsers/palm.py
@@ -90,7 +90,7 @@ async def deserialize(self, prompt: Prompt, aiconfig: "AIConfigRuntime", params:
         # Build Completion data
         model_settings = self.get_model_settings(prompt, aiconfig)

-        completion_data = refine_chat_completion_params(model_settings)
+        completion_data = refine_chat_completion_params(model_settings, aiconfig, prompt)

         prompt_str = resolve_prompt(prompt, params, aiconfig)

@@ -246,7 +246,7 @@ async def deserialize(self, prompt: Prompt, aiconfig: "AIConfigRuntime", params:
         # Build Completion data
         model_settings = self.get_model_settings(prompt, aiconfig)

-        completion_data = refine_chat_completion_params(model_settings)
+        completion_data = refine_chat_completion_params(model_settings, aiconfig, prompt)

         # TODO: handle if user specifies previous messages in settings
         completion_data["messages"] = []
@@ -366,7 +366,7 @@ def get_output_text(
         return ""


-def refine_chat_completion_params(model_settings):
+def refine_chat_completion_params(model_settings, aiconfig, prompt):
     # completion parameters to be used for Palm's chat completion api
     # messages handled separately
     supported_keys = {
@@ -384,6 +384,11 @@ def refine_chat_completion_params(model_settings):
         if key.lower() in supported_keys:
             completion_data[key.lower()] = model_settings[key]

+    # Explicitly set the model to use if not already specified
+    if completion_data.get("model") is None:
+        model_name = aiconfig.get_model_name(prompt)
+        completion_data["model"] = model_name
+
     return completion_data


33 changes: 10 additions & 23 deletions python/src/aiconfig/model_parser.py
@@ -33,7 +33,7 @@ async def serialize(
         **kwargs,
     ) -> List[Prompt]:
         """
-        Serialize a prompt and additional metadata/model settings into a 
+        Serialize a prompt and additional metadata/model settings into a
         list of Prompt objects that can be saved in the AIConfig.

         Args:
@@ -150,9 +150,7 @@ def get_output_text(
             str: The output text from the model inference response.
         """

-    def get_model_settings(
-        self, prompt: Prompt, aiconfig: "AIConfigRuntime"
-    ) -> Dict[str, Any]:
+    def get_model_settings(self, prompt: Prompt, aiconfig: "AIConfigRuntime") -> Dict[str, Any]:
         """
         Extracts the AI model's settings from the configuration. If both prompt and config level settings are defined, merge them with prompt settings taking precedence.
@@ -166,10 +164,7 @@ def get_model_settings(
             return aiconfig.get_global_settings(self.id())

         # Check if the prompt exists in the config
-        if (
-            prompt.name not in aiconfig.prompt_index
-            or aiconfig.prompt_index[prompt.name] != prompt
-        ):
+        if prompt.name not in aiconfig.prompt_index or aiconfig.prompt_index[prompt.name] != prompt:
             raise IndexError(f"Prompt '{prompt.name}' not in config.")

         model_metadata = prompt.metadata.model if prompt.metadata else None
@@ -178,25 +173,21 @@ def get_model_settings(
             # Use Default Model
             default_model = aiconfig.get_default_model()
             if not default_model:
-                raise KeyError(
-                    f"No default model specified in AIConfigMetadata, and prompt `{prompt.name}` does not specify a model."
-                )
+                raise KeyError(f"No default model specified in AIConfigMetadata, and prompt `{prompt.name}` does not specify a model.")
             return aiconfig.get_global_settings(default_model)
         elif isinstance(model_metadata, str):
             # Use Global settings
             return aiconfig.get_global_settings(model_metadata)
         else:
             # Merge config and prompt settings with prompt settings taking precedence
             model_settings = {}
-            global_settings = aiconfig.get_global_settings(model_metadata.name)
-            prompt_setings = (
-                prompt.metadata.model.settings
-                if prompt.metadata.model.settings is not None
-                else {}
-            )
+            prompt_settings = prompt.metadata.model.settings if prompt.metadata.model.settings is not None else {}
+
+            model_name = prompt_settings.get("model", model_metadata.name)
+            global_settings = aiconfig.get_global_settings(model_name)

             model_settings.update(global_settings)
-            model_settings.update(prompt_setings)
+            model_settings.update(prompt_settings)

             return model_settings

@@ -205,11 +196,7 @@ def print_stream_callback(data, accumulated_data, index: int):
     """
     Default streamCallback function that prints the output to the console.
     """
-    print(
-        "\ndata: {}\naccumulated_data:{}\nindex:{}\n".format(
-            data, accumulated_data, index
-        )
-    )
+    print("\ndata: {}\naccumulated_data:{}\nindex:{}\n".format(data, accumulated_data, index))


 def print_stream_delta(data, accumulated_data, index: int):
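
As a quick illustration of the merge order in get_model_settings() above (the settings values here are hypothetical), prompt-level settings override config-level settings, and global settings are now looked up by settings["model"] when one is given:

global_settings = {"temperature": 0.5, "max_tokens": 256}  # config level
prompt_settings = {"temperature": 0.9}                     # prompt level

model_settings = {}
model_settings.update(global_settings)  # config-level defaults first
model_settings.update(prompt_settings)  # prompt-level values win on conflict
assert model_settings == {"temperature": 0.9, "max_tokens": 256}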
3 changes: 2 additions & 1 deletion python/src/aiconfig/schema.py
@@ -289,7 +289,8 @@ def get_model_name(self, prompt: Union[str, Prompt]) -> str:
             return prompt.metadata.model
         else:
             # Expect a ModelMetadata object
-            return prompt.metadata.model.name
+            prompt_settings = prompt.metadata.model.settings if prompt.metadata.model.settings is not None else {}
+            return prompt_settings.get("model", prompt.metadata.model.name)

     def set_default_model(self, model_name: Union[str, None]):
         """
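
The resolution order get_model_name() now follows can be sketched standalone (the function below is a paraphrase for illustration; the model IDs are hypothetical):

from typing import Optional

def resolve_model_name(metadata_name: str, settings: Optional[dict]) -> str:
    # A "model" key in the prompt's settings (a concrete model ID) wins over
    # the metadata name, which may actually be a model parser ID (see #782).
    settings = settings if settings is not None else {}
    return settings.get("model", metadata_name)

assert resolve_model_name("HuggingFaceTextParser", {"model": "gpt2"}) == "gpt2"
assert resolve_model_name("gpt-4", None) == "gpt-4"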
