Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Edr/video bugs #69

Merged
merged 19 commits into from
Mar 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions aana/alembic/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from alembic import context
from sqlalchemy import engine_from_config, pool

from aana.configs.db import create_database_engine
from aana.configs.settings import settings
from aana.models.db.base import BaseEntity

Expand Down Expand Up @@ -38,12 +37,13 @@ def run_migrations_offline() -> None:
script output.

"""
engine = create_database_engine(settings.db_config)
engine = settings.db_config.get_engine()
context.configure(
url=engine.url,
target_metadata=target_metadata,
literal_binds=True,
dialect_opts={"paramstyle": "named"},
render_as_batched=True,
)

with context.begin_transaction():
Expand All @@ -58,7 +58,7 @@ def run_migrations_online() -> None:

"""
config_section = config.get_section(config.config_ini_section, {})
engine = create_database_engine(settings.db_config)
engine = settings.db_config.get_engine()
config_section["sqlalchemy.url"] = engine.url
connectable = engine_from_config(
config_section,
Expand All @@ -67,7 +67,9 @@ def run_migrations_online() -> None:
)

with connectable.connect() as connection:
context.configure(connection=connection, target_metadata=target_metadata)
context.configure(
connection=connection, target_metadata=target_metadata, render_as_batch=True
)

with context.begin_transaction():
context.run_migrations()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
"""Name constraints, fix cascading deletes.

Revision ID: 0918cb09ac67
Revises: 86a31568e6c2
Create Date: 2024-03-07 12:56:36.129852

"""
from collections.abc import Sequence

from alembic import op

# revision identifiers, used by Alembic.
revision: str = "0918cb09ac67"
down_revision: str | None = "86a31568e6c2"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None


def upgrade() -> None:
"""Upgrade database to this revision from previous."""
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table("media", schema=None) as batch_op:
batch_op.drop_constraint("fk_media_video_id_video", type_="foreignkey")
batch_op.create_foreign_key(
batch_op.f("fk_media_video_id_video"),
"video",
["video_id"],
["id"],
ondelete="CASCADE",
)

# ### end Alembic commands ###


def downgrade() -> None:
"""Downgrade database from this revision to previous."""
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table("media", schema=None) as batch_op:
batch_op.drop_constraint(
batch_op.f("fk_media_video_id_video"), type_="foreignkey"
)
batch_op.create_foreign_key(
"fk_media_video_id_video", "video", ["video_id"], ["id"]
)

# ### end Alembic commands ###
4 changes: 2 additions & 2 deletions aana/api/api_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,8 +337,8 @@ async def route_func_body(body: str, files: list[UploadFile] | None = None): #
files_as_bytes = [await file.read() for file in files]
getattr(data, file_upload_field.name).set_files(files_as_bytes)

# We have to do this instead of data.dict() because
# data.dict() will convert all nested models to dicts
# We have to do this instead of data.model_dump() because
# data.model_dump() will convert all nested models to dicts
# and we want to keep them as pydantic models
data_dict = {}
for field_name in data.model_fields:
Expand Down
54 changes: 42 additions & 12 deletions aana/configs/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@

from alembic import command
from alembic.config import Config
from pydantic_settings import BaseSettings
from sqlalchemy import String, TypeDecorator, create_engine
from sqlalchemy.engine import Engine
from typing_extensions import TypedDict

from aana.models.pydantic.media_id import MediaId
Expand Down Expand Up @@ -57,14 +59,37 @@ class DbType(str, Enum):
SQLITE = "sqlite"


class DBConfig(TypedDict):
class DbSettings(BaseSettings):
"""Database configuration."""

datastore_type: DbType | str
datastore_config: SQLiteConfig | PostgreSQLConfig


def create_database_engine(db_config):
datastore_type: DbType | str = DbType.SQLITE
datastore_config: SQLiteConfig | PostgreSQLConfig = SQLiteConfig(
path="/var/lib/aana_data"
)
engine: Engine | None = None

