From dff7ba7ab905860ee0eea5f0bee8aa248954d0c1 Mon Sep 17 00:00:00 2001 From: "Rossdan Craig rossdan@lastmileai.dev" <> Date: Thu, 11 Jan 2024 19:15:50 -0500 Subject: [PATCH] Connect local HF model parser prompt input schemas to the right classnames We did a bit of changing of classnames, so just updating here. Also important to note is that we currently don't have this for the remote inference text generation class: https://github.com/lastmile-ai/aiconfig/blob/2364ee32d340a105b4485fa134086f9093d42384/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/remote_inference_client/text_generation.py#L124 I'll handle that in the next diff, as well as rename that class (and test that it works without model name specified). For now I just copy-pasted remote_inference_client/text_generation.py --> hf.py (core parser), just to show that they indeed have the same functionality and that the core parser can be deleted. ## Test Plan Before https://github.com/lastmile-ai/aiconfig/assets/151060367/bf10e6f2-56db-423d-8ff0-a636a12f7f96 After https://github.com/lastmile-ai/aiconfig/assets/151060367/56ddc0fc-3416-4542-81ee-27c1e4c338fb --- .../text_generation.py | 9 +++- python/src/aiconfig/default_parsers/hf.py | 45 ++++++++++--------- .../editor/client/src/utils/promptUtils.ts | 34 +++++++++----- 3 files changed, 55 insertions(+), 33 deletions(-) diff --git a/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/remote_inference_client/text_generation.py b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/remote_inference_client/text_generation.py index 67406f811..02bc0e514 100644 --- a/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/remote_inference_client/text_generation.py +++ b/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/remote_inference_client/text_generation.py @@ -315,7 +315,14 @@ def get_output_text( output_data = output.data if isinstance(output_data, str): return output_data - + + # Doing this to be 
backwards-compatible with old output format + # where we used to save the TextGenerationResponse or + # TextGenerationStreamResponse in output.data + if hasattr(output_data, "generated_text"): + assert isinstance(output_data.generated_text, str) + return output_data.generated_text + # HuggingFace text generation outputs should only ever be string # format so shouldn't get here, but just being safe return json.dumps(output_data, indent=2) diff --git a/python/src/aiconfig/default_parsers/hf.py b/python/src/aiconfig/default_parsers/hf.py index ae9d04239..02bc0e514 100644 --- a/python/src/aiconfig/default_parsers/hf.py +++ b/python/src/aiconfig/default_parsers/hf.py @@ -2,17 +2,25 @@ import json from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Union +# HuggingFace API imports +from huggingface_hub import InferenceClient +from huggingface_hub.inference._text_generation import ( + TextGenerationResponse, + TextGenerationStreamResponse, +) + +from aiconfig import CallbackEvent from aiconfig.default_parsers.parameterized_model_parser import ParameterizedModelParser from aiconfig.model_parser import InferenceOptions +from aiconfig.schema import ( + ExecuteResult, + Output, + Prompt, + PromptMetadata, +) from aiconfig.util.config_utils import get_api_key_from_environment from aiconfig.util.params import resolve_prompt -# HuggingFace API imports -from huggingface_hub import InferenceClient -from huggingface_hub.inference._text_generation import TextGenerationResponse, TextGenerationStreamResponse - -from aiconfig import CallbackEvent -from aiconfig.schema import ExecuteResult, Output, OutputDataWithValue, Prompt, PromptMetadata # Circuluar Dependency Type Hints if TYPE_CHECKING: @@ -20,9 +28,7 @@ # Step 1: define Helpers - - -def refine_chat_completion_params(model_settings: dict[Any, Any]) -> dict[str, Any]: +def refine_completion_params(model_settings: dict[Any, Any]) -> dict[str, Any]: """ Refines the completion params for the HF text generation api. 
Removes any unsupported params. The supported keys were found by looking at the HF text generation api. `huggingface_hub.InferenceClient.text_generation()` @@ -99,7 +105,7 @@ def construct_stream_output( return output -def construct_regular_output(response: str, response_includes_details: bool) -> Output: +def construct_regular_output(response: TextGenerationResponse, response_includes_details: bool) -> Output: metadata = {"raw_response": response} if response_includes_details: metadata["details"] = response.details @@ -107,7 +113,7 @@ def construct_regular_output(response: str, response_includes_details: bool) -> output = ExecuteResult( **{ "output_type": "execute_result", - "data": response, + "data": response.generated_text, "execution_count": 0, "metadata": metadata, } @@ -120,7 +126,7 @@ class HuggingFaceTextGenerationParser(ParameterizedModelParser): A model parser for HuggingFace text generation models. """ - def __init__(self, model_id: str = None, use_api_token=True): + def __init__(self, model_id: str = None, use_api_token=False): """ Args: model_id (str): The model ID of the model to use. @@ -164,7 +170,7 @@ async def serialize( ai_config: "AIConfigRuntime", parameters: Optional[dict[Any, Any]] = None, **kwargs, - ) -> List[Prompt]: + ) -> list[Prompt]: """ Defines how a prompt and model inference settings get serialized in the .aiconfig. 
@@ -234,7 +240,7 @@ async def deserialize( # Build Completion data model_settings = self.get_model_settings(prompt, aiconfig) - completion_data = refine_chat_completion_params(model_settings) + completion_data = refine_completion_params(model_settings) completion_data["prompt"] = resolved_prompt @@ -309,16 +315,15 @@ def get_output_text( output_data = output.data if isinstance(output_data, str): return output_data - if isinstance(output_data, OutputDataWithValue): - if isinstance(output_data.value, str): - return output_data.value - # HuggingFace Text generation does not support function - # calls so shouldn't get here, but just being safe - return json.dumps(output_data.value, indent=2) # Doing this to be backwards-compatible with old output format # where we used to save the TextGenerationResponse or # TextGenerationStreamResponse in output.data if hasattr(output_data, "generated_text"): + assert isinstance(output_data.generated_text, str) return output_data.generated_text + + # HuggingFace text generation outputs should only ever be string + # format so shouldn't get here, but just being safe + return json.dumps(output_data, indent=2) return "" diff --git a/python/src/aiconfig/editor/client/src/utils/promptUtils.ts b/python/src/aiconfig/editor/client/src/utils/promptUtils.ts index 4231d74ae..c66961b8b 100644 --- a/python/src/aiconfig/editor/client/src/utils/promptUtils.ts +++ b/python/src/aiconfig/editor/client/src/utils/promptUtils.ts @@ -4,14 +4,14 @@ import { OpenAIChatVisionModelParserPromptSchema } from "../shared/prompt_schema import { DalleImageGenerationParserPromptSchema } from "../shared/prompt_schemas/DalleImageGenerationParserPromptSchema"; import { PaLMTextParserPromptSchema } from "../shared/prompt_schemas/PaLMTextParserPromptSchema"; import { PaLMChatParserPromptSchema } from "../shared/prompt_schemas/PaLMChatParserPromptSchema"; -import { HuggingFaceTextGenerationParserPromptSchema } from 
"../shared/prompt_schemas/HuggingFaceTextGenerationParserPromptSchema"; import { AnyscaleEndpointPromptSchema } from "../shared/prompt_schemas/AnyscaleEndpointPromptSchema"; -import { HuggingFaceText2ImageDiffusorPromptSchema } from "../shared/prompt_schemas/HuggingFaceText2ImageDiffusorPromptSchema"; -import { HuggingFaceTextGenerationTransformerPromptSchema } from "../shared/prompt_schemas/HuggingFaceTextGenerationTransformerPromptSchema"; import { HuggingFaceAutomaticSpeechRecognitionPromptSchema } from "../shared/prompt_schemas/HuggingFaceAutomaticSpeechRecognitionPromptSchema"; -import { HuggingFaceTextSummarizationTransformerPromptSchema } from "../shared/prompt_schemas/HuggingFaceTextSummarizationTransformerPromptSchema"; import { HuggingFaceImage2TextTransformerPromptSchema } from "../shared/prompt_schemas/HuggingFaceImage2TextTransformerPromptSchema"; +import { HuggingFaceText2ImageDiffusorPromptSchema } from "../shared/prompt_schemas/HuggingFaceText2ImageDiffusorPromptSchema"; import { HuggingFaceText2SpeechTransformerPromptSchema } from "../shared/prompt_schemas/HuggingFaceText2SpeechTransformerPromptSchema"; +import { HuggingFaceTextGenerationTransformerPromptSchema } from "../shared/prompt_schemas/HuggingFaceTextGenerationTransformerPromptSchema"; +import { HuggingFaceTextSummarizationTransformerPromptSchema } from "../shared/prompt_schemas/HuggingFaceTextSummarizationTransformerPromptSchema"; +import { HuggingFaceTextGenerationParserPromptSchema } from "../shared/prompt_schemas/HuggingFaceTextGenerationParserPromptSchema"; /** * Get the name of the model for the specified prompt. 
The name will either be specified in the prompt's @@ -75,9 +75,10 @@ export const PROMPT_SCHEMAS: Record = { "dall-e-2": DalleImageGenerationParserPromptSchema, "dall-e-3": DalleImageGenerationParserPromptSchema, - // HuggingFaceTextGenerationParser + // TODO: core parser and remote inference share the same code, delete + // hf core parser and keep it in the extension instead + // HuggingFaceTextGenerationParser (Core parser and remote inference extension) HuggingFaceTextGenerationParser: HuggingFaceTextGenerationParserPromptSchema, - "HuggingFaceAutomaticSpeechRecognitionTransformer": HuggingFaceAutomaticSpeechRecognitionPromptSchema, // PaLMTextParser "models/text-bison-001": PaLMTextParserPromptSchema, @@ -88,18 +89,27 @@ export const PROMPT_SCHEMAS: Record = { // AnyscaleEndpoint AnyscaleEndpoint: AnyscaleEndpointPromptSchema, - // Local HuggingFace Parsers + // Local HuggingFace Parser Extensions + HuggingFaceAutomaticSpeechRecognitionTransformer: + HuggingFaceAutomaticSpeechRecognitionPromptSchema, + + HuggingFaceImage2TextTransformer: + HuggingFaceImage2TextTransformerPromptSchema, + HuggingFaceText2ImageDiffusor: HuggingFaceText2ImageDiffusorPromptSchema, - TextGeneration: HuggingFaceTextGenerationTransformerPromptSchema, Text2Image: HuggingFaceText2ImageDiffusorPromptSchema, - AutomaticSpeechRecognition: HuggingFaceAutomaticSpeechRecognitionPromptSchema, + + HuggingFaceText2SpeechTransformer: + HuggingFaceText2SpeechTransformerPromptSchema, + Text2Speech: HuggingFaceText2SpeechTransformerPromptSchema, + + HuggingFaceTextGenerationTransformer: + HuggingFaceTextGenerationTransformerPromptSchema, + TextGeneration: HuggingFaceTextGenerationTransformerPromptSchema, HuggingFaceTextSummarizationTransformer: HuggingFaceTextSummarizationTransformerPromptSchema, HuggingFaceTextTranslationTransformer: HuggingFaceTextGenerationTransformerPromptSchema, - HuggingFaceImage2TextTransformer: - HuggingFaceImage2TextTransformerPromptSchema, - Text2Speech: 
HuggingFaceText2SpeechTransformerPromptSchema, }; export type PromptInputSchema =