Skip to content

Commit

Permalink
Explicitly set "model" property for models that support it
Browse files Browse the repository at this point in the history
This change allows us to automatically set the "model" for the following model parsers:
* OpenAI
* Anyscale Endpoint
* Hugging Face text generation
* PaLM
* Dall-E

Without this change the local editor breaks because if a user doesn't explicitly specify the "model" property in settings metadata, inference fails for (some of) the above model parsers. This isn't ideal, because the user has already selected the model they want to use from the model selector, so shouldn't need to additionally specify a model id as well.

There is still a larger issue of how we distinguish a model ID from a model parser ID. Filed #782 to come up with a proper fix for that.
  • Loading branch information
saqadri committed Jan 5, 2024
1 parent aa1eff3 commit 4925521
Show file tree
Hide file tree
Showing 7 changed files with 52 additions and 22 deletions.
2 changes: 1 addition & 1 deletion cookbooks/HuggingFace/hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ async def deserialize(
Returns:
dict: Model-specific completion parameters.
"""
resolved_prompt = resolve_prompt(prompt, params, aiconfig)
resolved_prompt = resolve_prompt(prompt, params if params is not None else {}, aiconfig)

# Build Completion data
model_settings = self.get_model_settings(prompt, aiconfig)
Expand Down
11 changes: 8 additions & 3 deletions python/src/aiconfig/default_parsers/dalle.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@


# TODO: Centralize this into utils function with the HuggingFaceTextParser class
def refine_image_completion_params(model_settings):
def refine_image_completion_params(model_settings, aiconfig, prompt):
"""
Refines the completion params for the Dall-E request API. Removes any unsupported params.
The supported keys were found by looking at the OpenAI Dall-E API: https://platform.openai.com/docs/api-reference/images/create?lang=python`
Expand All @@ -37,6 +37,11 @@ def refine_image_completion_params(model_settings):
if key.lower() in supported_keys:
completion_data[key.lower()] = model_settings[key]

# Explicitly set the model to use if not already specified
if completion_data.get("model") is None:
model_name = aiconfig.get_model_name(prompt)
completion_data["model"] = model_name

return completion_data


Expand Down Expand Up @@ -141,11 +146,11 @@ async def deserialize(self, prompt: Prompt, aiconfig: "AIConfigRuntime", params:
dict: Model-specific completion parameters.
"""
# Get inputs from aiconfig
resolved_prompt = resolve_prompt(prompt, params, aiconfig)
resolved_prompt = resolve_prompt(prompt, params if params is not None else {}, aiconfig)
model_settings = self.get_model_settings(prompt, aiconfig)

# Build Completion data
completion_data = refine_image_completion_params(model_settings)
completion_data = refine_image_completion_params(model_settings, aiconfig, prompt)
completion_data["prompt"] = resolved_prompt
return completion_data

Expand Down
11 changes: 8 additions & 3 deletions python/src/aiconfig/default_parsers/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ async def deserialize(self, prompt: Prompt, aiconfig: "AIConfigRuntime", params:
# Build Completion params
model_settings = self.get_model_settings(prompt, aiconfig)

completion_params = refine_chat_completion_params(model_settings)
completion_params = refine_chat_completion_params(model_settings, aiconfig, prompt)

# In the case that the messages array wasn't saved as part of the model settings, build it here. Messages array is used for conversation history.
if not completion_params.get("messages"):
Expand Down Expand Up @@ -407,10 +407,10 @@ def multi_choice_message_reducer(messages: Union[Dict[int, dict], None], chunk:
return messages


def refine_chat_completion_params(model_settings):
def refine_chat_completion_params(model_settings, aiconfig, prompt):
# completion parameters to be used for openai's chat completion api
# system prompt handled separately
# streaming handled seperatly.
# streaming handled separately.
# Messages array built dynamically and handled separately
supported_keys = {
"frequency_penalty",
Expand All @@ -436,6 +436,11 @@ def refine_chat_completion_params(model_settings):
if key in model_settings:
completion_data[key] = model_settings[key]

# Explicitly set the model to use if not already specified
if completion_data.get("model") is None:
model_name = aiconfig.get_model_name(prompt)
completion_data["model"] = model_name

return completion_data


Expand Down
11 changes: 8 additions & 3 deletions python/src/aiconfig/default_parsers/palm.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ async def deserialize(self, prompt: Prompt, aiconfig: "AIConfigRuntime", params:
# Build Completion data
model_settings = self.get_model_settings(prompt, aiconfig)

completion_data = refine_chat_completion_params(model_settings)
completion_data = refine_chat_completion_params(model_settings, aiconfig, prompt)

prompt_str = resolve_prompt(prompt, params, aiconfig)

Expand Down Expand Up @@ -246,7 +246,7 @@ async def deserialize(self, prompt: Prompt, aiconfig: "AIConfigRuntime", params:
# Build Completion data
model_settings = self.get_model_settings(prompt, aiconfig)

completion_data = refine_chat_completion_params(model_settings)
completion_data = refine_chat_completion_params(model_settings, aiconfig, prompt)

# TODO: handle if user specifies previous messages in settings
completion_data["messages"] = []
Expand Down Expand Up @@ -366,7 +366,7 @@ def get_output_text(
return ""


def refine_chat_completion_params(model_settings):
def refine_chat_completion_params(model_settings, aiconfig, prompt):
# completion parameters to be used for Palm's chat completion api
# messages handled separately
supported_keys = {
Expand All @@ -384,6 +384,11 @@ def refine_chat_completion_params(model_settings):
if key.lower() in supported_keys:
completion_data[key.lower()] = model_settings[key]

# Explicitly set the model to use if not already specified
if completion_data.get("model") is None:
model_name = aiconfig.get_model_name(prompt)
completion_data["model"] = model_name

return completion_data


Expand Down
8 changes: 5 additions & 3 deletions python/src/aiconfig/model_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,11 +174,13 @@ def get_model_settings(self, prompt: Prompt, aiconfig: "AIConfigRuntime") -> Dic
else:
# Merge config and prompt settings with prompt settings taking precedent
model_settings = {}
global_settings = aiconfig.get_global_settings(model_metadata.name)
prompt_setings = prompt.metadata.model.settings if prompt.metadata.model.settings is not None else {}
prompt_settings = prompt.metadata.model.settings if prompt.metadata.model.settings is not None else {}

model_name = prompt_settings.get("model", model_metadata.name)
global_settings = aiconfig.get_global_settings(model_name)

model_settings.update(global_settings)
model_settings.update(prompt_setings)
model_settings.update(prompt_settings)

return model_settings

Expand Down
3 changes: 2 additions & 1 deletion python/src/aiconfig/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,8 @@ def get_model_name(self, prompt: Union[str, Prompt]) -> str:
return prompt.metadata.model
else:
# Expect a ModelMetadata object
return prompt.metadata.model.name
prompt_settings = prompt.metadata.model.settings if prompt.metadata.model.settings is not None else {}
return prompt_settings.get("model", prompt.metadata.model.name)

def set_default_model(self, model_name: Union[str, None]):
"""
Expand Down
28 changes: 20 additions & 8 deletions python/tests/parsers/test_openai_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,32 @@
from aiconfig.default_parsers.openai import refine_chat_completion_params
from mock import patch

from aiconfig.schema import ExecuteResult, Prompt, PromptInput, PromptMetadata
from aiconfig.schema import ExecuteResult, Prompt, PromptInput, PromptMetadata, ModelMetadata

from ..conftest import mock_openai_chat_completion
from ..util.file_path_utils import get_absolute_file_path_from_relative


def test_refine_chat_completion_params():
model_settings_with_stream_and_system_prompt = {
"n": "3",
"stream": True,
"system_prompt": "system_prompt",
"random_attribute": "value_doesn't_matter",
}
refined_params = refine_chat_completion_params(model_settings_with_stream_and_system_prompt)
prompt = Prompt(
name="test",
input="test",
metadata=PromptMetadata(
model=ModelMetadata(
name="gpt-4",
settings={
"n": "3",
"stream": True,
"system_prompt": "system_prompt",
"random_attribute": "value_doesn't_matter",
},
)
),
)

aiconfig = AIConfigRuntime.create(name="test_refine_chat_completion_params", prompts=[prompt])

refined_params = refine_chat_completion_params(prompt.metadata.model.settings, aiconfig, prompt)

assert "system_prompt" not in refined_params
assert "stream" in refined_params
Expand Down

0 comments on commit 4925521

Please sign in to comment.