[HF][5/n] Image2Text: Allow base64 inputs for images (#856)
[HF][5/n] Image2Text: Allow base64 inputs for images

Previously we only allowed URIs (local paths or http/https URLs), not base64.
This matters because our text2Image model parser outputs images in base64
format, so this change lets us chain model prompts!
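
For context, here's a minimal sketch (not part of this diff) of the chaining this enables; the model name is just an example:

```python
import base64
from io import BytesIO

from PIL import Image
from transformers import pipeline

# Assume `image_b64` holds the base64 string emitted by the text2Image
# model parser (truncated placeholder here).
image_b64 = "iVBORw0KGgo..."

# Decode the base64 payload into a PIL image, mirroring the new parser logic.
image = Image.open(BytesIO(base64.b64decode(image_b64)))

# Feed the decoded image straight into an image-to-text pipeline.
captioner = pipeline(task="image-to-text", model="Salesforce/blip-image-captioning-base")
print(captioner(image))  # e.g. [{'generated_text': 'a bear eating honey'}]
```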

## Test Plan

Rebase and test on
0d7ae2b.

Follow the README from AIConfig Editor
https://github.com/lastmile-ai/aiconfig/tree/main/python/src/aiconfig/editor#dev,
then run these commands:
```bash
aiconfig_path=/Users/rossdancraig/Projects/aiconfig/cookbooks/Gradio/huggingface.aiconfig.json
parsers_path=/Users/rossdancraig/Projects/aiconfig/cookbooks/Gradio/hf_model_parsers.py
alias aiconfig="python3 -m 'aiconfig.scripts.aiconfig_cli'"
aiconfig edit --aiconfig-path=$aiconfig_path --server-port=8080 --server-mode=debug_servers --parsers-module-path=$parsers_path
```

Then run the prompt in AIConfig Editor (streaming is not supported, so I just took screenshots).
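
If you need a base64 test input, a hypothetical one-off snippet like this works (the filename is made up):

```python
import base64

# Hypothetical: encode a local test image to base64 so it can be used
# as the prompt's attachment data (the filename is just an example).
with open("bear-eating-honey.png", "rb") as f:
    image_b64 = base64.b64encode(f.read()).decode("utf-8")
print(image_b64[:80] + "...")
```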

These are the images I tested (with the bear in base64 format):

![fox_in_forest](https://github.com/lastmile-ai/aiconfig/assets/151060367/ca7d1723-9e12-4cc8-9d8d-41fa9f466919)

![bear-eating-honey](https://github.com/lastmile-ai/aiconfig/assets/151060367/a947d89e-c02a-4c64-8183-ff1c85802859)

<img width="1281" alt="Screenshot 2024-01-10 at 04 57 44"
src="https://github.com/lastmile-ai/aiconfig/assets/151060367/ea60cbc5-e6ab-4bf2-82e7-17f3182fdc5c">

---
Stack created with [Sapling](https://sapling-scm.com). Best reviewed
with
[ReviewStack](https://reviewstack.dev/lastmile-ai/aiconfig/pull/856).
* __->__ #856
* #855
* #854
* #853
* #851
saqadri authored Jan 10, 2024
2 parents 0c9e704 + fa9d88a commit 8d0a59c
Showing 6 changed files with 164 additions and 49 deletions.
5 changes: 3 additions & 2 deletions extensions/HuggingFace/python/requirements.txt
@@ -10,11 +10,12 @@ huggingface_hub

# Hugging Face Libraries - Local Inference Transformers & Diffusers
accelerate # Used to help speed up image generation
diffusers # Used for image + audio generation
diffusers # Used for image generation
scipy # array -> wav file, text-speech. torchaudio.save seems broken.
sentencepiece # Used for text translation
torch
torchvision
torchaudio
scipy # array -> wav file, text-speech. torchaudio.save seems broken.
transformers # Used for text generation

# Other
@@ -1,11 +1,24 @@
from typing import Any, Dict, Optional, List, TYPE_CHECKING
import base64
import json
from io import BytesIO
from PIL import Image
from typing import Any, Dict, Optional, List, TYPE_CHECKING, Union
from transformers import (
Pipeline,
pipeline,
)

from aiconfig import ParameterizedModelParser, InferenceOptions
from aiconfig.callback import CallbackEvent
import torch
from aiconfig.schema import Prompt, Output, ExecuteResult, Attachment

from transformers import pipeline, Pipeline

from aiconfig.schema import (
Attachment,
ExecuteResult,
Output,
OutputDataWithValue,
Prompt,
)

# Circular Dependency Type Hints
if TYPE_CHECKING:
from aiconfig import AIConfigRuntime

@@ -93,10 +106,11 @@ async def deserialize(
await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_deserialize_start", __name__, {"prompt": prompt, "params": params}))

# Build Completion data
completion_params = self.get_model_settings(prompt, aiconfig)

inputs = validate_and_retrieve_image_from_attachments(prompt)
model_settings = self.get_model_settings(prompt, aiconfig)
completion_params = refine_completion_params(model_settings)

# Add image inputs
inputs = validate_and_retrieve_images_from_attachments(prompt)
completion_params["inputs"] = inputs

await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_deserialize_complete", __name__, {"output": completion_params}))
@@ -110,24 +124,93 @@ async def run_inference(self, prompt: Prompt, aiconfig: "AIConfigRuntime", optio
{"prompt": prompt, "options": options, "parameters": parameters},
)
)
model_name = aiconfig.get_model_name(prompt)

self.pipelines[model_name] = pipeline(task="image-to-text", model=model_name)

captioner = self.pipelines[model_name]
completion_data = await self.deserialize(prompt, aiconfig, parameters)
inputs = completion_data.pop("inputs")
model = completion_data.pop("model")
response = captioner(inputs, **completion_data)

output = ExecuteResult(output_type="execute_result", data=response, metadata={})
model_name: str | None = aiconfig.get_model_name(prompt)
if isinstance(model_name, str) and model_name not in self.pipelines:
self.pipelines[model_name] = pipeline(task="image-to-text", model=model_name)
captioner = self.pipelines[model_name]

outputs: List[Output] = []
response: List[Any] = captioner(inputs, **completion_data)
for count, result in enumerate(response):
output: Output = construct_regular_output(result, count)
outputs.append(output)

prompt.outputs = [output]
await aiconfig.callback_manager.run_callbacks(CallbackEvent("on_run_complete", __name__, {"result": prompt.outputs}))
prompt.outputs = outputs
await aiconfig.callback_manager.run_callbacks(
CallbackEvent(
"on_run_complete",
__name__,
{"result": prompt.outputs},
)
)
return prompt.outputs

def get_output_text(self, response: dict[str, Any]) -> str:
raise NotImplementedError("get_output_text is not implemented for HuggingFaceImage2TextTransformer")
def get_output_text(
self,
prompt: Prompt,
aiconfig: "AIConfigRuntime",
output: Optional[Output] = None,
) -> str:
if output is None:
output = aiconfig.get_latest_output(prompt)

if output is None:
return ""

# TODO (rossdanlm): Handle multiple outputs in list
# https://github.com/lastmile-ai/aiconfig/issues/467
if output.output_type == "execute_result":
output_data = output.data
if isinstance(output_data, str):
return output_data
if isinstance(output_data, OutputDataWithValue):
if isinstance(output_data.value, str):
return output_data.value
# HuggingFace image-to-text pipelines do not support function
# calls, so we shouldn't get here, but just being safe
return json.dumps(output_data.value, indent=2)
return ""


def refine_completion_params(model_settings: Dict[str, Any]) -> Dict[str, Any]:
"""
Refines the completion params for the HF image-to-text API. Removes any unsupported params.
The supported keys were found by looking at the HF ImageToTextPipeline.__call__ method.
"""
supported_keys = {
"max_new_tokens",
"timeout",
}

completion_data = {}
for key in model_settings:
if key.lower() in supported_keys:
completion_data[key.lower()] = model_settings[key]

return completion_data

# Helper methods
def construct_regular_output(result: Dict[str, str], execution_count: int) -> Output:
"""
Construct regular output per response result, without streaming enabled
"""
output = ExecuteResult(
**{
"output_type": "execute_result",
# For some reason the result is always in list format; we haven't
# found a way to return multiple sequences from the
# image-to-text pipeline
"data": result[0]["generated_text"],
"execution_count": execution_count,
"metadata": {},
}
)
return output


def validate_attachment_type_is_image(attachment: Attachment):
@@ -138,7 +221,7 @@ def validate_attachment_type_is_image(attachment: Attachment):
raise ValueError(f"Invalid attachment mimetype {attachment.mime_type}. Expected image mimetype.")


def validate_and_retrieve_image_from_attachments(prompt: Prompt) -> list[str]:
def validate_and_retrieve_images_from_attachments(prompt: Prompt) -> list[Union[str, Image.Image]]:
"""
Retrieves the images (URIs or base64 data) from each attachment in the prompt input.
@@ -152,15 +235,23 @@ def validate_and_retrieve_image_from_attachments(prompt: Prompt) -> list[str]:
if not hasattr(prompt.input, "attachments") or len(prompt.input.attachments) == 0:
raise ValueError(f"No attachments found in input for prompt {prompt.name}. Please add an image attachment to the prompt input.")

image_uris: list[str] = []
images: list[Union[str, Image.Image]] = []

for i, attachment in enumerate(prompt.input.attachments):
validate_attachment_type_is_image(attachment)

if not isinstance(attachment.data, str):
input_data = attachment.data
if not isinstance(input_data, str):
# See TODO above; for now we only support string data (a URI or base64)
raise ValueError(f"Attachment #{i} data is not a uri. Please specify a uri for the image attachment in prompt {prompt.name}.")

image_uris.append(attachment.data)
# Really basic heuristic to check if the data is a base64-encoded str
# vs. a URI. This will be fixed once we have standardized inputs
# See https://github.com/lastmile-ai/aiconfig/issues/829
if len(input_data) > 10000:
pil_image: Image.Image = Image.open(BytesIO(base64.b64decode(input_data)))
images.append(pil_image)
else:
images.append(input_data)

return image_uris
return images
@@ -25,6 +25,8 @@

# Step 1: define Helpers
def refine_pipeline_creation_params(model_settings: Dict[str, Any]) -> List[Dict[str, Any]]:
# These are from the transformers GitHub repo:
# https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py#L2534
supported_keys = {
"torch_dtype",
"force_download",
@@ -61,9 +63,11 @@ def refine_pipeline_creation_params(model_settings: Dict[str, Any]) -> List[Dict


def refine_completion_params(unfiltered_completion_params: Dict[str, Any]) -> Dict[str, Any]:
supported_keys = {
# ???
}
# Note: There seems to be no public API docs on what completion
# params are supported for text to speech:
# https://huggingface.co/docs/transformers/tasks/text-to-speech#inference
# The only one mentioned is `forward_params` which can contain `speaker_embeddings`
supported_keys = {}

completion_params: Dict[str, Any] = {}
for key in unfiltered_completion_params:
@@ -153,7 +153,7 @@ def __init__(self):
config.register_model_parser(parser)
"""
super().__init__()
self.generators : dict[str, Pipeline]= {}
self.generators: dict[str, Pipeline] = {}

def id(self) -> str:
"""
@@ -217,14 +217,14 @@ async def deserialize(
# Build Completion data
model_settings = self.get_model_settings(prompt, aiconfig)
completion_data = refine_chat_completion_params(model_settings)

# Add resolved prompt
resolved_prompt = resolve_prompt(prompt, params, aiconfig)
completion_data["prompt"] = resolved_prompt
return completion_data

async def run_inference(
self, prompt: Prompt, aiconfig : "AIConfigRuntime", options : InferenceOptions, parameters: Dict[str, Any]
self, prompt: Prompt, aiconfig: "AIConfigRuntime", options: InferenceOptions, parameters: Dict[str, Any]
) -> List[Output]:
"""
Invoked to run a prompt in the .aiconfig. This method should perform
@@ -239,8 +239,8 @@ async def run_inference(
"""
completion_data = await self.deserialize(prompt, aiconfig, options, parameters)
completion_data["text_inputs"] = completion_data.pop("prompt", None)
model_name : str = aiconfig.get_model_name(prompt)

model_name: str | None = aiconfig.get_model_name(prompt)
if isinstance(model_name, str) and model_name not in self.generators:
self.generators[model_name] = pipeline('text-generation', model=model_name)
generator = self.generators[model_name]
@@ -251,14 +251,14 @@ async def run_inference(self, prompt: Prompt, aiconfig: "AIConfigRuntime", optio
not "stream" in completion_data or completion_data.get("stream") != False
)
if should_stream:
tokenizer : AutoTokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer: AutoTokenizer = AutoTokenizer.from_pretrained(model_name)
streamer = TextIteratorStreamer(tokenizer)
completion_data["streamer"] = streamer

outputs : List[Output] = []
outputs: List[Output] = []
output = None
if not should_stream:
response : List[Any] = generator(**completion_data)
response: List[Any] = generator(**completion_data)
for count, result in enumerate(response):
output = construct_regular_output(result, count)
outputs.append(output)
@@ -267,7 +267,7 @@ async def run_inference(self, prompt: Prompt, aiconfig: "AIConfigRuntime", optio
raise ValueError("Sorry, TextIteratorStreamer does not support multiple return sequences, please set `num_return_sequences` to 1")
if not streamer:
raise ValueError("Stream option is selected but streamer is not initialized")

# For streaming, cannot call `generator` directly otherwise response will be blocking
thread = threading.Thread(target=generator, kwargs=completion_data)
thread.start()
@@ -128,13 +128,18 @@ def construct_stream_output(
"metadata": {},
}
)

accumulated_message = ""
for new_text in streamer:
if isinstance(new_text, str):
# For some reason these symbols aren't filtered out by the streamer
new_text = new_text.replace("</s>", "")
new_text = new_text.replace("<s>", "")

accumulated_message += new_text
options.stream_callback(new_text, accumulated_message, 0)

output.data = accumulated_message

return output


@@ -245,18 +250,18 @@ async def run_inference(self, prompt: Prompt, aiconfig: "AIConfigRuntime", optio

# if stream enabled in runtime options and config, then stream. Otherwise don't stream.
streamer = None
should_stream = (options.stream if options else False) and (not "stream" in completion_data or completion_data.get("stream") != False)
should_stream = (options.stream if options else False) and (
not "stream" in completion_data or completion_data.get("stream") != False
)
if should_stream:
tokenizer: AutoTokenizer = AutoTokenizer.from_pretrained(model_name)
streamer = TextIteratorStreamer(tokenizer)
completion_data["streamer"] = streamer

outputs: List[Output] = []
output = None

def _summarize():
return summarizer(inputs, **completion_data)

outputs: List[Output] = []
if not should_stream:
response: List[Any] = _summarize()
for count, result in enumerate(response):
@@ -129,12 +129,19 @@ def construct_stream_output(
"metadata": {},
}
)

accumulated_message = ""
for new_text in streamer:
if isinstance(new_text, str):
# For some reason these symbols aren't filtered out by the streamer
new_text = new_text.replace("</s>", "")
new_text = new_text.replace("<s>", "")
new_text = new_text.replace("<pad>", "")

accumulated_message += new_text
options.stream_callback(new_text, accumulated_message, 0)
output.data = accumulated_message

return output


@@ -240,19 +247,26 @@ async def run_inference(self, prompt: Prompt, aiconfig: "AIConfigRuntime", optio

model_name: str = aiconfig.get_model_name(prompt)
if isinstance(model_name, str) and model_name not in self.translators:
self.translators[model_name] = pipeline(model_name)
self.translators[model_name] = pipeline("translation", model_name)
translator = self.translators[model_name]

# if stream enabled in runtime options and config, then stream. Otherwise don't stream.
streamer = None
should_stream = (options.stream if options else False) and (not "stream" in completion_data or completion_data.get("stream") != False)
should_stream = (options.stream if options else False) and (
not "stream" in completion_data or completion_data.get("stream") != False
)
if should_stream:
raise NotImplementedError("Streaming is not supported for HuggingFace Text Translation")
tokenizer: AutoTokenizer = AutoTokenizer.from_pretrained(model_name)
streamer = TextIteratorStreamer(tokenizer)
completion_data["streamer"] = streamer

def _translate():
return translator(inputs, **completion_data)

outputs: List[Output] = []
output = None
if not should_stream:
response: List[Any] = translator(inputs, **completion_data)
response: List[Any] = _translate()
for count, result in enumerate(response):
output = construct_regular_output(result, count)
outputs.append(output)
@@ -263,7 +277,7 @@ async def run_inference(self, prompt: Prompt, aiconfig: "AIConfigRuntime", optio
raise ValueError("Stream option is selected but streamer is not initialized")

# For streaming, cannot call `translator` directly otherwise response will be blocking
thread = threading.Thread(target=translator, kwargs=completion_data)
thread = threading.Thread(target=_translate)
thread.start()
output = construct_stream_output(streamer, options)
if output is not None:
