Skip to content

Commit

Permalink
Ruff fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Aleksandr Movchan committed Nov 29, 2023
1 parent 5c0a509 commit b0fb35d
Show file tree
Hide file tree
Showing 6 changed files with 19 additions and 29 deletions.
24 changes: 10 additions & 14 deletions aana/api/api_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,16 @@
from dataclasses import dataclass
from enum import Enum
from typing import Any, Optional

from fastapi import FastAPI, File, Form, UploadFile
from fastapi.responses import StreamingResponse
from mobius_pipeline.node.socket import Socket
from mobius_pipeline.pipeline.pipeline import Pipeline
from pydantic import BaseModel, Field, create_model, parse_raw_as
from pydantic import BaseModel, Field, ValidationError, create_model, parse_raw_as

from aana.api.responses import AanaJSONResponse
from aana.exceptions.general import MultipleFileUploadNotAllowed
from aana.models.pydantic.exception_response import ExceptionResponseModel
from pydantic import ValidationError


async def run_pipeline(
Expand Down Expand Up @@ -94,8 +94,7 @@ class FileUploadField:

@dataclass
class EndpointOutput:
"""
Class used to represent an endpoint output.
"""Class used to represent an endpoint output.
Attributes:
name (str): Name of the output that should be returned by the endpoint.
Expand Down Expand Up @@ -128,8 +127,7 @@ class Endpoint:
streaming: bool = False

def __post_init__(self):
"""
Post init method.
"""Post init method.
Creates dictionaries for fast lookup of outputs.
"""
Expand Down Expand Up @@ -167,15 +165,15 @@ def socket_to_field(self, socket: Socket) -> tuple[Any, Any]:
# to see if any of the fields are required
try:
data_model_instance = data_model()
return (data_model, data_model_instance)
except ValidationError:
# if we can't instantiate the data model
# it means that it has required fields
return (data_model, ...)
else:
return (data_model, data_model_instance)

def get_input_fields(self, sockets: list[Socket]) -> dict[str, tuple[Any, Any]]:
"""
Generate fields for the request Pydantic model based on the provided sockets.
"""Generate fields for the request Pydantic model based on the provided sockets.
Parameters:
sockets (list[Socket]): List of sockets.
Expand All @@ -190,8 +188,7 @@ def get_input_fields(self, sockets: list[Socket]) -> dict[str, tuple[Any, Any]]:
return fields

def get_output_fields(self, sockets: list[Socket]) -> dict[str, tuple[Any, Any]]:
"""
Generate fields for the response Pydantic model based on the provided sockets.
"""Generate fields for the response Pydantic model based on the provided sockets.
Parameters:
sockets (list[Socket]): List of sockets.
Expand Down Expand Up @@ -294,8 +291,7 @@ def get_response_model(self, output_sockets: list[Socket]) -> type[BaseModel]:
return ResponseModel

def process_output(self, output: dict[str, Any]) -> dict[str, Any]:
"""
Process the output of the pipeline.
"""Process the output of the pipeline.
Maps the output names of the pipeline to the names defined in the endpoint outputs.
Expand Down Expand Up @@ -366,7 +362,7 @@ async def route_func_body(body: str, files: list[UploadFile] | None = None):
if self.streaming:

async def generator_wrapper() -> AsyncGenerator[bytes, None]:
"""Serializes the output of the generator using ORJSONResponseCustom"""
"""Serializes the output of the generator using ORJSONResponseCustom."""
async for output in run_pipeline_streaming(
pipeline, data_dict, outputs
):
Expand Down
2 changes: 1 addition & 1 deletion aana/deployments/hf_blip2_deployment.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Any, TypedDict

import torch
from pydantic import BaseModel, Field, validator
from pydantic import BaseModel, Field
from ray import serve
from transformers import Blip2ForConditionalGeneration, Blip2Processor

Expand Down
3 changes: 1 addition & 2 deletions aana/deployments/whisper_deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,8 +177,7 @@ async def transcribe(
async def transcribe_stream(
self, media: Video, params: WhisperParams | None = None
) -> AsyncGenerator[WhisperOutput, None]:
"""
Transcribe the media with the whisper model in a streaming fashion.
"""Transcribe the media with the whisper model in a streaming fashion.
Right now this is the same as transcribe, but we will add support for
streaming in the future to support larger media and to make the ASR more responsive.
Expand Down
8 changes: 4 additions & 4 deletions aana/tests/test_frame_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,11 @@ def test_extract_frames_success(
def test_generate_frames_success(
video_name, extract_fps, fast_mode_enabled, expected_duration, expected_num_frames
):
"""
Test that generator generate_frames_decord can be used
to extract frames from a video successfully.
"""
"""Test generate_frames_decord.
generate_frames_decord is a generator function that yields a dictionary
containing the frames, timestamps and duration of the video.
"""
video_path = resources.path("aana.tests.files.videos", video_name)
video = Video(path=video_path)
params = VideoParams(extract_fps=extract_fps, fast_mode_enabled=fast_mode_enabled)
Expand Down
8 changes: 3 additions & 5 deletions aana/utils/video.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import numpy as np
import yt_dlp
from yt_dlp.utils import DownloadError

from aana.configs.settings import settings
from aana.exceptions.general import DownloadException, VideoReadingException
from aana.models.core.image import Image
Expand Down Expand Up @@ -67,8 +68,7 @@ def extract_frames_decord(video: Video, params: VideoParams) -> FramesDict:
def generate_frames_decord(
video: Video, params: VideoParams, batch_size: int = 8
) -> Generator[FramesDict, None, None]:
"""
Generate frames from a video using decord.
"""Generate frames from a video using decord.
Args:
video (Video): the video to extract frames from
Expand All @@ -79,7 +79,6 @@ def generate_frames_decord(
FramesDict: a dictionary containing the extracted frames, timestamps,
and duration for each batch
"""

device = decord.cpu(0)
num_threads = 1 # TODO: see if we can use more threads

Expand Down Expand Up @@ -117,8 +116,7 @@ def generate_frames_decord(


def download_video(video_input: VideoInput | Video) -> Video:
"""
Downloads videos for a VideoInput object.
"""Downloads videos for a VideoInput object.
Args:
video_input (VideoInput): the video input to download
Expand Down
3 changes: 0 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,6 @@ torch = { url = "https://download.pytorch.org/whl/cu118/torch-2.0.1%2Bcu118-cp31
torchvision = { url = "https://download.pytorch.org/whl/cu118/torchvision-0.15.2%2Bcu118-cp310-cp310-linux_x86_64.whl" }
vllm = "^0.2.1.post1"
yt-dlp = "^2023.10.13"
qdrant-client = "^1.6.9"
bitsandbytes = "^0.41.2.post2"
accelerate = "^0.24.1"

[tool.poetry.group.dev.dependencies]
ipykernel = "^6.25.2"
Expand Down

0 comments on commit b0fb35d

Please sign in to comment.