[AIC-py] hf machine translation (text to text) model parser (#753)
[AIC-py] hf mt model parser

Text translation, sans streaming. Streaming is not trivial here.

Test:

```
import asyncio
from aiconfig_extension_hugging_face.local_inference.text_translation import HuggingFaceTextTranslationTransformer

from aiconfig import AIConfigRuntime, InferenceOptions, CallbackManager

# Load the aiconfig.

mp = HuggingFaceTextTranslationTransformer()

AIConfigRuntime.register_model_parser(mp, "translation_en_to_fr")

config = AIConfigRuntime.load("/Users/jonathan/Projects/aiconfig/test_hf_transl.aiconfig.json")
config.callback_manager = CallbackManager([])


def print_stream(data, _accumulated_data, _index):
    print(data, end="", flush=True)


async def run():
    # print("Stream")
    # options = InferenceOptions(stream=True, stream_callback=print_stream)
    # out = await config.run("test_hf_trans", options=options)
    # print("Output:\n", out)

    print("no stream")
    options = InferenceOptions(stream=False)
    out = await config.run("test_hf_trans", options=options)
    print("Output:\n", out)


asyncio.run(run())

```
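For reference, a minimal aiconfig that a script like this could load might look as follows. This is a hypothetical file (the real `test_hf_transl.aiconfig.json` is not part of the commit); the prompt name and model ID match the script above:

```json
{
  "name": "test_hf_transl",
  "schema_version": "latest",
  "metadata": {},
  "prompts": [
    {
      "name": "test_hf_trans",
      "input": "Translate to French: this is a test.",
      "metadata": {
        "model": "translation_en_to_fr"
      }
    }
  ]
}
```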

---
Stack created with [Sapling](https://sapling-scm.com). Best reviewed
with
[ReviewStack](https://reviewstack.dev/lastmile-ai/aiconfig/pull/753).
* __->__ #753
* #740
jonathanlastmileai authored Jan 5, 2024
2 parents 1625d10 + 9fd71be · commit 00b7acd
Showing 3 changed files with 614 additions and 1 deletion.
`aiconfig_extension_hugging_face/__init__.py`:

```diff
@@ -1,7 +1,16 @@
 from .local_inference.text_2_image import HuggingFaceText2ImageDiffusor
 from .local_inference.text_generation import HuggingFaceTextGenerationTransformer
 from .remote_inference_client.text_generation import HuggingFaceTextGenerationClient
+from .local_inference.text_summarization import HuggingFaceTextSummarizationTransformer
+from .local_inference.text_translation import HuggingFaceTextTranslationTransformer
 
-LOCAL_INFERENCE_CLASSES = ["HuggingFaceText2ImageDiffusor", "HuggingFaceTextGenerationTransformer"]
+# from .remote_inference_client.text_generation import HuggingFaceTextGenerationClient
+
+LOCAL_INFERENCE_CLASSES = [
+    "HuggingFaceText2ImageDiffusor",
+    "HuggingFaceTextGenerationTransformer",
+    "HuggingFaceTextSummarizationTransformer",
+    "HuggingFaceTextTranslationTransformer",
+]
 REMOTE_INFERENCE_CLASSES = ["HuggingFaceTextGenerationClient"]
 __ALL__ = LOCAL_INFERENCE_CLASSES + REMOTE_INFERENCE_CLASSES
```
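With these exports in place, a consumer can register any of the local parsers from the package root. A minimal sketch (the model IDs are illustrative; `register_model_parser` is used the same way as in the test script above):

```python
from aiconfig import AIConfigRuntime
from aiconfig_extension_hugging_face import (
    HuggingFaceTextSummarizationTransformer,
    HuggingFaceTextTranslationTransformer,
)

# Register a parser instance under each model ID referenced by prompts in the aiconfig.
AIConfigRuntime.register_model_parser(HuggingFaceTextTranslationTransformer(), "translation_en_to_fr")
AIConfigRuntime.register_model_parser(HuggingFaceTextSummarizationTransformer(), "summarization")
```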
`aiconfig_extension_hugging_face/local_inference/text_summarization.py` (new file, +305 lines):

```python
import copy
import json
import threading
from typing import TYPE_CHECKING, Any, Dict, List, Optional

from transformers import (
    AutoTokenizer,
    Pipeline,
    pipeline,
    TextIteratorStreamer,
)

from aiconfig.default_parsers.parameterized_model_parser import ParameterizedModelParser
from aiconfig.model_parser import InferenceOptions
from aiconfig.schema import (
    ExecuteResult,
    Output,
    OutputDataWithValue,
    Prompt,
    PromptMetadata,
)
from aiconfig.util.params import resolve_prompt

# Circular Dependency Type Hints
if TYPE_CHECKING:
    from aiconfig.Config import AIConfigRuntime

# Step 1: define Helpers
def refine_chat_completion_params(model_settings: Dict[str, Any]) -> Dict[str, Any]:
    """
    Refines the completion params for the HF text summarization api. Removes any unsupported params.
    The supported keys were found by looking at the HF text summarization api: `huggingface_hub.InferenceClient.summarization()`.
    """

    supported_keys = {
        "max_length",
        "max_new_tokens",
        "min_length",
        "min_new_tokens",
        "early_stopping",
        "max_time",
        "do_sample",
        "num_beams",
        "num_beam_groups",
        "penalty_alpha",
        "use_cache",
        "temperature",
        "top_k",
        "top_p",
        "typical_p",
        "epsilon_cutoff",
        "eta_cutoff",
        "diversity_penalty",
        "repetition_penalty",
        "encoder_repetition_penalty",
        "length_penalty",
        "no_repeat_ngram_size",
        "bad_words_ids",
        "force_words_ids",
        "renormalize_logits",
        "constraints",
        "forced_bos_token_id",
        "forced_eos_token_id",
        "remove_invalid_values",
        "exponential_decay_length_penalty",
        "suppress_tokens",
        "begin_suppress_tokens",
        "forced_decoder_ids",
        "sequence_bias",
        "guidance_scale",
        "low_memory",
        "num_return_sequences",
        "output_attentions",
        "output_hidden_states",
        "output_scores",
        "return_dict_in_generate",
        "pad_token_id",
        "bos_token_id",
        "eos_token_id",
        "encoder_no_repeat_ngram_size",
        "decoder_start_token_id",
        "num_assistant_tokens",
        "num_assistant_tokens_schedule",
    }

    completion_data = {}
    for key in model_settings:
        if key.lower() in supported_keys:
            completion_data[key.lower()] = model_settings[key]

    return completion_data
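
# For example, refine_chat_completion_params({"Num_Beams": 4, "unsupported_option": True})
# returns {"num_beams": 4}: keys are lower-cased, and anything outside
# `supported_keys` is dropped rather than forwarded to the model.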


def construct_regular_output(result: Dict[str, str], execution_count: int) -> Output:
    """
    Construct regular output per response result, without streaming enabled
    """
    output = ExecuteResult(
        **{
            "output_type": "execute_result",
            "data": result["summary_text"],
            "execution_count": execution_count,
            "metadata": {},
        }
    )
    return output


def construct_stream_output(
    streamer: TextIteratorStreamer,
    options: InferenceOptions,
) -> Output:
    """
    Constructs the output for a stream response.

    Args:
        streamer (TextIteratorStreamer): Streams the output. See:
            https://huggingface.co/docs/transformers/v4.35.2/en/internal/generation_utils#transformers.TextIteratorStreamer
        options (InferenceOptions): The inference options. Used to determine
            the stream callback.
    """
    output = ExecuteResult(
        **{
            "output_type": "execute_result",
            "data": "",  # We update this below
            "execution_count": 0,  # Multiple outputs are not supported for streaming
            "metadata": {},
        }
    )
    accumulated_message = ""
    for new_text in streamer:
        if isinstance(new_text, str):
            accumulated_message += new_text
            options.stream_callback(new_text, accumulated_message, 0)

    output.data = accumulated_message
    return output


class HuggingFaceTextSummarizationTransformer(ParameterizedModelParser):
    """
    A model parser for HuggingFace text summarization models, using transformers.
    """

    def __init__(self):
        """
        Returns:
            HuggingFaceTextSummarizationTransformer

        Usage:
            1. Create a new model parser object with the model ID of the model to use.
                parser = HuggingFaceTextSummarizationTransformer()
            2. Add the model parser to the registry.
                config.register_model_parser(parser)
        """
        super().__init__()
        self.summarizers: dict[str, Pipeline] = {}

    def id(self) -> str:
        """
        Returns an identifier for the Model Parser
        """
        return "HuggingFaceTextSummarizationTransformer"

    async def serialize(
        self,
        prompt_name: str,
        data: Any,
        ai_config: "AIConfigRuntime",
        parameters: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> List[Prompt]:
        """
        Defines how a prompt and model inference settings get serialized in the .aiconfig.

        Args:
            prompt_name (str): The name under which the prompt is serialized.
            data (Any): Model-specific inference settings to be serialized.

        Returns:
            List[Prompt]: Serialized representation of the prompt and inference settings.
        """
        data = copy.deepcopy(data)

        # Assume data is completion params for HF text summarization
        prompt_input = data["prompt"]

        # Prompt is handled, remove it from data
        data.pop("prompt", None)

        model_metadata = ai_config.get_model_metadata(data, self.id())
        prompt = Prompt(
            name=prompt_name,
            input=prompt_input,
            metadata=PromptMetadata(model=model_metadata, parameters=parameters, **kwargs),
        )
        return [prompt]

    async def deserialize(
        self,
        prompt: Prompt,
        aiconfig: "AIConfigRuntime",
        _options,
        params: Optional[Dict[str, Any]] = {},
    ) -> Dict[str, Any]:
        """
        Defines how to parse a prompt in the .aiconfig for a particular model
        and constructs the completion params for that model.

        Args:
            prompt (Prompt): The serialized prompt from the .aiconfig.

        Returns:
            dict: Model-specific completion parameters.
        """
        # Build completion data
        model_settings = self.get_model_settings(prompt, aiconfig)
        completion_data = refine_chat_completion_params(model_settings)

        # Add the resolved prompt
        resolved_prompt = resolve_prompt(prompt, params, aiconfig)
        completion_data["prompt"] = resolved_prompt
        return completion_data
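
    # For a prompt whose model settings are {"max_new_tokens": 60}, the returned
    # completion params are {"max_new_tokens": 60, "prompt": "<resolved prompt text>"};
    # run_inference() below pops "prompt" and forwards the rest to the pipeline.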

    async def run_inference(
        self,
        prompt: Prompt,
        aiconfig: "AIConfigRuntime",
        options: InferenceOptions,
        parameters: Dict[str, Any],
    ) -> List[Output]:
        """
        Invoked to run a prompt in the .aiconfig. This method performs the
        actual model inference based on the provided prompt and inference settings.

        Args:
            prompt (Prompt): The prompt to run.
            options (InferenceOptions): Runtime options, including stream settings.
            parameters (dict): Parameter values to resolve into the prompt.

        Returns:
            List[Output]: The response from the model.
        """
        completion_data = await self.deserialize(prompt, aiconfig, options, parameters)
        inputs = completion_data.pop("prompt", None)

        model_name: str = aiconfig.get_model_name(prompt)
        if isinstance(model_name, str) and model_name not in self.summarizers:
            self.summarizers[model_name] = pipeline("summarization", model=model_name)
        summarizer = self.summarizers[model_name]

        # If stream is enabled in both the runtime options and the config, then stream. Otherwise don't stream.
        streamer = None
        should_stream = (options.stream if options else False) and ("stream" not in completion_data or completion_data.get("stream") != False)
        if should_stream:
            tokenizer: AutoTokenizer = AutoTokenizer.from_pretrained(model_name)
            streamer = TextIteratorStreamer(tokenizer)
            completion_data["streamer"] = streamer

        outputs: List[Output] = []
        output = None

        def _summarize():
            return summarizer(inputs, **completion_data)

        if not should_stream:
            response: List[Any] = _summarize()
            for count, result in enumerate(response):
                output = construct_regular_output(result, count)
                outputs.append(output)
        else:
            if completion_data.get("num_return_sequences", 1) > 1:
                raise ValueError("Sorry, TextIteratorStreamer does not support multiple return sequences, please set `num_return_sequences` to 1")
            if not streamer:
                raise ValueError("Stream option is selected but streamer is not initialized")

            # For streaming, `summarizer` cannot be called directly, otherwise the response will block
            thread = threading.Thread(target=_summarize)
            thread.start()
            output = construct_stream_output(streamer, options)
            if output is not None:
                outputs.append(output)

        prompt.outputs = outputs
        return prompt.outputs

    def get_output_text(
        self,
        prompt: Prompt,
        aiconfig: "AIConfigRuntime",
        output: Optional[Output] = None,
    ) -> str:
        if output is None:
            output = aiconfig.get_latest_output(prompt)

        if output is None:
            return ""

        # TODO (rossdanlm): Handle multiple outputs in list
        # https://github.com/lastmile-ai/aiconfig/issues/467
        if output.output_type == "execute_result":
            output_data = output.data
            if isinstance(output_data, str):
                return output_data
            if isinstance(output_data, OutputDataWithValue):
                if isinstance(output_data.value, str):
                    return output_data.value
                # HuggingFace text summarization does not support function
                # calls so shouldn't get here, but just being safe
                return json.dumps(output_data.value, indent=2)
        return ""
```
(The diff for the third changed file, the text translation parser itself, did not load.)
