Skip to content

Commit

Permalink
PR: fixes for batching, dict_output
Browse files: browse the repository at this point in the history
  • Loading branch information
evanderiel committed Dec 12, 2023
1 parent 2d764d3 commit 30f0d2a
Show file tree
Hide file tree
Showing 6 changed files with 61 additions and 36 deletions.
2 changes: 1 addition & 1 deletion aana/configs/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
name="blip2_video_generate",
path="/video/generate_captions",
summary="Generate captions for videos using BLIP2 OPT-2.7B",
outputs=["video_captions_hf_blip2_opt_2_7b", "timestamps"],
outputs=["video_captions_hf_blip2_opt_2_7b", "timestamps", "caption_ids"],
),
],
"video": [
Expand Down
25 changes: 16 additions & 9 deletions aana/configs/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,7 @@
"name": "save_video_info",
"type": "ray_task",
"function": "aana.utils.db.save_video_batch",
"dict_output": True,
"inputs": [
{
"name": "video_objects",
Expand All @@ -379,7 +380,10 @@
"name": "save_transcripts_medium",
"type": "ray_task",
"function": "aana.utils.db.save_transcripts_batch",
"model_name": "whisper_medium",
"kwargs": {
"model_name": "whisper_medium",
},
"dict_output": True,
"inputs": [
{
"name": "media_ids",
Expand All @@ -404,8 +408,8 @@
],
"outputs": [
{
"name": "transcription_id",
"key": "transcription_id",
"name": "transcription_ids",
"key": "transcription_ids",
"path": "video_batch.videos.[*].transcription.id",
}
],
Expand All @@ -414,7 +418,10 @@
"name": "save_video_captions_hf_blip2_opt_2_7b",
"type": "ray_task",
"function": "aana.utils.db.save_captions_batch",
"model_name": "hf_blip2_opt_2_7b",
"kwargs": {
"model_name": "hf_blip2_opt_2_7b",
},
"dict_output": True,
"inputs": [
{
"name": "media_ids",
Expand All @@ -423,24 +430,24 @@
},
{
"name": "video_captions_hf_blip2_opt_2_7b",
"key": "captions",
"key": "captions_list",
"path": "video_batch.videos.[*].frames.[*].caption_hf_blip2_opt_2_7b",
},
{
"name": "timestamps",
"key": "timestamps",
"key": "timestamps_list",
"path": "video_batch.videos.[*].timestamp",
},
{
"name": "frame_ids",
"key": "frame_ids",
"key": "frame_ids_list",
"path": "video_batch.videos.[*].frames.[*].id",
},
],
"outputs": [
{
"name": "caption_id",
"key": "caption_id",
"name": "caption_ids",
"key": "caption_ids",
"path": "video_batch.videos.[*].frames.[*].caption_hf_blip2_opt_2_7b.id",
}
],
Expand Down
3 changes: 3 additions & 0 deletions aana/models/db/caption.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations # Let classes use themselves in type annotations

import uuid

from sqlalchemy import CheckConstraint, Column, Float, ForeignKey, Integer, String
from sqlalchemy.orm import relationship

Expand Down Expand Up @@ -48,6 +50,7 @@ def from_caption_output(
) -> CaptionEntity:
"""Converts a Caption pydantic model to a CaptionEntity."""
return CaptionEntity(
id=str(uuid.uuid4()),
model=model_name,
media_id=media_id,
frame_id=frame_id,
Expand Down
28 changes: 14 additions & 14 deletions aana/tests/db/datastore/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,24 +69,24 @@ def test_save_transcripts(mock_session):

def test_save_captions(mock_session):
"""Tests save captions function."""
media_ids = ["0"] * 3
models = ["test_model"] * 3
media_ids = ["0"]
models = "test_model"
captions = ["A caption", "Another caption", "A third caption"]
captions_list = CaptionsList(
__root__=[Caption(__root__=caption) for caption in captions]
)
timestamps = [0.1, 0.2, 0.3]
frame_ids = [0, 1, 2]
captions_list = [
CaptionsList(__root__=[Caption(__root__=caption) for caption in captions])
]
timestamps = [[0.1, 0.2, 0.3, 0.4]]
frame_ids = [[0, 1, 2]]

ids = save_captions_batch(media_ids, models, captions_list, timestamps, frame_ids)
result = save_captions_batch(
media_ids, models, captions_list, timestamps, frame_ids
)

assert (
len(ids)
== len(captions_list)
== len(timestamps)
== len(models)
== len(media_ids)
== len(frame_ids)
len(result["caption_id"])
== len(captions_list[0])
== len(timestamps[0][:-1])
== len(frame_ids[0])
)
mock_session.context_var.add_all.assert_called_once()
mock_session.context_var.commit.assert_called_once()
37 changes: 26 additions & 11 deletions aana/utils/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@
# Just using raw utility functions like this isn't a permanent solution, but
# it's good enough for now to validate what we're working on.
def save_video_batch(
video_objects: list[Video], video_params: list[VideoParams]
videos: list[Video], # , video_params: list[VideoParams]
) -> list[id_type]:
"""Saves a batch of videos to datastore."""
entities = []
for video_object, _ in zip(video_objects, video_params, strict=True):
for video_object, _ in zip(videos, videos, strict=True):
if video_object.url is not None:
orig_url = video_object.url
parsed_url = urlparse(orig_url)
Expand All @@ -49,7 +49,9 @@ def save_video_batch(
with Session(engine) as session:
repo = MediaRepository(session)
results = repo.create_multiple(entities)
return [result.id for result in results] # type: ignore
return {
"media_ids": [result.id for result in results] # type: ignore
}


def save_media(media_type: MediaType, duration: float) -> id_type:
Expand All @@ -72,21 +74,32 @@ def save_media(media_type: MediaType, duration: float) -> id_type:
def save_captions_batch(
media_ids: list[id_type],
model_name: str,
captions: CaptionsList,
timestamps: list[float],
frame_ids: list[int],
captions_list: list[CaptionsList],
timestamps_list: list[list[float]],
frame_ids_list: list[list[int]],
) -> list[id_type]:
"""Save captions."""
print(
f"{len(media_ids)}\n{len(timestamps_list[0])}\n{len(list(frame_ids_list[0]))}\n{model_name=}\ncaptions: {len(captions_list[0])}"
)
with Session(engine) as session:
entities = [
CaptionEntity.from_caption_output(model_name, media_id, frame_id, t, c)
for media_id, c, t, frame_id in zip(
media_ids, captions, timestamps, frame_ids, strict=True
CaptionEntity.from_caption_output(
model_name, media_id, frame_id, timestamp, caption
)
for media_id, captions, timestamps, frame_ids in zip(
media_ids, captions_list, timestamps_list, frame_ids_list, strict=True
)
for caption, timestamp, frame_id in zip(
captions, timestamps[:-1], frame_ids, strict=True
)
]
repo = CaptionRepository(session)
results = repo.create_multiple(entities)
return [c.id for c in results] # type: ignore
# return [c.id for c in results]
return {
"caption_ids": [c.id for c in results] # type: ignore
}


def save_transcripts_batch(
Expand All @@ -107,4 +120,6 @@ def save_transcripts_batch(

repo = TranscriptRepository(session)
entities = repo.create_multiple(entities)
return [c.id for c in entities] # type: ignore
return {
"transcript_id": [c.id for c in entities] # type: ignore
}
2 changes: 1 addition & 1 deletion aana/utils/video.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def extract_frames_decord(video: Video, params: VideoParams) -> FramesDict:
frames=frames,
timestamps=timestamps,
duration=duration,
frame_ids=range(len(frames)),
frame_ids=list(range(len(frames))),
)


Expand Down

0 comments on commit 30f0d2a

Please sign in to comment.