[AIC-py] add hf text 2 speech parser
```
{
  "name": "NYC Trip Planner",
  "description": "Intrepid explorer with ChatGPT and AIConfig",
  "schema_version": "latest",
  "metadata": {
    "models": {
      "gpt2": {
        "model": "gpt2",
        "max_tokens": 3000,
        "system_prompt": "Some random system prompt"
      },
      "suno/bark": {
        "model": "suno/bark"
      }
    },
    "model_parsers": {
      "suno/bark": "HuggingFaceText2SpeechTransformer"
    },
    "default_model": "gpt2"
  },
  "prompts": [
    {
      "name": "test_tts",
      "input": "the quick brown fox",
      "metadata": {
        "model": {
          "name": "suno/bark",
          "settings": {}
        }
      }
    }
  ]
}


    mp = HuggingFaceText2SpeechTransformer()
    ModelParserRegistry.register_model_parser(mp)

    config = AIConfigRuntime.load("/Users/jonathan/Projects/aiconfig/test_hf_t2s.aiconfig.json")
    config.callback_manager = CallbackManager([])

    options = InferenceOptions(stream=False)
    out = await config.run("test_tts", options=options)
    with open("test.wav", "wb") as f:
        f.write(_b64_decode_str(out[0].data.value))


Listened to test.wav; it is in fact someone saying "the quick brown fox".

```
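The `_b64_decode_str` helper used in the test snippet is referenced but not defined in the commit message. A minimal sketch, assuming it simply inverts the base64 encoding the parser applies to the WAV bytes (see `construct_output` below):

```python
import base64

def _b64_decode_str(data: str) -> bytes:
    # Assumed inverse of the parser's base64.b64encode(wav_bytes).decode("utf-8")
    return base64.b64decode(data)
```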
jonathanlastmileai committed Jan 8, 2024
1 parent 375f1a3 commit 5ccf174
Showing 3 changed files with 236 additions and 1 deletion.
extensions/HuggingFace/python/requirements.txt (2 additions, 1 deletion)

```diff
@@ -14,7 +14,8 @@ diffusers # Used for image + audio generation
 torch
 torchvision
 torchaudio
+scipy # array -> wav file, text-speech. torchaudio.save seems broken.
 transformers # Used for text generation
 
 #Other
-asyncio
+asyncio
```
Package `__init__.py` (3 additions)

```diff
@@ -3,6 +3,8 @@
 from .remote_inference_client.text_generation import HuggingFaceTextGenerationParser
 from .local_inference.text_summarization import HuggingFaceTextSummarizationTransformer
 from .local_inference.text_translation import HuggingFaceTextTranslationTransformer
+from .local_inference.text_2_speech import HuggingFaceText2SpeechTransformer
+
 
 # from .remote_inference_client.text_generation import HuggingFaceTextGenerationClient
 
@@ -11,6 +13,7 @@
     "HuggingFaceTextGenerationTransformer",
     "HuggingFaceTextSummarizationTransformer",
     "HuggingFaceTextTranslationTransformer",
+    "HuggingFaceText2SpeechTransformer",
 ]
 REMOTE_INFERENCE_CLASSES = ["HuggingFaceTextGenerationParser"]
 __ALL__ = LOCAL_INFERENCE_CLASSES + REMOTE_INFERENCE_CLASSES
```
New file: `local_inference/text_2_speech.py` (231 lines)
```python
import base64
import copy
import io
import numpy as np
import torch
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from transformers import Pipeline, pipeline
from scipy.io.wavfile import write as write_wav

from aiconfig.default_parsers.parameterized_model_parser import ParameterizedModelParser
from aiconfig.model_parser import InferenceOptions
from aiconfig.schema import (
    ExecuteResult,
    Output,
    OutputDataWithValue,
    Prompt,
    PromptMetadata,
)
from aiconfig.util.params import resolve_prompt

# Circular Dependency Type Hints
if TYPE_CHECKING:
    from aiconfig.Config import AIConfigRuntime


# Step 1: define Helpers
def refine_pipeline_creation_params(model_settings: Dict[str, Any]) -> List[Dict[str, Any]]:
    # Keys accepted at pipeline-construction time; anything else is treated
    # as a candidate call-time completion param.
    supported_keys = {
        "torch_dtype",
        "force_download",
        "cache_dir",
        "resume_download",
        "proxies",
        "output_loading_info",
        "local_files_only",
        "use_auth_token",
        "revision",
        "custom_revision",
        "mirror",
        "device_map",
        "max_memory",
        "offload_folder",
        "offload_state_dict",
        "low_cpu_mem_usage",
        "use_safetensors",
        "variant",
    }

    pipeline_creation_params: Dict[str, Any] = {}
    completion_params: Dict[str, Any] = {}
    for key in model_settings:
        if key.lower() in supported_keys:
            pipeline_creation_params[key.lower()] = model_settings[key]
        else:
            if key.lower() == "kwargs" and isinstance(model_settings[key], Dict):
                completion_params.update(model_settings[key])
            else:
                completion_params[key.lower()] = model_settings[key]

    return [pipeline_creation_params, completion_params]
```
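For illustration, how the split behaves on a small settings dict (these particular settings are hypothetical, not from the commit):

```python
settings = {"torch_dtype": "float16", "kwargs": {"forward_params": {"do_sample": True}}}
creation, completion = refine_pipeline_creation_params(settings)
# creation   == {"torch_dtype": "float16"}               -> used when building the pipeline
# completion == {"forward_params": {"do_sample": True}}  -> candidate call-time params
```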


```python
def refine_completion_params(unfiltered_completion_params: Dict[str, Any]) -> Dict[str, Any]:
    # No call-time params are supported yet, so everything gets filtered out.
    supported_keys = {
        # ???
    }

    completion_params: Dict[str, Any] = {}
    for key in unfiltered_completion_params:
        if key.lower() in supported_keys:
            completion_params[key.lower()] = unfiltered_completion_params[key]

    return completion_params


def construct_output(audio, execution_count: int) -> Output:
    def _b64_encode_bytes(byte_array: bytes) -> str:
        return base64.b64encode(byte_array).decode("utf-8")

    def _audio_ndarray_to_wav_bytes(audio: np.ndarray, sampling_rate: int) -> bytes:
        buffered = io.BytesIO()
        write_wav(buffered, sampling_rate, audio)

        # get byte array from the buffer
        byte_array = buffered.getvalue()
        return byte_array

    def _audio_ndarray_to_b64_str(audio: np.ndarray, sampling_rate: int) -> str:
        byte_array = _audio_ndarray_to_wav_bytes(audio, sampling_rate)
        return _b64_encode_bytes(byte_array)

    data = dict(kind="base64", value=_audio_ndarray_to_b64_str(np.squeeze(audio["audio"]), audio["sampling_rate"]))
    output = ExecuteResult(
        **{
            "output_type": "execute_result",
            "data": data,
            "execution_count": execution_count,
            "metadata": {},
            "mime_type": "audio/wav",
        }
    )
    return output
```
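For reference, a `transformers` text-to-speech pipeline returns a dict with an `audio` array and a `sampling_rate`, which is the shape `construct_output` consumes. A synthetic sketch (fabricated audio, not from the commit):

```python
import numpy as np

# One second of silence at 24 kHz, shaped like a typical pipeline result.
fake_result = {"audio": np.zeros((1, 24000), dtype=np.float32), "sampling_rate": 24000}
output = construct_output(fake_result, execution_count=0)
# output.data.value holds the base64-encoded WAV bytes.
```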


```python
class HuggingFaceText2SpeechTransformer(ParameterizedModelParser):
    def __init__(self):
        super().__init__()
        self.synthesizers: dict[str, Pipeline] = {}

    def id(self) -> str:
        """
        Returns an identifier for the Model Parser
        """
        return "HuggingFaceText2SpeechTransformer"

    async def serialize(
        self,
        prompt_name: str,
        data: Any,
        ai_config: "AIConfigRuntime",
        parameters: Optional[Dict[str, Any]] = None,
        **completion_params,
    ) -> List[Prompt]:
        """
        Defines how a prompt and model inference settings get serialized in the .aiconfig.
        Args:
            prompt (str): The prompt to be serialized.
            inference_settings (dict): Model-specific inference settings to be serialized.
        Returns:
            str: Serialized representation of the prompt and inference settings.
        """
        data = copy.deepcopy(data)

        # Assume data is completion params for the HF text-to-speech task,
        # with the input text under the "prompt" key.
        prompt_input = data["prompt"]

        # Prompt is handled, remove from data
        data.pop("prompt", None)

        model_metadata = ai_config.get_model_metadata(data, self.id())
        prompt = Prompt(
            name=prompt_name,
            input=prompt_input,
            metadata=PromptMetadata(model=model_metadata, parameters=parameters, **completion_params),
        )
        return [prompt]
```
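A hypothetical call, using the prompt from the test config above (assumes `parser` is a `HuggingFaceText2SpeechTransformer` and `config` an `AIConfigRuntime`, run inside an async context):

```python
prompts = await parser.serialize(
    "test_tts",
    {"model": "suno/bark", "prompt": "the quick brown fox"},
    config,
)
# prompts[0].input == "the quick brown fox"; the model name lands in the prompt metadata.
```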

```python
    # HuggingFaceText2SpeechTransformer (continued)
    async def deserialize(
        self,
        prompt: Prompt,
        aiconfig: "AIConfigRuntime",
        _options,
        params: Optional[Dict[str, Any]] = {},
    ) -> Dict[str, Any]:
        """
        Defines how to parse a prompt in the .aiconfig for a particular model
        and constructs the completion params for that model.
        Args:
            serialized_data (str): Serialized data from the .aiconfig.
        Returns:
            dict: Model-specific completion parameters.
        """
        # Build Completion data
        model_settings = self.get_model_settings(prompt, aiconfig)
        [_pipeline_creation_params, unfiltered_completion_params] = refine_pipeline_creation_params(model_settings)
        completion_data = refine_completion_params(unfiltered_completion_params)

        # Add resolved prompt
        resolved_prompt = resolve_prompt(prompt, params, aiconfig)
        completion_data["prompt"] = resolved_prompt
        return completion_data
```
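Since `refine_completion_params` currently supports no keys, every model setting is dropped and deserialization reduces to the resolved prompt text. A sketch against the test config:

```python
completion_data = await parser.deserialize(prompt, config, None)
# completion_data == {"prompt": "the quick brown fox"}
```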

```python
    # HuggingFaceText2SpeechTransformer (continued)
    async def run_inference(self, prompt: Prompt, aiconfig: "AIConfigRuntime", options: InferenceOptions, parameters: Dict[str, Any]) -> List[Output]:
        """
        Invoked to run a prompt in the .aiconfig. This method should perform
        the actual model inference based on the provided prompt and inference settings.
        Args:
            prompt (str): The input prompt.
            inference_settings (dict): Model-specific inference settings.
        Returns:
            InferenceResponse: The response from the model.
        """
        model_settings = self.get_model_settings(prompt, aiconfig)
        # NOTE: pipeline_creation_data is computed but not yet passed to pipeline() below.
        [pipeline_creation_data, _] = refine_pipeline_creation_params(model_settings)

        # Lazily construct and cache one text-to-speech pipeline per model name.
        model_name: str = aiconfig.get_model_name(prompt)
        if isinstance(model_name, str) and model_name not in self.synthesizers:
            self.synthesizers[model_name] = pipeline("text-to-speech", model_name)
        synthesizer = self.synthesizers[model_name]

        completion_data = await self.deserialize(prompt, aiconfig, options, parameters)
        inputs = completion_data.pop("prompt", None)
        response = synthesizer(inputs, **completion_data)

        outputs: List[Output] = []
        assert not isinstance(response, list)
        for count, audio in enumerate([response]):
            output = construct_output(audio, count)
            outputs.append(output)

        prompt.outputs = outputs
        return prompt.outputs

    def get_output_text(
        self,
        prompt: Prompt,
        aiconfig: "AIConfigRuntime",
        output: Optional[Output] = None,
    ) -> str:
        if output is None:
            output = aiconfig.get_latest_output(prompt)

        if output is None:
            return ""

        # TODO (rossdanlm): Handle multiple outputs in list
        # https://github.com/lastmile-ai/aiconfig/issues/467
        if output.output_type == "execute_result":
            if isinstance(output.data, OutputDataWithValue):
                return output.data.value
            elif isinstance(output.data, str):
                return output.data
        return ""
```
