Bug fixes
Aleksandr Movchan committed Dec 12, 2023
1 parent 4535609 · commit 78905b9
Showing 6 changed files with 25 additions and 25 deletions.
13 changes: 3 additions & 10 deletions aana/api/api_generation.py
@@ -6,6 +6,7 @@
 
 from fastapi import FastAPI, File, Form, UploadFile
 from fastapi.responses import StreamingResponse
+from mobius_pipeline.exceptions import BaseException
 from mobius_pipeline.node.socket import Socket
 from mobius_pipeline.pipeline.pipeline import Pipeline
 from pydantic import BaseModel, Field, ValidationError, create_model, parse_raw_as
@@ -345,10 +346,6 @@ async def route_func_body(body: str, files: list[UploadFile] | None = None): #
             data_dict = {}
             for field_name in data.__fields__:
                 field_value = getattr(data, field_name)
-                # check if it has a method convert_to_entities
-                # if it does, call it to convert the model to an entity
-                # if hasattr(field_value, "convert_input_to_object"):
-                #     field_value = field_value.convert_input_to_object()
                 data_dict[field_name] = field_value
 
             if self.output_filter:
@@ -393,14 +390,10 @@ async def generator_wrapper() -> AsyncGenerator[bytes, None]:
                     output = self.process_output(output)
                     yield AanaJSONResponse(content=output).body
                 except RayTaskError as e:
-                    print(f"Got exception: {e} Type: {type(e)}")
                     yield custom_exception_handler(None, e).body
-                # except BaseException as e:
-                #     print(f"Got exception: {e} Type: {type(e)}")
-                #     yield custom_exception_handler(None, e)
+                except BaseException as e:
+                    yield custom_exception_handler(None, e)
                 except Exception as e:
-                    print(f"Got exception: {e} Type: {type(e)}")
-                    # yield custom_exception_handler(None, e).body
                     error = e.__class__.__name__
                     stacktrace = traceback.format_exc()
                     yield AanaJSONResponse(
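Note: the new `except BaseException` branch catches mobius_pipeline's exception base class (imported above), not Python's builtin, so pipeline errors now reach the client as a JSON chunk instead of killing the stream. A minimal sketch of the pattern, with hypothetical stand-ins rather than the repository's exact code:

import json
from collections.abc import AsyncGenerator, AsyncIterator

async def stream_with_error_envelope(
    results: AsyncIterator[dict],
) -> AsyncGenerator[bytes, None]:
    # Yield each pipeline output as a JSON chunk; if anything raises,
    # emit one final error envelope so the HTTP stream stays well formed.
    try:
        async for output in results:
            yield json.dumps(output).encode()
    except Exception as e:  # mirrors the commit's catch-all branch
        yield json.dumps(
            {"error": e.__class__.__name__, "message": str(e)}
        ).encode()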
4 changes: 3 additions & 1 deletion aana/api/app.py
@@ -32,7 +32,9 @@ async def validation_exception_handler(request: Request, exc: ValidationError):
     )
 
 
-def custom_exception_handler(request: Request, exc_raw: BaseException | RayTaskError):
+def custom_exception_handler(
+    request: Request | None, exc_raw: BaseException | RayTaskError
+):
     """This handler is used to handle custom exceptions raised in the application.
 
     BaseException is the base exception for all the exceptions
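Widening `request` to `Request | None` lets the streaming generator above call the handler directly with `None`, outside FastAPI's exception-handling machinery. A hedged sketch of a handler written against that contract (illustrative only, not the repository's implementation):

from fastapi import Request
from fastapi.responses import JSONResponse

def demo_exception_handler(request: Request | None, exc: Exception) -> JSONResponse:
    # `request` is None when called manually (e.g. from generator_wrapper),
    # and a real Request when invoked as a registered exception handler.
    path = request.url.path if request is not None else None
    return JSONResponse(
        status_code=400,
        content={"error": exc.__class__.__name__, "message": str(exc), "path": path},
    )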
12 changes: 6 additions & 6 deletions aana/configs/deployments.py
@@ -13,12 +13,12 @@
     "vllm_deployment_llama2_7b_chat": VLLMDeployment.options(
         num_replicas=1,
         max_concurrent_queries=1000,
-        ray_actor_options={"num_gpus": 0.9},
+        ray_actor_options={"num_gpus": 0.25},
         user_config=VLLMConfig(
             model="TheBloke/Llama-2-7b-Chat-AWQ",
             dtype="auto",
             quantization="awq",
-            gpu_memory_utilization=0.9,
+            gpu_memory_reserved=10000,
             default_sampling_params=SamplingParams(
                 temperature=1.0, top_p=1.0, top_k=-1, max_tokens=256
             ),
@@ -28,12 +28,12 @@
     "vllm_deployment_zephyr_7b_beta": VLLMDeployment.options(
         num_replicas=1,
         max_concurrent_queries=1000,
-        ray_actor_options={"num_gpus": 0.5},
+        ray_actor_options={"num_gpus": 0.25},
         user_config=VLLMConfig(
             model="TheBloke/zephyr-7B-beta-AWQ",
             dtype="auto",
             quantization="awq",
-            gpu_memory_utilization=0.9,
+            gpu_memory_reserved=10000,
             max_model_len=512,
             default_sampling_params=SamplingParams(
                 temperature=1.0, top_p=1.0, top_k=-1, max_tokens=256
@@ -43,7 +43,7 @@
     "hf_blip2_deployment_opt_2_7b": HFBlip2Deployment.options(
         num_replicas=1,
         max_concurrent_queries=1000,
-        ray_actor_options={"num_gpus": 0.45},
+        ray_actor_options={"num_gpus": 0.25},
         user_config=HFBlip2Config(
             model="Salesforce/blip2-opt-2.7b",
             dtype=Dtype.FLOAT16,
@@ -54,7 +54,7 @@
     "whisper_deployment_medium": WhisperDeployment.options(
         num_replicas=1,
         max_concurrent_queries=1000,
-        ray_actor_options={"num_gpus": 0.45},
+        ray_actor_options={"num_gpus": 0.25},
         user_config=WhisperConfig(
             model_size=WhisperModelSize.MEDIUM,
             compute_type=WhisperComputeType.FLOAT16,
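All four deployments now request a quarter of a GPU, and the vLLM configs pin an absolute reservation (10000 MB) instead of a 0.9 fraction. A back-of-the-envelope check, assuming a 40 GB card (the GPU size is an assumption, not stated in the diff):

# Rough packing math under an assumed 40 GB (40 * 1024 MB) device.
total_mb = 40 * 1024                # 40960
reserved_mb = 10000                 # gpu_memory_reserved from the config
print(reserved_mb / total_mb)       # ~0.244, consistent with num_gpus=0.25
print(4 * 0.25)                     # 1.0 -> the four deployments share one GPU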
15 changes: 10 additions & 5 deletions aana/deployments/vllm_deployment.py
@@ -7,7 +7,7 @@
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.model_executor.utils import set_random_seed
 from vllm.sampling_params import SamplingParams as VLLMSamplingParams
-from vllm.utils import random_uuid
+from vllm.utils import get_gpu_memory, random_uuid
 
 from aana.deployments.base_deployment import BaseDeployment
 from aana.exceptions.general import InferenceException
@@ -24,15 +24,15 @@ class VLLMConfig(BaseModel):
         model (str): the model name
         dtype (str): the data type (optional, default: "auto")
         quantization (str): the quantization method (optional, default: None)
-        gpu_memory_utilization (float): the GPU memory utilization.
+        gpu_memory_reserved (float): the GPU memory reserved for the model in MB
         default_sampling_params (SamplingParams): the default sampling parameters.
         max_model_len (int): the maximum generated text length in tokens (optional, default: None)
     """
 
     model: str
     dtype: str | None = Field(default="auto")
     quantization: str | None = Field(default=None)
-    gpu_memory_utilization: float
+    gpu_memory_reserved: float
     default_sampling_params: SamplingParams
     max_model_len: int | None = Field(default=None)
     chat_template: str | None = Field(default=None)
@@ -83,7 +83,7 @@ async def apply_config(self, config: dict[str, Any]):
         - model: the model name
         - dtype: the data type (optional, default: "auto")
         - quantization: the quantization method (optional, default: None)
-        - gpu_memory_utilization: the GPU memory utilization.
+        - gpu_memory_reserved: the GPU memory reserved for the model in MB
         - default_sampling_params: the default sampling parameters.
         - max_model_len: the maximum generated text length in tokens (optional, default: None)
         - chat_template: the name of the chat template (optional, default: None)
@@ -93,6 +93,11 @@
         """
         config_obj = VLLMConfig(**config)
         self.model = config_obj.model
+        total_gpu_memory_bytes = get_gpu_memory()
+        total_gpu_memory_mb = total_gpu_memory_bytes / 1024**2
+        self.gpu_memory_utilization = (
+            config_obj.gpu_memory_reserved / total_gpu_memory_mb
+        )
         self.default_sampling_params: SamplingParams = (
             config_obj.default_sampling_params
         )
@@ -101,7 +106,7 @@
             model=config_obj.model,
             dtype=config_obj.dtype,
             quantization=config_obj.quantization,
-            gpu_memory_utilization=config_obj.gpu_memory_utilization,
+            gpu_memory_utilization=self.gpu_memory_utilization,
             max_model_len=config_obj.max_model_len,
         )
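apply_config now derives the fraction vLLM expects from the fixed reservation: reserved MB divided by total device memory in MB. A standalone sketch of the same arithmetic, using torch directly since vllm.utils.get_gpu_memory is a vLLM internal that wraps the same call and may not exist in every version:

import torch

def reserved_mb_to_utilization(gpu_memory_reserved_mb: float, gpu: int = 0) -> float:
    # Convert an absolute reservation in MB into the 0..1 fraction that
    # vLLM's gpu_memory_utilization engine argument expects.
    # Requires a CUDA device at runtime.
    total_bytes = torch.cuda.get_device_properties(gpu).total_memory
    total_mb = total_bytes / 1024**2
    return gpu_memory_reserved_mb / total_mb

# e.g. reserved_mb_to_utilization(10000) on a 40 GB card -> ~0.24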
4 changes: 2 additions & 2 deletions aana/utils/video.py
@@ -1,4 +1,4 @@
-import json
+import json  # noqa: I001
 import pickle
 from collections import defaultdict
 from collections.abc import Generator
@@ -7,7 +7,7 @@
 from typing import TypedDict
 
 import numpy as np
-import torch, decord  # See https://github.com/dmlc/decord/issues/263  # noqa: F401
+import torch, decord  # noqa: F401  # See https://github.com/dmlc/decord/issues/263
 import yt_dlp
 from yt_dlp.utils import DownloadError
 
2 changes: 1 addition & 1 deletion mobius-pipeline (submodule pointer update; contents not shown)