Skip to content

Commit

Permalink
Fixes for tests and configs (PR)
Browse files Browse the repository at this point in the history
  • Loading branch information
evanderiel committed Dec 7, 2023
1 parent 4e6536e commit 49397bb
Show file tree
Hide file tree
Showing 6 changed files with 788 additions and 848 deletions.
Empty file added aana/alembic/__init__.py
Empty file.
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
"""Initialize.
Revision ID: d540b01b2c34
Revision ID: d63d3838d344
Revises:
Create Date: 2023-12-01 12:04:04.960099
Create Date: 2023-12-07 11:58:35.095546
"""
from collections.abc import Sequence
Expand All @@ -11,7 +11,7 @@
from alembic import op

# revision identifiers, used by Alembic.
revision: str = 'd540b01b2c34'
revision: str = 'd63d3838d344'
down_revision: str | None = None
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
Expand All @@ -21,17 +21,19 @@ def upgrade() -> None:
"""Upgrade database to this revision from previous."""
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('media',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('id', sa.String(), nullable=False),
sa.Column('duration', sa.Float(), nullable=True, comment='Media duration in seconds'),
sa.Column('media_type', sa.String(), nullable=True, comment='Media type'),
sa.Column('orig_filename', sa.String(), nullable=True, comment='Original filename'),
sa.Column('orig_url', sa.String(), nullable=True, comment='Original URL'),
sa.Column('create_ts', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=True, comment='Timestamp when row is inserted'),
sa.Column('update_ts', sa.DateTime(timezone=True), nullable=True, comment='Timestamp when row is updated'),
sa.PrimaryKeyConstraint('id')
)
op.create_table('captions',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('id', sa.String(), nullable=False),
sa.Column('model', sa.String(), nullable=True, comment='Name of model used to generate the caption'),
sa.Column('media_id', sa.Integer(), nullable=True, comment='Foreign key to media table'),
sa.Column('media_id', sa.String(), nullable=True, comment='Foreign key to media table'),
sa.Column('frame_id', sa.Integer(), nullable=True, comment='The 0-based frame id of media for caption'),
sa.Column('caption', sa.String(), nullable=True, comment='Frame caption'),
sa.Column('timestamp', sa.Float(), nullable=True, comment='Frame timestamp in seconds'),
Expand All @@ -41,9 +43,9 @@ def upgrade() -> None:
sa.PrimaryKeyConstraint('id')
)
op.create_table('transcripts',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('id', sa.String(), nullable=False),
sa.Column('model', sa.String(), nullable=True, comment='Name of model used to generate transcript'),
sa.Column('media_id', sa.Integer(), nullable=True, comment='Foreign key to media table'),
sa.Column('media_id', sa.String(), nullable=True, comment='Foreign key to media table'),
sa.Column('transcript', sa.String(), nullable=True, comment='Full text transcript of media'),
sa.Column('segments', sa.String(), nullable=True, comment='Segments of the transcript'),
sa.Column('language', sa.String(), nullable=True, comment='Language of the transcript as predicted by model'),
Expand Down
6 changes: 3 additions & 3 deletions aana/configs/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@
from os import PathLike
from typing import TypeAlias, TypedDict

from sqlalchemy import Integer, create_engine
from sqlalchemy import String, create_engine

# These are here so we can change types in a single place.

id_type: TypeAlias = int
IdSqlType: TypeAlias = Integer
id_type: TypeAlias = str
IdSqlType: TypeAlias = String


class SQLiteConfig(TypedDict):
Expand Down
100 changes: 100 additions & 0 deletions aana/configs/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,4 +351,104 @@
},
],
},
{
"name": "save_video_info",
"type": "ray_task",
"function": "aana.utils.db.save_video_batch",
"batched": True,
"flatten_by": "video_batch.videos.[*]",
"inputs": [
{
"name": "videos",
"key": "video_inputs",
"path": "video_batch.videos.[*].video_input",
},
],
"outputs": [
{
"name": "media_ids",
"key": "media_ids",
"path": "video_batch.videos.[*].id",
}
],
},
{
"name": "save_transcripts_medium",
"type": "ray_task",
"function": "aana.utils.db.save_transcripts_batch",
"batched": True,
"flatten_by": "video_batch.videos.[*]",
"inputs": [
{
"name": "media_ids",
"key": "media_ids",
"path": "video_batch.videos.[*].id",
},
{"name": "model_name", "key": "model_name", "path": "model_name"},
{
"name": "video_transcriptions_info_whisper_medium",
"key": "transcription_info",
"path": "video_batch.videos.[*].transcription_info",
},
{
"name": "video_transcriptions_segments_whisper_medium",
"key": "segments",
"path": "video_batch.videos.[*].segments",
},
{
"name": "video_transcriptions_whisper_medium",
"key": "transcription",
"path": "video_batch.videos.[*].transcription",
},
],
"outputs": [
{
"name": "transcription_id",
"key": "transcription_id",
"path": "video_batch.videos.[*].transcription.id",
}
],
},
{
"name": "save_video_captions_hf_blip2_opt_2_7b",
"type": "ray_task",
"function": "aana.utils.db.save_captions_batch",
"batched": True,
"flatten_by": "video_batch.videos.[*]",
"inputs": [
{
"name": "media_ids",
"key": "media_ids",
"path": "video_batch.videos.[*].id",
},
{"name": "model_name", "key": "model_name", "path": "model_name"},
{
"name": "timestamps",
"key": "timestamps",
"path": "video_batch.videos.[*].timestamp",
},
{
"name": "duration",
"key": "duration",
"path": "video_batch.videos.[*].duration",
},
{
"name": "video_captions_hf_blip2_opt_2_7b",
"key": "captions",
"path": "video_batch.videos.[*].frames.[*].caption_hf_blip2_opt_2_7b",
},
{
"name": "frame_id",
"key": "frame_id",
"path": "video_batch.videos.[*].frames.[*].id",
},
],
"outputs": [
{
"name": "caption_id",
"key": "caption_id",
"path": "video_batch.videos.[*].frames.[*].caption_hf_blip2_opt_2_7b.id",
}
],
},
]
32 changes: 32 additions & 0 deletions aana/utils/db.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
# ruff: noqa: A002
from pathlib import Path
from urllib.parse import urlparse

