Skip to content

Commit

Permalink
Connect local HF model parser prompt input schemas to the right class…
Browse files Browse the repository at this point in the history
…names (#900)

Connect local HF model parser prompt input schemas to the right
classnames

We recently renamed some of these classes, so this updates the references here, as well
as re-ordering some of the imports


Also important to note is that we currently don't have this for the
remote inference text generation class:


https://github.com/lastmile-ai/aiconfig/blob/2364ee32d340a105b4485fa134086f9093d42384/extensions/HuggingFace/python/src/aiconfig_extension_hugging_face/remote_inference_client/text_generation.py#L124

I'll handle that in the next diff, as well as rename that class (and test that
it works without a model name specified). For now I just copy-pasted
remote_inference/text_generation.py --> hf.py (core parser), just to
show that they're indeed the same functionality and that the core parser can
be deleted

## Test Plan
Before



https://github.com/lastmile-ai/aiconfig/assets/151060367/bf10e6f2-56db-423d-8ff0-a636a12f7f96



After



https://github.com/lastmile-ai/aiconfig/assets/151060367/56ddc0fc-3416-4542-81ee-27c1e4c338fb
  • Loading branch information
rossdanlm authored Jan 12, 2024
2 parents 2364ee3 + dff7ba7 commit 78a9378
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,14 @@ def get_output_text(
output_data = output.data
if isinstance(output_data, str):
return output_data


# Doing this to be backwards-compatible with old output format
# where we used to save the TextGenerationResponse or
# TextGenerationStreamResponse in output.data
if hasattr(output_data, "generated_text"):
assert isinstance(output_data.generated_text, str)
return output_data.generated_text

# HuggingFace text generation outputs should only ever be string
# format so shouldn't get here, but just being safe
return json.dumps(output_data, indent=2)
Expand Down
45 changes: 25 additions & 20 deletions python/src/aiconfig/default_parsers/hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,27 +2,33 @@
import json
from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Union

# HuggingFace API imports
from huggingface_hub import InferenceClient
from huggingface_hub.inference._text_generation import (
TextGenerationResponse,
TextGenerationStreamResponse,
)

from aiconfig import CallbackEvent
from aiconfig.default_parsers.parameterized_model_parser import ParameterizedModelParser
from aiconfig.model_parser import InferenceOptions
from aiconfig.schema import (
ExecuteResult,
Output,
Prompt,
PromptMetadata,
)
from aiconfig.util.config_utils import get_api_key_from_environment
from aiconfig.util.params import resolve_prompt

# HuggingFace API imports
from huggingface_hub import InferenceClient
from huggingface_hub.inference._text_generation import TextGenerationResponse, TextGenerationStreamResponse

from aiconfig import CallbackEvent
from aiconfig.schema import ExecuteResult, Output, OutputDataWithValue, Prompt, PromptMetadata

# Circular Dependency Type Hints
if TYPE_CHECKING:
from aiconfig.Config import AIConfigRuntime


# Step 1: define Helpers


def refine_chat_completion_params(model_settings: dict[Any, Any]) -> dict[str, Any]:
def refine_completion_params(model_settings: dict[Any, Any]) -> dict[str, Any]:
"""
Refines the completion params for the HF text generation api. Removes any unsupported params.
The supported keys were found by looking at the HF text generation api. `huggingface_hub.InferenceClient.text_generation()`
Expand Down Expand Up @@ -99,15 +105,15 @@ def construct_stream_output(
return output


def construct_regular_output(response: str, response_includes_details: bool) -> Output:
def construct_regular_output(response: TextGenerationResponse, response_includes_details: bool) -> Output:
metadata = {"raw_response": response}
if response_includes_details:
metadata["details"] = response.details

output = ExecuteResult(
**{
"output_type": "execute_result",
"data": response,
"data": response.generated_text,
"execution_count": 0,
"metadata": metadata,
}
Expand All @@ -120,7 +126,7 @@ class HuggingFaceTextGenerationParser(ParameterizedModelParser):
A model parser for HuggingFace text generation models.
"""

def __init__(self, model_id: str = None, use_api_token=True):
def __init__(self, model_id: str = None, use_api_token=False):
"""
Args:
model_id (str): The model ID of the model to use.
Expand Down Expand Up @@ -164,7 +170,7 @@ async def serialize(
ai_config: "AIConfigRuntime",
parameters: Optional[dict[Any, Any]] = None,
**kwargs,
) -> List[Prompt]:
) -> list[Prompt]:
"""
Defines how a prompt and model inference settings get serialized in the .aiconfig.
Expand Down Expand Up @@ -234,7 +240,7 @@ async def deserialize(
# Build Completion data
model_settings = self.get_model_settings(prompt, aiconfig)

completion_data = refine_chat_completion_params(model_settings)
completion_data = refine_completion_params(model_settings)

completion_data["prompt"] = resolved_prompt

Expand Down Expand Up @@ -309,16 +315,15 @@ def get_output_text(
output_data = output.data
if isinstance(output_data, str):
return output_data
if isinstance(output_data, OutputDataWithValue):
if isinstance(output_data.value, str):
return output_data.value
# HuggingFace Text generation does not support function
# calls so shouldn't get here, but just being safe
return json.dumps(output_data.value, indent=2)

# Doing this to be backwards-compatible with old output format
# where we used to save the TextGenerationResponse or
# TextGenerationStreamResponse in output.data
if hasattr(output_data, "generated_text"):
assert isinstance(output_data.generated_text, str)
return output_data.generated_text

# HuggingFace text generation outputs should only ever be string
# format so shouldn't get here, but just being safe
return json.dumps(output_data, indent=2)
return ""
34 changes: 22 additions & 12 deletions python/src/aiconfig/editor/client/src/utils/promptUtils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@ import { OpenAIChatVisionModelParserPromptSchema } from "../shared/prompt_schema
import { DalleImageGenerationParserPromptSchema } from "../shared/prompt_schemas/DalleImageGenerationParserPromptSchema";
import { PaLMTextParserPromptSchema } from "../shared/prompt_schemas/PaLMTextParserPromptSchema";
import { PaLMChatParserPromptSchema } from "../shared/prompt_schemas/PaLMChatParserPromptSchema";
import { HuggingFaceTextGenerationParserPromptSchema } from "../shared/prompt_schemas/HuggingFaceTextGenerationParserPromptSchema";
import { AnyscaleEndpointPromptSchema } from "../shared/prompt_schemas/AnyscaleEndpointPromptSchema";
import { HuggingFaceText2ImageDiffusorPromptSchema } from "../shared/prompt_schemas/HuggingFaceText2ImageDiffusorPromptSchema";
import { HuggingFaceTextGenerationTransformerPromptSchema } from "../shared/prompt_schemas/HuggingFaceTextGenerationTransformerPromptSchema";
import { HuggingFaceAutomaticSpeechRecognitionPromptSchema } from "../shared/prompt_schemas/HuggingFaceAutomaticSpeechRecognitionPromptSchema";
import { HuggingFaceTextSummarizationTransformerPromptSchema } from "../shared/prompt_schemas/HuggingFaceTextSummarizationTransformerPromptSchema";
import { HuggingFaceImage2TextTransformerPromptSchema } from "../shared/prompt_schemas/HuggingFaceImage2TextTransformerPromptSchema";
import { HuggingFaceText2ImageDiffusorPromptSchema } from "../shared/prompt_schemas/HuggingFaceText2ImageDiffusorPromptSchema";
import { HuggingFaceText2SpeechTransformerPromptSchema } from "../shared/prompt_schemas/HuggingFaceText2SpeechTransformerPromptSchema";
import { HuggingFaceTextGenerationTransformerPromptSchema } from "../shared/prompt_schemas/HuggingFaceTextGenerationTransformerPromptSchema";
import { HuggingFaceTextSummarizationTransformerPromptSchema } from "../shared/prompt_schemas/HuggingFaceTextSummarizationTransformerPromptSchema";
import { HuggingFaceTextGenerationParserPromptSchema } from "../shared/prompt_schemas/HuggingFaceTextGenerationParserPromptSchema";

/**
* Get the name of the model for the specified prompt. The name will either be specified in the prompt's
Expand Down Expand Up @@ -75,9 +75,10 @@ export const PROMPT_SCHEMAS: Record<string, PromptSchema> = {
"dall-e-2": DalleImageGenerationParserPromptSchema,
"dall-e-3": DalleImageGenerationParserPromptSchema,

// HuggingFaceTextGenerationParser
// TODO: core parser and remote inference share the same code, delete
// hf core parser and keep it in the extension instead
// HuggingFaceTextGenerationParser (Core parser and remote inference extension)
HuggingFaceTextGenerationParser: HuggingFaceTextGenerationParserPromptSchema,
"HuggingFaceAutomaticSpeechRecognitionTransformer": HuggingFaceAutomaticSpeechRecognitionPromptSchema,

// PaLMTextParser
"models/text-bison-001": PaLMTextParserPromptSchema,
Expand All @@ -88,18 +89,27 @@ export const PROMPT_SCHEMAS: Record<string, PromptSchema> = {
// AnyscaleEndpoint
AnyscaleEndpoint: AnyscaleEndpointPromptSchema,

// Local HuggingFace Parsers
// Local HuggingFace Parser Extensions
HuggingFaceAutomaticSpeechRecognitionTransformer:
HuggingFaceAutomaticSpeechRecognitionPromptSchema,

HuggingFaceImage2TextTransformer:
HuggingFaceImage2TextTransformerPromptSchema,

HuggingFaceText2ImageDiffusor: HuggingFaceText2ImageDiffusorPromptSchema,
TextGeneration: HuggingFaceTextGenerationTransformerPromptSchema,
Text2Image: HuggingFaceText2ImageDiffusorPromptSchema,
AutomaticSpeechRecognition: HuggingFaceAutomaticSpeechRecognitionPromptSchema,

HuggingFaceText2SpeechTransformer:
HuggingFaceText2SpeechTransformerPromptSchema,
Text2Speech: HuggingFaceText2SpeechTransformerPromptSchema,

HuggingFaceTextGenerationTransformer:
HuggingFaceTextGenerationTransformerPromptSchema,
TextGeneration: HuggingFaceTextGenerationTransformerPromptSchema,
HuggingFaceTextSummarizationTransformer:
HuggingFaceTextSummarizationTransformerPromptSchema,
HuggingFaceTextTranslationTransformer:
HuggingFaceTextGenerationTransformerPromptSchema,
HuggingFaceImage2TextTransformer:
HuggingFaceImage2TextTransformerPromptSchema,
Text2Speech: HuggingFaceText2SpeechTransformerPromptSchema,
};

export type PromptInputSchema =
Expand Down

0 comments on commit 78a9378

Please sign in to comment.