def get_engine(self):
"""Gets engine. Each instance of DbSettings will create a max.of 1 engine."""
if not self.engine:
self.engine = _create_database_engine(self)
return self.engine

def __getstate__(self):
"""Used by pickle to pickle an object."""
# We need to remove the "engine" property because SqlAlchemy engines
# are not picklable
state = self.__dict__.copy()
state.pop("engine", None)
return state

def __setstate__(self, state):
"""Used to restore a runtime object from pickle; the opposite of __getstate__()."""
# We don't need to do anything special here, since the engine will be recreated
# if needed.
self.__dict__.update(state)


def _create_database_engine(db_config):
"""Create SQLAlchemy engine based on the provided configuration.

Args:
Expand All @@ -73,12 +98,12 @@ def create_database_engine(db_config):
Returns:
sqlalchemy.engine.Engine: SQLAlchemy engine instance.
"""
db_type = db_config.get("datastore_type", "").lower()
db_type = getattr(db_config, "datastore_type", "").lower()

if db_type == DbType.POSTGRESQL:
return create_postgresql_engine(db_config["datastore_config"])
return create_postgresql_engine(db_config.datastore_config)
elif db_type == DbType.SQLITE:
return create_sqlite_engine(db_config["datastore_config"])
return create_sqlite_engine(db_config.datastore_config)
else:
raise ValueError(f"Unsupported database type: {db_type}") # noqa: TRY003

Expand Down Expand Up @@ -109,9 +134,11 @@ def create_sqlite_engine(config):
return create_engine(connection_string)


def get_alembic_config(app_config, ini_file_path, alembic_data_path) -> Config:
def get_alembic_config(
app_config, ini_file_path: Path, alembic_data_path: Path
) -> Config:
"""Produces an alembic config to run migrations programmatically."""
engine = create_database_engine(app_config.db_config)
engine = app_config.db_config.get_engine()
alembic_config = Config(ini_file_path)
alembic_config.set_main_option("script_location", str(alembic_data_path))
config_section = alembic_config.get_section(alembic_config.config_ini_section, {})
Expand All @@ -132,7 +159,10 @@ def run_alembic_migrations(settings):
alembic_data_path = aana_root / "alembic"

alembic_config = get_alembic_config(settings, ini_file_path, alembic_data_path)
command.upgrade(alembic_config, "head")
engine = settings.db_config.get_engine()
with engine.begin() as connection:
alembic_config.attributes["connection"] = connection
command.upgrade(alembic_config, "head")


