From 1a045c3287cbb9684bcfad60ddf17027d8bafeac Mon Sep 17 00:00:00 2001 From: movchan74 Date: Thu, 23 Nov 2023 15:17:12 +0000 Subject: [PATCH 1/4] Added streaming endpoints for video processing --- aana/api/api_generation.py | 116 ++++++++++++++++---- aana/configs/build.py | 2 +- aana/configs/endpoints.py | 96 +++++++++++++++-- aana/configs/pipeline.py | 136 +++++++++++++++++++++++- aana/deployments/hf_blip2_deployment.py | 31 ++---- aana/deployments/whisper_deployment.py | 27 ++++- aana/tests/test_api_generation.py | 18 ++-- aana/tests/test_app.py | 6 +- aana/tests/test_app_streaming.py | 6 +- aana/tests/test_build.py | 14 +-- aana/utils/video.py | 58 +++++++++- mobius-pipeline | 2 +- pyproject.toml | 3 + 13 files changed, 436 insertions(+), 79 deletions(-) diff --git a/aana/api/api_generation.py b/aana/api/api_generation.py index 0f656308..18363555 100644 --- a/aana/api/api_generation.py +++ b/aana/api/api_generation.py @@ -10,6 +10,7 @@ from aana.exceptions.general import MultipleFileUploadNotAllowed from aana.models.pydantic.exception_response import ExceptionResponseModel +from pydantic import ValidationError async def run_pipeline( @@ -93,6 +94,20 @@ class FileUploadField: description: str +@dataclass +class EndpointOutput: + """ + Class used to represent an endpoint output. + + Attributes: + name (str): Name of the output that should be returned by the endpoint. + output (str): Output of the pipeline that should be returned by the endpoint. + """ + + name: str + output: str + + @dataclass class Endpoint: """ @@ -102,8 +117,7 @@ class Endpoint: name (str): Name of the endpoint. path (str): Path of the endpoint. summary (str): Description of the endpoint that will be shown in the API documentation. - outputs (List[str]): List of required outputs from the pipeline that should be returned - by the endpoint. + outputs (List[EndpointOutput]): List of outputs that should be returned by the endpoint. output_filter (Optional[OutputFilter]): The parameter will be added to the request and will allow to choose subset of `outputs` to return. streaming (bool): Whether the endpoint outputs a stream of data. @@ -112,10 +126,19 @@ class Endpoint: name: str path: str summary: str - outputs: List[str] + outputs: list[EndpointOutput] output_filter: Optional[OutputFilter] = None streaming: bool = False + def __post_init__(self): + """ + Post init method. + + Creates dictionaries for fast lookup of outputs. + """ + self.name_to_output = {output.name: output.output for output in self.outputs} + self.output_to_name = {output.output: output.name for output in self.outputs} + def generate_model_name(self, suffix: str) -> str: """ Generate a Pydantic model name based on a given suffix. @@ -145,21 +168,25 @@ def socket_to_field(self, socket: Socket) -> Tuple[Any, Any]: data_model = Any return (data_model, Field(None)) - # check if any of the fields are required - if any(field.required for field in data_model.__fields__.values()): + # try to instantiate the data model + # to see if any of the fields are required + try: + data_model_instance = data_model() + return (data_model, data_model_instance) + except ValidationError: + # if we can't instantiate the data model + # it means that it has required fields return (data_model, ...) - return (data_model, data_model()) - - def get_fields(self, sockets: List[Socket]) -> Dict[str, Tuple[Any, Any]]: + def get_input_fields(self, sockets: list[Socket]) -> dict[str, tuple[Any, Any]]: """ - Generate fields for the Pydantic model based on the provided sockets. + Generate fields for the request Pydantic model based on the provided sockets. Parameters: - sockets (List[Socket]): List of sockets. + sockets (list[Socket]): List of sockets. Returns: - Dict[str, Tuple[Any, Field]]: Dictionary of fields for the Pydantic model. + dict[str, tuple[Any, Field]]: Dictionary of fields for the request Pydantic model. """ fields = {} for socket in sockets: @@ -167,6 +194,23 @@ def get_fields(self, sockets: List[Socket]) -> Dict[str, Tuple[Any, Any]]: fields[socket.name] = field return fields + def get_output_fields(self, sockets: list[Socket]) -> dict[str, tuple[Any, Any]]: + """ + Generate fields for the response Pydantic model based on the provided sockets. + + Parameters: + sockets (list[Socket]): List of sockets. + + Returns: + dict[str, tuple[Any, Field]]: Dictionary of fields for the response Pydantic model. + """ + fields = {} + for socket in sockets: + field = self.socket_to_field(socket) + name = self.output_to_name[socket.name] + fields[name] = field + return fields + def get_file_upload_field( self, input_sockets: List[Socket] ) -> Optional[FileUploadField]: @@ -219,7 +263,9 @@ def get_output_filter_field(self) -> Optional[Tuple[Any, Any]]: description = self.output_filter.description outputs_enum_name = self.generate_model_name("Outputs") outputs_enum = Enum( # type: ignore - outputs_enum_name, [(output, output) for output in self.outputs], type=str + outputs_enum_name, + [(output.name, output.name) for output in self.outputs], + type=str, ) field = (Optional[List[outputs_enum]], Field(None, description=description)) return field @@ -235,7 +281,7 @@ def get_request_model(self, input_sockets: List[Socket]) -> Type[BaseModel]: Type[BaseModel]: Pydantic model for the request. """ model_name = self.generate_model_name("Request") - input_fields = self.get_fields(input_sockets) + input_fields = self.get_input_fields(input_sockets) output_filter_field = self.get_output_filter_field() if output_filter_field and self.output_filter: input_fields[self.output_filter.name] = output_filter_field @@ -253,10 +299,30 @@ def get_response_model(self, output_sockets: List[Socket]) -> Type[BaseModel]: Type[BaseModel]: Pydantic model for the response. """ model_name = self.generate_model_name("Response") - output_fields = self.get_fields(output_sockets) + output_fields = self.get_output_fields(output_sockets) ResponseModel = create_model(model_name, **output_fields) return ResponseModel + def process_output(self, output: dict[str, Any]) -> dict[str, Any]: + """ + Process the output of the pipeline. + + Maps the output names of the pipeline to the names defined in the endpoint outputs. + + For example, maps videos_captions_hf_blip2_opt_2_7b to captions. + + Args: + output (dict): The output of the pipeline. + + Returns: + dict: The processed output. + """ + output = { + self.output_to_name.get(output_name, output_name): output_value + for output_name, output_value in output.items() + } + return output + def create_endpoint_func( self, pipeline: Pipeline, @@ -293,10 +359,12 @@ async def route_func_body(body: str, files: Optional[List[UploadFile]] = None): if requested_outputs: # get values for requested outputs because it's a list of enums requested_outputs = [output.value for output in requested_outputs] - outputs = requested_outputs + # map the requested outputs to the actual outputs + # for example, videos_captions_hf_blip2_opt_2_7b to captions + outputs = [self.name_to_output[output] for output in requested_outputs] # otherwise use the required outputs from the config (all outputs endpoint provides) else: - outputs = self.outputs + outputs = [output.output for output in self.outputs] # remove the output filter parameter from the data if self.output_filter and self.output_filter.name in data_dict: @@ -304,15 +372,23 @@ async def route_func_body(body: str, files: Optional[List[UploadFile]] = None): # run the pipeline if self.streaming: + async def generator_wrapper(): """ Serializes the output of the generator using ORJSONResponseCustom """ - async for output in run_pipeline_streaming(pipeline, data_dict, outputs): + async for output in run_pipeline_streaming( + pipeline, data_dict, outputs + ): + output = self.process_output(output) yield AanaJSONResponse(content=output).body - return StreamingResponse(generator_wrapper(), media_type="application/json") + + return StreamingResponse( + generator_wrapper(), media_type="application/json" + ) else: output = await run_pipeline(pipeline, data_dict, outputs) + output = self.process_output(output) return AanaJSONResponse(content=output) if file_upload_field: @@ -336,7 +412,9 @@ def register( pipeline (Pipeline): Pipeline to register the endpoint to. custom_schemas (Dict[str, Dict]): Dictionary of custom schemas. """ - input_sockets, output_sockets = pipeline.get_sockets(self.outputs) + input_sockets, output_sockets = pipeline.get_sockets( + [output.output for output in self.outputs] + ) RequestModel = self.get_request_model(input_sockets) ResponseModel = self.get_response_model(output_sockets) file_upload_field = self.get_file_upload_field(input_sockets) diff --git a/aana/configs/build.py b/aana/configs/build.py index 8cd724db..50b8857c 100644 --- a/aana/configs/build.py +++ b/aana/configs/build.py @@ -43,7 +43,7 @@ def get_configuration(target: str, endpoints, nodes, deployments) -> Dict: # Target endpoints require the following outputs endpoint_outputs = [] for endpoint in target_endpoints: - endpoint_outputs += endpoint.outputs + endpoint_outputs += [output.output for output in endpoint.outputs] # Build the output graph for the whole pipeline node_definitions = [NodeDefinition.from_dict(node_dict) for node_dict in nodes] diff --git a/aana/configs/endpoints.py b/aana/configs/endpoints.py index e0d1b98e..cd62d22d 100644 --- a/aana/configs/endpoints.py +++ b/aana/configs/endpoints.py @@ -1,4 +1,4 @@ -from aana.api.api_generation import Endpoint +from aana.api.api_generation import Endpoint, EndpointOutput endpoints = { @@ -7,13 +7,19 @@ name="llm_generate", path="/llm/generate", summary="Generate text using LLaMa2 7B Chat", - outputs=["vllm_llama2_7b_chat_output"], + outputs=[ + EndpointOutput(name="completion", output="vllm_llama2_7b_chat_output") + ], ), Endpoint( name="llm_generate_stream", path="/llm/generate_stream", summary="Generate text using LLaMa2 7B Chat (streaming)", - outputs=["vllm_llama2_7b_chat_output_stream"], + outputs=[ + EndpointOutput( + name="completion", output="vllm_llama2_7b_chat_output_stream" + ) + ], streaming=True, ), ], @@ -22,7 +28,9 @@ name="zephyr_generate", path="/llm/generate", summary="Generate text using Zephyr 7B Beta", - outputs=["vllm_zephyr_7b_beta_output"], + outputs=[ + EndpointOutput(name="completion", output="vllm_zephyr_7b_beta_output") + ], ) ], "blip2": [ @@ -30,13 +38,20 @@ name="blip2_generate", path="/image/generate_captions", summary="Generate captions for images using BLIP2 OPT-2.7B", - outputs=["captions_hf_blip2_opt_2_7b"], + outputs=[ + EndpointOutput(name="captions", output="captions_hf_blip2_opt_2_7b") + ], ), Endpoint( name="blip2_video_generate", path="/video/generate_captions", summary="Generate captions for videos using BLIP2 OPT-2.7B", - outputs=["video_captions_hf_blip2_opt_2_7b", "timestamps"], + outputs=[ + EndpointOutput( + name="captions", output="videos_captions_hf_blip2_opt_2_7b" + ), + EndpointOutput(name="timestamps", output="timestamps"), + ], ), ], "video": [ @@ -44,7 +59,10 @@ name="video_extract_frames", path="/video/extract_frames", summary="Extract frames from a video", - outputs=["timestamps", "duration"], + outputs=[ + EndpointOutput(name="timestamps", output="timestamps"), + EndpointOutput(name="duration", output="duration"), + ], ) ], "whisper": [ @@ -53,10 +71,68 @@ path="/video/transcribe", summary="Transcribe a video using Whisper Medium", outputs=[ - "video_transcriptions_whisper_medium", - "video_transcriptions_segments_whisper_medium", - "video_transcriptions_info_whisper_medium", + EndpointOutput( + name="transcription", output="videos_transcriptions_whisper_medium" + ), + EndpointOutput( + name="segments", + output="videos_transcriptions_segments_whisper_medium", + ), + EndpointOutput( + name="info", output="videos_transcriptions_info_whisper_medium" + ), ], ) ], + "chat_with_video": [ + Endpoint( + name="blip2_video_generate", + path="/video/generate_captions", + summary="Generate captions for videos using BLIP2 OPT-2.7B", + outputs=[ + EndpointOutput( + name="captions", output="video_captions_hf_blip2_opt_2_7b" + ), + EndpointOutput(name="timestamps", output="video_timestamps"), + ], + streaming=True, + ), + Endpoint( + name="whisper_transcribe", + path="/video/transcribe", + summary="Transcribe a video using Whisper Medium", + outputs=[ + EndpointOutput( + name="transcription", output="video_transcriptions_whisper_medium" + ), + EndpointOutput( + name="segments", + output="video_transcriptions_segments_whisper_medium", + ), + EndpointOutput( + name="info", output="video_transcriptions_info_whisper_medium" + ), + ], + streaming=True, + ), + Endpoint( + name="llm_generate", + path="/llm/generate", + summary="Generate text using LLaMa2 7B Chat", + outputs=[ + EndpointOutput(name="completion", output="vllm_llama2_7b_chat_output") + ], + ), + Endpoint( + name="llm_generate_stream", + path="/llm/generate_stream", + summary="Generate text using LLaMa2 7B Chat (streaming)", + outputs=[ + EndpointOutput( + name="completion", output="vllm_llama2_7b_chat_output_stream" + ) + ], + streaming=True, + ), + ], } diff --git a/aana/configs/pipeline.py b/aana/configs/pipeline.py index 8ad8465a..99b1d397 100644 --- a/aana/configs/pipeline.py +++ b/aana/configs/pipeline.py @@ -12,7 +12,7 @@ from aana.models.pydantic.image_input import ImageInputList from aana.models.pydantic.prompt import Prompt from aana.models.pydantic.sampling_params import SamplingParams -from aana.models.pydantic.video_input import VideoInputList +from aana.models.pydantic.video_input import VideoInput, VideoInputList from aana.models.pydantic.video_params import VideoParams from aana.models.pydantic.whisper_params import WhisperParams @@ -211,7 +211,7 @@ ], }, { - "name": "download_video", + "name": "download_videos", "type": "ray_task", "function": "aana.utils.video.download_video", "batched": True, @@ -278,7 +278,7 @@ ], }, { - "name": "hf_blip2_opt_2_7b_video", + "name": "hf_blip2_opt_2_7b_videos", "type": "ray_deployment", "deployment_name": "hf_blip2_deployment_opt_2_7b", "method": "generate_batch", @@ -292,7 +292,7 @@ ], "outputs": [ { - "name": "video_captions_hf_blip2_opt_2_7b", + "name": "videos_captions_hf_blip2_opt_2_7b", "key": "captions", "path": "video_batch.videos.[*].frames.[*].caption_hf_blip2_opt_2_7b", "data_model": VideoCaptionsList, @@ -330,6 +330,134 @@ "data_model": WhisperParams, }, ], + "outputs": [ + { + "name": "videos_transcriptions_segments_whisper_medium", + "key": "segments", + "path": "video_batch.videos.[*].segments", + "data_model": AsrSegmentsList, + }, + { + "name": "videos_transcriptions_info_whisper_medium", + "key": "transcription_info", + "path": "video_batch.videos.[*].transcription_info", + "data_model": AsrTranscriptionInfoList, + }, + { + "name": "videos_transcriptions_whisper_medium", + "key": "transcription", + "path": "video_batch.videos.[*].transcription", + "data_model": AsrTranscriptionList, + }, + ], + }, + { + "name": "video", + "type": "input", + "inputs": [], + "outputs": [ + { + "name": "video", + "key": "video", + "path": "video.video_input", + "data_model": VideoInput, + } + ], + }, + { + "name": "download_video", + "type": "ray_task", + "function": "aana.utils.video.download_video", + "dict_output": False, + "inputs": [ + { + "name": "video", + "key": "video_input", + "path": "video.video_input", + }, + ], + "outputs": [ + { + "name": "video_object", + "key": "output", + "path": "video.video", + }, + ], + }, + { + "name": "generate_frames_for_video", + "type": "ray_task", + "function": "aana.utils.video.generate_frames_decord", + "data_type": "generator", + "generator_path": "video", + "inputs": [ + { + "name": "video_object", + "key": "video", + "path": "video.video", + }, + {"name": "video_params", "key": "params", "path": "video_batch.params"}, + ], + "outputs": [ + { + "name": "video_frames", + "key": "frames", + "path": "video.frames.[*].image", + }, + { + "name": "video_timestamps", + "key": "timestamps", + "path": "video.timestamps", + }, + { + "name": "video_duration", + "key": "duration", + "path": "video.duration", + }, + ], + }, + { + "name": "hf_blip2_opt_2_7b_video", + "type": "ray_deployment", + "deployment_name": "hf_blip2_deployment_opt_2_7b", + "method": "generate_batch", + "flatten_by": "video.frames.[*]", + "inputs": [ + { + "name": "video_frames", + "key": "images", + "path": "video.frames.[*].image", + } + ], + "outputs": [ + { + "name": "video_captions_hf_blip2_opt_2_7b", + "key": "captions", + "path": "video.frames.[*].caption_hf_blip2_opt_2_7b", + "data_model": VideoCaptionsList, + } + ], + }, + { + "name": "whisper_medium_transcribe_video", + "type": "ray_deployment", + "deployment_name": "whisper_deployment_medium", + "data_type": "generator", + "generator_path": "video", + "method": "transcribe_stream", + "inputs": [ + { + "name": "video_object", + "key": "media", + "path": "video.video", + }, + { + "name": "whisper_params", + "key": "params", + "path": "video_batch.whisper_params", + "data_model": WhisperParams, + }, + ], "outputs": [ { "name": "video_transcriptions_segments_whisper_medium", diff --git a/aana/deployments/hf_blip2_deployment.py b/aana/deployments/hf_blip2_deployment.py index f5b5c677..af0b225d 100644 --- a/aana/deployments/hf_blip2_deployment.py +++ b/aana/deployments/hf_blip2_deployment.py @@ -27,26 +27,6 @@ class HFBlip2Config(BaseModel): batch_size: int = Field(default=1) num_processing_threads: int = Field(default=1) - @validator("dtype", pre=True, always=True) - def validate_dtype(cls, value: Dtype) -> Dtype: - """ - Validate the data type. For BLIP2 only "float32" and "float16" are supported. - - Args: - value (Dtype): the data type - - Returns: - Dtype: the validated data type - - Raises: - ValueError: if the data type is not supported - """ - if value not in {Dtype.AUTO, Dtype.FLOAT32, Dtype.FLOAT16}: - raise ValueError( - f"Invalid dtype: {value}. BLIP2 only supports 'auto', 'float32', and 'float16'." - ) - return value - class CaptioningOutput(TypedDict): """ @@ -105,10 +85,17 @@ async def apply_config(self, config: Dict[str, Any]): # Load the model and processor for BLIP2 from HuggingFace self.model_id = config_obj.model self.dtype = config_obj.dtype - self.torch_dtype = self.dtype.to_torch() + if self.dtype == Dtype.INT8: + load_in_8bit = True + self.torch_dtype = Dtype.FLOAT16.to_torch() + else: + load_in_8bit = False + self.torch_dtype = self.dtype.to_torch() self.model = Blip2ForConditionalGeneration.from_pretrained( - self.model_id, torch_dtype=self.torch_dtype + self.model_id, torch_dtype=self.torch_dtype, load_in_8bit=load_in_8bit ) + self.model = torch.compile(self.model) + self.model.eval() self.processor = Blip2Processor.from_pretrained(self.model_id) self.device = "cuda" if torch.cuda.is_available() else "cpu" diff --git a/aana/deployments/whisper_deployment.py b/aana/deployments/whisper_deployment.py index bf661912..4ee6927f 100644 --- a/aana/deployments/whisper_deployment.py +++ b/aana/deployments/whisper_deployment.py @@ -147,7 +147,7 @@ async def apply_config(self, config: Dict[str, Any]): # TODO: add audio support async def transcribe( - self, media: Video, params: WhisperParams = WhisperParams() + self, media: Video, params: WhisperParams | None = None ) -> WhisperOutput: """ Transcribe the media with the whisper model. @@ -165,6 +165,8 @@ async def transcribe( Raises: InferenceException: If the inference fails. """ + if params is None: + params = WhisperParams() media_path: str = str(media.path) try: @@ -183,6 +185,29 @@ async def transcribe( transcription=asr_transcription, ) + async def transcribe_stream( + self, media: Video, params: WhisperParams | None = None + ) -> WhisperOutput: + """ + Transcribe the media with the whisper model in a streaming fashion. + + Right now this is the same as transcribe, but we will add support for + streaming in the future to support larger media and to make the ASR more responsive. + + Args: + media (Video): The media to transcribe. + params (WhisperParams): The parameters for the whisper model. + + Yields: + WhisperOutput: The transcription output as a dictionary: + segments (List[AsrSegment]): The ASR segments. + transcription_info (AsrTranscriptionInfo): The ASR transcription info. + transcription (AsrTranscription): The ASR transcription. + """ + # TODO: add streaming support + output = await self.transcribe(media, params) + yield output + async def transcribe_batch( self, media_batch: List[Video], params: WhisperParams = WhisperParams() ) -> WhisperBatchOutput: diff --git a/aana/tests/test_api_generation.py b/aana/tests/test_api_generation.py index ccfb2b03..edba374e 100644 --- a/aana/tests/test_api_generation.py +++ b/aana/tests/test_api_generation.py @@ -6,7 +6,7 @@ from pydantic import BaseModel, Field, Extra -from aana.api.api_generation import Endpoint +from aana.api.api_generation import Endpoint, EndpointOutput from aana.exceptions.general import MultipleFileUploadNotAllowed @@ -49,7 +49,7 @@ def test_get_request_model(): name="test_endpoint", summary="Test endpoint", path="/test_endpoint", - outputs=["output"], + outputs=[EndpointOutput(name="output", output="output")], ) input_sockets = [ @@ -81,7 +81,12 @@ def test_get_response_model(): name="test_endpoint", summary="Test endpoint", path="/test_endpoint", - outputs=["output", "output_without_datamodel"], + outputs=[ + EndpointOutput(name="output", output="output"), + EndpointOutput( + name="output_without_datamodel", output="output_without_datamodel" + ), + ], ) output_sockets = [ @@ -109,7 +114,7 @@ def test_get_response_model(): name="test_endpoint", summary="Test endpoint", path="/test_endpoint", - outputs=["output"], + outputs=[EndpointOutput(name="output", output="output")], ) output_sockets = [ @@ -135,7 +140,7 @@ def test_get_file_upload_field(): name="test_endpoint", summary="Test endpoint", path="/test_endpoint", - outputs=["output"], + outputs=[EndpointOutput(name="output", output="output")], ) input_sockets = [ @@ -155,6 +160,7 @@ def test_get_file_upload_field(): # Check that the file upload field has the correct description assert file_upload_field.description == "Upload image files." + def test_get_file_upload_field_multiple_file_uploads(): """Test the get_file_upload_field function with multiple file uploads.""" @@ -162,7 +168,7 @@ def test_get_file_upload_field_multiple_file_uploads(): name="test_endpoint", summary="Test endpoint", path="/test_endpoint", - outputs=["output"], + outputs=[EndpointOutput(name="output", output="output")], ) input_sockets = [ diff --git a/aana/tests/test_app.py b/aana/tests/test_app.py index 7e9ef534..0def3694 100644 --- a/aana/tests/test_app.py +++ b/aana/tests/test_app.py @@ -4,7 +4,7 @@ from ray import serve import requests -from aana.api.api_generation import Endpoint +from aana.api.api_generation import Endpoint, EndpointOutput from aana.api.request_handler import RequestHandler @@ -62,7 +62,7 @@ async def lower(self, text: str) -> dict: name="lowercase", path="/lowercase", summary="Lowercase text", - outputs=["lowercase_text"], + outputs=[EndpointOutput(name="text", output="lowercase_text")], ) ] @@ -103,5 +103,5 @@ def test_app(ray_setup): data={"body": json.dumps(data)}, ) assert response.status_code == 200 - lowercase_text = response.json().get("lowercase_text") + lowercase_text = response.json().get("text") assert lowercase_text == ["hello world!", "this is a test."] diff --git a/aana/tests/test_app_streaming.py b/aana/tests/test_app_streaming.py index 386c3c64..aed1cdd7 100644 --- a/aana/tests/test_app_streaming.py +++ b/aana/tests/test_app_streaming.py @@ -7,7 +7,7 @@ from ray import serve import requests -from aana.api.api_generation import Endpoint +from aana.api.api_generation import Endpoint, EndpointOutput from aana.api.request_handler import RequestHandler @@ -71,7 +71,7 @@ async def lower_stream(self, text: str) -> AsyncGenerator[dict, None]: name="lowercase", path="/lowercase", summary="Lowercase text", - outputs=["lowercase_text"], + outputs=[EndpointOutput(name="text", output="lowercase_text")], streaming=True, ) ] @@ -121,7 +121,7 @@ def test_app_streaming(ray_setup): offset = 0 for chunk in response.iter_content(chunk_size=None): json_data = json.loads(chunk) - lowercase_text_chunk = json_data["lowercase_text"] + lowercase_text_chunk = json_data["text"] lowercase_text += lowercase_text_chunk chunk_size = len(lowercase_text_chunk) diff --git a/aana/tests/test_build.py b/aana/tests/test_build.py index 0e069f3d..0e8f7ef9 100644 --- a/aana/tests/test_build.py +++ b/aana/tests/test_build.py @@ -1,7 +1,7 @@ from mobius_pipeline.exceptions import OutputNotFoundException import pytest -from aana.api.api_generation import Endpoint +from aana.api.api_generation import Endpoint, EndpointOutput from aana.configs.build import get_configuration nodes = [ @@ -70,7 +70,7 @@ name="lowercase", path="/lowercase", summary="Lowercase text", - outputs=["lowercase_text"], + outputs=[EndpointOutput(name="lowercase_text", output="lowercase_text")], ) ], "uppercase": [ @@ -78,7 +78,7 @@ name="uppercase", path="/uppercase", summary="Uppercase text", - outputs=["uppercase_text"], + outputs=[EndpointOutput(name="uppercase_text", output="uppercase_text")], ) ], "both": [ @@ -86,13 +86,13 @@ name="lowercase", path="/lowercase", summary="Lowercase text", - outputs=["lowercase_text"], + outputs=[EndpointOutput(name="lowercase_text", output="lowercase_text")], ), Endpoint( name="uppercase", path="/uppercase", summary="Uppercase text", - outputs=["uppercase_text"], + outputs=[EndpointOutput(name="uppercase_text", output="uppercase_text")], ), ], "non_existent": [ @@ -100,7 +100,7 @@ name="non_existent", path="/non_existent", summary="Non existent endpoint", - outputs=["non_existent"], + outputs=[EndpointOutput(name="non_existent", output="non_existent")], ) ], "capitalize": [ @@ -108,7 +108,7 @@ name="capitalize", path="/capitalize", summary="Capitalize text", - outputs=["capitalize_text"], + outputs=[EndpointOutput(name="capitalize_text", output="capitalize_text")], ) ], } diff --git a/aana/utils/video.py b/aana/utils/video.py index 234ca3d5..4c904d88 100644 --- a/aana/utils/video.py +++ b/aana/utils/video.py @@ -3,7 +3,7 @@ import numpy as np import yt_dlp from yt_dlp.utils import DownloadError -from typing import List, TypedDict +from typing import Generator, List, TypedDict from aana.configs.settings import settings from aana.exceptions.general import DownloadException, VideoReadingException from aana.models.core.image import Image @@ -62,7 +62,59 @@ def extract_frames_decord(video: Video, params: VideoParams) -> FramesDict: return FramesDict(frames=frames, timestamps=timestamps, duration=duration) -def download_video(video_input: VideoInput) -> Video: +def generate_frames_decord( + video: Video, params: VideoParams, batch_size: int = 8 +) -> Generator[FramesDict, None, None]: + """ + Generate frames from a video using decord. + + Args: + video (Video): the video to extract frames from + params (VideoParams): the parameters of the video extraction + batch_size (int): the number of frames to yield at each iteration + + Yields: + FramesDict: a dictionary containing the extracted frames, timestamps, + and duration for each batch + """ + + device = decord.cpu(0) + num_threads = 1 # TODO: see if we can use more threads + + num_fps: float = params.extract_fps + try: + video_reader = decord.VideoReader( + str(video.path), ctx=device, num_threads=num_threads + ) + except Exception as e: + raise VideoReadingException(video) from e + + video_fps = video_reader.get_avg_fps() + num_frames = len(video_reader) + duration = num_frames / video_fps + + if params.fast_mode_enabled: + indexes = video_reader.get_key_indices() + else: + # num_fps can be smaller than 1 (e.g. 0.5 means 1 frame every 2 seconds) + indexes = np.arange(0, num_frames, int(video_fps / num_fps)) + timestamps = video_reader.get_frame_timestamp(indexes)[:, 0].tolist() + + for i in range(0, len(indexes), batch_size): + batch = indexes[i : i + batch_size] + batch_frames_array = video_reader.get_batch(batch).asnumpy() + batch_frames = [] + for _, frame in enumerate(batch_frames_array): + img = Image(numpy=frame) + batch_frames.append(img) + + batch_timestamps = timestamps[i : i + batch_size] + yield FramesDict( + frames=batch_frames, timestamps=batch_timestamps, duration=duration + ) + + +def download_video(video_input: VideoInput | Video) -> Video: """ Downloads videos for a VideoInput object. @@ -72,6 +124,8 @@ def download_video(video_input: VideoInput) -> Video: Returns: Video: the video object """ + if isinstance(video_input, Video): + return video_input if video_input.url is not None: video_source: VideoSource = VideoSource.from_url(video_input.url) if video_source == VideoSource.YOUTUBE: diff --git a/mobius-pipeline b/mobius-pipeline index f455c656..386943bd 160000 --- a/mobius-pipeline +++ b/mobius-pipeline @@ -1 +1 @@ -Subproject commit f455c656da566866f031d74d8676bef0f558b4f3 +Subproject commit 386943bd78d8c3617013ac52bd18a92be0e19c5e diff --git a/pyproject.toml b/pyproject.toml index 7f1fc97d..42f759c3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,9 @@ faster-whisper = "^0.9.0" onnxruntime = "1.16.1" deepdiff = "^6.7.0" yt-dlp = "^2023.10.13" +qdrant-client = "^1.6.9" +bitsandbytes = "^0.41.2.post2" +accelerate = "^0.24.1" [tool.poetry.group.dev.dependencies] ipykernel = "^6.25.2" From 6de17654909b36f329c0c6291abaa2eafccd13c9 Mon Sep 17 00:00:00 2001 From: movchan74 Date: Thu, 23 Nov 2023 15:59:44 +0000 Subject: [PATCH 2/4] Add tests for transcribe_stream and generate_frames_decord --- .../deployments/test_whisper_deployment.py | 19 ++++++++- aana/tests/test_frame_extraction.py | 42 +++++++++++++++++-- aana/tests/test_video.py | 9 +++- aana/utils/video.py | 1 - 4 files changed, 64 insertions(+), 7 deletions(-) diff --git a/aana/tests/deployments/test_whisper_deployment.py b/aana/tests/deployments/test_whisper_deployment.py index dff4ea2e..69a9c635 100644 --- a/aana/tests/deployments/test_whisper_deployment.py +++ b/aana/tests/deployments/test_whisper_deployment.py @@ -77,6 +77,7 @@ async def test_whisper_deployment(video_file): with open(path, "r") as f: expected_output = json.load(f) + # Test transcribe method path = resources.path("aana.tests.files.videos", video_file) assert path.exists(), f"Video not found: {path}" video = Video(path=path) @@ -88,10 +89,26 @@ async def test_whisper_deployment(video_file): compare_transcriptions(expected_output, output) + # Test transcribe_stream method + path = resources.path("aana.tests.files.videos", video_file) + assert path.exists(), f"Video not found: {path}" + video = Video(path=path) + + stream = handle.options(stream=True).transcribe_stream.remote( + media=video, params=WhisperParams(word_timestamps=True) + ) + # We only have one chunk now + # TODO: test multiple chunks when steaming is implemented properly + async for chunk in stream: + chunk = await chunk + output = pydantic_to_dict(chunk) + compare_transcriptions(expected_output, output) + + # Test transcribe_batch method videos = [video, video] batch_output = await handle.transcribe_batch.remote( - media=videos, params=WhisperParams(word_timestamps=True) + media_batch=videos, params=WhisperParams(word_timestamps=True) ) batch_output = pydantic_to_dict(batch_output) diff --git a/aana/tests/test_frame_extraction.py b/aana/tests/test_frame_extraction.py index 3211ed38..109c67a5 100644 --- a/aana/tests/test_frame_extraction.py +++ b/aana/tests/test_frame_extraction.py @@ -4,7 +4,7 @@ from aana.models.core.image import Image from aana.models.core.video import Video from aana.models.pydantic.video_params import VideoParams -from aana.utils.video import extract_frames_decord +from aana.utils.video import extract_frames_decord, generate_frames_decord @pytest.mark.parametrize( @@ -33,9 +33,43 @@ def test_extract_frames_success( assert isinstance(result["frames"][0], Image) assert result["duration"] == expected_duration assert len(result["frames"]) == expected_num_frames - assert ( - len(result["frames"]) == len(result["timestamps"]) - 1 - ) # Minus 1 because the duration is added as the last timestamp + assert len(result["timestamps"]) == expected_num_frames + + +@pytest.mark.parametrize( + "video_name, extract_fps, fast_mode_enabled, expected_duration, expected_num_frames", + [ + ("squirrel.mp4", 1.0, False, 10.0, 10), + ("squirrel.mp4", 0.5, False, 10.0, 5), + ("squirrel.mp4", 1.0, True, 10.0, 4), + ], +) +def test_generate_frames_success( + video_name, extract_fps, fast_mode_enabled, expected_duration, expected_num_frames +): + """ + Test that generator generate_frames_decord can be used + to extract frames from a video successfully. + """ + + video_path = resources.path("aana.tests.files.videos", video_name) + video = Video(path=video_path) + params = VideoParams(extract_fps=extract_fps, fast_mode_enabled=fast_mode_enabled) + gen_frame = generate_frames_decord(video=video, params=params, batch_size=1) + total_frames = 0 + for result in gen_frame: + assert "frames" in result + assert "timestamps" in result + assert "duration" in result + assert isinstance(result["duration"], float) + assert isinstance(result["frames"], list) + assert isinstance(result["frames"][0], Image) + assert result["duration"] == expected_duration + assert len(result["frames"]) == 1 # batch_size = 1 + assert len(result["timestamps"]) == 1 # batch_size = 1 + total_frames += 1 + + assert total_frames == expected_num_frames def test_extract_frames_failure(): diff --git a/aana/tests/test_video.py b/aana/tests/test_video.py index 862245d1..b64dbe8e 100644 --- a/aana/tests/test_video.py +++ b/aana/tests/test_video.py @@ -149,7 +149,7 @@ def test_download_video(mock_download_file): """ Test download_video. """ - # Test VideoInput + # Test VideoInput with path path = resources.path("aana.tests.files.videos", "squirrel.mp4") video_input = VideoInput(path=str(path)) video = download_video(video_input) @@ -158,6 +158,7 @@ def test_download_video(mock_download_file): assert video.content is None assert video.url is None + # Test VideoInput with url try: url = "http://example.com/squirrel.mp4" video_input = VideoInput(url=url) @@ -195,3 +196,9 @@ def test_download_video(mock_download_file): youtube_video_input = VideoInput(url=youtube_url) with pytest.raises(DownloadException): download_video(youtube_video_input) + + # Test Video object as input + path = resources.path("aana.tests.files.videos", "squirrel.mp4") + video = Video(path=path) + downloaded_video = download_video(video) + assert downloaded_video == video diff --git a/aana/utils/video.py b/aana/utils/video.py index 4c904d88..a1f3f919 100644 --- a/aana/utils/video.py +++ b/aana/utils/video.py @@ -51,7 +51,6 @@ def extract_frames_decord(video: Video, params: VideoParams) -> FramesDict: # num_fps can be smaller than 1 (e.g. 0.5 means 1 frame every 2 seconds) indexes = np.arange(0, num_frames, int(video_fps / num_fps)) timestamps = video_reader.get_frame_timestamp(indexes)[:, 0].tolist() - timestamps.append(duration) frames_array = video_reader.get_batch(indexes).asnumpy() frames = [] From d58a705a8b587a0f02916ace40fd55a470e73d8b Mon Sep 17 00:00:00 2001 From: movchan74 Date: Wed, 29 Nov 2023 10:47:05 +0000 Subject: [PATCH 3/4] Added missing type annotations --- aana/api/api_generation.py | 5 +++-- aana/deployments/whisper_deployment.py | 3 ++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/aana/api/api_generation.py b/aana/api/api_generation.py index 18363555..88383297 100644 --- a/aana/api/api_generation.py +++ b/aana/api/api_generation.py @@ -1,6 +1,7 @@ from dataclasses import dataclass +from collections.abc import AsyncGenerator from enum import Enum -from typing import AsyncGenerator, Dict, Tuple, Type, Any, List, Optional +from typing import Dict, Tuple, Type, Any, List, Optional from fastapi import FastAPI, File, Form, UploadFile from fastapi.responses import StreamingResponse from mobius_pipeline.pipeline.pipeline import Pipeline @@ -373,7 +374,7 @@ async def route_func_body(body: str, files: Optional[List[UploadFile]] = None): # run the pipeline if self.streaming: - async def generator_wrapper(): + async def generator_wrapper() -> AsyncGenerator[str, None]: """ Serializes the output of the generator using ORJSONResponseCustom """ diff --git a/aana/deployments/whisper_deployment.py b/aana/deployments/whisper_deployment.py index 4ee6927f..34d37098 100644 --- a/aana/deployments/whisper_deployment.py +++ b/aana/deployments/whisper_deployment.py @@ -1,3 +1,4 @@ +from collections.abc import AsyncGenerator from enum import Enum from typing import Any, Dict, List, TypedDict, cast from faster_whisper import WhisperModel @@ -187,7 +188,7 @@ async def transcribe( async def transcribe_stream( self, media: Video, params: WhisperParams | None = None - ) -> WhisperOutput: + ) -> AsyncGenerator[WhisperOutput, None]: """ Transcribe the media with the whisper model in a streaming fashion. From b0fb35dab33d92cbfde2761666089a25a17ffcf1 Mon Sep 17 00:00:00 2001 From: Aleksandr Movchan Date: Wed, 29 Nov 2023 11:40:55 +0000 Subject: [PATCH 4/4] Ruff fixes --- aana/api/api_generation.py | 24 ++++++++++-------------- aana/deployments/hf_blip2_deployment.py | 2 +- aana/deployments/whisper_deployment.py | 3 +-- aana/tests/test_frame_extraction.py | 8 ++++---- aana/utils/video.py | 8 +++----- pyproject.toml | 3 --- 6 files changed, 19 insertions(+), 29 deletions(-) diff --git a/aana/api/api_generation.py b/aana/api/api_generation.py index 1c00b2fa..588f64af 100644 --- a/aana/api/api_generation.py +++ b/aana/api/api_generation.py @@ -2,16 +2,16 @@ from dataclasses import dataclass from enum import Enum from typing import Any, Optional + from fastapi import FastAPI, File, Form, UploadFile from fastapi.responses import StreamingResponse from mobius_pipeline.node.socket import Socket from mobius_pipeline.pipeline.pipeline import Pipeline -from pydantic import BaseModel, Field, create_model, parse_raw_as +from pydantic import BaseModel, Field, ValidationError, create_model, parse_raw_as from aana.api.responses import AanaJSONResponse from aana.exceptions.general import MultipleFileUploadNotAllowed from aana.models.pydantic.exception_response import ExceptionResponseModel -from pydantic import ValidationError async def run_pipeline( @@ -94,8 +94,7 @@ class FileUploadField: @dataclass class EndpointOutput: - """ - Class used to represent an endpoint output. + """Class used to represent an endpoint output. Attributes: name (str): Name of the output that should be returned by the endpoint. @@ -128,8 +127,7 @@ class Endpoint: streaming: bool = False def __post_init__(self): - """ - Post init method. + """Post init method. Creates dictionaries for fast lookup of outputs. """ @@ -167,15 +165,15 @@ def socket_to_field(self, socket: Socket) -> tuple[Any, Any]: # to see if any of the fields are required try: data_model_instance = data_model() - return (data_model, data_model_instance) except ValidationError: # if we can't instantiate the data model # it means that it has required fields return (data_model, ...) + else: + return (data_model, data_model_instance) def get_input_fields(self, sockets: list[Socket]) -> dict[str, tuple[Any, Any]]: - """ - Generate fields for the request Pydantic model based on the provided sockets. + """Generate fields for the request Pydantic model based on the provided sockets. Parameters: sockets (list[Socket]): List of sockets. @@ -190,8 +188,7 @@ def get_input_fields(self, sockets: list[Socket]) -> dict[str, tuple[Any, Any]]: return fields def get_output_fields(self, sockets: list[Socket]) -> dict[str, tuple[Any, Any]]: - """ - Generate fields for the response Pydantic model based on the provided sockets. + """Generate fields for the response Pydantic model based on the provided sockets. Parameters: sockets (list[Socket]): List of sockets. @@ -294,8 +291,7 @@ def get_response_model(self, output_sockets: list[Socket]) -> type[BaseModel]: return ResponseModel def process_output(self, output: dict[str, Any]) -> dict[str, Any]: - """ - Process the output of the pipeline. + """Process the output of the pipeline. Maps the output names of the pipeline to the names defined in the endpoint outputs. @@ -366,7 +362,7 @@ async def route_func_body(body: str, files: list[UploadFile] | None = None): if self.streaming: async def generator_wrapper() -> AsyncGenerator[bytes, None]: - """Serializes the output of the generator using ORJSONResponseCustom""" + """Serializes the output of the generator using ORJSONResponseCustom.""" async for output in run_pipeline_streaming( pipeline, data_dict, outputs ): diff --git a/aana/deployments/hf_blip2_deployment.py b/aana/deployments/hf_blip2_deployment.py index 74bf18b7..2cb4ff0e 100644 --- a/aana/deployments/hf_blip2_deployment.py +++ b/aana/deployments/hf_blip2_deployment.py @@ -1,7 +1,7 @@ from typing import Any, TypedDict import torch -from pydantic import BaseModel, Field, validator +from pydantic import BaseModel, Field from ray import serve from transformers import Blip2ForConditionalGeneration, Blip2Processor diff --git a/aana/deployments/whisper_deployment.py b/aana/deployments/whisper_deployment.py index f0121af8..3fbaeb25 100644 --- a/aana/deployments/whisper_deployment.py +++ b/aana/deployments/whisper_deployment.py @@ -177,8 +177,7 @@ async def transcribe( async def transcribe_stream( self, media: Video, params: WhisperParams | None = None ) -> AsyncGenerator[WhisperOutput, None]: - """ - Transcribe the media with the whisper model in a streaming fashion. + """Transcribe the media with the whisper model in a streaming fashion. Right now this is the same as transcribe, but we will add support for streaming in the future to support larger media and to make the ASR more responsive. diff --git a/aana/tests/test_frame_extraction.py b/aana/tests/test_frame_extraction.py index e000dc48..761c160d 100644 --- a/aana/tests/test_frame_extraction.py +++ b/aana/tests/test_frame_extraction.py @@ -48,11 +48,11 @@ def test_extract_frames_success( def test_generate_frames_success( video_name, extract_fps, fast_mode_enabled, expected_duration, expected_num_frames ): - """ - Test that generator generate_frames_decord can be used - to extract frames from a video successfully. - """ + """Test generate_frames_decord. + generate_frames_decord is a generator function that yields a dictionary + containing the frames, timestamps and duration of the video. + """ video_path = resources.path("aana.tests.files.videos", video_name) video = Video(path=video_path) params = VideoParams(extract_fps=extract_fps, fast_mode_enabled=fast_mode_enabled) diff --git a/aana/utils/video.py b/aana/utils/video.py index 209da4bc..b7bbecc3 100644 --- a/aana/utils/video.py +++ b/aana/utils/video.py @@ -6,6 +6,7 @@ import numpy as np import yt_dlp from yt_dlp.utils import DownloadError + from aana.configs.settings import settings from aana.exceptions.general import DownloadException, VideoReadingException from aana.models.core.image import Image @@ -67,8 +68,7 @@ def extract_frames_decord(video: Video, params: VideoParams) -> FramesDict: def generate_frames_decord( video: Video, params: VideoParams, batch_size: int = 8 ) -> Generator[FramesDict, None, None]: - """ - Generate frames from a video using decord. + """Generate frames from a video using decord. Args: video (Video): the video to extract frames from @@ -79,7 +79,6 @@ def generate_frames_decord( FramesDict: a dictionary containing the extracted frames, timestamps, and duration for each batch """ - device = decord.cpu(0) num_threads = 1 # TODO: see if we can use more threads @@ -117,8 +116,7 @@ def generate_frames_decord( def download_video(video_input: VideoInput | Video) -> Video: - """ - Downloads videos for a VideoInput object. + """Downloads videos for a VideoInput object. Args: video_input (VideoInput): the video input to download diff --git a/pyproject.toml b/pyproject.toml index 0893282f..c93eca25 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,9 +28,6 @@ torch = { url = "https://download.pytorch.org/whl/cu118/torch-2.0.1%2Bcu118-cp31 torchvision = { url = "https://download.pytorch.org/whl/cu118/torchvision-0.15.2%2Bcu118-cp310-cp310-linux_x86_64.whl" } vllm = "^0.2.1.post1" yt-dlp = "^2023.10.13" -qdrant-client = "^1.6.9" -bitsandbytes = "^0.41.2.post2" -accelerate = "^0.24.1" [tool.poetry.group.dev.dependencies] ipykernel = "^6.25.2"