[RFC][AIC-py] hf summarization model parser
Very similar to the text generation parser.

Test: added the following `test_hf_sum` prompt to an existing aiconfig and ran it with the script below.

```
...
    {
      "name": "test_hf_sum",
      "input": "HMS Duncan was a D-class destroyer ...", # [contents of https://en.wikipedia.org/wiki/HMS_Duncan_(D99)]
      "metadata": {
        "model": {
          "name": "stevhliu/my_awesome_billsum_model",
          "settings": {
            "min_length": 100,
            "max_length": 200,
            "num_beams": 1
          }
        }
      }
    },
...
}
```

```
import asyncio
from aiconfig_extension_hugging_face.local_inference.text_summarization import HuggingFaceTextSummarizationTransformer

from aiconfig import AIConfigRuntime, InferenceOptions, CallbackManager

# Load the aiconfig.

mp = HuggingFaceTextSummarizationTransformer()

AIConfigRuntime.register_model_parser(mp, "stevhliu/my_awesome_billsum_model")

config = AIConfigRuntime.load("/Users/jonathan/Projects/aiconfig/cookbooks/Getting-Started/travel.aiconfig.json")
config.callback_manager = CallbackManager([])


def print_stream(data, _accumulated_data, _index):
    print(data, end="", flush=True)


async def run():
    print("Stream")
    options = InferenceOptions(stream=True, stream_callback=print_stream)
    out = await config.run("test_hf_sum", options=options)
    print("Output:\n", out)

    print("no stream")
    options = InferenceOptions(stream=False)
    out = await config.run("test_hf_sum", options=options)
    print("Output:\n", out)


asyncio.run(run())


# OUT:

Stream
Token indices sequence length is longer than the specified maximum sequence length for this model (2778 > 512). Running this sequence through the model will result in indexing errors
/opt/homebrew/Caskroom/miniconda/base/envs/aiconfig/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:430: UserWarning: `num_beams` is set to 1. However, `early_stopping` is set to `True` -- this flag is only used in beam-based generation modes. You should set `num_beams>1` or unset `early_stopping`.
  warnings.warn(
/opt/homebrew/Caskroom/miniconda/base/envs/aiconfig/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:449: UserWarning: `num_beams` is set to 1. However, `length_penalty` is set to `2.0` -- this flag is only used in beam-based generation modes. You should set `num_beams>1` or unset `length_penalty`.
  warnings.warn(
<pad> escorted the 13th Destroyer Flotilla in the Mediterranean and escorited the carrier Argus to Malta during the war  The ship was recalled home to be converted into an escoter destroyer in late 1942. The vessel was repaired and given a refit at Gibraltar on 16 November, and was sold for scrap later that year. The crew of the ship escores the ship to the Middle East, and the ship is a 'disaster' of the </s>Output:
 [ExecuteResult(output_type='execute_result', execution_count=0, data="<pad> escorted the 13th Destroyer Flotilla in the Mediterranean and escorited the carrier Argus to Malta during the war  The ship was recalled home to be converted into an escoter destroyer in late 1942. The vessel was repaired and given a refit at Gibraltar on 16 November, and was sold for scrap later that year. The crew of the ship escores the ship to the Middle East, and the ship is a 'disaster' of the </s>", mime_type=None, metadata={})]
no stream
Output:
 [ExecuteResult(output_type='execute_result', execution_count=0, data="escorted the 13th Destroyer Flotilla in the Mediterranean and escorited the carrier Argus to Malta during the war . The ship was recalled home to be converted into an escoter destroyer in late 1942. The vessel was repaired and given a refit at Gibraltar on 16 November, and was sold for scrap later that year. The crew of the ship escores the ship to the Middle East, and the ship is a 'disaster' of the .", mime_type=None, metadata={})]

```
jonathanlastmileai committed Jan 4, 2024
1 parent 6bbe147 commit 13e590f
Showing 2 changed files with 309 additions and 1 deletion.
aiconfig_extension_hugging_face/__init__.py
@@ -1,7 +1,10 @@
 from .local_inference.text_2_image import HuggingFaceText2ImageDiffusor
 from .local_inference.text_generation import HuggingFaceTextGenerationTransformer
 from .remote_inference_client.text_generation import HuggingFaceTextGenerationClient
+from .local_inference.text_summarization import HuggingFaceTextSummarizationTransformer
 
-LOCAL_INFERENCE_CLASSES = ["HuggingFaceText2ImageDiffusor", "HuggingFaceTextGenerationTransformer"]
+# from .remote_inference_client.text_generation import HuggingFaceTextGenerationClient
+
+LOCAL_INFERENCE_CLASSES = ["HuggingFaceText2ImageDiffusor", "HuggingFaceTextGenerationTransformer", "HuggingFaceTextSummarizationTransformer"]
 REMOTE_INFERENCE_CLASSES = ["HuggingFaceTextGenerationClient"]
 __ALL__ = LOCAL_INFERENCE_CLASSES + REMOTE_INFERENCE_CLASSES
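
With this export added to the extension's `__init__`, the parser should also be importable from the package root. A minimal sketch (the registration call mirrors the test script above):

```
from aiconfig import AIConfigRuntime
from aiconfig_extension_hugging_face import HuggingFaceTextSummarizationTransformer

# Register the parser under the model name referenced by the prompt's metadata.
AIConfigRuntime.register_model_parser(
    HuggingFaceTextSummarizationTransformer(),
    "stevhliu/my_awesome_billsum_model",
)
```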
aiconfig_extension_hugging_face/local_inference/text_summarization.py
@@ -0,0 +1,305 @@
import copy
import json
import threading
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from transformers import (
    AutoTokenizer,
    Pipeline,
    pipeline,
    TextIteratorStreamer,
)

from aiconfig.default_parsers.parameterized_model_parser import ParameterizedModelParser
from aiconfig.model_parser import InferenceOptions
from aiconfig.schema import (
    ExecuteResult,
    Output,
    OutputDataWithValue,
    Prompt,
    PromptMetadata,
)
from aiconfig.util.params import resolve_prompt

# Circular Dependency Type Hints
if TYPE_CHECKING:
    from aiconfig.Config import AIConfigRuntime


# Step 1: define Helpers
def refine_chat_completion_params(model_settings: Dict[str, Any]) -> Dict[str, Any]:
    """
    Refines the completion params for the HF text summarization API, removing any unsupported params.
    The supported keys were found by looking at the HF text summarization API: `huggingface_hub.InferenceClient.text_summarization()`
    """

    supported_keys = {
        "max_length",
        "max_new_tokens",
        "min_length",
        "min_new_tokens",
        "early_stopping",
        "max_time",
        "do_sample",
        "num_beams",
        "num_beam_groups",
        "penalty_alpha",
        "use_cache",
        "temperature",
        "top_k",
        "top_p",
        "typical_p",
        "epsilon_cutoff",
        "eta_cutoff",
        "diversity_penalty",
        "repetition_penalty",
        "encoder_repetition_penalty",
        "length_penalty",
        "no_repeat_ngram_size",
        "bad_words_ids",
        "force_words_ids",
        "renormalize_logits",
        "constraints",
        "forced_bos_token_id",
        "forced_eos_token_id",
        "remove_invalid_values",
        "exponential_decay_length_penalty",
        "suppress_tokens",
        "begin_suppress_tokens",
        "forced_decoder_ids",
        "sequence_bias",
        "guidance_scale",
        "low_memory",
        "num_return_sequences",
        "output_attentions",
        "output_hidden_states",
        "output_scores",
        "return_dict_in_generate",
        "pad_token_id",
        "bos_token_id",
        "eos_token_id",
        "encoder_no_repeat_ngram_size",
        "decoder_start_token_id",
        "num_assistant_tokens",
        "num_assistant_tokens_schedule",
    }

    completion_data = {}
    for key in model_settings:
        if key.lower() in supported_keys:
            completion_data[key.lower()] = model_settings[key]

    return completion_data
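
For illustration only (not part of the diff): the settings from the `test_hf_sum` prompt above pass through the filter unchanged, while an unsupported key (here a hypothetical `"stream"` flag) is dropped:

```
settings = {"min_length": 100, "max_length": 200, "num_beams": 1, "stream": True}
print(refine_chat_completion_params(settings))
# {'min_length': 100, 'max_length': 200, 'num_beams': 1}
```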


def construct_regular_output(result: Dict[str, str], execution_count: int) -> Output:
    """
    Construct regular output per response result, without streaming enabled
    """
    output = ExecuteResult(
        **{
            "output_type": "execute_result",
            "data": result["summary_text"],
            "execution_count": execution_count,
            "metadata": {},
        }
    )
    return output


def construct_stream_output(
    streamer: TextIteratorStreamer,
    options: InferenceOptions,
) -> Output:
    """
    Constructs the output for a stream response.

    Args:
        streamer (TextIteratorStreamer): Streams the output. See:
            https://huggingface.co/docs/transformers/v4.35.2/en/internal/generation_utils#transformers.TextIteratorStreamer
        options (InferenceOptions): The inference options. Used to determine
            the stream callback.
    """
    output = ExecuteResult(
        **{
            "output_type": "execute_result",
            "data": "",  # We update this below
            "execution_count": 0,  # Multiple outputs are not supported for streaming
            "metadata": {},
        }
    )
    accumulated_message = ""
    for new_text in streamer:
        if isinstance(new_text, str):
            accumulated_message += new_text
            options.stream_callback(new_text, accumulated_message, 0)

    output.data = accumulated_message
    return output
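
For context (illustration only, not part of the diff): this helper consumes the standard transformers streaming pattern that `run_inference` sets up below. The pipeline call must run on a worker thread, since it blocks until generation finishes, while the streamer is iterated on the calling thread:

```
from threading import Thread

from transformers import AutoTokenizer, TextIteratorStreamer, pipeline

model_name = "stevhliu/my_awesome_billsum_model"
summarizer = pipeline("summarization", model=model_name)
streamer = TextIteratorStreamer(AutoTokenizer.from_pretrained(model_name))

# The `streamer` kwarg is forwarded to `generate()` by the pipeline; decoded text
# chunks appear on the iterator as they are produced.
thread = Thread(target=summarizer, args=("Some long article text ...",), kwargs={"streamer": streamer})
thread.start()
for chunk in streamer:
    print(chunk, end="", flush=True)
thread.join()
```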


class HuggingFaceTextSummarizationTransformer(ParameterizedModelParser):
    """
    A model parser for HuggingFace text summarization models run locally with transformers.
    """

    def __init__(self):
        """
        Returns:
            HuggingFaceTextSummarizationTransformer

        Usage:
            1. Create a new model parser object.
                parser = HuggingFaceTextSummarizationTransformer()
            2. Add the model parser to the registry.
                config.register_model_parser(parser)
        """
        super().__init__()
        self.summarizers: dict[str, Pipeline] = {}

    def id(self) -> str:
        """
        Returns an identifier for the Model Parser
        """
        return "HuggingFaceTextSummarizationTransformer"

    async def serialize(
        self,
        prompt_name: str,
        data: Any,
        ai_config: "AIConfigRuntime",
        parameters: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> List[Prompt]:
        """
        Defines how a prompt and model inference settings get serialized in the .aiconfig.

        Args:
            prompt_name (str): The prompt to be serialized.
            data (Any): Model-specific inference settings to be serialized.

        Returns:
            List[Prompt]: Serialized representation of the prompt and inference settings.
        """
        data = copy.deepcopy(data)

        # Assume data is completion params for HF text summarization
        prompt_input = data["prompt"]

        # Prompt is handled, remove from data
        data.pop("prompt", None)

        model_metadata = ai_config.get_model_metadata(data, self.id())
        prompt = Prompt(
            name=prompt_name,
            input=prompt_input,
            metadata=PromptMetadata(model=model_metadata, parameters=parameters, **kwargs),
        )
        return [prompt]

    async def deserialize(
        self,
        prompt: Prompt,
        aiconfig: "AIConfigRuntime",
        _options,
        params: Optional[Dict[str, Any]] = {},
    ) -> Dict[str, Any]:
        """
        Defines how to parse a prompt in the .aiconfig for a particular model
        and constructs the completion params for that model.

        Args:
            serialized_data (str): Serialized data from the .aiconfig.

        Returns:
            dict: Model-specific completion parameters.
        """
        # Build completion data
        model_settings = self.get_model_settings(prompt, aiconfig)
        completion_data = refine_chat_completion_params(model_settings)

        # Add resolved prompt
        resolved_prompt = resolve_prompt(prompt, params, aiconfig)
        completion_data["prompt"] = resolved_prompt
        return completion_data

    async def run_inference(self, prompt: Prompt, aiconfig: "AIConfigRuntime", options: InferenceOptions, parameters: Dict[str, Any]) -> List[Output]:
        """
        Invoked to run a prompt in the .aiconfig. This method should perform
        the actual model inference based on the provided prompt and inference settings.

        Args:
            prompt (str): The input prompt.
            inference_settings (dict): Model-specific inference settings.

        Returns:
            InferenceResponse: The response from the model.
        """
        completion_data = await self.deserialize(prompt, aiconfig, options, parameters)
        inputs = completion_data.pop("prompt", None)

        model_name: str = aiconfig.get_model_name(prompt)
        if isinstance(model_name, str) and model_name not in self.summarizers:
            self.summarizers[model_name] = pipeline("summarization", model=model_name)
        summarizer = self.summarizers[model_name]

        # If streaming is enabled in both the runtime options and the config, then stream. Otherwise don't stream.
        streamer = None
        should_stream = (options.stream if options else False) and ("stream" not in completion_data or completion_data.get("stream") != False)
        if should_stream:
            tokenizer: AutoTokenizer = AutoTokenizer.from_pretrained(model_name)
            streamer = TextIteratorStreamer(tokenizer)
            completion_data["streamer"] = streamer

        outputs: List[Output] = []
        output = None

        def _summarize():
            return summarizer(inputs, **completion_data)

        if not should_stream:
            response: List[Any] = _summarize()
            for count, result in enumerate(response):
                output = construct_regular_output(result, count)
                outputs.append(output)
        else:
            if completion_data.get("num_return_sequences", 1) > 1:
                raise ValueError("Sorry, TextIteratorStreamer does not support multiple return sequences, please set `num_return_sequences` to 1")
            if not streamer:
                raise ValueError("Stream option is selected but streamer is not initialized")

            # For streaming, cannot call `summarizer` directly, otherwise the response will be blocking
            thread = threading.Thread(target=_summarize)
            thread.start()
            output = construct_stream_output(streamer, options)
            if output is not None:
                outputs.append(output)

        prompt.outputs = outputs
        return prompt.outputs

    def get_output_text(
        self,
        prompt: Prompt,
        aiconfig: "AIConfigRuntime",
        output: Optional[Output] = None,
    ) -> str:
        if output is None:
            output = aiconfig.get_latest_output(prompt)

        if output is None:
            return ""

        # TODO (rossdanlm): Handle multiple outputs in list
        # https://github.com/lastmile-ai/aiconfig/issues/467
        if output.output_type == "execute_result":
            output_data = output.data
            if isinstance(output_data, str):
                return output_data
            if isinstance(output_data, OutputDataWithValue):
                if isinstance(output_data.value, str):
                    return output_data.value
                # HuggingFace text summarization does not support function
                # calls so shouldn't get here, but just being safe
                return json.dumps(output_data.value, indent=2)
        return ""
