Skip to content

Commit

Permalink
Merge pull request #21 from mobiusml/streaming_endpoints_video
Browse files Browse the repository at this point in the history
Streaming Endpoints for Captioning and ASR
  • Loading branch information
movchan74 authored Nov 29, 2023
2 parents ac309e8 + b0fb35d commit 23945bb
Show file tree
Hide file tree
Showing 15 changed files with 482 additions and 84 deletions.
106 changes: 87 additions & 19 deletions aana/api/api_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from fastapi.responses import StreamingResponse
from mobius_pipeline.node.socket import Socket
from mobius_pipeline.pipeline.pipeline import Pipeline
from pydantic import BaseModel, Field, create_model, parse_raw_as
from pydantic import BaseModel, Field, ValidationError, create_model, parse_raw_as

from aana.api.responses import AanaJSONResponse
from aana.exceptions.general import MultipleFileUploadNotAllowed
Expand Down Expand Up @@ -92,6 +92,19 @@ class FileUploadField:
description: str


@dataclass
class EndpointOutput:
    """Pairs a public response field name with the pipeline output behind it.

    Attributes:
        name (str): Name under which the value is returned by the endpoint.
        output (str): Name of the pipeline output that supplies the value.
    """

    name: str
    output: str


@dataclass
class Endpoint:
"""Class used to represent an endpoint.
Expand All @@ -100,8 +113,7 @@ class Endpoint:
name (str): Name of the endpoint.
path (str): Path of the endpoint.
summary (str): Description of the endpoint that will be shown in the API documentation.
outputs (List[str]): List of required outputs from the pipeline that should be returned
by the endpoint.
outputs (List[EndpointOutput]): List of outputs that should be returned by the endpoint.
output_filter (Optional[OutputFilter]): The parameter will be added to the request and
will allow to choose subset of `outputs` to return.
streaming (bool): Whether the endpoint outputs a stream of data.
Expand All @@ -110,10 +122,18 @@ class Endpoint:
name: str
path: str
summary: str
outputs: list[str]
outputs: list[EndpointOutput]
output_filter: OutputFilter | None = None
streaming: bool = False

def __post_init__(self):
    """Build the lookup tables between endpoint names and pipeline outputs.

    After this runs, `name_to_output` maps each endpoint output name to its
    pipeline output, and `output_to_name` maps each pipeline output back to
    its endpoint output name.
    """
    by_name = {}
    by_output = {}
    # Single pass over the configured outputs; later duplicates overwrite
    # earlier ones, exactly as two separate dict comprehensions would.
    for entry in self.outputs:
        by_name[entry.name] = entry.output
        by_output[entry.output] = entry.name
    self.name_to_output = by_name
    self.output_to_name = by_output

def generate_model_name(self, suffix: str) -> str:
"""Generate a Pydantic model name based on a given suffix.
Expand Down Expand Up @@ -141,27 +161,48 @@ def socket_to_field(self, socket: Socket) -> tuple[Any, Any]:
data_model = Any
return (data_model, Field(None))

# check if any of the fields are required
if any(field.required for field in data_model.__fields__.values()):
# try to instantiate the data model
# to see if any of the fields are required
try:
data_model_instance = data_model()
except ValidationError:
# if we can't instantiate the data model
# it means that it has required fields
return (data_model, ...)
else:
return (data_model, data_model_instance)

return (data_model, data_model())

