Skip to content

Commit

Permalink
Add error handling for add_ouput in Python SDK (#687)
Browse files Browse the repository at this point in the history
Summary:
This pull request addresses a critical flaw in the add_output method
within the Python SDK. Currently, the method incorrectly adds None to
the output list when the Output parameter is None, which contradicts the
intended behavior. The expected functionality is for the method to
reject None as an invalid output or, alternatively, raise an exception.

Changes Made:
- Modified the add_output method to raise an exception when the Output
parameter is None.

Context:
This change ensures that the add_output method behaves as expected,
providing more robust and predictable behavior when handling output
parameters.

Related Issues:
Closes #529 

Test Plan:
New test similar to `test_add_output_existing_prompt_no_overwrite` but
with overwrite needs to be added.

Things left to complete:
- [x] Add test for `add_output` with overwrite enabled
  • Loading branch information
rossdanlm authored Jan 3, 2024
2 parents cec0103 + 0513ac4 commit ed5fb9d
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 15 deletions.
8 changes: 5 additions & 3 deletions python/src/aiconfig/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -825,7 +825,7 @@ def delete_metadata(self, key: str, prompt_name: Optional[str] = None):

def add_output(self, prompt_name: str, output: Output, overwrite: bool = False):
"""
Add an output to the [rompt with the given name in the AIConfig
Add an output to the prompt with the given name in the AIConfig
Args:
prompt_name (str): The name of the prompt to add the output to.
Expand All @@ -834,8 +834,10 @@ def add_output(self, prompt_name: str, output: Output, overwrite: bool = False):
"""
prompt = self.get_prompt(prompt_name)
if not prompt:
raise IndexError(f"Cannot out output. Prompt '{prompt_name}' not found in config.")
if overwrite or not output:
raise IndexError(f"Cannot add output. Prompt '{prompt_name}' not found in config.")
if not output:
raise ValueError(f"Cannot add output to prompt '{prompt_name}'. Output is not defined.")
if overwrite:
prompt.outputs = [output]
else:
prompt.outputs.append(output)
Expand Down
80 changes: 68 additions & 12 deletions python/tests/test_programmatically_create_an_AIConfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,12 +447,12 @@ def test_set_and_delete_metadata_ai_config(ai_config_runtime: AIConfigRuntime):

def test_set_and_delete_metadata_ai_config_prompt(ai_config_runtime: AIConfigRuntime):
"""Test deleting a non-existent metadata key at the AIConfig level."""
prompt1 = Prompt(
prompt = Prompt(
name="GreetingPrompt",
input="Hello, how are you?",
metadata=PromptMetadata(model="fakemodel"),
)
ai_config_runtime.add_prompt(prompt1.name, prompt1)
ai_config_runtime.add_prompt(prompt.name, prompt)
ai_config_runtime.set_metadata("testkey", "testvalue", "GreetingPrompt")

assert (
Expand All @@ -469,36 +469,92 @@ def test_set_and_delete_metadata_ai_config_prompt(ai_config_runtime: AIConfigRun

def test_add_output_existing_prompt_no_overwrite(ai_config_runtime: AIConfigRuntime):
"""Test adding an output to an existing prompt without overwriting."""
prompt1 = Prompt(
prompt = Prompt(
name="GreetingPrompt",
input="Hello, how are you?",
metadata=PromptMetadata(model="fakemodel"),
)
ai_config_runtime.add_prompt(prompt1.name, prompt1)
ai_config_runtime.add_prompt(prompt.name, prompt)
test_result = ExecuteResult(
output_type="execute_result",
execution_count=0.0,
data={"role": "assistant", "content": "test output"},
metadata={"finish_reason": "stop"},
execution_count=0,
data="test output",
metadata={
"raw_response": {"role": "assistant", "content": "test output"}
},
)
ai_config_runtime.add_output("GreetingPrompt", test_result)

assert ai_config_runtime.get_latest_output("GreetingPrompt") == test_result

test_result2 = ExecuteResult(
output_type="execute_result",
execution_count=0.0,
data={"role": "assistant", "content": "test output"},
metadata={"finish_reason": "stop"},
)
execution_count=0,
data="test output",
metadata={
"raw_response": {"role": "assistant", "content": "test output for second time"}
},
)

ai_config_runtime.add_output("GreetingPrompt", test_result2)
assert ai_config_runtime.get_latest_output("GreetingPrompt") == test_result2

ai_config_runtime.delete_output("GreetingPrompt")

assert ai_config_runtime.get_latest_output("GreetingPrompt") == None


def test_add_output_existing_prompt_overwrite(ai_config_runtime: AIConfigRuntime):
"""Test adding an output to an existing prompt with overwriting."""
original_output = ExecuteResult(
output_type="execute_result",
execution_count=0,
data="original output",
metadata={
"raw_response": {"role": "assistant", "content": "original output"}
},
)
prompt = Prompt(
name="GreetingPrompt",
input="Hello, how are you?",
metadata=PromptMetadata(model="fakemodel"),
outputs=[original_output],
)
ai_config_runtime.add_prompt(prompt.name, prompt)
# check that the original_output is there
assert ai_config_runtime.get_latest_output("GreetingPrompt") == original_output
expected_output = ExecuteResult(
output_type="execute_result",
execution_count=0,
data="original output",
metadata={
"raw_response": {"role": "assistant", "content": "original output"}
},
)
# overwrite the original_output
ai_config_runtime.add_output("GreetingPrompt", expected_output, True)
assert ai_config_runtime.get_latest_output("GreetingPrompt") == expected_output

def test_add_undefined_output_to_prompt(ai_config_runtime: AIConfigRuntime):
"""Test for adding an undefined output to a prompt with/without overwriting. Should result in an error."""
prompt = Prompt(
name="GreetingPrompt",
input="Hello, how are you?",
metadata=PromptMetadata(model="fakemodel"),
)
ai_config_runtime.add_prompt(prompt.name, prompt)
# Case 1: No output, overwrite param not defined
with pytest.raises(
ValueError,
match=r"Cannot add output to prompt 'GreetingPrompt'. Output is not defined.",
):
ai_config_runtime.add_output("GreetingPrompt", None)
# Case 2: No output, overwrite param set to True
with pytest.raises(
ValueError,
match=r"Cannot add output to prompt 'GreetingPrompt'. Output is not defined.",
):
ai_config_runtime.add_output("GreetingPrompt", None, True)

def test_extract_override_settings(ai_config_runtime: AIConfigRuntime):
initial_settings = {"topP": 0.9}

Expand Down

0 comments on commit ed5fb9d

Please sign in to comment.