Skip to content

Commit

Permalink
[python][image] Save output.data images in OutputData format, not jus…
Browse files Browse the repository at this point in the history
…t strings

This will make it easy for Ryan on frontend to know exactly what data type we need. This is also backwards compatible for existing AIConfigs that have data saved only as a text string

## Test
Automated tests
  • Loading branch information
Rossdan Craig [email protected] committed Dec 24, 2023
1 parent bff0b53 commit 430623a
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,13 @@

from aiconfig.default_parsers.parameterized_model_parser import ParameterizedModelParser
from aiconfig.model_parser import InferenceOptions
from aiconfig.schema import ExecuteResult, Output, Prompt, PromptMetadata
from aiconfig.schema import (
ExecuteResult,
Output,
OutputData,
Prompt,
PromptMetadata,
)
from aiconfig.util.params import resolve_prompt

# Circular Dependency Type Hints
Expand Down Expand Up @@ -116,7 +122,7 @@ def refine_image_completion_params(unfiltered_completion_params: Dict[str, Any])
return completion_params


def construct_regular_output(
def construct_output(
image: Image.Image,
has_nsfw: Optional[bool],
execution_count: int
Expand All @@ -135,10 +141,14 @@ def pillow_image_to_base64_string(img : Image.Image):
img.save(buffered, format="PNG")
return base64.b64encode(buffered.getvalue()).decode("utf-8")

data = OutputData(
kind="base64",
value=pillow_image_to_base64_string(image),
)
output = ExecuteResult(
**{
"output_type": "execute_result",
"data": pillow_image_to_base64_string(image),
"data": data,
"execution_count": execution_count,
"metadata": {"has_nsfw": has_nsfw},
"mime_type": "image/png",
Expand Down Expand Up @@ -315,7 +325,7 @@ async def run_inference(
has_nsfw_responses,
)
):
output = construct_regular_output(image, has_nsfw, count)
output = construct_output(image, has_nsfw, count)
outputs.append(output)

prompt.outputs = outputs
Expand All @@ -338,8 +348,9 @@ def get_output_text(
if output.output_type == "execute_result":
if isinstance(output.data, str):
return output.data
else:
return ""
elif isinstance(output.data, OutputData):
return output.data.value
return ""

def _get_device(self) -> str:
if torch.cuda.is_available():
Expand Down
29 changes: 22 additions & 7 deletions python/src/aiconfig/default_parsers/dalle.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,20 @@

import openai
from aiconfig.default_parsers.parameterized_model_parser import ParameterizedModelParser
from aiconfig.schema import (
ExecuteResult,
Output,
OutputData,
Prompt,
PromptMetadata,
)
from aiconfig.util.config_utils import get_api_key_from_environment
from aiconfig.util.params import resolve_prompt
from openai import OpenAI

# Dall-E API imports
from openai.types import Image, ImagesResponse

from aiconfig.schema import ExecuteResult, Output, Prompt, PromptMetadata

# ModelParser Utils
# Type hint imports

Expand Down Expand Up @@ -41,10 +46,19 @@ def refine_image_completion_params(model_settings):


def construct_output(image_data: Image, execution_count: int) -> Output:
data = None
if image_data.b64_json is not None:
data = OutputData(kind="base64", value=image_data.b64_json)
elif image_data.url is not None:
data = OutputData(kind="file_uri", value=image_data.url)
else:
raise ValueError(
f"Did not receive a valid image type from image_data: {image_data}"
)
output = ExecuteResult(
**{
"output_type": "execute_result",
"data": image_data.b64_json or image_data.url,
"data": data,
"execution_count": execution_count,
"metadata": {"revised_prompt": image_data.revised_prompt},
"mime_type": "image/png",
Expand Down Expand Up @@ -124,7 +138,7 @@ async def serialize(
model=model_metadata, parameters=parameters, **kwargs
),
)
return [prompt]
return prompt

# TODO (rossdanlm): Update documentation for args
async def deserialize(
Expand Down Expand Up @@ -201,7 +215,8 @@ def get_output_text(
# TODO (rossdanlm): Handle multiple outputs in list
# https://github.com/lastmile-ai/aiconfig/issues/467
if output.output_type == "execute_result":
if isinstance(output.data, str):
if isinstance(output.data, OutputData):
return output.data.value
elif isinstance(output.data, str):
return output.data
else:
return ""
return ""
5 changes: 2 additions & 3 deletions python/tests/parsers/test_dalle_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,10 @@ async def test_serialize_basic(set_temporary_env_vars: None):
"size": "1024x1024",
}
aiconfig = AIConfigRuntime.create()
serialized_prompts = await aiconfig.serialize(
serialized_prompt = await aiconfig.serialize(
"dall-e-3", completion_params, prompt_name="panda_eating_dumplings"
)
new_prompt = serialized_prompts[0]
assert new_prompt == Prompt(
assert serialized_prompt == Prompt(
name="panda_eating_dumplings",
input="Panda eating dumplings on a yellow mountain",
metadata=PromptMetadata(
Expand Down

0 comments on commit 430623a

Please sign in to comment.