from sqlalchemy.orm import Session

from aana.configs.db import id_type
Expand All @@ -9,6 +12,7 @@
AsrTranscriptionList,
)
from aana.models.pydantic.captions import CaptionsList
from aana.models.pydantic.video_input import VideoInputList
from aana.repository.datastore.caption_repo import CaptionRepository
from aana.repository.datastore.engine import engine
from aana.repository.datastore.media_repo import MediaRepository
Expand All @@ -17,6 +21,34 @@

# Just using raw utility functions like this isn't a permanent solution, but
# it's good enough for now to validate what we're working on.
def save_video_batch(video_inputs: VideoInputList) -> list[id_type]:
"""Saves a batch of videos to datastore."""
entities = []
for video_input in video_inputs:
if video_input.url is not None:
orig_url = video_input.url
parsed_url = urlparse(orig_url)
orig_filename = Path(parsed_url.path).name
elif video_input.path is not None:
parsed_path = Path(video_input.path)
orig_filename = parsed_path.name
orig_url = None
else:
orig_url = None
orig_filename = None
entity = MediaEntity(
id=video_input.media_id,
media_type=MediaType.VIDEO,
orig_filename=orig_filename,
orig_url=orig_url,
)
entities.append(entity)
with Session(engine) as session:
repo = MediaRepository(session)
results = repo.create_multiple(entities)
return [result.id for result in results] # type: ignore


def save_media(media_type: MediaType, duration: float) -> id_type:
"""Creates and saves media to datastore.
Expand Down
Loading

0 comments on commit 49397bb

Please sign in to comment.