From 13e590f75fd298b2fa9320d071f6067809416875 Mon Sep 17 00:00:00 2001 From: Jonathan Lessinger Date: Wed, 3 Jan 2024 19:44:12 -0500 Subject: [PATCH] [RFC][AIC-py] hf summarization model parser Very similar to text gen Test: ``` ... { "name": "test_hf_sum", "input": "HMS Duncan was a D-class destroyer ...", # [contents of https://en.wikipedia.org/wiki/HMS_Duncan_(D99)] "metadata": { "model": { "name": "stevhliu/my_awesome_billsum_model", "settings": { "min_length": 100, "max_length": 200, "num_beams": 1 } } } }, ... } ``` ``` import asyncio from aiconfig_extension_hugging_face.local_inference.text_summarization import HuggingFaceTextSummarizationTransformer from aiconfig import AIConfigRuntime, InferenceOptions, CallbackManager # Load the aiconfig. mp = HuggingFaceTextSummarizationTransformer() AIConfigRuntime.register_model_parser(mp, "stevhliu/my_awesome_billsum_model") config = AIConfigRuntime.load("/Users/jonathan/Projects/aiconfig/cookbooks/Getting-Started/travel.aiconfig.json") config.callback_manager = CallbackManager([]) def print_stream(data, _accumulated_data, _index): print(data, end="", flush=True) async def run(): print("Stream") options = InferenceOptions(stream=True, stream_callback=print_stream) out = await config.run("test_hf_sum", options=options) print("Output:\n", out) print("no stream") options = InferenceOptions(stream=False) out = await config.run("test_hf_sum", options=options) print("Output:\n", out) asyncio.run(run()) # OUT: Stream Token indices sequence length is longer than the specified maximum sequence length for this model (2778 > 512). Running this sequence through the model will result in indexing errors /opt/homebrew/Caskroom/miniconda/base/envs/aiconfig/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:430: UserWarning: `num_beams` is set to 1. However, `early_stopping` is set to `True` -- this flag is only used in beam-based generation modes. You should set `num_beams>1` or unset `early_stopping`. warnings.warn( /opt/homebrew/Caskroom/miniconda/base/envs/aiconfig/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:449: UserWarning: `num_beams` is set to 1. However, `length_penalty` is set to `2.0` -- this flag is only used in beam-based generation modes. You should set `num_beams>1` or unset `length_penalty`. warnings.warn( escorted the 13th Destroyer Flotilla in the Mediterranean and escorited the carrier Argus to Malta during the war The ship was recalled home to be converted into an escoter destroyer in late 1942. The vessel was repaired and given a refit at Gibraltar on 16 November, and was sold for scrap later that year. The crew of the ship escores the ship to the Middle East, and the ship is a 'disaster' of the Output: [ExecuteResult(output_type='execute_result', execution_count=0, data=" escorted the 13th Destroyer Flotilla in the Mediterranean and escorited the carrier Argus to Malta during the war The ship was recalled home to be converted into an escoter destroyer in late 1942. The vessel was repaired and given a refit at Gibraltar on 16 November, and was sold for scrap later that year. The crew of the ship escores the ship to the Middle East, and the ship is a 'disaster' of the ", mime_type=None, metadata={})] no stream Output: [ExecuteResult(output_type='execute_result', execution_count=0, data="escorted the 13th Destroyer Flotilla in the Mediterranean and escorited the carrier Argus to Malta during the war . The ship was recalled home to be converted into an escoter destroyer in late 1942. The vessel was repaired and given a refit at Gibraltar on 16 November, and was sold for scrap later that year. The crew of the ship escores the ship to the Middle East, and the ship is a 'disaster' of the .", mime_type=None, metadata={})] ``` --- .../__init__.py | 5 +- .../local_inference/text_summarization.py | 305 ++++++++++++++++++ 2 files changed, 309 insertions(+), 1 deletion(-) create mode 100644 extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/text_summarization.py diff --git a/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/__init__.py b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/__init__.py index d08f71857..43e3bf721 100644 --- a/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/__init__.py +++ b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/__init__.py @@ -1,7 +1,10 @@ from .local_inference.text_2_image import HuggingFaceText2ImageDiffusor from .local_inference.text_generation import HuggingFaceTextGenerationTransformer from .remote_inference_client.text_generation import HuggingFaceTextGenerationClient +from .local_inference.text_summarization import HuggingFaceTextSummarizationTransformer -LOCAL_INFERENCE_CLASSES = ["HuggingFaceText2ImageDiffusor", "HuggingFaceTextGenerationTransformer"] +# from .remote_inference_client.text_generation import HuggingFaceTextGenerationClient + +LOCAL_INFERENCE_CLASSES = ["HuggingFaceText2ImageDiffusor", "HuggingFaceTextGenerationTransformer", "HuggingFaceTextSummarizationTransformer"] REMOTE_INFERENCE_CLASSES = ["HuggingFaceTextGenerationClient"] __ALL__ = LOCAL_INFERENCE_CLASSES + REMOTE_INFERENCE_CLASSES diff --git a/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/text_summarization.py b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/text_summarization.py new file mode 100644 index 000000000..bba735b4f --- /dev/null +++ b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/local_inference/text_summarization.py @@ -0,0 +1,305 @@ +import copy +import json +import threading +from typing import TYPE_CHECKING, Any, Dict, List, Optional +from transformers import ( + AutoTokenizer, + Pipeline, + pipeline, + TextIteratorStreamer, +) + +from aiconfig.default_parsers.parameterized_model_parser import ParameterizedModelParser +from aiconfig.model_parser import InferenceOptions +from aiconfig.schema import ( + ExecuteResult, + Output, + OutputDataWithValue, + Prompt, + PromptMetadata, +) +from aiconfig.util.params import resolve_prompt + +# Circuluar Dependency Type Hints +if TYPE_CHECKING: + from aiconfig.Config import AIConfigRuntime + + +# Step 1: define Helpers +def refine_chat_completion_params(model_settings: Dict[str, Any]) -> Dict[str, Any]: + """ + Refines the completion params for the HF text summarization api. Removes any unsupported params. + The supported keys were found by looking at the HF text summarization api. `huggingface_hub.InferenceClient.text_summarization()` + """ + + supported_keys = { + "max_length", + "max_new_tokens", + "min_length", + "min_new_tokens", + "early_stopping", + "max_time", + "do_sample", + "num_beams", + "num_beam_groups", + "penalty_alpha", + "use_cache", + "temperature", + "top_k", + "top_p", + "typical_p", + "epsilon_cutoff", + "eta_cutoff", + "diversity_penalty", + "repetition_penalty", + "encoder_repetition_penalty", + "length_penalty", + "no_repeat_ngram_size", + "bad_words_ids", + "force_words_ids", + "renormalize_logits", + "constraints", + "forced_bos_token_id", + "forced_eos_token_id", + "remove_invalid_values", + "exponential_decay_length_penalty", + "suppress_tokens", + "begin_suppress_tokens", + "forced_decoder_ids", + "sequence_bias", + "guidance_scale", + "low_memory", + "num_return_sequences", + "output_attentions", + "output_hidden_states", + "output_scores", + "return_dict_in_generate", + "pad_token_id", + "bos_token_id", + "eos_token_id", + "encoder_no_repeat_ngram_size", + "decoder_start_token_id", + "num_assistant_tokens", + "num_assistant_tokens_schedule", + } + + completion_data = {} + for key in model_settings: + if key.lower() in supported_keys: + completion_data[key.lower()] = model_settings[key] + + return completion_data + + +def construct_regular_output(result: Dict[str, str], execution_count: int) -> Output: + """ + Construct regular output per response result, without streaming enabled + """ + output = ExecuteResult( + **{ + "output_type": "execute_result", + "data": result["summary_text"], + "execution_count": execution_count, + "metadata": {}, + } + ) + return output + + +def construct_stream_output( + streamer: TextIteratorStreamer, + options: InferenceOptions, +) -> Output: + """ + Constructs the output for a stream response. + + Args: + streamer (TextIteratorStreamer): Streams the output. See: + https://huggingface.co/docs/transformers/v4.35.2/en/internal/summarization_utils#transformers.TextIteratorStreamer + options (InferenceOptions): The inference options. Used to determine + the stream callback. + + """ + output = ExecuteResult( + **{ + "output_type": "execute_result", + "data": "", # We update this below + "execution_count": 0, # Multiple outputs are not supported for streaming + "metadata": {}, + } + ) + accumulated_message = "" + for new_text in streamer: + if isinstance(new_text, str): + accumulated_message += new_text + options.stream_callback(new_text, accumulated_message, 0) + + output.data = accumulated_message + return output + + +class HuggingFaceTextSummarizationTransformer(ParameterizedModelParser): + """ + A model parser for HuggingFace models of type text summarization task using transformers. + """ + + def __init__(self): + """ + Returns: + HuggingFaceTextSummarizationTransformer + + Usage: + 1. Create a new model parser object with the model ID of the model to use. + parser = HuggingFaceTextSummarizationTransformer() + 2. Add the model parser to the registry. + config.register_model_parser(parser) + """ + super().__init__() + self.summarizers: dict[str, Pipeline] = {} + + def id(self) -> str: + """ + Returns an identifier for the Model Parser + """ + return "HuggingFaceTextSummarizationTransformer" + + async def serialize( + self, + prompt_name: str, + data: Any, + ai_config: "AIConfigRuntime", + parameters: Optional[Dict[str, Any]] = None, + **kwargs, + ) -> List[Prompt]: + """ + Defines how a prompt and model inference settings get serialized in the .aiconfig. + + Args: + prompt_name (str): The prompt to be serialized. + inference_settings (dict): Model-specific inference settings to be serialized. + + Returns: + List[Prompt]: Serialized representation of the prompt and inference settings. + """ + data = copy.deepcopy(data) + + # assume data is completion params for HF text summarization + prompt_input = data["prompt"] + + # Prompt is handled, remove from data + data.pop("prompt", None) + + model_metadata = ai_config.get_model_metadata(data, self.id()) + prompt = Prompt( + name=prompt_name, + input=prompt_input, + metadata=PromptMetadata(model=model_metadata, parameters=parameters, **kwargs), + ) + return [prompt] + + async def deserialize( + self, + prompt: Prompt, + aiconfig: "AIConfigRuntime", + _options, + params: Optional[Dict[str, Any]] = {}, + ) -> Dict[str, Any]: + """ + Defines how to parse a prompt in the .aiconfig for a particular model + and constructs the completion params for that model. + + Args: + serialized_data (str): Serialized data from the .aiconfig. + + Returns: + dict: Model-specific completion parameters. + """ + # Build Completion data + model_settings = self.get_model_settings(prompt, aiconfig) + completion_data = refine_chat_completion_params(model_settings) + + # Add resolved prompt + resolved_prompt = resolve_prompt(prompt, params, aiconfig) + completion_data["prompt"] = resolved_prompt + return completion_data + + async def run_inference(self, prompt: Prompt, aiconfig: "AIConfigRuntime", options: InferenceOptions, parameters: Dict[str, Any]) -> List[Output]: + """ + Invoked to run a prompt in the .aiconfig. This method should perform + the actual model inference based on the provided prompt and inference settings. + + Args: + prompt (str): The input prompt. + inference_settings (dict): Model-specific inference settings. + + Returns: + InferenceResponse: The response from the model. + """ + completion_data = await self.deserialize(prompt, aiconfig, options, parameters) + inputs = completion_data.pop("prompt", None) + + model_name: str = aiconfig.get_model_name(prompt) + if isinstance(model_name, str) and model_name not in self.summarizers: + self.summarizers[model_name] = pipeline("summarization", model=model_name) + summarizer = self.summarizers[model_name] + + # if stream enabled in runtime options and config, then stream. Otherwise don't stream. + streamer = None + should_stream = (options.stream if options else False) and (not "stream" in completion_data or completion_data.get("stream") != False) + if should_stream: + tokenizer: AutoTokenizer = AutoTokenizer.from_pretrained(model_name) + streamer = TextIteratorStreamer(tokenizer) + completion_data["streamer"] = streamer + + outputs: List[Output] = [] + output = None + + def _summarize(): + return summarizer(inputs, **completion_data) + + if not should_stream: + response: List[Any] = _summarize() + for count, result in enumerate(response): + output = construct_regular_output(result, count) + outputs.append(output) + else: + if completion_data.get("num_return_sequences", 1) > 1: + raise ValueError("Sorry, TextIteratorStreamer does not support multiple return sequences, please set `num_return_sequences` to 1") + if not streamer: + raise ValueError("Stream option is selected but streamer is not initialized") + + # For streaming, cannot call `summarizer` directly otherwise response will be blocking + thread = threading.Thread(target=_summarize) + thread.start() + output = construct_stream_output(streamer, options) + if output is not None: + outputs.append(output) + + prompt.outputs = outputs + return prompt.outputs + + def get_output_text( + self, + prompt: Prompt, + aiconfig: "AIConfigRuntime", + output: Optional[Output] = None, + ) -> str: + if output is None: + output = aiconfig.get_latest_output(prompt) + + if output is None: + return "" + + # TODO (rossdanlm): Handle multiple outputs in list + # https://github.com/lastmile-ai/aiconfig/issues/467 + if output.output_type == "execute_result": + output_data = output.data + if isinstance(output_data, str): + return output_data + if isinstance(output_data, OutputDataWithValue): + if isinstance(output_data.value, str): + return output_data.value + # HuggingFace Text summarization does not support function + # calls so shouldn't get here, but just being safe + return json.dumps(output_data.value, indent=2) + return ""