adding a Hugging Face deployment
appoose committed Feb 20, 2024
1 parent acb4fc9 commit 817ae4b
Showing 2 changed files with 232 additions and 0 deletions.
16 changes: 16 additions & 0 deletions aana/configs/deployments.py
@@ -47,7 +47,23 @@
model="Salesforce/blip2-opt-2.7b",
dtype=Dtype.FLOAT16,
batch_size=2,
num_processing_threads=2,
).dict(),
),
"hf_aanaphi2_deployment": HFAnnaphi2Deployment.options(
num_replicas=1,
max_concurrent_queries=1000,
ray_actor_options={"num_gpus": 0.25},
user_config=Aanaphi2Config(
model="mobiuslabsgmbh/aanaphi2-v0.1",
dtype=Dtype.FLOAT16,
# batch_size=2,
num_processing_threads=2,
chat_template="aanaphi2",
max_length=1024,
# default_sampling_params=SamplingParams(
# max_length=1024
# ),
).dict(),
),
"whisper_deployment_medium": WhisperDeployment.options(
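As context for the registry entry above, here is a minimal usage sketch showing how the new deployment could be stood up and queried directly with Ray Serve. It assumes aana's BaseDeployment routes user_config into apply_config via Ray Serve's reconfigure mechanism; the prompt string and printout are purely illustrative, and the classes come from the file added below.

from ray import serve

from aana.deployments.hf_aanaphi2_deployment import (
    HFAnnaphi2Config,
    HFAnnaphi2Deployment,
)
from aana.models.core.dtype import Dtype

# Bind the deployment with the same options as the registry entry above.
app = HFAnnaphi2Deployment.options(
    num_replicas=1,
    ray_actor_options={"num_gpus": 0.25},
    user_config=HFAnnaphi2Config(
        model="mobiuslabsgmbh/aanaphi2-v0.1",
        dtype=Dtype.FLOAT16,
        chat_template="aanaphi2",
        max_length=1024,
    ).dict(),
).bind()

handle = serve.run(app)

# Depending on the Ray version, .remote() returns an ObjectRef (resolve with
# ray.get) or a DeploymentResponse (resolve with .result()).
response = handle.generate.remote(prompt="What is the capital of France?")
print(response.result()["text"])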
216 changes: 216 additions & 0 deletions aana/deployments/hf_aanaphi2_deployment.py
@@ -0,0 +1,216 @@
from typing import Any, TypedDict

import torch
import transformers
from pydantic import BaseModel, Field
from ray import serve
# from transformers import Blip2ForConditionalGeneration, Blip2Processor
from aana.deployments.base_deployment import BaseDeployment
from aana.exceptions.general import InferenceException
from aana.models.core.dtype import Dtype
from aana.models.core.image import Image
from aana.utils.batch_processor import BatchProcessor
from aana.utils.test import test_cache
from aana.utils.chat_template import apply_chat_template


class HFAnnaphi2Config(BaseModel):
    """The configuration for the aanaphi2 deployment with HuggingFace models.

    Attributes:
        model (str): the model ID on HuggingFace
        dtype (Dtype): the data type (optional, default: "auto"), one of "auto", "float32", "float16"
        num_processing_threads (int): the number of processing threads (optional, default: 1)
        max_length (int): the maximum generation length in tokens (optional, default: 1024)
        chat_template (str | None): the name of the chat template to apply (optional, default: None)
    """

    model: str
    dtype: Dtype = Field(default=Dtype.AUTO)
    # batch_size: int = Field(default=1)
    num_processing_threads: int = Field(default=1)
    # default_sampling_params: SamplingParams
    max_length: int = Field(default=1024)
    chat_template: str | None = Field(default=None)


class LLMOutput(TypedDict):
    """The output of the LLM model.

    Attributes:
        text (str): the generated text
    """

    text: str


class LLMBatchOutput(TypedDict):
    """The output of the LLM model for a batch of inputs.

    Attributes:
        texts (list[str]): the list of generated texts
    """

    texts: list[str]



# class ChatOutput(TypedDict):
# """The output of the chat model.

# Attributes:
# dialog (ChatDialog): the dialog with the responses from the model
# """

# message: ChatMessage


@serve.deployment
class HFAnnaphi2Deployment(BaseDeployment):
    """Deployment to serve aanaphi2 models using HuggingFace."""

    async def apply_config(self, config: dict[str, Any]):
        """Apply the configuration.

        The method is called when the deployment is created or updated.
        It loads the model and tokenizer from HuggingFace.

        The configuration should conform to the HFAnnaphi2Config schema.
        """
        config_obj = HFAnnaphi2Config(**config)

        # Create the batch processor to split the requests into batches
        # and process them in parallel
        # self.batch_size = config_obj.batch_size
        self.num_processing_threads = config_obj.num_processing_threads
        self.chat_template_name = config_obj.chat_template
        # The actual inference is done in _generate()
        # We use lambda because BatchProcessor expects dict as input
        # and we use **kwargs to unpack the dict into named arguments for _generate()
        # self.batch_processor = BatchProcessor(
        #     process_batch=lambda request: self._generate(**request),
        #     batch_size=self.batch_size,
        #     num_threads=self.num_processing_threads,
        # )

        # Load the model and tokenizer for aanaphi2 from HuggingFace
        self.model_id = config_obj.model
        self.dtype = config_obj.dtype
        self.max_length = config_obj.max_length
        # if self.dtype == Dtype.INT8:
        #     load_in_8bit = True
        #     self.torch_dtype = Dtype.FLOAT16.to_torch()
        # else:
        #     load_in_8bit = False
        self.torch_dtype = self.dtype.to_torch()
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = transformers.AutoModelForCausalLM.from_pretrained(
            self.model_id,
            torch_dtype=self.torch_dtype,
            # load_in_8bit=load_in_8bit,
            device_map=self.device,
        )
        self.model = torch.compile(self.model)
        self.model.eval()
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_id)
        self.model.to(self.device)

    @test_cache
    async def generate(self, prompt: str) -> LLMOutput:
        """Generate a completion for the given prompt.

        Args:
            prompt (str): the prompt

        Returns:
            LLMOutput: the dictionary with the key "text" and the generated text as the value

        Raises:
            InferenceException: if the inference fails
        """
        prompt = str(prompt)
        prompt_chat = apply_chat_template(self.tokenizer, prompt, self.chat_template_name)
        inputs = self.tokenizer(prompt_chat, return_tensors="pt", return_attention_mask=True).to(self.device)
        try:
            outputs = self.model.generate(**inputs, max_length=self.max_length, eos_token_id=self.tokenizer.eos_token_id)
        except Exception as e:
            raise InferenceException(self.model_id) from e
        # Decode everything except the trailing EOS token
        text = self.tokenizer.batch_decode(outputs[:, :-1])[0]
        return LLMOutput(text=text)






# @test_cache
# async def generate(self, prompt: str) -> AsyncGenerator[LLMOutput, None]:
# """Generate completion for the given prompt.

# Args:
# prompt (str): the prompt

# Returns:
# LLMOutput: the dictionary with the key "text" and the generated text as the value

# Raises:
# InferenceException: if the inference fails
# """
# prompt = str(prompt)
# prompt = apply_chat_template(self.tokenizer, dialog, self.chat_template_name)



# captions: CaptioningBatchOutput = await self.batch_processor.process(
# {"images": [image]}
# )
# return CaptioningOutput(caption=captions["captions"][0])

# @test_cache
# async def generate_batch(self, **kwargs) -> CaptioningBatchOutput:
# """Generate captions for the given images.

# Args:
# images (List[Image]): the images
# **kwargs (dict[str, Any]): keyword arguments to pass to the
# batch processor.

# Returns:
# CaptioningBatchOutput: the dictionary with one key "captions"
# and the list of captions for the images as value

# Raises:
# InferenceException: if the inference fails
# """
# # Call the batch processor to process the requests
# # The actual inference is done in _generate()
# return await self.batch_processor.process(kwargs)

# def _generate(self, images: list[Image]) -> CaptioningBatchOutput:
# """Generate captions for the given images.

# This method is called by the batch processor.

# Args:
# images (List[Image]): the images

# Returns:
# CaptioningBatchOutput: the dictionary with one key "captions"
# and the list of captions for the images as value

# Raises:
# InferenceException: if the inference fails
# """
# # Set the seed to make the results reproducible
# transformers.set_seed(42)
# # Loading images
# numpy_images = [im.get_numpy() for im in images]
# inputs = self.processor(numpy_images, return_tensors="pt").to(
# self.device, self.torch_dtype
# )

# try:
# generated_ids = self.model.generate(**inputs)
# generated_texts = self.processor.batch_decode(
# generated_ids, skip_special_tokens=True
# )
# generated_texts = [
# generated_text.strip() for generated_text in generated_texts
# ]
# return CaptioningBatchOutput(captions=generated_texts)
# except Exception as e:
# raise InferenceException(self.model_id) from e
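
For a quick sanity check outside Ray Serve, the following standalone sketch mirrors what generate() does above: format the prompt, run generation with the configured max_length and EOS token, and decode while dropping the trailing EOS token. The "### Human: / ### Assistant:" format is an assumption about what the "aanaphi2" chat template expands to; treat aana/utils/chat_template.py as the authoritative source.

import torch
import transformers

model_id = "mobiuslabsgmbh/aanaphi2-v0.1"
device = "cuda" if torch.cuda.is_available() else "cpu"

model = transformers.AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    device_map=device,
)
model.eval()
tokenizer = transformers.AutoTokenizer.from_pretrained(model_id)

# Assumed expansion of the "aanaphi2" chat template for a single user turn.
prompt = "### Human: What is the capital of France?\n### Assistant:"
inputs = tokenizer(prompt, return_tensors="pt", return_attention_mask=True).to(device)

with torch.no_grad():
    outputs = model.generate(
        **inputs, max_length=1024, eos_token_id=tokenizer.eos_token_id
    )

# Decode everything except the trailing EOS token, as the deployment does.
print(tokenizer.batch_decode(outputs[:, :-1])[0])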
