Skip to content

Commit

Permalink
Fix image model parsers issue for output type not being correct (#739)
Browse files Browse the repository at this point in the history
Fix image model parsers issue for output type not being correct



I made this mistake when editing the outputs in image output format
(initially in #608, but
actual mistake came in
#637). It's because I set
this as Output TYPE (`OutputDataWithValue`), not Output CLASS
(`OutputDataWithStringValue`)

This highlights the need for:

1. Better automated tests. We currently have `load()` and `serialize()`,
but not for `run()`: #294. I
thought this was a pretty simple change, so I didn't manually test

## Test Plan

Follow dev README to setup the local editor:
https://github.com/lastmile-ai/aiconfig/tree/main/python/src/aiconfig/editor#dev,
then run this command (or just do it manually with local editor)
```
curl http://localhost:8080/api/run -d '{"prompt_name":"get_activities"}' -X POST -H 'Content-Type: application/json'
```

| Before | After |
|--|--|
| <img width="1920" alt="Screenshot 2024-01-03 at 15 18 14"
src="https://github.com/lastmile-ai/aiconfig/assets/151060367/0a6b9345-60bb-475a-a85f-44981e028b21">
| <img width="1920" alt="Screenshot 2024-01-03 at 15 16 35"
src="https://github.com/lastmile-ai/aiconfig/assets/151060367/428aafe1-79a0-49fc-85ee-621a78eacd04">
|

---
Stack created with [Sapling](https://sapling-scm.com). Best reviewed
with
[ReviewStack](https://reviewstack.dev/lastmile-ai/aiconfig/pull/739).
* __->__ #739
* #737
  • Loading branch information
rossdanlm authored Jan 3, 2024
2 parents f862a24 + 9c3c2a6 commit 2567e43
Show file tree
Hide file tree
Showing 6 changed files with 35 additions and 81 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from aiconfig.schema import (
ExecuteResult,
Output,
OutputDataWithValue,
OutputDataWithStringValue,
Prompt,
PromptMetadata,
)
Expand Down Expand Up @@ -141,7 +141,7 @@ def pillow_image_to_base64_string(img : Image.Image):
img.save(buffered, format="PNG")
return base64.b64encode(buffered.getvalue()).decode("utf-8")

data = OutputDataWithValue(
data = OutputDataWithStringValue(
kind="base64",
value=pillow_image_to_base64_string(image),
)
Expand Down Expand Up @@ -346,7 +346,7 @@ def get_output_text(
# TODO (rossdanlm): Handle multiple outputs in list
# https://github.com/lastmile-ai/aiconfig/issues/467
if output.output_type == "execute_result":
if isinstance(output.data, OutputDataWithValue):
if isinstance(output.data, OutputDataWithStringValue):
return output.data.value
elif isinstance(output.data, str):
return output.data
Expand Down
40 changes: 11 additions & 29 deletions python/src/aiconfig/default_parsers/dalle.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,15 @@

import openai
from aiconfig.default_parsers.parameterized_model_parser import ParameterizedModelParser
from aiconfig.schema import (
ExecuteResult,
Output,
OutputDataWithValue,
Prompt,
PromptMetadata,
)
from aiconfig.util.config_utils import get_api_key_from_environment
from aiconfig.util.params import resolve_prompt
from openai import OpenAI

# Dall-E API imports
from openai.types import Image, ImagesResponse

from aiconfig.schema import ExecuteResult, Output, OutputDataWithStringValue, Prompt, PromptMetadata

# ModelParser Utils
# Type hint imports

Expand Down Expand Up @@ -48,13 +43,11 @@ def refine_image_completion_params(model_settings):
def construct_output(image_data: Image, execution_count: int) -> Output:
data = None
if image_data.b64_json is not None:
data = OutputDataWithValue(kind="base64", value=image_data.b64_json)
data = OutputDataWithStringValue(kind="base64", value=str(image_data.b64_json))
elif image_data.url is not None:
data = OutputDataWithValue(kind="file_uri", value=image_data.url)
data = OutputDataWithStringValue(kind="file_uri", value=str(image_data.url))
else:
raise ValueError(
f"Did not receive a valid image type from image_data: {image_data}"
)
raise ValueError(f"Did not receive a valid image type from image_data: {image_data}")
output = ExecuteResult(
**{
"output_type": "execute_result",
Expand Down Expand Up @@ -90,10 +83,7 @@ def __init__(self, model_id: str = "dall-e-3"):
"dall-e-3",
}
if model_id.lower() not in supported_models:
raise ValueError(
"{model_id}"
+ " is not a valid model ID for Dall-E image generation. Supported models: {supported_models}."
)
raise ValueError("{model_id}" + " is not a valid model ID for Dall-E image generation. Supported models: {supported_models}.")
self.model_id = model_id

self.client = None
Expand Down Expand Up @@ -134,16 +124,12 @@ async def serialize(
prompt = Prompt(
name=prompt_name,
input=prompt_input,
metadata=PromptMetadata(
model=model_metadata, parameters=parameters, **kwargs
),
metadata=PromptMetadata(model=model_metadata, parameters=parameters, **kwargs),
)
return [prompt]

# TODO (rossdanlm): Update documentation for args
async def deserialize(
self, prompt: Prompt, aiconfig: "AIConfigRuntime", params: Optional[Dict] = {}
) -> Dict:
async def deserialize(self, prompt: Prompt, aiconfig: "AIConfigRuntime", params: Optional[Dict] = {}) -> Dict:
"""
Defines how to parse a prompt in the .aiconfig for a particular model
and constructs the completion params for that model.
Expand All @@ -163,9 +149,7 @@ async def deserialize(
completion_data["prompt"] = resolved_prompt
return completion_data

async def run_inference(
self, prompt: Prompt, aiconfig, _options, parameters
) -> List[Output]:
async def run_inference(self, prompt: Prompt, aiconfig, _options, parameters) -> List[Output]:
"""
Invoked to run a prompt in the .aiconfig. This method should perform
the actual model inference based on the provided prompt and inference settings.
Expand All @@ -185,9 +169,7 @@ async def run_inference(

completion_data = await self.deserialize(prompt, aiconfig, parameters)

print(
"Calling image generation. This can take several seconds, please hold on..."
)
print("Calling image generation. This can take several seconds, please hold on...")
response: ImagesResponse = self.client.images.generate(**completion_data)

outputs = []
Expand Down Expand Up @@ -215,7 +197,7 @@ def get_output_text(
# TODO (rossdanlm): Handle multiple outputs in list
# https://github.com/lastmile-ai/aiconfig/issues/467
if output.output_type == "execute_result":
if isinstance(output.data, OutputDataWithValue):
if isinstance(output.data, OutputDataWithStringValue):
return output.data.value
elif isinstance(output.data, str):
return output.data
Expand Down
23 changes: 7 additions & 16 deletions python/src/aiconfig/default_parsers/hf.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,18 @@
import copy
import json
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Union
from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Union

# HuggingFace API imports
from huggingface_hub import InferenceClient
from huggingface_hub.inference._text_generation import (
TextGenerationResponse,
TextGenerationStreamResponse,
)

from aiconfig import CallbackEvent
from aiconfig.default_parsers.parameterized_model_parser import ParameterizedModelParser
from aiconfig.model_parser import InferenceOptions
from aiconfig.schema import (
ExecuteResult,
Output,
OutputDataWithValue,
Prompt,
PromptMetadata,
)
from aiconfig.util.config_utils import get_api_key_from_environment
from aiconfig.util.params import resolve_prompt

# HuggingFace API imports
from huggingface_hub import InferenceClient
from huggingface_hub.inference._text_generation import TextGenerationResponse, TextGenerationStreamResponse

from aiconfig import CallbackEvent
from aiconfig.schema import ExecuteResult, Output, OutputDataWithValue, Prompt, PromptMetadata

# Circuluar Dependency Type Hints
if TYPE_CHECKING:
Expand Down
13 changes: 5 additions & 8 deletions python/src/aiconfig/default_parsers/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,27 +4,24 @@
from typing import TYPE_CHECKING, Dict, List, Optional, Union

import openai
from openai.types.chat import ChatCompletionMessage
from aiconfig.callback import CallbackEvent
from aiconfig.default_parsers.parameterized_model_parser import ParameterizedModelParser
from aiconfig.model_parser import InferenceOptions
from aiconfig.util.config_utils import get_api_key_from_environment
from aiconfig.util.params import resolve_prompt, resolve_prompt_string, resolve_system_prompt
from openai.types.chat import ChatCompletionMessage

from aiconfig.schema import (
ExecuteResult,
FunctionCallData,
Output,
OutputDataWithValue,
OutputDataWithToolCallsValue,
OutputDataWithValue,
Prompt,
PromptInput,
PromptMetadata,
ToolCallData,
)
from aiconfig.util.config_utils import get_api_key_from_environment
from aiconfig.util.params import (
resolve_prompt,
resolve_prompt_string,
resolve_system_prompt,
)

if TYPE_CHECKING:
from aiconfig.Config import AIConfigRuntime
Expand Down
12 changes: 3 additions & 9 deletions python/src/aiconfig/default_parsers/palm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,14 @@
from typing import TYPE_CHECKING, Dict, List, Optional

import google.generativeai as palm
from google.generativeai.text import Completion
from google.generativeai.types.discuss_types import MessageDict
from aiconfig.callback import CallbackEvent
from aiconfig.default_parsers.parameterized_model_parser import ParameterizedModelParser
from aiconfig.model_parser import InferenceOptions
from aiconfig.schema import (
ExecuteResult,
Output,
OutputDataWithValue,
Prompt,
PromptMetadata,
)
from aiconfig.util.params import resolve_parameters, resolve_prompt
from google.generativeai.text import Completion
from google.generativeai.types.discuss_types import MessageDict

from aiconfig.schema import ExecuteResult, Output, OutputDataWithValue, Prompt, PromptMetadata

if TYPE_CHECKING:
from aiconfig.Config import AIConfigRuntime
Expand Down
22 changes: 6 additions & 16 deletions python/src/aiconfig/default_parsers/parameterized_model_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from aiconfig.model_parser import InferenceOptions, ModelParser
from aiconfig.util.params import get_dependency_graph, resolve_prompt_string

from aiconfig.schema import AIConfig, ExecuteResult, JSONObject, Output, Prompt, PromptInput
from aiconfig.schema import AIConfig, JSONObject, Output, Prompt, PromptInput

if typing.TYPE_CHECKING:
from aiconfig import AIConfigRuntime
Expand Down Expand Up @@ -49,15 +49,11 @@ async def run(
) -> List[Output]:
# maybe use prompt metadata instead of kwargs?
if kwargs.get("run_with_dependencies", False):
return await self.run_with_dependencies(
prompt, aiconfig, options, parameters
)
return await self.run_with_dependencies(prompt, aiconfig, options, parameters)
else:
return await self.run_inference(prompt, aiconfig, options, parameters)

async def run_with_dependencies(
self, prompt: Prompt, aiconfig: AIConfig, options=None, parameters: Dict = {}
) -> List[Output]:
async def run_with_dependencies(self, prompt: Prompt, aiconfig: AIConfig, options=None, parameters: Dict = {}) -> List[Output]:
"""
Executes the AI model with the resolved dependencies and prompt references and returns the API response.
Expand All @@ -69,9 +65,7 @@ async def run_with_dependencies(
Returns:
ExecuteResult: An Object containing the response from the AI model.
"""
dependency_graph = get_dependency_graph(
prompt, aiconfig.prompts, aiconfig.prompt_index
)
dependency_graph = get_dependency_graph(prompt, aiconfig.prompts, aiconfig.prompt_index)

# Create a set to keep track of visited prompts
visited_prompts = set()
Expand Down Expand Up @@ -129,11 +123,7 @@ def get_prompt_template(self, prompt: Prompt, aiConfig: "AIConfigRuntime") -> st
"""
if isinstance(prompt.input, str):
return prompt.input
elif isinstance(prompt.input, PromptInput) and isinstance(
prompt.input.data, str
):
elif isinstance(prompt.input, PromptInput) and isinstance(prompt.input.data, str):
return prompt.input.data
else:
raise Exception(
f"Cannot get prompt template string from prompt input: {prompt.input}"
)
raise Exception(f"Cannot get prompt template string from prompt input: {prompt.input}")

0 comments on commit 2567e43

Please sign in to comment.