Showing 2 changed files with 232 additions and 0 deletions.
@@ -0,0 +1,216 @@
from typing import Any, TypedDict

import torch
import transformers
from pydantic import BaseModel, Field
from ray import serve

# from transformers import Blip2ForConditionalGeneration, Blip2Processor
from aana.deployments.base_deployment import BaseDeployment
from aana.exceptions.general import InferenceException
from aana.models.core.dtype import Dtype
from aana.models.core.image import Image
from aana.utils.batch_processor import BatchProcessor
from aana.utils.chat_template import apply_chat_template
from aana.utils.test import test_cache


class HFAnnaphi2Config(BaseModel):
    """The configuration for the Annaphi2 deployment with HuggingFace models.

    Attributes:
        model (str): the model ID on HuggingFace
        dtype (Dtype): the data type (optional, default: "auto"), one of "auto", "float32", "float16"
        num_processing_threads (int): the number of processing threads (optional, default: 1)
        max_length (int): the maximum generation length in tokens (optional, default: 1024)
        chat_template (str | None): the name of the chat template to apply (optional, default: None)
    """

    model: str
    dtype: Dtype = Field(default=Dtype.AUTO)
    # batch_size: int = Field(default=1)
    num_processing_threads: int = Field(default=1)
    # default_sampling_params: SamplingParams
    max_length: int = Field(default=1024)
    chat_template: str | None = Field(default=None)


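# Illustrative only: a config dict of the shape that apply_config() below
# expects. The model ID and the values here are assumptions, not part of this
# commit.
#
# example_config = {
#     "model": "microsoft/phi-2",
#     "dtype": "float16",
#     "num_processing_threads": 2,
#     "max_length": 1024,
#     "chat_template": None,
# }

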
class LLMOutput(TypedDict):
    """The output of the LLM model.

    Attributes:
        text (str): the generated text
    """

    text: str


class LLMBatchOutput(TypedDict):
    """The output of the LLM model for a batch of inputs.

    Attributes:
        texts (list[str]): the list of generated texts
    """

    texts: list[str]


# class ChatOutput(TypedDict):
#     """The output of the chat model.
#
#     Attributes:
#         dialog (ChatDialog): the dialog with the responses from the model
#     """
#
#     message: ChatMessage


@serve.deployment
class HFAnnaphi2Deployment(BaseDeployment):
    """Deployment to serve Annaphi2 models using HuggingFace."""

    async def apply_config(self, config: dict[str, Any]):
        """Apply the configuration.

        The method is called when the deployment is created or updated.
        It loads the model and tokenizer from HuggingFace.
        The configuration should conform to the HFAnnaphi2Config schema.
        """
        config_obj = HFAnnaphi2Config(**config)

        # Create the batch processor to split the requests into batches
        # and process them in parallel
        # self.batch_size = config_obj.batch_size
        self.num_processing_threads = config_obj.num_processing_threads
        self.chat_template_name = config_obj.chat_template
        # The actual inference is done in _generate()
        # We use lambda because BatchProcessor expects dict as input
        # and we use **kwargs to unpack the dict into named arguments for _generate()
        # self.batch_processor = BatchProcessor(
        #     process_batch=lambda request: self._generate(**request),
        #     batch_size=self.batch_size,
        #     num_threads=self.num_processing_threads,
        # )

        # Load the model and tokenizer for Annaphi2 from HuggingFace
        self.model_id = config_obj.model
        self.dtype = config_obj.dtype
        self.max_length = config_obj.max_length
        # if self.dtype == Dtype.INT8:
        #     load_in_8bit = True
        #     self.torch_dtype = Dtype.FLOAT16.to_torch()
        # else:
        #     load_in_8bit = False
        self.torch_dtype = self.dtype.to_torch()
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = transformers.AutoModelForCausalLM.from_pretrained(
            self.model_id,
            torch_dtype=self.torch_dtype,
            # load_in_8bit=load_in_8bit,
            device_map=self.device,
        )
        self.model = torch.compile(self.model)
        self.model.eval()
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_id)
        self.model.to(self.device)

    @test_cache
    async def generate(self, prompt: str) -> LLMOutput:
        """Generate a completion for the given prompt.

        Args:
            prompt (str): the prompt

        Returns:
            LLMOutput: the dictionary with the key "text" and the generated text as the value

        Raises:
            InferenceException: if the inference fails
        """
        prompt = str(prompt)
        prompt_chat = apply_chat_template(self.tokenizer, prompt, self.chat_template_name)
        inputs = self.tokenizer(
            prompt_chat, return_tensors="pt", return_attention_mask=True
        ).to(self.device)
        try:
            outputs = self.model.generate(
                **inputs,
                max_length=self.max_length,
                eos_token_id=self.tokenizer.eos_token_id,
            )
        except Exception as e:
            raise InferenceException(self.model_id) from e
        # Decode the generated ids, dropping the trailing EOS token
        text = self.tokenizer.batch_decode(outputs[:, :-1])[0]
        return LLMOutput(text=text)

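    # Illustrative call (assumes the deployment has already been configured;
    # the prompt below is a placeholder, not part of this commit):
    #
    #     out = await deployment.generate("Write a haiku about GPUs.")
    #     out["text"]  # -> the generated completion as a plain string
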
    # @test_cache
    # async def generate(self, prompt: str) -> AsyncGenerator[LLMOutput, None]:
    #     """Generate completion for the given prompt.
    #
    #     Args:
    #         prompt (str): the prompt
    #
    #     Returns:
    #         LLMOutput: the dictionary with the key "text" and the generated text as the value
    #
    #     Raises:
    #         InferenceException: if the inference fails
    #     """
    #     prompt = str(prompt)
    #     prompt = apply_chat_template(self.tokenizer, dialog, self.chat_template_name)
    #
    #     captions: CaptioningBatchOutput = await self.batch_processor.process(
    #         {"images": [image]}
    #     )
    #     return CaptioningOutput(caption=captions["captions"][0])

    # @test_cache
    # async def generate_batch(self, **kwargs) -> CaptioningBatchOutput:
    #     """Generate captions for the given images.
    #
    #     Args:
    #         images (List[Image]): the images
    #         **kwargs (dict[str, Any]): keyword arguments to pass to the
    #             batch processor.
    #
    #     Returns:
    #         CaptioningBatchOutput: the dictionary with one key "captions"
    #             and the list of captions for the images as value
    #
    #     Raises:
    #         InferenceException: if the inference fails
    #     """
    #     # Call the batch processor to process the requests
    #     # The actual inference is done in _generate()
    #     return await self.batch_processor.process(kwargs)

    # def _generate(self, images: list[Image]) -> CaptioningBatchOutput:
    #     """Generate captions for the given images.
    #
    #     This method is called by the batch processor.
    #
    #     Args:
    #         images (List[Image]): the images
    #
    #     Returns:
    #         CaptioningBatchOutput: the dictionary with one key "captions"
    #             and the list of captions for the images as value
    #
    #     Raises:
    #         InferenceException: if the inference fails
    #     """
    #     # Set the seed to make the results reproducible
    #     transformers.set_seed(42)
    #     # Loading images
    #     numpy_images = [im.get_numpy() for im in images]
    #     inputs = self.processor(numpy_images, return_tensors="pt").to(
    #         self.device, self.torch_dtype
    #     )
    #
    #     try:
    #         generated_ids = self.model.generate(**inputs)
    #         generated_texts = self.processor.batch_decode(
    #             generated_ids, skip_special_tokens=True
    #         )
    #         generated_texts = [
    #             generated_text.strip() for generated_text in generated_texts
    #         ]
    #         return CaptioningBatchOutput(captions=generated_texts)
    #     except Exception as e:
    #         raise InferenceException(self.model_id) from e
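

# --- Illustrative usage with Ray Serve (not part of this commit) ---
# A rough sketch of how this deployment might be bound and queried through a
# Serve handle. The deployment options, GPU count, and config wiring below are
# assumptions; in the aana framework apply_config() is normally driven by the
# application configuration rather than called by hand.
#
# from ray import serve
#
# app = HFAnnaphi2Deployment.options(
#     num_replicas=1, ray_actor_options={"num_gpus": 1}
# ).bind()
# handle = serve.run(app)
# # From async code, once the deployment has been configured:
# # response = await handle.generate.remote("What is a chat template?")
# # print(response["text"])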