Fix parameterization issues (#866)

Now that we have non-text input model parsers, we can see some issues in
our parameterization.

1) ASR and Image-to-Text model parsers should NOT be
`ParameterizedModelParser` instances, since their inputs cannot be
parameterized.

2) Parameter resolution logic shouldn't throw errors when it is resolving
prompts that belong to regular (non-parameterized) model parsers; a sketch
of the intended behavior follows below.
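
To illustrate the intended behavior, here is a minimal sketch. The helper name
and the substitution logic are simplified placeholders, not the actual patch:

```python
# Hypothetical sketch -- not the actual aiconfig implementation.
from aiconfig import ModelParser, ParameterizedModelParser


def resolve_if_parameterized(parser: ModelParser, prompt_input, params: dict):
    """Apply {{param}} substitution only for text-template parsers.

    Attachment-based parsers (ASR, image-to-text) are plain ModelParsers,
    so their inputs pass through untouched instead of raising an exception.
    """
    if isinstance(parser, ParameterizedModelParser) and isinstance(prompt_input, str):
        resolved = prompt_input
        for name, value in params.items():
            resolved = resolved.replace("{{" + name + "}}", str(value))
        return resolved
    return prompt_input
```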

Test Plan:
Ran the translation prompt for this aiconfig:

```
prompts: [
    {
      "name": "image_caption",
      "input": {
        "attachments": [
          {
            "data": "https://s3.amazonaws.com/lastmileai.aiconfig.public/uploads/2024_1_10_18_41_31/1742/bear_with_honey.png",
            "mime_type": "image/png"
          },
          {
            "data": "https://s3.amazonaws.com/lastmileai.aiconfig.public/uploads/2024_1_10_18_41_31/7275/fox_in_forest.png",
            "mime_type": "image/png"
          }
        ]
      },
      "metadata": {
        "model": {
          "name": "HuggingFaceImage2TextTransformer",
          "settings": {
            "max_new_tokens": 10,
            "model": "Salesforce/blip-image-captioning-base"
          }
        },
        "parameters": {}
      },
      "outputs": [
        {
          "output_type": "execute_result",
          "execution_count": 0,
          "data": "a bear sitting on a rock eating honey",
          "metadata": {}
        },
        {
          "output_type": "execute_result",
          "execution_count": 1,
          "data": "a red fox in the woods",
          "metadata": {}
        }
      ]
    },
    {
      "name": "translation",
      "input": "Once upon a time, in a lush and vibrant forest, there lived a magnificent creature known as the Quick Brown Fox. This fox was unlike any other, possessing incredible speed and agility that awed all the animals in the forest. With its fur as golden as the sun and its eyes as sharp as emeralds, the Quick Brown Fox was admired by everyone, from the tiniest hummingbird to the mightiest bear. The fox had a kind heart and would often lend a helping paw to those in need. The Quick Brown Fox had a particular fondness for games and challenges. It loved to test its skills against others, always seeking new adventures to satisfy its boundless curiosity. Its favorite game was called \"The Great Word Hunt,\" where it would embark on a quest to find hidden words scattered across the forest. \n\nOne day it got very old and died",
      "metadata": {
        "model": {
          "name": "HuggingFaceTextTranslationTransformer",
          "settings": {
            "model": "Helsinki-NLP/opus-mt-en-fr",
            "max_length": 100,
            "min_length": 50,
            "num_beams": 1
          }
        },
        "parameters": {}
      }
    }
]
```

Before:

```
File "/opt/homebrew/lib/python3.11/site-packages/aiconfig/util/params.py", line 235, in get_prompt_template
    raise Exception(f"Cannot get prompt template string from prompt: {prompt.input}")
Exception: Cannot get prompt template string from prompt: attachments=[Attachment(data='https://s3.amazonaws.com/lastmileai.aiconfig.public/uploads/2024_1_10_18_41_31/1742/bear_with_honey.png', mime_type='image/png', metadata=None), Attachment(data='https://s3.amazonaws.com/lastmileai.aiconfig.public/uploads/2024_1_10_18_41_31/7275/fox_in_forest.png', mime_type='image/png', metadata=None)] data=None
```

After:
* Beautiful translation
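
For reference, a minimal way to re-run this check. The config filename is an
assumed placeholder, and the snippet assumes the HuggingFace extension parsers
are already registered with the runtime:

```python
# Hypothetical verification snippet; the prompt name matches the config in the test plan.
import asyncio

from aiconfig import AIConfigRuntime


async def main():
    # "translation_test.aiconfig.json" is an assumed filename for the config above.
    config = AIConfigRuntime.load("translation_test.aiconfig.json")
    # Running the translation prompt previously raised the exception shown above.
    await config.run("translation")
    print(config.get_output_text("translation"))


asyncio.run(main())
```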

---
Stack created with [Sapling](https://sapling-scm.com). Best reviewed with [ReviewStack](https://reviewstack.dev/lastmile-ai/aiconfig/pull/866).
* #826
* __->__ #866
* #863
saqadri authored Jan 10, 2024
2 parents 9f5cd3d + 69778fa commit bf2fc42
Showing 10 changed files with 129 additions and 86 deletions.
@@ -6,7 +6,9 @@
from .local_inference.text_summarization import HuggingFaceTextSummarizationTransformer
from .local_inference.text_translation import HuggingFaceTextTranslationTransformer
from .remote_inference_client.text_generation import HuggingFaceTextGenerationParser
from .local_inference.util import get_hf_model

UTILS = [get_hf_model]

LOCAL_INFERENCE_CLASSES = [
"HuggingFaceAutomaticSpeechRecognitionTransformer",
@@ -18,4 +20,4 @@
"HuggingFaceTextTranslationTransformer",
]
REMOTE_INFERENCE_CLASSES = ["HuggingFaceTextGenerationParser"]
__ALL__ = LOCAL_INFERENCE_CLASSES + REMOTE_INFERENCE_CLASSES
__ALL__ = LOCAL_INFERENCE_CLASSES + REMOTE_INFERENCE_CLASSES + UTILS
@@ -2,16 +2,16 @@

import torch
from transformers import pipeline, Pipeline
from aiconfig import ParameterizedModelParser, InferenceOptions
from aiconfig_extension_hugging_face.local_inference.util import get_hf_model
from aiconfig import ModelParser, InferenceOptions
from aiconfig.callback import CallbackEvent
from aiconfig.schema import Prompt, Output, ExecuteResult, Attachment


if TYPE_CHECKING:
from aiconfig import AIConfigRuntime


class HuggingFaceAutomaticSpeechRecognitionTransformer(ParameterizedModelParser):
class HuggingFaceAutomaticSpeechRecognitionTransformer(ModelParser):
"""
Model Parser for HuggingFace ASR (Automatic Speech Recognition) models.
"""
@@ -85,7 +85,7 @@ async def deserialize(
await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_deserialize_complete", __name__, {"output": completion_data}))
return completion_data

async def run_inference(self, prompt: Prompt, aiconfig: "AIConfigRuntime", options: InferenceOptions, parameters: Dict[str, Any]) -> list[Output]:
async def run(self, prompt: Prompt, aiconfig: "AIConfigRuntime", options: InferenceOptions, parameters: Dict[str, Any]) -> list[Output]:
await aiconfig.callback_manager.run_callbacks(
CallbackEvent(
"on_run_start",
@@ -96,15 +96,16 @@ async def run_inference(self, prompt: Prompt, aiconfig: "AIConfigRuntime", optio

model_settings = self.get_model_settings(prompt, aiconfig)
[pipeline_creation_data, _] = refine_pipeline_creation_params(model_settings)
model_name = aiconfig.get_model_name(prompt)
model_name = get_hf_model(aiconfig, prompt, self)
key = model_name if model_name is not None else "__default__"

if isinstance(model_name, str) and model_name not in self.pipelines:
if key not in self.pipelines:
device = self._get_device()
if pipeline_creation_data.get("device", None) is None:
pipeline_creation_data["device"] = device
self.pipelines[model_name] = pipeline(task="automatic-speech-recognition", **pipeline_creation_data)
self.pipelines[key] = pipeline(task="automatic-speech-recognition", model=model_name, **pipeline_creation_data)

asr_pipeline = self.pipelines[model_name]
asr_pipeline = self.pipelines[key]
completion_data = await self.deserialize(prompt, aiconfig, parameters)

response = asr_pipeline(**completion_data)
@@ -234,8 +235,8 @@ def refine_asr_completion_params(unfiltered_completion_params: Dict[str, Any]) -
Note: This doesn't support base pipeline params like `num_workers`
TODO: Figure out how to find which params are supported.
TODO: Distinguish pipeline creation and refine completion
https://github.com/lastmile-ai/aiconfig/issues/825
TODO: Distinguish pipeline creation and refine completion
https://github.com/lastmile-ai/aiconfig/issues/825
https://github.com/lastmile-ai/aiconfig/issues/824
"""

@@ -8,7 +8,9 @@
pipeline,
)

from aiconfig import ParameterizedModelParser, InferenceOptions
from aiconfig_extension_hugging_face.local_inference.util import get_hf_model

from aiconfig import ModelParser, InferenceOptions
from aiconfig.callback import CallbackEvent
from aiconfig.schema import (
Attachment,
@@ -22,7 +24,7 @@
from aiconfig import AIConfigRuntime


class HuggingFaceImage2TextTransformer(ParameterizedModelParser):
class HuggingFaceImage2TextTransformer(ModelParser):
def __init__(self):
"""
Returns:
@@ -108,14 +110,14 @@ async def deserialize(
model_settings = self.get_model_settings(prompt, aiconfig)
completion_params = refine_completion_params(model_settings)

#Add image inputs
# Add image inputs
inputs = validate_and_retrieve_images_from_attachments(prompt)
completion_params["inputs"] = inputs

await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_deserialize_complete", __name__, {"output": completion_params}))
return completion_params

async def run_inference(self, prompt: Prompt, aiconfig: "AIConfigRuntime", options: InferenceOptions, parameters: Dict[str, Any]) -> list[Output]:
async def run(self, prompt: Prompt, aiconfig: "AIConfigRuntime", options: InferenceOptions, parameters: Dict[str, Any]) -> list[Output]:
await aiconfig.callback_manager.run_callbacks(
CallbackEvent(
"on_run_start",
@@ -127,10 +129,12 @@ async def run_inference(self, prompt: Prompt, aiconfig: "AIConfigRuntime", optio
completion_data = await self.deserialize(prompt, aiconfig, parameters)
inputs = completion_data.pop("inputs")

model_name: str | None = aiconfig.get_model_name(prompt)
if isinstance(model_name, str) and model_name not in self.pipelines:
self.pipelines[model_name] = pipeline(task="image-to-text", model=model_name)
captioner = self.pipelines[model_name]
model_name = get_hf_model(aiconfig, prompt, self)
key = model_name if model_name is not None else "__default__"

if key not in self.pipelines:
self.pipelines[key] = pipeline(task="image-to-text", model=model_name)
captioner = self.pipelines[key]

outputs: List[Output] = []
response: List[Any] = captioner(inputs, **completion_data)
@@ -189,6 +193,7 @@ def refine_completion_params(model_settings: Dict[str, Any]) -> Dict[str, Any]:

return completion_data


# Helper methods
def construct_regular_output(result: Dict[str, str], execution_count: int) -> Output:
"""
@@ -198,7 +203,7 @@ def construct_regular_output(result: Dict[str, str], execution_count: int) -> Ou
**{
"output_type": "execute_result",
# For some reason result is always in list format we haven't found
# a way of being able to return multiple sequences from the image
# a way of being able to return multiple sequences from the image
# to text pipeline
"data": result[0]["generated_text"],
"execution_count": execution_count,
@@ -251,7 +256,7 @@ def validate_and_retrieve_images_from_attachments(prompt: Prompt) -> list[Union[
# vs. uri. This will be fixed once we have standardized inputs
# See https://github.com/lastmile-ai/aiconfig/issues/829
if len(input_data) > 10000:
pil_image : Image = Image.open(BytesIO(base64.b64decode(input_data)))
pil_image: Image = Image.open(BytesIO(base64.b64decode(input_data)))
images.append(pil_image)
else:
images.append(input_data)
@@ -11,6 +11,7 @@
from PIL import Image
from transformers import Pipeline

from aiconfig_extension_hugging_face.local_inference.util import get_hf_model

from aiconfig.default_parsers.parameterized_model_parser import ParameterizedModelParser
from aiconfig.model_parser import InferenceOptions
@@ -124,11 +125,12 @@ def refine_image_completion_params(unfiltered_completion_params: Dict[str, Any])
return completion_params


class ImageData():
class ImageData:
"""
Helper class to store each image response data as fields instead
Helper class to store each image response data as fields instead
of separate arrays. See `_refine_responses` for more details
"""

image: Image.Image
nsfw_content_detected: bool

@@ -289,17 +291,17 @@ async def run_inference(self, prompt: Prompt, aiconfig: "AIConfigRuntime", optio
"""
print(pipeline_building_disclaimer_message)

model_name: str = aiconfig.get_model_name(prompt)
model_name = get_hf_model(aiconfig, prompt, self)
key = model_name if model_name is not None else "__default__"

# TODO (rossdanlm): Figure out a way to save model and re-use checkpoint
# Otherwise right now a lot of these models are taking 5 mins to load with 50
# num_inference_steps (default value). See here for more details:
# https://huggingface.co/docs/diffusers/using-diffusers/loading#checkpoint-variants
if isinstance(model_name, str) and model_name not in self.generators:
if key not in self.generators:
device = self._get_device()
self.generators[model_name] = AutoPipelineForText2Image.from_pretrained(pretrained_model_or_path=model_name, **pipeline_creation_data).to(
device
)
generator = self.generators[model_name]
self.generators[key] = AutoPipelineForText2Image.from_pretrained(pretrained_model_or_path=model_name, **pipeline_creation_data).to(device)
generator = self.generators[key]

disclaimer_long_response_print_message = """\n
Calling image generation. This can take a long time, (up to SEVERAL MINUTES depending
@@ -370,21 +372,22 @@ def _get_device(self) -> str:
return "mps"
return "cpu"


def _refine_responses(
response_images: List[Image.Image],
response_images: List[Image.Image],
nsfw_content_detected: List[bool],
) -> List[ImageData]:
"""
Helper function for taking the separate response data lists (`images` and
`nsfw_content_detected`) from StableDiffusionPipelineOutput or
`nsfw_content_detected`) from StableDiffusionPipelineOutput or
StableDiffusionXLPipelineOutput and merging this data into a single array
containing ImageData which stores information at the image-level. This
makes processing later easier since all the data we need is stored in a
containing ImageData which stores information at the image-level. This
makes processing later easier since all the data we need is stored in a
single object, so we don't need to compare two separate lists
Args:
response_images List[Image.Image]: List of images
nsfw_content_detected List[bool]: List of whether the image at that
nsfw_content_detected List[bool]: List of whether the image at that
corresponding index from `response_images` has detected that it
contains nsfw_content. It is possible for this list to be empty
@@ -396,8 +399,5 @@ def _refine_responses(
# Use zip.longest because nsfw_content_detected can be empty
itertools.zip_longest(response_images, nsfw_content_detected)
)
image_data_objects: List[ImageData] = [
ImageData(image=image, nsfw_content_detected=has_nsfw)
for (image, has_nsfw) in merged_responses
]
image_data_objects: List[ImageData] = [ImageData(image=image, nsfw_content_detected=has_nsfw) for (image, has_nsfw) in merged_responses]
return image_data_objects
@@ -3,10 +3,12 @@
import io
import json
import numpy as np
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from transformers import Pipeline, pipeline
from scipy.io.wavfile import write as write_wav

from aiconfig_extension_hugging_face.local_inference.util import get_hf_model

from aiconfig.default_parsers.parameterized_model_parser import ParameterizedModelParser
from aiconfig.model_parser import InferenceOptions
from aiconfig.schema import (
Expand All @@ -24,7 +26,7 @@

# Step 1: define Helpers
def refine_pipeline_creation_params(model_settings: Dict[str, Any]) -> List[Dict[str, Any]]:
# These are from the transformers Github repo:
# These are from the transformers Github repo:
# https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py#L2534
supported_keys = {
"torch_dtype",
@@ -63,7 +65,7 @@ def refine_pipeline_creation_params(model_settings: Dict[str, Any]) -> List[Dict

def refine_completion_params(unfiltered_completion_params: Dict[str, Any]) -> Dict[str, Any]:
# Note: There seems to be no public API docs on what completion
# params are supported for text to speech:
# params are supported for text to speech:
# https://huggingface.co/docs/transformers/tasks/text-to-speech#inference
# The only one mentioned is `forward_params` which can contain `speaker_embeddings`
supported_keys = {}
@@ -194,10 +196,11 @@ async def run_inference(self, prompt: Prompt, aiconfig: "AIConfigRuntime", optio
model_settings = self.get_model_settings(prompt, aiconfig)
[pipeline_creation_data, _] = refine_pipeline_creation_params(model_settings)

model_name: str = aiconfig.get_model_name(prompt)
if isinstance(model_name, str) and model_name not in self.synthesizers:
self.synthesizers[model_name] = pipeline("text-to-speech", model_name)
synthesizer = self.synthesizers[model_name]
model_name = get_hf_model(aiconfig, prompt, self)
key = model_name if model_name is not None else "__default__"
if key not in self.synthesizers:
self.synthesizers[key] = pipeline("text-to-speech", model=model_name)
synthesizer = self.synthesizers[key]

completion_data = await self.deserialize(prompt, aiconfig, options, parameters)
inputs = completion_data.pop("prompt", None)
@@ -9,6 +9,8 @@
TextIteratorStreamer,
)

from aiconfig_extension_hugging_face.local_inference.util import get_hf_model

from aiconfig.default_parsers.parameterized_model_parser import ParameterizedModelParser
from aiconfig.model_parser import InferenceOptions
from aiconfig.schema import (
@@ -79,7 +81,7 @@ def refine_completion_params(model_settings: Dict[str, Any]) -> Dict[str, Any]:
"encoder_no_repeat_ngram_size",
"decoder_start_token_id",
"num_assistant_tokens",
"num_assistant_tokens_schedule"
"num_assistant_tokens_schedule",
}

completion_data = {}
@@ -104,6 +106,7 @@ def construct_regular_output(result: Dict[str, str], execution_count: int) -> Ou
)
return output


def construct_stream_output(
streamer: TextIteratorStreamer,
options: InferenceOptions,
@@ -119,13 +122,13 @@ def construct_stream_output(
"""
output = ExecuteResult(
**{
"output_type": "execute_result",
"data": "", # We update this below
"execution_count": 0, #Multiple outputs are not supported for streaming
"metadata": {},
}
)
**{
"output_type": "execute_result",
"data": "", # We update this below
"execution_count": 0, # Multiple outputs are not supported for streaming
"metadata": {},
}
)
accumulated_message = ""
for new_text in streamer:
if isinstance(new_text, str):
@@ -152,7 +155,7 @@ def __init__(self):
config.register_model_parser(parser)
"""
super().__init__()
self.generators: dict[str, Pipeline]= {}
self.generators: dict[str, Pipeline] = {}

def id(self) -> str:
"""
@@ -190,9 +193,7 @@ async def serialize(
prompt = Prompt(
name=prompt_name,
input=prompt_input,
metadata=PromptMetadata(
model=model_metadata, parameters=parameters, **kwargs
),
metadata=PromptMetadata(model=model_metadata, parameters=parameters, **kwargs),
)
return [prompt]

@@ -217,14 +218,12 @@ async def deserialize(
model_settings = self.get_model_settings(prompt, aiconfig)
completion_data = refine_completion_params(model_settings)

#Add resolved prompt
# Add resolved prompt
resolved_prompt = resolve_prompt(prompt, params, aiconfig)
completion_data["prompt"] = resolved_prompt
return completion_data

async def run_inference(
self, prompt: Prompt, aiconfig: "AIConfigRuntime", options: InferenceOptions, parameters: Dict[str, Any]
) -> List[Output]:
async def run_inference(self, prompt: Prompt, aiconfig: "AIConfigRuntime", options: InferenceOptions, parameters: Dict[str, Any]) -> List[Output]:
"""
Invoked to run a prompt in the .aiconfig. This method should perform
the actual model inference based on the provided prompt and inference settings.
@@ -239,16 +238,15 @@
completion_data = await self.deserialize(prompt, aiconfig, options, parameters)
completion_data["text_inputs"] = completion_data.pop("prompt", None)

model_name: str | None = aiconfig.get_model_name(prompt)
if isinstance(model_name, str) and model_name not in self.generators:
self.generators[model_name] = pipeline('text-generation', model=model_name)
generator = self.generators[model_name]
model_name = get_hf_model(aiconfig, prompt, self)
key = model_name if model_name is not None else "__default__"
if key not in self.generators:
self.generators[key] = pipeline("text-generation", model=model_name)
generator = self.generators[key]

# if stream enabled in runtime options and config, then stream. Otherwise don't stream.
streamer = None
should_stream = (options.stream if options else False) and (
not "stream" in completion_data or completion_data.get("stream") != False
)
should_stream = (options.stream if options else False) and (not "stream" in completion_data or completion_data.get("stream") != False)
if should_stream:
tokenizer: AutoTokenizer = AutoTokenizer.from_pretrained(model_name)
streamer = TextIteratorStreamer(tokenizer)