def get_fields(self, sockets: list[Socket]) -> dict[str, tuple[Any, Any]]:
"""Generate fields for the Pydantic model based on the provided sockets.
def get_input_fields(self, sockets: list[Socket]) -> dict[str, tuple[Any, Any]]:
"""Generate fields for the request Pydantic model based on the provided sockets.
Parameters:
sockets (List[Socket]): List of sockets.
sockets (list[Socket]): List of sockets.
Returns:
Dict[str, Tuple[Any, Field]]: Dictionary of fields for the Pydantic model.
dict[str, tuple[Any, Field]]: Dictionary of fields for the request Pydantic model.
"""
fields = {}
for socket in sockets:
field = self.socket_to_field(socket)
fields[socket.name] = field
return fields

def get_output_fields(self, sockets: list[Socket]) -> dict[str, tuple[Any, Any]]:
    """Build the response-model fields, keyed by endpoint output name.

    Each socket is converted with `socket_to_field`, and the result is stored
    under the endpoint output name that `output_to_name` gives for the
    socket's name.

    Parameters:
        sockets (list[Socket]): List of output sockets.

    Returns:
        dict[str, tuple[Any, Field]]: Dictionary of fields for the response Pydantic model.
    """
    response_fields = {}
    for sock in sockets:
        field_spec = self.socket_to_field(sock)
        # Expose the field under the endpoint's public name, not the socket name.
        public_name = self.output_to_name[sock.name]
        response_fields[public_name] = field_spec
    return response_fields

def get_file_upload_field(
self, input_sockets: list[Socket]
) -> FileUploadField | None:
Expand Down Expand Up @@ -211,7 +252,9 @@ def get_output_filter_field(self) -> tuple[Any, Field] | None:
description = self.output_filter.description
outputs_enum_name = self.generate_model_name("Outputs")
outputs_enum = Enum( # type: ignore
outputs_enum_name, [(output, output) for output in self.outputs], type=str
outputs_enum_name,
[(output.name, output.name) for output in self.outputs],
type=str,
)
field = (Optional[list[outputs_enum]], Field(None, description=description)) # noqa: UP007
return field
Expand All @@ -226,7 +269,7 @@ def get_request_model(self, input_sockets: list[Socket]) -> type[BaseModel]:
Type[BaseModel]: Pydantic model for the request.
"""
model_name = self.generate_model_name("Request")
input_fields = self.get_fields(input_sockets)
input_fields = self.get_input_fields(input_sockets)
output_filter_field = self.get_output_filter_field()
if output_filter_field and self.output_filter:
input_fields[self.output_filter.name] = output_filter_field
Expand All @@ -243,10 +286,29 @@ def get_response_model(self, output_sockets: list[Socket]) -> type[BaseModel]:
Type[BaseModel]: Pydantic model for the response.
"""
model_name = self.generate_model_name("Response")
output_fields = self.get_fields(output_sockets)
output_fields = self.get_output_fields(output_sockets)
ResponseModel = create_model(model_name, **output_fields)
return ResponseModel

def process_output(self, output: dict[str, Any]) -> dict[str, Any]:
    """Rename pipeline output keys to the endpoint's output names.

    Keys present in `output_to_name` are replaced by the mapped name
    (for example, videos_captions_hf_blip2_opt_2_7b becomes captions).
    Keys without a mapping are passed through unchanged.

    Args:
        output (dict): The output of the pipeline.

    Returns:
        dict: A new dictionary holding the same values under the renamed keys.
    """
    renamed = {}
    for pipeline_key, value in output.items():
        public_key = self.output_to_name.get(pipeline_key, pipeline_key)
        renamed[public_key] = value
    return renamed

def create_endpoint_func( # noqa: C901
self,
pipeline: Pipeline,
Expand Down Expand Up @@ -285,10 +347,12 @@ async def route_func_body(body: str, files: list[UploadFile] | None = None):
if requested_outputs:
# get values for requested outputs because it's a list of enums
requested_outputs = [output.value for output in requested_outputs]
outputs = requested_outputs
# map the requested output names to the actual pipeline outputs
# for example, captions to videos_captions_hf_blip2_opt_2_7b
outputs = [self.name_to_output[output] for output in requested_outputs]
# otherwise use the required outputs from the config (all outputs endpoint provides)
else:
outputs = self.outputs
outputs = [output.output for output in self.outputs]

# remove the output filter parameter from the data
if self.output_filter and self.output_filter.name in data_dict:
Expand All @@ -297,18 +361,20 @@ async def route_func_body(body: str, files: list[UploadFile] | None = None):
# run the pipeline
if self.streaming:

async def generator_wrapper():
async def generator_wrapper() -> AsyncGenerator[bytes, None]:
"""Serializes the output of the generator using AanaJSONResponse."""
async for output in run_pipeline_streaming(
pipeline, data_dict, outputs
):
output = self.process_output(output)
yield AanaJSONResponse(content=output).body

return StreamingResponse(
generator_wrapper(), media_type="application/json"
)
else:
output = await run_pipeline(pipeline, data_dict, outputs)
output = self.process_output(output)
return AanaJSONResponse(content=output)

if file_upload_field:
Expand All @@ -331,7 +397,9 @@ def register(
pipeline (Pipeline): Pipeline to register the endpoint to.
custom_schemas (Dict[str, Dict]): Dictionary of custom schemas.
"""
input_sockets, output_sockets = pipeline.get_sockets(self.outputs)
input_sockets, output_sockets = pipeline.get_sockets(
[output.output for output in self.outputs]
)
RequestModel = self.get_request_model(input_sockets)
ResponseModel = self.get_response_model(output_sockets)
file_upload_field = self.get_file_upload_field(input_sockets)
Expand Down
2 changes: 1 addition & 1 deletion aana/configs/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def get_configuration(target: str, endpoints, nodes, deployments) -> dict:
# Target endpoints require the following outputs
endpoint_outputs = []
for endpoint in target_endpoints:
endpoint_outputs += endpoint.outputs
endpoint_outputs += [output.output for output in endpoint.outputs]

