Chat with Video #23

Merged: 8 commits, Dec 14, 2023
6 changes: 3 additions & 3 deletions README.md
@@ -30,7 +30,7 @@ sh install.sh
5. Run the SDK.

```bash
CUDA_VISIBLE_DEVICES=0 poetry run aana --port 8000 --host 0.0.0.0 --target llama2
CUDA_VISIBLE_DEVICES=0 poetry run aana --port 8000 --host 0.0.0.0 --target chat_with_video
```

The target parameter specifies the set of endpoints to deploy.
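For reference, each target name is a key in the endpoint registry in `aana/configs/endpoints.py` (diffed further below); running with `--target chat_with_video` deploys every endpoint listed under that key. A minimal sketch of the mapping, assuming both classes are importable from `aana.api.api_generation` and that the video endpoints sit under a `chat_with_video` key as the command above suggests:

```python
from aana.api.api_generation import Endpoint, EndpointOutput

# Sketch of the target -> endpoints registry; the entry shown is the
# video_metadata endpoint added in this PR, the other chat-with-video
# endpoints are elided.
endpoints: dict[str, list[Endpoint]] = {
    "chat_with_video": [
        Endpoint(
            name="video_metadata",
            path="/video/metadata",
            summary="Load video metadata",
            outputs=[EndpointOutput(name="metadata", output="video_metadata")],
        ),
        # ... indexing, transcription, and chat endpoints added in this PR
    ],
}
```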
@@ -105,7 +105,7 @@ This project uses Ruff for linting and formatting. If you want to
manually run Ruff on the codebase, it's

```sh
ruff check aana
poetry run ruff check aana
```

You can automatically fix some issues with the `--fix`
@@ -115,7 +115,7 @@ You can automatically fix some issues with the `--fix`
To run the auto-formatter, it's

```sh
ruff format aana
poetry run ruff format aana
```

If you want to enable this as a local pre-commit hook, additionally
60 changes: 47 additions & 13 deletions aana/api/api_generation.py
@@ -1,14 +1,18 @@
import traceback
from collections.abc import AsyncGenerator, Callable
from dataclasses import dataclass
from enum import Enum
from typing import Any, Optional

from fastapi import FastAPI, File, Form, UploadFile
from fastapi.responses import StreamingResponse
from mobius_pipeline.exceptions import BaseException
from mobius_pipeline.node.socket import Socket
from mobius_pipeline.pipeline.pipeline import Pipeline
from pydantic import BaseModel, Field, ValidationError, create_model, parse_raw_as
from ray.exceptions import RayTaskError

from aana.api.app import custom_exception_handler
from aana.api.responses import AanaJSONResponse
from aana.exceptions.general import MultipleFileUploadNotAllowed
from aana.models.pydantic.exception_response import ExceptionResponseModel
@@ -41,7 +45,10 @@ async def run_pipeline(


async def run_pipeline_streaming(
pipeline: Pipeline, data: dict, required_outputs: list[str]
pipeline: Pipeline,
data: dict,
requested_partial_outputs: list[str],
requested_full_outputs: list[str],
) -> AsyncGenerator[dict[str, Any], None]:
"""This function is used to run a Mobius Pipeline as a generator.

@@ -50,7 +57,8 @@ async def run_pipeline_streaming(
Args:
pipeline (Pipeline): The pipeline to run.
data (dict): The data to create the container from.
required_outputs (List[str]): The required outputs of the pipeline.
requested_partial_outputs (list[str]): The required partial outputs of the pipeline that should be streamed.
requested_full_outputs (list[str]): The required full outputs of the pipeline that should be returned at the end.

Yields:
dict[str, Any]: The output of the pipeline and the execution time of the pipeline.
@@ -59,7 +67,9 @@
container = pipeline.parse_dict(data)

# run the pipeline
async for output in pipeline.run_generator(container, required_outputs):
async for output in pipeline.run_generator(
container, requested_partial_outputs, requested_full_outputs
):
yield output
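A hedged usage sketch of the new signature; the pipeline object, input dict, and output names below are placeholders, not values from the PR:

```python
from typing import Any

from aana.api.api_generation import run_pipeline_streaming


async def consume(pipeline, data: dict[str, Any]) -> None:
    # Partial outputs ("captions" here) are yielded chunk by chunk as the
    # pipeline produces them; full outputs ("captions_path") only show up
    # in the final yielded dict.
    async for chunk in run_pipeline_streaming(
        pipeline,
        data,
        requested_partial_outputs=["captions"],
        requested_full_outputs=["captions_path"],
    ):
        print(chunk)
```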


@@ -103,6 +113,7 @@ class EndpointOutput:

name: str
output: str
streaming: bool = False


@dataclass
@@ -133,6 +144,9 @@ def __post_init__(self):
"""
self.name_to_output = {output.name: output.output for output in self.outputs}
self.output_to_name = {output.output: output.name for output in self.outputs}
self.streaming_outputs = {
output.output for output in self.outputs if output.streaming
}

def generate_model_name(self, suffix: str) -> str:
"""Generate a Pydantic model name based on a given suffix.
@@ -317,7 +331,7 @@ def create_endpoint_func( # noqa: C901
) -> Callable:
"""Create a function for routing an endpoint."""

async def route_func_body(body: str, files: list[UploadFile] | None = None):
async def route_func_body(body: str, files: list[UploadFile] | None = None): # noqa: C901
# parse form data as a pydantic model and validate it
data = parse_raw_as(RequestModel, body)

@@ -332,10 +346,6 @@ async def route_func_body(body: str, files: list[UploadFile] | None = None):
data_dict = {}
for field_name in data.__fields__:
field_value = getattr(data, field_name)
# check if it has a method convert_to_entities
# if it does, call it to convert the model to an entity
if hasattr(field_value, "convert_input_to_object"):
field_value = field_value.convert_input_to_object()
data_dict[field_name] = field_value

if self.output_filter:
@@ -360,14 +370,38 @@ async def route_func_body(body: str, files: list[UploadFile] | None = None):

# run the pipeline
if self.streaming:
requested_partial_outputs = []
requested_full_outputs = []
for output in outputs:
if output in self.streaming_outputs:
requested_partial_outputs.append(output)
else:
requested_full_outputs.append(output)

async def generator_wrapper() -> AsyncGenerator[bytes, None]:
"""Serializes the output of the generator using ORJSONResponseCustom."""
async for output in run_pipeline_streaming(
pipeline, data_dict, outputs
):
output = self.process_output(output)
yield AanaJSONResponse(content=output).body
try:
async for output in run_pipeline_streaming(
pipeline,
data_dict,
requested_partial_outputs,
requested_full_outputs,
):
output = self.process_output(output)
yield AanaJSONResponse(content=output).body
except RayTaskError as e:
yield custom_exception_handler(None, e).body
except BaseException as e:
yield custom_exception_handler(None, e).body
except Exception as e:
error = e.__class__.__name__
stacktrace = traceback.format_exc()
yield AanaJSONResponse(
status_code=400,
content=ExceptionResponseModel(
error=error, message=str(e), stacktrace=stacktrace
).dict(),
).body

return StreamingResponse(
generator_wrapper(), media_type="application/json"
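With the added try/except, a client reading the stream receives a serialized error object instead of a dropped connection. A hedged sketch of what one such chunk deserializes to (field names follow ExceptionResponseModel; the values are illustrative):

```python
# Illustrative only: shape of a streamed error chunk.
error_chunk = {
    "error": "ValidationError",                              # exception class name
    "message": "something went wrong",                       # str(exc)
    "stacktrace": "Traceback (most recent call last): ...",  # traceback.format_exc()
}
```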
7 changes: 5 additions & 2 deletions aana/api/app.py
@@ -32,7 +32,9 @@ async def validation_exception_handler(request: Request, exc: ValidationError):
)


def custom_exception_handler(request: Request, exc_raw: BaseException | RayTaskError):
def custom_exception_handler(
request: Request | None, exc_raw: BaseException | RayTaskError
):
"""This handler is used to handle custom exceptions raised in the application.

BaseException is the base exception for all the exceptions
@@ -74,8 +76,9 @@ def custom_exception_handler(request: Request, exc_raw: BaseException | RayTaskError):
error = exc.__class__.__name__
# get the message of the exception
message = str(exc)
status_code = getattr(exc, "http_status_code", 400)
return AanaJSONResponse(
status_code=400,
status_code=status_code,
content=ExceptionResponseModel(
error=error, message=message, data=data, stacktrace=stacktrace
).dict(),
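For illustration, a hypothetical exception (not part of the PR) that uses the new hook: since the handler now reads `getattr(exc, "http_status_code", 400)`, any handled exception carrying that attribute controls its own response status. In the real application it would presumably derive from the project's exception base rather than plain `Exception`:

```python
# Hypothetical exception; setting http_status_code makes custom_exception_handler
# return a 404 response instead of the default 400.
class VideoNotFoundException(Exception):
    http_status_code = 404

    def __init__(self, video_id: str):
        self.video_id = video_id
        super().__init__(f"Video '{video_id}' was not found")
```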
26 changes: 6 additions & 20 deletions aana/configs/deployments.py
@@ -13,36 +13,22 @@
"vllm_deployment_llama2_7b_chat": VLLMDeployment.options(
num_replicas=1,
max_concurrent_queries=1000,
ray_actor_options={"num_gpus": 0.5},
ray_actor_options={"num_gpus": 0.25},
user_config=VLLMConfig(
model="TheBloke/Llama-2-7b-Chat-AWQ",
dtype="auto",
quantization="awq",
gpu_memory_utilization=0.7,
gpu_memory_reserved=10000,
default_sampling_params=SamplingParams(
temperature=1.0, top_p=1.0, top_k=-1, max_tokens=256
),
).dict(),
),
"vllm_deployment_zephyr_7b_beta": VLLMDeployment.options(
num_replicas=1,
max_concurrent_queries=1000,
ray_actor_options={"num_gpus": 0.5},
user_config=VLLMConfig(
model="TheBloke/zephyr-7B-beta-AWQ",
dtype="auto",
quantization="awq",
gpu_memory_utilization=0.9,
max_model_len=512,
default_sampling_params=SamplingParams(
temperature=1.0, top_p=1.0, top_k=-1, max_tokens=256
temperature=0.0, top_p=1.0, top_k=-1, max_tokens=1024
),
chat_template="llama2",
).dict(),
),
"hf_blip2_deployment_opt_2_7b": HFBlip2Deployment.options(
num_replicas=1,
max_concurrent_queries=1000,
ray_actor_options={"num_gpus": 0.5},
ray_actor_options={"num_gpus": 0.25},
user_config=HFBlip2Config(
model="Salesforce/blip2-opt-2.7b",
dtype=Dtype.FLOAT16,
@@ -53,7 +39,7 @@
"whisper_deployment_medium": WhisperDeployment.options(
num_replicas=1,
max_concurrent_queries=1000,
ray_actor_options={"num_gpus": 0.5},
ray_actor_options={"num_gpus": 0.25},
user_config=WhisperConfig(
model_size=WhisperModelSize.MEDIUM,
compute_type=WhisperComputeType.FLOAT16,
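One plausible reading of the num_gpus changes, stated here as an assumption rather than anything the PR claims: at 0.25 GPU per replica, the LLM, BLIP2, and Whisper deployments used together for chat-with-video can share a single GPU.

```python
# Assumed reasoning only: Ray Serve schedules replicas by fractional GPU
# reservation, so the fractions of co-located deployments must sum to <= 1.0.
reservations = {
    "vllm_deployment_llama2_7b_chat": 0.25,
    "hf_blip2_deployment_opt_2_7b": 0.25,
    "whisper_deployment_medium": 0.25,
}
assert sum(reservations.values()) <= 1.0
```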
120 changes: 107 additions & 13 deletions aana/configs/endpoints.py
@@ -16,21 +16,36 @@
summary="Generate text using LLaMa2 7B Chat (streaming)",
outputs=[
EndpointOutput(
name="completion", output="vllm_llama2_7b_chat_output_stream"
name="completion",
output="vllm_llama2_7b_chat_output_stream",
streaming=True,
)
],
streaming=True,
),
],
"zephyr": [
Endpoint(
name="zephyr_generate",
path="/llm/generate",
summary="Generate text using Zephyr 7B Beta",
name="llm_chat",
path="/llm/chat",
summary="Chat with LLaMa2 7B Chat",
outputs=[
EndpointOutput(name="completion", output="vllm_zephyr_7b_beta_output")
EndpointOutput(
name="message", output="vllm_llama2_7b_chat_output_message"
)
],
)
),
Endpoint(
name="llm_chat_stream",
path="/llm/chat_stream",
summary="Chat with LLaMa2 7B Chat (streaming)",
outputs=[
EndpointOutput(
name="completion",
output="vllm_llama2_7b_chat_output_dialog_stream",
streaming=True,
)
],
streaming=True,
),
],
"blip2": [
Endpoint(
@@ -90,9 +105,16 @@
summary="Generate captions for videos using BLIP2 OPT-2.7B",
outputs=[
EndpointOutput(
name="captions", output="video_captions_hf_blip2_opt_2_7b"
name="captions",
output="video_captions_hf_blip2_opt_2_7b",
streaming=True,
),
EndpointOutput(
name="timestamps", output="video_timestamps", streaming=True
),
EndpointOutput(
name="video_captions_path", output="video_captions_path"
),
EndpointOutput(name="timestamps", output="video_timestamps"),
],
streaming=True,
),
@@ -102,15 +124,64 @@
summary="Transcribe a video using Whisper Medium",
outputs=[
EndpointOutput(
name="transcription", output="video_transcriptions_whisper_medium"
name="transcription",
output="video_transcriptions_whisper_medium",
streaming=True,
),
EndpointOutput(
name="segments",
output="video_transcriptions_segments_whisper_medium",
streaming=True,
),
EndpointOutput(
name="info", output="video_transcriptions_info_whisper_medium"
name="info",
output="video_transcriptions_info_whisper_medium",
streaming=True,
),
EndpointOutput(name="transcription_path", output="transcription_path"),
],
streaming=True,
),
Endpoint(
name="index_video_stream",
path="/video/index_stream",
summary="Index a video and return the captions and transcriptions as a stream",
outputs=[
EndpointOutput(
name="transcription",
output="video_transcriptions_whisper_medium",
streaming=True,
),
EndpointOutput(
name="segments",
output="video_transcriptions_segments_whisper_medium",
streaming=True,
),
EndpointOutput(
name="info",
output="video_transcriptions_info_whisper_medium",
streaming=True,
),
EndpointOutput(name="transcription_path", output="transcription_path"),
EndpointOutput(
name="captions",
output="video_captions_hf_blip2_opt_2_7b",
streaming=True,
),
EndpointOutput(
name="timestamps", output="video_timestamps", streaming=True
),
EndpointOutput(name="combined_timeline", output="combined_timeline"),
EndpointOutput(
name="combined_timeline_path", output="combined_timeline_path"
),
EndpointOutput(
name="video_metadata_path", output="video_metadata_path"
),
EndpointOutput(
name="video_captions_path", output="video_captions_path"
),
EndpointOutput(name="transcription_path", output="transcription_path"),
],
streaming=True,
),
@@ -128,10 +199,33 @@
summary="Generate text using LLaMa2 7B Chat (streaming)",
outputs=[
EndpointOutput(
name="completion", output="vllm_llama2_7b_chat_output_stream"
name="completion",
output="vllm_llama2_7b_chat_output_stream",
streaming=True,
)
],
streaming=True,
),
Endpoint(
name="video_chat_stream",
path="/video/chat_stream",
summary="Chat with video using LLaMa2 7B Chat (streaming)",
outputs=[
EndpointOutput(
name="completion",
output="vllm_llama2_7b_chat_output_dialog_stream_video",
streaming=True,
)
],
streaming=True,
),
Endpoint(
name="video_metadata",
path="/video/metadata",
summary="Load video metadata",
outputs=[
EndpointOutput(name="metadata", output="video_metadata"),
],
),
],
}