From 074b768ba6eee4fb6dd1b2493c258cee26d2ed75 Mon Sep 17 00:00:00 2001
From: "Rossdan Craig rossdan@lastmileai.dev" <>
Date: Wed, 10 Jan 2024 01:39:03 -0500
Subject: [PATCH 1/5] [HF][streaming][1/n] Text Summarization
TSIA
Adding streaming functionality to the text summarization model parser.
## Test Plan
Rebase onto and test it with https://github.com/lastmile-ai/aiconfig/pull/852/commits/11ace0a5a31b9404a97965e5fa478d5b19adcb67.
Follow the README from AIConfig Editor https://github.com/lastmile-ai/aiconfig/tree/main/python/src/aiconfig/editor#dev, then run these commands:
```bash
aiconfig_path=/Users/rossdancraig/Projects/aiconfig/cookbooks/Gradio/huggingface.aiconfig.json
parsers_path=/Users/rossdancraig/Projects/aiconfig/cookbooks/Gradio/hf_model_parsers.py
alias aiconfig="python3 -m 'aiconfig.scripts.aiconfig_cli'"
aiconfig edit --aiconfig-path=$aiconfig_path --server-port=8080 --server-mode=debug_servers --parsers-module-path=$parsers_path
```
Then in AIConfig Editor run the prompt (it will stream by default):
https://github.com/lastmile-ai/aiconfig/assets/151060367/e91a1d8b-a3e9-459c-9eb1-2d8e5ec58e73
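For reference, the streaming pattern this parser now follows boils down to the snippet below. This is a minimal sketch, assuming an illustrative summarization checkpoint (`sshleifer/distilbart-cnn-12-6`) rather than whatever model the aiconfig prompt actually names:
```python
import threading

from transformers import AutoTokenizer, TextIteratorStreamer, pipeline

model_name = "sshleifer/distilbart-cnn-12-6"  # illustrative checkpoint
summarizer = pipeline("summarization", model=model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
streamer = TextIteratorStreamer(tokenizer)

# The pipeline call blocks until generation finishes, so it runs on a
# worker thread while the main thread drains the streamer.
thread = threading.Thread(
    target=summarizer,
    kwargs={"inputs": "Some long article text ...", "streamer": streamer},
)
thread.start()

accumulated_message = ""
for new_text in streamer:
    # The parser forwards each chunk to options.stream_callback here.
    accumulated_message += new_text
print(accumulated_message)
```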
---
.../local_inference/text_generation.py | 2 +-
.../local_inference/text_summarization.py | 11 +++++++++--
2 files changed, 10 insertions(+), 3 deletions(-)
diff --git a/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/text_generation.py b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/text_generation.py
index e0941b0fd..c6218c82e 100644
--- a/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/text_generation.py
+++ b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/text_generation.py
@@ -251,7 +251,7 @@ async def run_inference(
not "stream" in completion_data or completion_data.get("stream") != False
)
if should_stream:
- tokenizer : AutoTokenizer = AutoTokenizer.from_pretrained(model_name)
+ tokenizer: AutoTokenizer = AutoTokenizer.from_pretrained(model_name)
streamer = TextIteratorStreamer(tokenizer)
completion_data["streamer"] = streamer
diff --git a/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/text_summarization.py b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/text_summarization.py
index bba735b4f..32b90b908 100644
--- a/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/text_summarization.py
+++ b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/text_summarization.py
@@ -128,13 +128,18 @@ def construct_stream_output(
"metadata": {},
}
)
+
accumulated_message = ""
for new_text in streamer:
if isinstance(new_text, str):
+ # For some reason these special tokens aren't filtered out by the streamer
+ new_text = new_text.replace("</s>", "")
+ new_text = new_text.replace("<pad>", "")
+
accumulated_message += new_text
options.stream_callback(new_text, accumulated_message, 0)
-
output.data = accumulated_message
+
return output
@@ -245,7 +250,9 @@ async def run_inference(self, prompt: Prompt, aiconfig: "AIConfigRuntime", optio
# if stream enabled in runtime options and config, then stream. Otherwise don't stream.
streamer = None
- should_stream = (options.stream if options else False) and (not "stream" in completion_data or completion_data.get("stream") != False)
+ should_stream = (options.stream if options else False) and (
+ not "stream" in completion_data or completion_data.get("stream") != False
+ )
if should_stream:
tokenizer: AutoTokenizer = AutoTokenizer.from_pretrained(model_name)
streamer = TextIteratorStreamer(tokenizer)
From 13a4c6ed44a3307807bb5a8b8a8e7407f4f4e45b Mon Sep 17 00:00:00 2001
From: "Rossdan Craig rossdan@lastmileai.dev" <>
Date: Wed, 10 Jan 2024 02:08:54 -0500
Subject: [PATCH 2/5] [HF][streaming][2/n] Text Translation
TSIA
Adding streaming output support for the text translation model parser. I also fixed a bug where we didn't pass the `"translation"` task key into the pipeline.
## Test Plan
Rebase onto and test it: https://github.com/lastmile-ai/aiconfig/commit/5b7434483e2521110bbe3cc3380d9b99a6d4e8be.
Follow the README from AIConfig Editor https://github.com/lastmile-ai/aiconfig/tree/main/python/src/aiconfig/editor#dev, then run these commands:
```bash
aiconfig_path=/Users/rossdancraig/Projects/aiconfig/cookbooks/Gradio/huggingface.aiconfig.json
parsers_path=/Users/rossdancraig/Projects/aiconfig/cookbooks/Gradio/hf_model_parsers.py
alias aiconfig="python3 -m 'aiconfig.scripts.aiconfig_cli'"
aiconfig edit --aiconfig-path=$aiconfig_path --server-port=8080 --server-mode=debug_servers --parsers-module-path=$parsers_path
```
With Streaming
https://github.com/lastmile-ai/aiconfig/assets/151060367/d7bc9df2-2993-4709-bf9b-c5b7979fb00f
Without Streaming
https://github.com/lastmile-ai/aiconfig/assets/151060367/71eb6ab3-5d6f-4c5d-8b82-f3daf4c5e610
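Both changes reduce to the pattern below; a minimal sketch, assuming an illustrative Marian checkpoint (`Helsinki-NLP/opus-mt-en-fr`) rather than whatever model the aiconfig prompt names:
```python
import threading

from transformers import AutoTokenizer, TextIteratorStreamer, pipeline

model_name = "Helsinki-NLP/opus-mt-en-fr"  # illustrative checkpoint
# The bug fix: pass the "translation" task explicitly. pipeline(model_name)
# would have treated the model name as the task string.
translator = pipeline("translation", model_name)

tokenizer = AutoTokenizer.from_pretrained(model_name)
streamer = TextIteratorStreamer(tokenizer)

# As with summarization, the call blocks, so run it on a worker thread.
thread = threading.Thread(
    target=translator,
    kwargs={"inputs": "It never rains but it pours.", "streamer": streamer},
)
thread.start()

for new_text in streamer:
    print(new_text, end="", flush=True)
```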
---
.../HuggingFace/python/requirements.txt | 5 ++--
.../local_inference/text_translation.py | 24 +++++++++++++++----
2 files changed, 22 insertions(+), 7 deletions(-)
diff --git a/extensions/HuggingFace/python/requirements.txt b/extensions/HuggingFace/python/requirements.txt
index 6388e1c9e..79e5db10b 100644
--- a/extensions/HuggingFace/python/requirements.txt
+++ b/extensions/HuggingFace/python/requirements.txt
@@ -10,11 +10,12 @@ huggingface_hub
#Hugging Face Libraries - Local Inference Tranformers & Diffusors
accelerate # Used to help speed up image generation
-diffusers # Used for image + audio generation
+diffusers # Used for image generation
+scipy # array -> wav file for text-to-speech. torchaudio.save seems broken.
+sentencepiece # Used for text translation
torch
torchvision
torchaudio
-scipy # array -> wav file, text-speech. torchaudio.save seems broken.
transformers # Used for text generation
#Other
diff --git a/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/text_translation.py b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/text_translation.py
index 9ee8bb357..860a11e46 100644
--- a/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/text_translation.py
+++ b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/text_translation.py
@@ -129,12 +129,19 @@ def construct_stream_output(
"metadata": {},
}
)
+
accumulated_message = ""
for new_text in streamer:
if isinstance(new_text, str):
+ # For some reason these special tokens aren't filtered out by the streamer
+ new_text = new_text.replace("<s>", "")
+ new_text = new_text.replace("</s>", "")
+ new_text = new_text.replace("<pad>", "")
+
accumulated_message += new_text
options.stream_callback(new_text, accumulated_message, 0)
output.data = accumulated_message
+
return output
@@ -240,19 +247,26 @@ async def run_inference(self, prompt: Prompt, aiconfig: "AIConfigRuntime", optio
model_name: str = aiconfig.get_model_name(prompt)
if isinstance(model_name, str) and model_name not in self.translators:
- self.translators[model_name] = pipeline(model_name)
+ self.translators[model_name] = pipeline("translation", model_name)
translator = self.translators[model_name]
# if stream enabled in runtime options and config, then stream. Otherwise don't stream.
streamer = None
- should_stream = (options.stream if options else False) and (not "stream" in completion_data or completion_data.get("stream") != False)
+ should_stream = (options.stream if options else False) and (
+ not "stream" in completion_data or completion_data.get("stream") != False
+ )
if should_stream:
- raise NotImplementedError("Streaming is not supported for HuggingFace Text Translation")
+ tokenizer: AutoTokenizer = AutoTokenizer.from_pretrained(model_name)
+ streamer = TextIteratorStreamer(tokenizer)
+ completion_data["streamer"] = streamer
+
+ def _translate():
+ return translator(inputs, **completion_data)
outputs: List[Output] = []
output = None
if not should_stream:
- response: List[Any] = translator(inputs, **completion_data)
+ response: List[Any] = _translate()
for count, result in enumerate(response):
output = construct_regular_output(result, count)
outputs.append(output)
@@ -263,7 +277,7 @@ async def run_inference(self, prompt: Prompt, aiconfig: "AIConfigRuntime", optio
raise ValueError("Stream option is selected but streamer is not initialized")
# For streaming, cannot call `translator` directly otherwise response will be blocking
- thread = threading.Thread(target=translator, kwargs=completion_data)
+ thread = threading.Thread(target=_translate)
thread.start()
output = construct_stream_output(streamer, options)
if output is not None:
From 1f161e5b4fafdad0980735d45f66f532ea993749 Mon Sep 17 00:00:00 2001
From: "Rossdan Craig rossdan@lastmileai.dev" <>
Date: Wed, 10 Jan 2024 02:49:07 -0500
Subject: [PATCH 3/5] [HF][streaming][3/n] Text2Speech (no streaming, but
updating docs on completion params)
Ok, this one is weird. Today, streaming is only supported for text outputs in the Transformers library. See `BaseStreamer` here: https://github.com/search?q=repo%3Ahuggingface%2Ftransformers%20BaseStreamer&type=code
In the future it may support other formats, but it doesn't yet. OpenAI, for example, already supports streaming text-to-speech: https://community.openai.com/t/streaming-from-text-to-speech-api/493784
Anyway, here I only updated the docs to clarify why the completion params are empty. Jonathan and I synced about this briefly offline, but I forgot again, so I'm capturing it here so no one forgets.
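For context, the linked HF docs only surface `forward_params` (which can carry `speaker_embeddings`). Below is a sketch of a plain text-to-speech call per those docs, using their SpeechT5 example checkpoint and x-vector dataset, not anything this parser pins:
```python
import torch
from datasets import load_dataset
from transformers import pipeline

synthesiser = pipeline("text-to-speech", model="microsoft/speecht5_tts")

# SpeechT5 needs a speaker embedding, passed via forward_params; this
# x-vector dataset is the one the HF docs use.
embeddings = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
speaker_embedding = torch.tensor(embeddings[7306]["xvector"]).unsqueeze(0)

speech = synthesiser(
    "Hello, my dog is cooler than you!",
    forward_params={"speaker_embeddings": speaker_embedding},
)
# speech["audio"] is a numpy array at speech["sampling_rate"] Hz. Note there
# is no `streamer` hook anywhere in this call, hence no streaming support.
```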
---
.../local_inference/text_2_speech.py | 10 +++++++---
1 file changed, 7 insertions(+), 3 deletions(-)
diff --git a/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/text_2_speech.py b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/text_2_speech.py
index 85dee4add..97e172fde 100644
--- a/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/text_2_speech.py
+++ b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/text_2_speech.py
@@ -25,6 +25,8 @@
# Step 1: define Helpers
def refine_pipeline_creation_params(model_settings: Dict[str, Any]) -> List[Dict[str, Any]]:
+ # These are from the transformers GitHub repo:
+ # https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py#L2534
supported_keys = {
"torch_dtype",
"force_download",
@@ -61,9 +63,11 @@ def refine_pipeline_creation_params(model_settings: Dict[str, Any]) -> List[Dict
def refine_completion_params(unfiltered_completion_params: Dict[str, Any]) -> Dict[str, Any]:
- supported_keys = {
- # ???
- }
+ # Note: There seems to be no public API docs on what completion
+ # params are supported for text to speech:
+ # https://huggingface.co/docs/transformers/tasks/text-to-speech#inference
+ # The only one mentioned is `forward_params` which can contain `speaker_embeddings`
+ supported_keys = {}
completion_params: Dict[str, Any] = {}
for key in unfiltered_completion_params:
From 19d7844b0da33d143e370031b60397c11d0183a5 Mon Sep 17 00:00:00 2001
From: "Rossdan Craig rossdan@lastmileai.dev" <>
Date: Wed, 10 Jan 2024 05:02:48 -0500
Subject: [PATCH 4/5] [HF][streaming][4/n] Image2Text (no streaming, but lots
of fixing)
This model parser does not support streaming (surprising!):
```
TypeError: ImageToTextPipeline._sanitize_parameters() got an unexpected keyword argument 'streamer'
```
In general, I mainly just did a lot of fixing up to make sure that this worked as expected. Things I fixed:
1. Now works for multiple images (it ran before, but didn't process the response for each image properly; it just stored the entire response)
2. Constructed responses as pure text output
3. Specified the completion params that are supported (only 2: https://github.com/huggingface/transformers/blob/701298d2d3d5c7bde45e71cce12736098e3f05ef/src/transformers/pipelines/image_to_text.py#L97-L102C13)
In the next diff I will add support for b64-encoded images; we need to convert these to PIL Images, see https://github.com/huggingface/transformers/blob/701298d2d3d5c7bde45e71cce12736098e3f05ef/src/transformers/pipelines/image_to_text.py#L83
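To make fixes 1-3 concrete, the end-to-end call now looks roughly like the sketch below, with an illustrative BLIP checkpoint and placeholder image URLs:
```python
from transformers import pipeline

# Illustrative checkpoint; the parser uses whatever model the prompt names.
captioner = pipeline(task="image-to-text", model="Salesforce/blip-image-captioning-base")

images = [
    "https://example.com/cat.png",   # placeholder URLs
    "https://example.com/bear.png",
]

# Only the two supported completion params survive refine_completion_params.
response = captioner(images, max_new_tokens=20, timeout=30.0)

# One response entry per image; each entry is a list holding a dict with
# "generated_text", hence result[0]["generated_text"] in construct_regular_output.
for count, result in enumerate(response):
    print(count, result[0]["generated_text"])
```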
## Test Plan
Rebase onto and test it: 5f3b667f8c31b59c23555a269ccd4d17388f6361.
Follow the README from AIConfig Editor https://github.com/lastmile-ai/aiconfig/tree/main/python/src/aiconfig/editor#dev, then run these commands:
```bash
aiconfig_path=/Users/rossdancraig/Projects/aiconfig/cookbooks/Gradio/huggingface.aiconfig.json
parsers_path=/Users/rossdancraig/Projects/aiconfig/cookbooks/Gradio/hf_model_parsers.py
alias aiconfig="python3 -m 'aiconfig.scripts.aiconfig_cli'"
aiconfig edit --aiconfig-path=$aiconfig_path --server-port=8080 --server-mode=debug_servers --parsers-module-path=$parsers_path
```
Then in AIConfig Editor run the prompt (streaming is not supported for this model, so I just took screenshots).
These are the images I tested (image attachments omitted):
Before: (screenshot omitted)
After: (screenshot omitted)
---
.../local_inference/image_2_text.py | 116 +++++++++++++++---
.../local_inference/text_generation.py | 16 +--
.../local_inference/text_summarization.py | 4 +-
3 files changed, 107 insertions(+), 29 deletions(-)
diff --git a/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/image_2_text.py b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/image_2_text.py
index ad408ca17..aab79c965 100644
--- a/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/image_2_text.py
+++ b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/image_2_text.py
@@ -1,11 +1,21 @@
+import json
from typing import Any, Dict, Optional, List, TYPE_CHECKING
+from transformers import (
+ Pipeline,
+ pipeline,
+)
+
from aiconfig import ParameterizedModelParser, InferenceOptions
from aiconfig.callback import CallbackEvent
-import torch
-from aiconfig.schema import Prompt, Output, ExecuteResult, Attachment
-
-from transformers import pipeline, Pipeline
-
+from aiconfig.schema import (
+ Attachment,
+ ExecuteResult,
+ Output,
+ OutputDataWithValue,
+ Prompt,
+)
+
+# Circular Dependency Type Hints
if TYPE_CHECKING:
from aiconfig import AIConfigRuntime
@@ -93,10 +103,11 @@ async def deserialize(
await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_deserialize_start", __name__, {"prompt": prompt, "params": params}))
# Build Completion data
- completion_params = self.get_model_settings(prompt, aiconfig)
+ model_settings = self.get_model_settings(prompt, aiconfig)
+ completion_params = refine_completion_params(model_settings)
+ # Add image inputs
inputs = validate_and_retrieve_image_from_attachments(prompt)
-
completion_params["inputs"] = inputs
await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_deserialize_complete", __name__, {"output": completion_params}))
@@ -110,24 +121,93 @@ async def run_inference(self, prompt: Prompt, aiconfig: "AIConfigRuntime", optio
{"prompt": prompt, "options": options, "parameters": parameters},
)
)
- model_name = aiconfig.get_model_name(prompt)
-
- self.pipelines[model_name] = pipeline(task="image-to-text", model=model_name)
- captioner = self.pipelines[model_name]
completion_data = await self.deserialize(prompt, aiconfig, parameters)
inputs = completion_data.pop("inputs")
- model = completion_data.pop("model")
- response = captioner(inputs, **completion_data)
- output = ExecuteResult(output_type="execute_result", data=response, metadata={})
+ model_name: str | None = aiconfig.get_model_name(prompt)
+ if isinstance(model_name, str) and model_name not in self.pipelines:
+ self.pipelines[model_name] = pipeline(task="image-to-text", model=model_name)
+ captioner = self.pipelines[model_name]
+
+ outputs: List[Output] = []
+ response: List[Any] = captioner(inputs, **completion_data)
+ for count, result in enumerate(response):
+ output: Output = construct_regular_output(result, count)
+ outputs.append(output)
- prompt.outputs = [output]
- await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_run_complete", __name__, {"result": prompt.outputs}))
+ prompt.outputs = outputs
+ print(f"{prompt.outputs=}")
+ await aiconfig.callback_manager.run_callbacks(
+ CallbackEvent(
+ "on_run_complete",
+ __name__,
+ {"result": prompt.outputs},
+ )
+ )
return prompt.outputs
- def get_output_text(self, response: dict[str, Any]) -> str:
- raise NotImplementedError("get_output_text is not implemented for HuggingFaceImage2TextTransformer")
+ def get_output_text(
+ self,
+ prompt: Prompt,
+ aiconfig: "AIConfigRuntime",
+ output: Optional[Output] = None,
+ ) -> str:
+ if output is None:
+ output = aiconfig.get_latest_output(prompt)
+
+ if output is None:
+ return ""
+
+ # TODO (rossdanlm): Handle multiple outputs in list
+ # https://github.com/lastmile-ai/aiconfig/issues/467
+ if output.output_type == "execute_result":
+ output_data = output.data
+ if isinstance(output_data, str):
+ return output_data
+ if isinstance(output_data, OutputDataWithValue):
+ if isinstance(output_data.value, str):
+ return output_data.value
+ # HuggingFace image-to-text does not support function
+ # calls, so we shouldn't get here, but just being safe
+ return json.dumps(output_data.value, indent=2)
+ return ""
+
+
+def refine_completion_params(model_settings: Dict[str, Any]) -> Dict[str, Any]:
+ """
+ Refines the completion params for the HF image-to-text API, removing any unsupported params.
+ The supported keys were found by looking at the HF ImageToTextPipeline.__call__ method.
+ """
+ supported_keys = {
+ "max_new_tokens",
+ "timeout",
+ }
+
+ completion_data = {}
+ for key in model_settings:
+ if key.lower() in supported_keys:
+ completion_data[key.lower()] = model_settings[key]
+
+ return completion_data
+
+# Helper methods
+def construct_regular_output(result: Dict[str, str], execution_count: int) -> Output:
+ """
+ Construct regular output per response result, without streaming enabled
+ """
+ output = ExecuteResult(
+ **{
+ "output_type": "execute_result",
+ # For some reason the result is always in list format; we haven't found
+ # a way to return multiple sequences from the image-to-text pipeline.
+ "data": result[0]["generated_text"],
+ "execution_count": execution_count,
+ "metadata": {},
+ }
+ )
+ return output
def validate_attachment_type_is_image(attachment: Attachment):
diff --git a/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/text_generation.py b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/text_generation.py
index c6218c82e..4da5d7037 100644
--- a/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/text_generation.py
+++ b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/text_generation.py
@@ -153,7 +153,7 @@ def __init__(self):
config.register_model_parser(parser)
"""
super().__init__()
- self.generators : dict[str, Pipeline]= {}
+ self.generators: dict[str, Pipeline]= {}
def id(self) -> str:
"""
@@ -217,14 +217,14 @@ async def deserialize(
# Build Completion data
model_settings = self.get_model_settings(prompt, aiconfig)
completion_data = refine_chat_completion_params(model_settings)
-
+
#Add resolved prompt
resolved_prompt = resolve_prompt(prompt, params, aiconfig)
completion_data["prompt"] = resolved_prompt
return completion_data
async def run_inference(
- self, prompt: Prompt, aiconfig : "AIConfigRuntime", options : InferenceOptions, parameters: Dict[str, Any]
+ self, prompt: Prompt, aiconfig: "AIConfigRuntime", options: InferenceOptions, parameters: Dict[str, Any]
) -> List[Output]:
"""
Invoked to run a prompt in the .aiconfig. This method should perform
@@ -239,8 +239,8 @@ async def run_inference(
"""
completion_data = await self.deserialize(prompt, aiconfig, options, parameters)
completion_data["text_inputs"] = completion_data.pop("prompt", None)
-
- model_name : str = aiconfig.get_model_name(prompt)
+
+ model_name: str | None = aiconfig.get_model_name(prompt)
if isinstance(model_name, str) and model_name not in self.generators:
self.generators[model_name] = pipeline('text-generation', model=model_name)
generator = self.generators[model_name]
@@ -255,10 +255,10 @@ async def run_inference(
streamer = TextIteratorStreamer(tokenizer)
completion_data["streamer"] = streamer
- outputs : List[Output] = []
+ outputs: List[Output] = []
output = None
if not should_stream:
- response : List[Any] = generator(**completion_data)
+ response: List[Any] = generator(**completion_data)
for count, result in enumerate(response):
output = construct_regular_output(result, count)
outputs.append(output)
@@ -267,7 +267,7 @@ async def run_inference(
raise ValueError("Sorry, TextIteratorStreamer does not support multiple return sequences, please set `num_return_sequences` to 1")
if not streamer:
raise ValueError("Stream option is selected but streamer is not initialized")
-
+
# For streaming, cannot call `generator` directly otherwise response will be blocking
thread = threading.Thread(target=generator, kwargs=completion_data)
thread.start()
diff --git a/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/text_summarization.py b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/text_summarization.py
index 32b90b908..2b3b61358 100644
--- a/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/text_summarization.py
+++ b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/text_summarization.py
@@ -258,12 +258,10 @@ async def run_inference(self, prompt: Prompt, aiconfig: "AIConfigRuntime", optio
streamer = TextIteratorStreamer(tokenizer)
completion_data["streamer"] = streamer
- outputs: List[Output] = []
- output = None
-
def _summarize():
return summarizer(inputs, **completion_data)
+ outputs: List[Output] = []
if not should_stream:
response: List[Any] = _summarize()
for count, result in enumerate(response):
From fa9d88a6e8f41be10bc002da58341644b9e5d767 Mon Sep 17 00:00:00 2001
From: "Rossdan Craig rossdan@lastmileai.dev" <>
Date: Wed, 10 Jan 2024 05:06:17 -0500
Subject: [PATCH 5/5] [HF][5/n] Image2Text: Allow base64 inputs for images
Before, we didn't allow base64, only URIs (either local or http(s)). This is good because our text2Image model parser outputs in base64 format, so this will allow us to chain model prompts!
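The conversion itself is small; below is a sketch of the new branch, with a hypothetical helper name and the length cutoff taken from the diff:
```python
import base64
from io import BytesIO

from PIL import Image


def to_pipeline_input(data: str):
    """Hypothetical helper mirroring the branch added in this diff."""
    # Crude length heuristic from the patch: very long strings are assumed
    # to be base64 payloads rather than URIs (see issue #829).
    if len(data) > 10000:
        return Image.open(BytesIO(base64.b64decode(data)))
    return data  # URI (local path or http/https), passed through unchanged
```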
## Test Plan
Rebase and test on https://github.com/lastmile-ai/aiconfig/commit/0d7ae2b32ac2eb75cd41ac91e222c5bb6bb5c918.
Follow the README from AIConfig Editor https://github.com/lastmile-ai/aiconfig/tree/main/python/src/aiconfig/editor#dev, then run these commands:
```bash
aiconfig_path=/Users/rossdancraig/Projects/aiconfig/cookbooks/Gradio/huggingface.aiconfig.json
parsers_path=/Users/rossdancraig/Projects/aiconfig/cookbooks/Gradio/hf_model_parsers.py
alias aiconfig="python3 -m 'aiconfig.scripts.aiconfig_cli'"
aiconfig edit --aiconfig-path=$aiconfig_path --server-port=8080 --server-mode=debug_servers --parsers-module-path=$parsers_path
```
Then in AIConfig Editor run the prompt (streaming is not supported, so I just took screenshots).
These are the images I tested, with the bear in base64 format (image attachments omitted).
---
.../local_inference/image_2_text.py | 25 +++++++++++++------
1 file changed, 18 insertions(+), 7 deletions(-)
diff --git a/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/image_2_text.py b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/image_2_text.py
index aab79c965..ddb0eb624 100644
--- a/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/image_2_text.py
+++ b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/image_2_text.py
@@ -1,5 +1,8 @@
+import base64
import json
-from typing import Any, Dict, Optional, List, TYPE_CHECKING
+from io import BytesIO
+from PIL import Image
+from typing import Any, Dict, Optional, List, TYPE_CHECKING, Union
from transformers import (
Pipeline,
pipeline,
@@ -107,7 +110,7 @@ async def deserialize(
completion_params = refine_completion_params(model_settings)
#Add image inputs
- inputs = validate_and_retrieve_image_from_attachments(prompt)
+ inputs = validate_and_retrieve_images_from_attachments(prompt)
completion_params["inputs"] = inputs
await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_deserialize_complete", __name__, {"output": completion_params}))
@@ -218,7 +221,7 @@ def validate_attachment_type_is_image(attachment: Attachment):
raise ValueError(f"Invalid attachment mimetype {attachment.mime_type}. Expected image mimetype.")
-def validate_and_retrieve_image_from_attachments(prompt: Prompt) -> list[str]:
+def validate_and_retrieve_images_from_attachments(prompt: Prompt) -> list[Union[str, Image]]:
"""
Retrieves the image uri's from each attachment in the prompt input.
@@ -232,15 +235,23 @@ def validate_and_retrieve_image_from_attachments(prompt: Prompt) -> list[str]:
if not hasattr(prompt.input, "attachments") or len(prompt.input.attachments) == 0:
raise ValueError(f"No attachments found in input for prompt {prompt.name}. Please add an image attachment to the prompt input.")
- image_uris: list[str] = []
+ images: list[Union[str, Image]] = []
for i, attachment in enumerate(prompt.input.attachments):
validate_attachment_type_is_image(attachment)
- if not isinstance(attachment.data, str):
+ input_data = attachment.data
+ if not isinstance(input_data, str):
# See todo above, but for now only support uri's
raise ValueError(f"Attachment #{i} data is not a uri. Please specify a uri for the image attachment in prompt {prompt.name}.")
- image_uris.append(attachment.data)
+ # Really basic heuristic to check if the data is a base64-encoded str
+ # vs. a URI. This will be fixed once we have standardized inputs
+ # See https://github.com/lastmile-ai/aiconfig/issues/829
+ if len(input_data) > 10000:
+ pil_image: Image = Image.open(BytesIO(base64.b64decode(input_data)))
+ images.append(pil_image)
+ else:
+ images.append(input_data)
- return image_uris
+ return images