# Build the output graph for the whole pipeline
node_definitions = [NodeDefinition.from_dict(node_dict) for node_dict in nodes]
Expand Down
96 changes: 86 additions & 10 deletions aana/configs/endpoints.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,24 @@
from aana.api.api_generation import Endpoint
from aana.api.api_generation import Endpoint, EndpointOutput

endpoints = {
"llama2": [
Endpoint(
name="llm_generate",
path="/llm/generate",
summary="Generate text using LLaMa2 7B Chat",
outputs=["vllm_llama2_7b_chat_output"],
outputs=[
EndpointOutput(name="completion", output="vllm_llama2_7b_chat_output")
],
),
Endpoint(
name="llm_generate_stream",
path="/llm/generate_stream",
summary="Generate text using LLaMa2 7B Chat (streaming)",
outputs=["vllm_llama2_7b_chat_output_stream"],
outputs=[
EndpointOutput(
name="completion", output="vllm_llama2_7b_chat_output_stream"
)
],
streaming=True,
),
],
Expand All @@ -21,29 +27,41 @@
name="zephyr_generate",
path="/llm/generate",
summary="Generate text using Zephyr 7B Beta",
outputs=["vllm_zephyr_7b_beta_output"],
outputs=[
EndpointOutput(name="completion", output="vllm_zephyr_7b_beta_output")
],
)
],
"blip2": [
Endpoint(
name="blip2_generate",
path="/image/generate_captions",
summary="Generate captions for images using BLIP2 OPT-2.7B",
outputs=["captions_hf_blip2_opt_2_7b"],
outputs=[
EndpointOutput(name="captions", output="captions_hf_blip2_opt_2_7b")
],
),
Endpoint(
name="blip2_video_generate",
path="/video/generate_captions",
summary="Generate captions for videos using BLIP2 OPT-2.7B",
outputs=["video_captions_hf_blip2_opt_2_7b", "timestamps"],
outputs=[
EndpointOutput(
name="captions", output="videos_captions_hf_blip2_opt_2_7b"
),
EndpointOutput(name="timestamps", output="timestamps"),
],
),
],
"video": [
Endpoint(
name="video_extract_frames",
path="/video/extract_frames",
summary="Extract frames from a video",
outputs=["timestamps", "duration"],
outputs=[
EndpointOutput(name="timestamps", output="timestamps"),
EndpointOutput(name="duration", output="duration"),
],
)
],
"whisper": [
Expand All @@ -52,10 +70,68 @@
path="/video/transcribe",
summary="Transcribe a video using Whisper Medium",
outputs=[
"video_transcriptions_whisper_medium",
"video_transcriptions_segments_whisper_medium",
"video_transcriptions_info_whisper_medium",
EndpointOutput(
name="transcription", output="videos_transcriptions_whisper_medium"
),
EndpointOutput(
name="segments",
output="videos_transcriptions_segments_whisper_medium",
),
EndpointOutput(
name="info", output="videos_transcriptions_info_whisper_medium"
),
],
)
],
"chat_with_video": [
Endpoint(
name="blip2_video_generate",
path="/video/generate_captions",
summary="Generate captions for videos using BLIP2 OPT-2.7B",
outputs=[
EndpointOutput(
name="captions", output="video_captions_hf_blip2_opt_2_7b"
),
EndpointOutput(name="timestamps", output="video_timestamps"),
],
streaming=True,
),
Endpoint(
name="whisper_transcribe",
path="/video/transcribe",
summary="Transcribe a video using Whisper Medium",
outputs=[
EndpointOutput(
name="transcription", output="video_transcriptions_whisper_medium"
),
EndpointOutput(
name="segments",
output="video_transcriptions_segments_whisper_medium",
),
EndpointOutput(
name="info", output="video_transcriptions_info_whisper_medium"
),
],
streaming=True,
),
Endpoint(
name="llm_generate",
path="/llm/generate",
summary="Generate text using LLaMa2 7B Chat",
outputs=[
EndpointOutput(name="completion", output="vllm_llama2_7b_chat_output")
],
),
Endpoint(
name="llm_generate_stream",
path="/llm/generate_stream",
summary="Generate text using LLaMa2 7B Chat (streaming)",
outputs=[
EndpointOutput(
name="completion", output="vllm_llama2_7b_chat_output_stream"
)
],
streaming=True,
),
],
}
Loading

0 comments on commit 23945bb

Please sign in to comment.