def drop_all_tables(settings):
Expand Down
45 changes: 25 additions & 20 deletions aana/configs/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -886,6 +886,26 @@
}
],
},
{
"name": "delete_media",
"type": "function",
"function": "aana.utils.db.delete_media",
"dict_output": True,
"inputs": [
{
"name": "media_id",
"key": "media_id",
"path": "media_id",
},
],
"outputs": [
{
"name": "deleted_media_id",
"key": "media_id",
"path": "deleted_media_id",
}
],
},
{
"name": "save_video",
"type": "function",
Expand All @@ -897,6 +917,11 @@
"key": "video",
"path": "video.video",
},
{
"name": "video_duration",
"key": "duration",
movchan74 marked this conversation as resolved.
Show resolved Hide resolved
"path": "video.duration",
},
],
"outputs": [
{
Expand All @@ -911,26 +936,6 @@
},
],
},
{
"name": "delete_media",
"type": "function",
"function": "aana.utils.db.delete_media",
"dict_output": True,
"inputs": [
{
"name": "media_id",
"key": "media_id",
"path": "media_id",
},
],
"outputs": [
{
"name": "deleted_media_id",
"key": "media_id",
"path": "deleted_media_id",
}
],
},
{
"name": "save_videos_info",
"type": "function",
Expand Down
7 changes: 2 additions & 5 deletions aana/configs/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from pydantic_settings import BaseSettings

from aana.configs.db import DBConfig
from aana.configs.db import DbSettings


class TestSettings(BaseSettings):
Expand All @@ -23,10 +23,7 @@ class Settings(BaseSettings):
model_dir: Path = tmp_data_dir / "models"
num_workers: int = 2

db_config: DBConfig = {
"datastore_type": "sqlite",
"datastore_config": {"path": Path("/var/lib/aana_data")},
}
db_config: DbSettings = DbSettings()

test: TestSettings = TestSettings()

Expand Down
12 changes: 10 additions & 2 deletions aana/models/db/base.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,19 @@
from sqlalchemy import Column, DateTime, func
from sqlalchemy import Column, DateTime, MetaData, func
from sqlalchemy.orm import DeclarativeBase


class BaseEntity(DeclarativeBase):
"""Base for all ORM classes."""

pass
metadata = MetaData(
naming_convention={
"ix": "ix_%(column_0_label)s",
"uq": "uq_%(table_name)s_%(column_0_name)s",
"ck": "ck_%(table_name)s_`%(constraint_name)s`",
"fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
"pk": "pk_%(table_name)s",
}
)


class TimeStampEntity:
Expand Down
7 changes: 2 additions & 5 deletions aana/models/db/caption.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import typing

from sqlalchemy import CheckConstraint, Column, Float, ForeignKey, Integer, String
from sqlalchemy.orm import relationship

from aana.configs.db import MediaIdSqlType
from aana.models.db.base import BaseEntity, TimeStampEntity
Expand Down Expand Up @@ -37,18 +36,16 @@ class CaptionEntity(BaseEntity, TimeStampEntity):

frame_id = Column(
Integer,
CheckConstraint("frame_id >= 0"),
CheckConstraint("frame_id >= 0", "frame_id_positive"),
comment="The 0-based frame id of video for caption",
)
caption = Column(String, comment="Frame caption")
timestamp = Column(
Float,
CheckConstraint("timestamp >= 0"),
CheckConstraint("timestamp >= 0", name="timestamp_positive"),
comment="Frame timestamp in seconds",
)

video = relationship("VideoEntity", back_populates="captions")

@classmethod
def from_caption_output(
cls,
Expand Down
7 changes: 4 additions & 3 deletions aana/models/db/media.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from enum import Enum

from sqlalchemy import Column, ForeignKey, Integer, String
from sqlalchemy.orm import relationship
from sqlalchemy.orm import backref, relationship

from aana.configs.db import MediaIdSqlType
from aana.models.db.base import BaseEntity, TimeStampEntity
Expand All @@ -21,15 +21,16 @@ class MediaEntity(BaseEntity, TimeStampEntity):
media_type = Column(String, comment="The type of media")
video_id = Column(
Integer,
ForeignKey("video.id"),
ForeignKey("video.id", ondelete="CASCADE"),
nullable=True,
comment="If media_type is `video`, the id of the video this entry represents.",
)

video = relationship(
"VideoEntity",
back_populates="media",
backref=backref("media", passive_deletes=True, uselist=False),
cascade="all, delete",
uselist=False,
post_update=True,
foreign_keys=[video_id],
)
7 changes: 3 additions & 4 deletions aana/models/db/transcript.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from typing import TYPE_CHECKING

from sqlalchemy import JSON, CheckConstraint, Column, Float, ForeignKey, Integer, String
from sqlalchemy.orm import relationship

from aana.configs.db import MediaIdSqlType
from aana.models.db.base import BaseEntity, TimeStampEntity
Expand Down Expand Up @@ -45,12 +44,12 @@ class TranscriptEntity(BaseEntity, TimeStampEntity):
)
language_confidence = Column(
Float,
CheckConstraint("0 <= language_confidence <= 1"),
CheckConstraint(
"0 <= language_confidence <= 1", name="language_confidence_value_range"
),
comment="Confidence score of language prediction",
)

video = relationship("VideoEntity", back_populates="transcripts")

@classmethod
def from_asr_output(
cls,
Expand Down
Loading
Loading