From 5e533b0a0c68603115a4392ad160dd782c80e858 Mon Sep 17 00:00:00 2001 From: Sarmad Qadri Date: Wed, 10 Jan 2024 17:45:19 -0500 Subject: [PATCH 1/2] Support custom specified models for HF model parsers The HF prompt schema has been updated in #850 to support a "model" property, which can specify the model to use for a given HF task. This change updates the model parsers to use this setting. Test Plan: Tested with the following prompt: ``` { "name": "image_caption", "input": { "attachments": [ { "data": "https://s3.amazonaws.com/lastmileai.aiconfig.public/uploads/2024_1_10_18_41_31/1742/bear_with_honey.png", "mime_type": "image/png" }, { "data": "https://s3.amazonaws.com/lastmileai.aiconfig.public/uploads/2024_1_10_18_41_31/7275/fox_in_forest.png", "mime_type": "image/png" } ] }, "metadata": { "model": { "name": "HuggingFaceImage2TextTransformer", "settings": { "max_new_tokens": 10, "model": "Salesforce/blip-image-captioning-base" } }, "parameters": {} }, "outputs": [ { "output_type": "execute_result", "execution_count": 0, "data": "a bear sitting on a rock eating honey", "metadata": {} }, { "output_type": "execute_result", "execution_count": 1, "data": "a red fox in the woods", "metadata": {} } ] }, ``` Validated the the model setting was respected and worked. --- .../__init__.py | 4 +- .../automatic_speech_recognition.py | 15 ++++--- .../local_inference/image_2_text.py | 19 +++++--- .../local_inference/text_2_image.py | 34 +++++++------- .../local_inference/text_2_speech.py | 17 ++++--- .../local_inference/text_generation.py | 44 +++++++++---------- .../local_inference/text_summarization.py | 15 ++++--- .../local_inference/text_translation.py | 15 ++++--- .../local_inference/util.py | 28 ++++++++++++ 9 files changed, 115 insertions(+), 76 deletions(-) create mode 100644 extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/util.py diff --git a/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/__init__.py b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/__init__.py index 70f811527..a6e28cab8 100644 --- a/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/__init__.py +++ b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/__init__.py @@ -6,7 +6,9 @@ from .local_inference.text_summarization import HuggingFaceTextSummarizationTransformer from .local_inference.text_translation import HuggingFaceTextTranslationTransformer from .remote_inference_client.text_generation import HuggingFaceTextGenerationParser +from .local_inference.util import get_hf_model +UTILS = [get_hf_model] LOCAL_INFERENCE_CLASSES = [ "HuggingFaceAutomaticSpeechRecognitionTransformer", @@ -18,4 +20,4 @@ "HuggingFaceTextTranslationTransformer", ] REMOTE_INFERENCE_CLASSES = ["HuggingFaceTextGenerationParser"] -__ALL__ = LOCAL_INFERENCE_CLASSES + REMOTE_INFERENCE_CLASSES +__ALL__ = LOCAL_INFERENCE_CLASSES + REMOTE_INFERENCE_CLASSES + UTILS diff --git a/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/automatic_speech_recognition.py b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/automatic_speech_recognition.py index c627e8986..82a5f37f5 100644 --- a/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/automatic_speech_recognition.py +++ b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/automatic_speech_recognition.py @@ -2,11 +2,11 @@ import torch from transformers import pipeline, Pipeline +from aiconfig_extension_hugging_face.local_inference.util import get_hf_model from aiconfig import ParameterizedModelParser, InferenceOptions from aiconfig.callback import CallbackEvent from aiconfig.schema import Prompt, Output, ExecuteResult, Attachment - if TYPE_CHECKING: from aiconfig import AIConfigRuntime @@ -96,15 +96,16 @@ async def run_inference(self, prompt: Prompt, aiconfig: "AIConfigRuntime", optio model_settings = self.get_model_settings(prompt, aiconfig) [pipeline_creation_data, _] = refine_pipeline_creation_params(model_settings) - model_name = aiconfig.get_model_name(prompt) + model_name = get_hf_model(aiconfig, prompt, self) + key = model_name if model_name is not None else "__default__" - if isinstance(model_name, str) and model_name not in self.pipelines: + if key not in self.pipelines: device = self._get_device() if pipeline_creation_data.get("device", None) is None: pipeline_creation_data["device"] = device - self.pipelines[model_name] = pipeline(task="automatic-speech-recognition", **pipeline_creation_data) + self.pipelines[key] = pipeline(task="automatic-speech-recognition", model=model_name, **pipeline_creation_data) - asr_pipeline = self.pipelines[model_name] + asr_pipeline = self.pipelines[key] completion_data = await self.deserialize(prompt, aiconfig, parameters) response = asr_pipeline(**completion_data) @@ -234,8 +235,8 @@ def refine_asr_completion_params(unfiltered_completion_params: Dict[str, Any]) - Note: This doesn't support base pipeline params like `num_workers` TODO: Figure out how to find which params are supported. - TODO: Distinguish pipeline creation and refine completion - https://github.com/lastmile-ai/aiconfig/issues/825 + TODO: Distinguish pipeline creation and refine completion + https://github.com/lastmile-ai/aiconfig/issues/825 https://github.com/lastmile-ai/aiconfig/issues/824 """ diff --git a/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/image_2_text.py b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/image_2_text.py index 9d72ef9c2..5e7f2527b 100644 --- a/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/image_2_text.py +++ b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/image_2_text.py @@ -8,6 +8,8 @@ pipeline, ) +from aiconfig_extension_hugging_face.local_inference.util import get_hf_model + from aiconfig import ParameterizedModelParser, InferenceOptions from aiconfig.callback import CallbackEvent from aiconfig.schema import ( @@ -108,7 +110,7 @@ async def deserialize( model_settings = self.get_model_settings(prompt, aiconfig) completion_params = refine_completion_params(model_settings) - #Add image inputs + # Add image inputs inputs = validate_and_retrieve_images_from_attachments(prompt) completion_params["inputs"] = inputs @@ -127,10 +129,12 @@ async def run_inference(self, prompt: Prompt, aiconfig: "AIConfigRuntime", optio completion_data = await self.deserialize(prompt, aiconfig, parameters) inputs = completion_data.pop("inputs") - model_name: str | None = aiconfig.get_model_name(prompt) - if isinstance(model_name, str) and model_name not in self.pipelines: - self.pipelines[model_name] = pipeline(task="image-to-text", model=model_name) - captioner = self.pipelines[model_name] + model_name = get_hf_model(aiconfig, prompt, self) + key = model_name if model_name is not None else "__default__" + + if key not in self.pipelines: + self.pipelines[key] = pipeline(task="image-to-text", model=model_name) + captioner = self.pipelines[key] outputs: List[Output] = [] response: List[Any] = captioner(inputs, **completion_data) @@ -189,6 +193,7 @@ def refine_completion_params(model_settings: Dict[str, Any]) -> Dict[str, Any]: return completion_data + # Helper methods def construct_regular_output(result: Dict[str, str], execution_count: int) -> Output: """ @@ -198,7 +203,7 @@ def construct_regular_output(result: Dict[str, str], execution_count: int) -> Ou **{ "output_type": "execute_result", # For some reason result is always in list format we haven't found - # a way of being able to return multiple sequences from the image + # a way of being able to return multiple sequences from the image # to text pipeline "data": result[0]["generated_text"], "execution_count": execution_count, @@ -251,7 +256,7 @@ def validate_and_retrieve_images_from_attachments(prompt: Prompt) -> list[Union[ # vs. uri. This will be fixed once we have standardized inputs # See https://github.com/lastmile-ai/aiconfig/issues/829 if len(input_data) > 10000: - pil_image : Image = Image.open(BytesIO(base64.b64decode(input_data))) + pil_image: Image = Image.open(BytesIO(base64.b64decode(input_data))) images.append(pil_image) else: images.append(input_data) diff --git a/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/text_2_image.py b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/text_2_image.py index f9a82fb13..360ef9dc5 100644 --- a/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/text_2_image.py +++ b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/text_2_image.py @@ -11,6 +11,7 @@ from PIL import Image from transformers import Pipeline +from aiconfig_extension_hugging_face.local_inference.util import get_hf_model from aiconfig.default_parsers.parameterized_model_parser import ParameterizedModelParser from aiconfig.model_parser import InferenceOptions @@ -124,11 +125,12 @@ def refine_image_completion_params(unfiltered_completion_params: Dict[str, Any]) return completion_params -class ImageData(): +class ImageData: """ - Helper class to store each image response data as fields instead + Helper class to store each image response data as fields instead of separate arrays. See `_refine_responses` for more details """ + image: Image.Image nsfw_content_detected: bool @@ -289,17 +291,17 @@ async def run_inference(self, prompt: Prompt, aiconfig: "AIConfigRuntime", optio """ print(pipeline_building_disclaimer_message) - model_name: str = aiconfig.get_model_name(prompt) + model_name = get_hf_model(aiconfig, prompt, self) + key = model_name if model_name is not None else "__default__" + # TODO (rossdanlm): Figure out a way to save model and re-use checkpoint # Otherwise right now a lot of these models are taking 5 mins to load with 50 # num_inference_steps (default value). See here for more details: # https://huggingface.co/docs/diffusers/using-diffusers/loading#checkpoint-variants - if isinstance(model_name, str) and model_name not in self.generators: + if key not in self.generators: device = self._get_device() - self.generators[model_name] = AutoPipelineForText2Image.from_pretrained(pretrained_model_or_path=model_name, **pipeline_creation_data).to( - device - ) - generator = self.generators[model_name] + self.generators[key] = AutoPipelineForText2Image.from_pretrained(pretrained_model_or_path=model_name, **pipeline_creation_data).to(device) + generator = self.generators[key] disclaimer_long_response_print_message = """\n Calling image generation. This can take a long time, (up to SEVERAL MINUTES depending @@ -370,21 +372,22 @@ def _get_device(self) -> str: return "mps" return "cpu" + def _refine_responses( - response_images: List[Image.Image], + response_images: List[Image.Image], nsfw_content_detected: List[bool], ) -> List[ImageData]: """ Helper function for taking the separate response data lists (`images` and - `nsfw_content_detected`) from StableDiffusionPipelineOutput or + `nsfw_content_detected`) from StableDiffusionPipelineOutput or StableDiffusionXLPipelineOutput and merging this data into a single array - containing ImageData which stores information at the image-level. This - makes processing later easier since all the data we need is stored in a + containing ImageData which stores information at the image-level. This + makes processing later easier since all the data we need is stored in a single object, so we don't need to compare two separate lists Args: response_images List[Image.Image]: List of images - nsfw_content_detected List[bool]: List of whether the image at that + nsfw_content_detected List[bool]: List of whether the image at that corresponding index from `response_images` has detected that it contains nsfw_content. It is possible for this list to be empty @@ -396,8 +399,5 @@ def _refine_responses( # Use zip.longest because nsfw_content_detected can be empty itertools.zip_longest(response_images, nsfw_content_detected) ) - image_data_objects: List[ImageData] = [ - ImageData(image=image, nsfw_content_detected=has_nsfw) - for (image, has_nsfw) in merged_responses - ] + image_data_objects: List[ImageData] = [ImageData(image=image, nsfw_content_detected=has_nsfw) for (image, has_nsfw) in merged_responses] return image_data_objects diff --git a/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/text_2_speech.py b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/text_2_speech.py index 2474014cc..7c132108f 100644 --- a/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/text_2_speech.py +++ b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/text_2_speech.py @@ -3,10 +3,12 @@ import io import json import numpy as np -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional from transformers import Pipeline, pipeline from scipy.io.wavfile import write as write_wav +from aiconfig_extension_hugging_face.local_inference.util import get_hf_model + from aiconfig.default_parsers.parameterized_model_parser import ParameterizedModelParser from aiconfig.model_parser import InferenceOptions from aiconfig.schema import ( @@ -24,7 +26,7 @@ # Step 1: define Helpers def refine_pipeline_creation_params(model_settings: Dict[str, Any]) -> List[Dict[str, Any]]: - # These are from the transformers Github repo: + # These are from the transformers Github repo: # https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py#L2534 supported_keys = { "torch_dtype", @@ -63,7 +65,7 @@ def refine_pipeline_creation_params(model_settings: Dict[str, Any]) -> List[Dict def refine_completion_params(unfiltered_completion_params: Dict[str, Any]) -> Dict[str, Any]: # Note: There seems to be no public API docs on what completion - # params are supported for text to speech: + # params are supported for text to speech: # https://huggingface.co/docs/transformers/tasks/text-to-speech#inference # The only one mentioned is `forward_params` which can contain `speaker_embeddings` supported_keys = {} @@ -194,10 +196,11 @@ async def run_inference(self, prompt: Prompt, aiconfig: "AIConfigRuntime", optio model_settings = self.get_model_settings(prompt, aiconfig) [pipeline_creation_data, _] = refine_pipeline_creation_params(model_settings) - model_name: str = aiconfig.get_model_name(prompt) - if isinstance(model_name, str) and model_name not in self.synthesizers: - self.synthesizers[model_name] = pipeline("text-to-speech", model_name) - synthesizer = self.synthesizers[model_name] + model_name = get_hf_model(aiconfig, prompt, self) + key = model_name if model_name is not None else "__default__" + if key not in self.synthesizers: + self.synthesizers[key] = pipeline("text-to-speech", model=model_name) + synthesizer = self.synthesizers[key] completion_data = await self.deserialize(prompt, aiconfig, options, parameters) inputs = completion_data.pop("prompt", None) diff --git a/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/text_generation.py b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/text_generation.py index 774365334..7ad79cf97 100644 --- a/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/text_generation.py +++ b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/text_generation.py @@ -9,6 +9,8 @@ TextIteratorStreamer, ) +from aiconfig_extension_hugging_face.local_inference.util import get_hf_model + from aiconfig.default_parsers.parameterized_model_parser import ParameterizedModelParser from aiconfig.model_parser import InferenceOptions from aiconfig.schema import ( @@ -79,7 +81,7 @@ def refine_completion_params(model_settings: Dict[str, Any]) -> Dict[str, Any]: "encoder_no_repeat_ngram_size", "decoder_start_token_id", "num_assistant_tokens", - "num_assistant_tokens_schedule" + "num_assistant_tokens_schedule", } completion_data = {} @@ -104,6 +106,7 @@ def construct_regular_output(result: Dict[str, str], execution_count: int) -> Ou ) return output + def construct_stream_output( streamer: TextIteratorStreamer, options: InferenceOptions, @@ -119,13 +122,13 @@ def construct_stream_output( """ output = ExecuteResult( - **{ - "output_type": "execute_result", - "data": "", # We update this below - "execution_count": 0, #Multiple outputs are not supported for streaming - "metadata": {}, - } - ) + **{ + "output_type": "execute_result", + "data": "", # We update this below + "execution_count": 0, # Multiple outputs are not supported for streaming + "metadata": {}, + } + ) accumulated_message = "" for new_text in streamer: if isinstance(new_text, str): @@ -152,7 +155,7 @@ def __init__(self): config.register_model_parser(parser) """ super().__init__() - self.generators: dict[str, Pipeline]= {} + self.generators: dict[str, Pipeline] = {} def id(self) -> str: """ @@ -190,9 +193,7 @@ async def serialize( prompt = Prompt( name=prompt_name, input=prompt_input, - metadata=PromptMetadata( - model=model_metadata, parameters=parameters, **kwargs - ), + metadata=PromptMetadata(model=model_metadata, parameters=parameters, **kwargs), ) return [prompt] @@ -217,14 +218,12 @@ async def deserialize( model_settings = self.get_model_settings(prompt, aiconfig) completion_data = refine_completion_params(model_settings) - #Add resolved prompt + # Add resolved prompt resolved_prompt = resolve_prompt(prompt, params, aiconfig) completion_data["prompt"] = resolved_prompt return completion_data - async def run_inference( - self, prompt: Prompt, aiconfig: "AIConfigRuntime", options: InferenceOptions, parameters: Dict[str, Any] - ) -> List[Output]: + async def run_inference(self, prompt: Prompt, aiconfig: "AIConfigRuntime", options: InferenceOptions, parameters: Dict[str, Any]) -> List[Output]: """ Invoked to run a prompt in the .aiconfig. This method should perform the actual model inference based on the provided prompt and inference settings. @@ -239,16 +238,15 @@ async def run_inference( completion_data = await self.deserialize(prompt, aiconfig, options, parameters) completion_data["text_inputs"] = completion_data.pop("prompt", None) - model_name: str | None = aiconfig.get_model_name(prompt) - if isinstance(model_name, str) and model_name not in self.generators: - self.generators[model_name] = pipeline('text-generation', model=model_name) - generator = self.generators[model_name] + model_name = get_hf_model(aiconfig, prompt, self) + key = model_name if model_name is not None else "__default__" + if key not in self.generators: + self.generators[key] = pipeline("text-generation", model=model_name) + generator = self.generators[key] # if stream enabled in runtime options and config, then stream. Otherwise don't stream. streamer = None - should_stream = (options.stream if options else False) and ( - not "stream" in completion_data or completion_data.get("stream") != False - ) + should_stream = (options.stream if options else False) and (not "stream" in completion_data or completion_data.get("stream") != False) if should_stream: tokenizer: AutoTokenizer = AutoTokenizer.from_pretrained(model_name) streamer = TextIteratorStreamer(tokenizer) diff --git a/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/text_summarization.py b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/text_summarization.py index 532bbbadc..2c0125a28 100644 --- a/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/text_summarization.py +++ b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/text_summarization.py @@ -9,6 +9,8 @@ TextIteratorStreamer, ) +from aiconfig_extension_hugging_face.local_inference.util import get_hf_model + from aiconfig.default_parsers.parameterized_model_parser import ParameterizedModelParser from aiconfig.model_parser import InferenceOptions from aiconfig.schema import ( @@ -242,16 +244,15 @@ async def run_inference(self, prompt: Prompt, aiconfig: "AIConfigRuntime", optio completion_data = await self.deserialize(prompt, aiconfig, options, parameters) inputs = completion_data.pop("prompt", None) - model_name: str = aiconfig.get_model_name(prompt) - if isinstance(model_name, str) and model_name not in self.summarizers: - self.summarizers[model_name] = pipeline("summarization", model=model_name) - summarizer = self.summarizers[model_name] + model_name = get_hf_model(aiconfig, prompt, self) + key = model_name if model_name is not None else "__default__" + if key not in self.summarizers: + self.summarizers[key] = pipeline("summarization", model=model_name) + summarizer = self.summarizers[key] # if stream enabled in runtime options and config, then stream. Otherwise don't stream. streamer = None - should_stream = (options.stream if options else False) and ( - not "stream" in completion_data or completion_data.get("stream") != False - ) + should_stream = (options.stream if options else False) and (not "stream" in completion_data or completion_data.get("stream") != False) if should_stream: tokenizer: AutoTokenizer = AutoTokenizer.from_pretrained(model_name) streamer = TextIteratorStreamer(tokenizer) diff --git a/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/text_translation.py b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/text_translation.py index 5100e3e24..abdc5a625 100644 --- a/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/text_translation.py +++ b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/text_translation.py @@ -9,6 +9,8 @@ TextIteratorStreamer, ) +from aiconfig_extension_hugging_face.local_inference.util import get_hf_model + from aiconfig.default_parsers.parameterized_model_parser import ParameterizedModelParser from aiconfig.model_parser import InferenceOptions from aiconfig.schema import ( @@ -244,16 +246,15 @@ async def run_inference(self, prompt: Prompt, aiconfig: "AIConfigRuntime", optio completion_data = await self.deserialize(prompt, aiconfig, options, parameters) inputs = completion_data.pop("prompt", None) - model_name: str = aiconfig.get_model_name(prompt) - if isinstance(model_name, str) and model_name not in self.translators: - self.translators[model_name] = pipeline("translation", model_name) - translator = self.translators[model_name] + model_name = get_hf_model(aiconfig, prompt, self) + key = model_name if model_name is not None else "__default__" + if key not in self.translators: + self.translators[key] = pipeline("translation", model_name) + translator = self.translators[key] # if stream enabled in runtime options and config, then stream. Otherwise don't stream. streamer = None - should_stream = (options.stream if options else False) and ( - not "stream" in completion_data or completion_data.get("stream") != False - ) + should_stream = (options.stream if options else False) and (not "stream" in completion_data or completion_data.get("stream") != False) if should_stream: tokenizer: AutoTokenizer = AutoTokenizer.from_pretrained(model_name) streamer = TextIteratorStreamer(tokenizer) diff --git a/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/util.py b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/util.py new file mode 100644 index 000000000..14db41a3a --- /dev/null +++ b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/util.py @@ -0,0 +1,28 @@ +from typing import TYPE_CHECKING +from aiconfig import ParameterizedModelParser +from aiconfig.schema import Prompt + + +# Circular Dependency Type Hints +if TYPE_CHECKING: + from aiconfig import AIConfigRuntime + + +def get_hf_model(aiconfig: "AIConfigRuntime", prompt: Prompt, model_parser: ParameterizedModelParser) -> str | None: + """ + Returns the HuggingFace model to use for the given prompt and model parser. + """ + model_name: str | None = aiconfig.get_model_name(prompt) + model_settings = model_parser.get_model_settings(prompt, aiconfig) + hf_model = model_settings.get("model", None) + + if hf_model is not None and isinstance(hf_model, str): + # If the model property is set in the model settings, use that. + return hf_model + elif model_name == model_parser.id(): + # If the model name is the name of the model parser itself, + # then return None to signify using the default model for that HF task. + return None + else: + # Otherwise the model name is the name of the model to use. + return model_name From 69778fa57cfb64dc5504e1942f2066e4f29e0a43 Mon Sep 17 00:00:00 2001 From: Sarmad Qadri Date: Wed, 10 Jan 2024 17:54:10 -0500 Subject: [PATCH 2/2] Fix parameterization issues Now that we have non-text input model parsers, we can see some issues in our parameterization. 1) ASR and Image-to-Text modelparsers should NOT be `ParameterizedModelParser` instances, since their inputs cannot be parameterized. 2) Parameter resolution logic shouldn't throw errors in the case where it's parsing prompts belonging to regular model parsers. Test Plan: Ran the translation prompt for this aiconfig: ``` prompts: [ { "name": "image_caption", "input": { "attachments": [ { "data": "https://s3.amazonaws.com/lastmileai.aiconfig.public/uploads/2024_1_10_18_41_31/1742/bear_with_honey.png", "mime_type": "image/png" }, { "data": "https://s3.amazonaws.com/lastmileai.aiconfig.public/uploads/2024_1_10_18_41_31/7275/fox_in_forest.png", "mime_type": "image/png" } ] }, "metadata": { "model": { "name": "HuggingFaceImage2TextTransformer", "settings": { "max_new_tokens": 10, "model": "Salesforce/blip-image-captioning-base" } }, "parameters": {} }, "outputs": [ { "output_type": "execute_result", "execution_count": 0, "data": "a bear sitting on a rock eating honey", "metadata": {} }, { "output_type": "execute_result", "execution_count": 1, "data": "a red fox in the woods", "metadata": {} } ] }, { "name": "translation", "input": "Once upon a time, in a lush and vibrant forest, there lived a magnificent creature known as the Quick Brown Fox. This fox was unlike any other, possessing incredible speed and agility that awed all the animals in the forest. With its fur as golden as the sun and its eyes as sharp as emeralds, the Quick Brown Fox was admired by everyone, from the tiniest hummingbird to the mightiest bear. The fox had a kind heart and would often lend a helping paw to those in need. The Quick Brown Fox had a particular fondness for games and challenges. It loved to test its skills against others, always seeking new adventures to satisfy its boundless curiosity. Its favorite game was called \"The Great Word Hunt,\" where it would embark on a quest to find hidden words scattered across the forest. \n\nOne day it got very old and died", "metadata": { "model": { "name": "HuggingFaceTextTranslationTransformer", "settings": { "model": "Helsinki-NLP/opus-mt-en-fr", "max_length": 100, "min_length": 50, "num_beams": 1 } }, "parameters": {} } } ] ``` Before: ``` File "/opt/homebrew/lib/python3.11/site-packages/aiconfig/util/params.py", line 235, in get_prompt_template raise Exception(f"Cannot get prompt template string from prompt: {prompt.input}") Exception: Cannot get prompt template string from prompt: attachments=[Attachment(data='https://s3.amazonaws.com/lastmileai.aiconfig.public/uploads/2024_1_10_18_41_31/1742/bear_with_honey.png', mime_type='image/png', metadata=None), Attachment(data='https://s3.amazonaws.com/lastmileai.aiconfig.public/uploads/2024_1_10_18_41_31/7275/fox_in_forest.png', mime_type='image/png', metadata=None)] data=None ``` After: * Beautiful translation --- .../local_inference/automatic_speech_recognition.py | 6 +++--- .../local_inference/image_2_text.py | 6 +++--- python/src/aiconfig/util/params.py | 12 ++++++++---- 3 files changed, 14 insertions(+), 10 deletions(-) diff --git a/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/automatic_speech_recognition.py b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/automatic_speech_recognition.py index 82a5f37f5..54830533a 100644 --- a/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/automatic_speech_recognition.py +++ b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/automatic_speech_recognition.py @@ -3,7 +3,7 @@ import torch from transformers import pipeline, Pipeline from aiconfig_extension_hugging_face.local_inference.util import get_hf_model -from aiconfig import ParameterizedModelParser, InferenceOptions +from aiconfig import ModelParser, InferenceOptions from aiconfig.callback import CallbackEvent from aiconfig.schema import Prompt, Output, ExecuteResult, Attachment @@ -11,7 +11,7 @@ from aiconfig import AIConfigRuntime -class HuggingFaceAutomaticSpeechRecognitionTransformer(ParameterizedModelParser): +class HuggingFaceAutomaticSpeechRecognitionTransformer(ModelParser): """ Model Parser for HuggingFace ASR (Automatic Speech Recognition) models. """ @@ -85,7 +85,7 @@ async def deserialize( await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_deserialize_complete", __name__, {"output": completion_data})) return completion_data - async def run_inference(self, prompt: Prompt, aiconfig: "AIConfigRuntime", options: InferenceOptions, parameters: Dict[str, Any]) -> list[Output]: + async def run(self, prompt: Prompt, aiconfig: "AIConfigRuntime", options: InferenceOptions, parameters: Dict[str, Any]) -> list[Output]: await aiconfig.callback_manager.run_callbacks( CallbackEvent( "on_run_start", diff --git a/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/image_2_text.py b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/image_2_text.py index 5e7f2527b..7cbf54ccf 100644 --- a/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/image_2_text.py +++ b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/image_2_text.py @@ -10,7 +10,7 @@ from aiconfig_extension_hugging_face.local_inference.util import get_hf_model -from aiconfig import ParameterizedModelParser, InferenceOptions +from aiconfig import ModelParser, InferenceOptions from aiconfig.callback import CallbackEvent from aiconfig.schema import ( Attachment, @@ -24,7 +24,7 @@ from aiconfig import AIConfigRuntime -class HuggingFaceImage2TextTransformer(ParameterizedModelParser): +class HuggingFaceImage2TextTransformer(ModelParser): def __init__(self): """ Returns: @@ -117,7 +117,7 @@ async def deserialize( await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_deserialize_complete", __name__, {"output": completion_params})) return completion_params - async def run_inference(self, prompt: Prompt, aiconfig: "AIConfigRuntime", options: InferenceOptions, parameters: Dict[str, Any]) -> list[Output]: + async def run(self, prompt: Prompt, aiconfig: "AIConfigRuntime", options: InferenceOptions, parameters: Dict[str, Any]) -> list[Output]: await aiconfig.callback_manager.run_callbacks( CallbackEvent( "on_run_start", diff --git a/python/src/aiconfig/util/params.py b/python/src/aiconfig/util/params.py index 33905c034..959d2a278 100644 --- a/python/src/aiconfig/util/params.py +++ b/python/src/aiconfig/util/params.py @@ -207,7 +207,7 @@ def resolve_parameters(params, prompt: Prompt, ai_config: "AIConfig"): return resolved_prompt -def get_prompt_template(prompt: Prompt, aiconfig: "AIConfigRuntime"): +def get_prompt_template(prompt: Prompt, aiconfig: "AIConfigRuntime") -> str | None: """ Returns the template for a prompt. @@ -224,13 +224,13 @@ def get_prompt_template(prompt: Prompt, aiconfig: "AIConfigRuntime"): if isinstance(model_parser, ParameterizedModelParser): return model_parser.get_prompt_template(prompt, aiconfig) - - if isinstance(prompt.input, str): + elif isinstance(prompt.input, str): return prompt.input elif isinstance(prompt.input.data, str): return prompt.input.data else: - raise Exception(f"Cannot get prompt template string from prompt: {prompt.input}") + # Not all model inputs are parameterizable (e.g. image inputs) + return None def collect_prompt_references(current_prompt: Prompt, ai_config: "AIConfigRuntime"): @@ -243,6 +243,10 @@ def collect_prompt_references(current_prompt: Prompt, ai_config: "AIConfigRuntim break prompt_input = get_prompt_template(previous_prompt, ai_config) + if prompt_input is None: + # not all model inputs are parameterizable + continue + prompt_output = ai_config.get_output_text(previous_prompt, ai_config.get_latest_output(previous_prompt)) if previous_prompt.outputs else None prompt_references[previous_prompt.name] = { "input": prompt_input,