Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Edr/video bugs #69

Merged
merged 19 commits into from
Mar 25, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions aana/alembic/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from alembic import context
from sqlalchemy import engine_from_config, pool

from aana.configs.db import create_database_engine
from aana.configs.settings import settings
from aana.models.db.base import BaseEntity

Expand Down Expand Up @@ -38,7 +37,7 @@ def run_migrations_offline() -> None:
script output.

"""
engine = create_database_engine(settings.db_config)
engine = settings.db_config.get_engine()
context.configure(
url=engine.url,
target_metadata=target_metadata,
Expand All @@ -58,7 +57,7 @@ def run_migrations_online() -> None:

"""
config_section = config.get_section(config.config_ini_section, {})
engine = create_database_engine(settings.db_config)
engine = settings.db_config.get_engine()
config_section["sqlalchemy.url"] = engine.url
connectable = engine_from_config(
config_section,
Expand Down
36 changes: 26 additions & 10 deletions aana/configs/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@

from alembic import command
from alembic.config import Config
from pydantic import BaseSettings
from sqlalchemy import String, TypeDecorator, create_engine
from sqlalchemy.engine import Engine

from aana.models.pydantic.media_id import MediaId

Expand Down Expand Up @@ -56,14 +58,23 @@ class DbType(str, Enum):
SQLITE = "sqlite"


class DBConfig(TypedDict):
class DbSettings(BaseSettings):
"""Database configuration."""

datastore_type: DbType | str
datastore_config: SQLiteConfig | PostgreSQLConfig
datastore_type: DbType | str = DbType.SQLITE
datastore_config: SQLiteConfig | PostgreSQLConfig = SQLiteConfig(
path="/var/lib/aana_data"
)
engine: Engine | None = None

def get_engine(self):
"""Gets engine. Each instance of DbSettings will create a max.of 1 engine."""
if not self.engine:
self.engine = _create_database_engine(self)
return self.engine

def create_database_engine(db_config):

def _create_database_engine(db_config):
"""Create SQLAlchemy engine based on the provided configuration.

Args:
Expand All @@ -72,12 +83,12 @@ def create_database_engine(db_config):
Returns:
sqlalchemy.engine.Engine: SQLAlchemy engine instance.
"""
db_type = db_config.get("datastore_type", "").lower()
db_type = getattr(db_config, "datastore_type", "").lower()

if db_type == DbType.POSTGRESQL:
return create_postgresql_engine(db_config["datastore_config"])
return create_postgresql_engine(db_config.datastore_config)
elif db_type == DbType.SQLITE:
return create_sqlite_engine(db_config["datastore_config"])
return create_sqlite_engine(db_config.datastore_config)
else:
raise ValueError(f"Unsupported database type: {db_type}") # noqa: TRY003

Expand Down Expand Up @@ -108,9 +119,11 @@ def create_sqlite_engine(config):
return create_engine(connection_string)


def get_alembic_config(app_config, ini_file_path, alembic_data_path) -> Config:
def get_alembic_config(
app_config, ini_file_path: Path, alembic_data_path: Path
) -> Config:
"""Produces an alembic config to run migrations programmatically."""
engine = create_database_engine(app_config.db_config)
engine = app_config.db_config.get_engine()
alembic_config = Config(ini_file_path)
alembic_config.set_main_option("script_location", str(alembic_data_path))
config_section = alembic_config.get_section(alembic_config.config_ini_section, {})
Expand All @@ -131,7 +144,10 @@ def run_alembic_migrations(settings):
alembic_data_path = aana_root / "alembic"

alembic_config = get_alembic_config(settings, ini_file_path, alembic_data_path)
command.upgrade(alembic_config, "head")
engine = settings.db_config.get_engine()
with engine.begin() as connection:
alembic_config.attributes["connection"] = connection
command.upgrade(alembic_config, "head")


def drop_all_tables(settings):
Expand Down
5 changes: 5 additions & 0 deletions aana/configs/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -768,6 +768,11 @@
"key": "video",
"path": "video.video",
},
{
"name": "video_duration",
"key": "duration",
movchan74 marked this conversation as resolved.
Show resolved Hide resolved
"path": "video.duration",
},
],
"outputs": [
{
Expand Down
7 changes: 2 additions & 5 deletions aana/configs/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from pydantic import BaseSettings

from aana.configs.db import DBConfig
from aana.configs.db import DbSettings


class TestSettings(BaseSettings):
Expand All @@ -21,10 +21,7 @@ class Settings(BaseSettings):
video_dir = tmp_data_dir / "videos"
num_workers: int = 2

db_config: DBConfig = {
"datastore_type": "sqlite",
"datastore_config": {"path": Path("/var/lib/aana_data")},
}
db_config: DbSettings = DbSettings()

test: TestSettings = TestSettings()

Expand Down
1 change: 1 addition & 0 deletions aana/models/db/media.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,5 +31,6 @@ class MediaEntity(BaseEntity, TimeStampEntity):
back_populates="media",
cascade="all, delete",
uselist=False,
post_update=True,
foreign_keys=[video_id],
)
5 changes: 1 addition & 4 deletions aana/models/db/video.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,7 @@ class VideoEntity(BaseEntity, TimeStampEntity):
title = Column(String, comment="Title of the video")
description = Column(String, comment="Description of the video")

media = relationship(
"MediaEntity",
foreign_keys=[media_id],
)
media = relationship("MediaEntity", foreign_keys=[media_id], post_update=True)
evanderiel marked this conversation as resolved.
Show resolved Hide resolved
captions = relationship(
"CaptionEntity", back_populates="video", cascade="all, delete-orphan"
)
Expand Down
3 changes: 1 addition & 2 deletions aana/repository/datastore/engine.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from aana.configs.db import create_database_engine
from aana.configs.settings import settings

engine = create_database_engine(settings.db_config)
engine = settings.db_config.get_engine()
4 changes: 2 additions & 2 deletions aana/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from aana.api.request_handler import RequestHandler
from aana.configs.build import get_configuration
from aana.configs.db import (
DBConfig,
DbSettings,
DbType,
SQLiteConfig,
)
Expand Down Expand Up @@ -109,7 +109,7 @@ def app_setup(ray_serve_setup): # noqa: D417
"""
# create temporary database
tmp_database_path = Path(tempfile.mkstemp(suffix=".db")[1])
db_config = DBConfig(
db_config = DbSettings(
datastore_type=DbType.SQLITE,
datastore_config=SQLiteConfig(path=tmp_database_path),
)
Expand Down
115 changes: 63 additions & 52 deletions aana/tests/db/datastore/test_config.py
Original file line number Diff line number Diff line change
@@ -1,75 +1,86 @@
# ruff: noqa: S101
import pytest

from aana.configs.db import create_database_engine
from aana.configs.db import (
DbSettings,
PostgreSQLConfig,
SQLiteConfig,
)


def test_pg_datastore_config():
"""Tests datastore config."""
db_config = {
"datastore_type": "postgresql",
"datastore_config": {
"host": "0.0.0.0", # noqa: S104
"port": "5432",
"database": "postgres",
"user": "postgres",
"password": "bogus",
},
}
@pytest.fixture
def pg_settings():
"""Fixture for working PostgreSQL settings."""
return DbSettings(
datastore_type="postgresql",
datastore_config=PostgreSQLConfig(
host="0.0.0.0", # noqa: S104
port="5432",
database="postgres",
user="postgres",
password="bogus", # noqa: S106
),
)

engine = create_database_engine(db_config)

@pytest.fixture
def sqlite_settings():
"""Fixture for working sqlite settings."""
return DbSettings(
datastore_type="sqlite",
datastore_config=SQLiteConfig(path="/tmp/deleteme.sqlite"), # noqa: S108
)


@pytest.mark.parameterize("db_settings", [sqlite_settings, pg_settings])
def test_et_engine_idempotent(db_settings):
"""Tests that get_engine returns the same engine on subsequent calls."""
e1 = db_settings.get_engine()
e2 = db_settings.get_engine()
assert e1 is e2


def test_pg_datastore_config(pg_settings):
"""Tests datastore config for postgres."""
engine = pg_settings.get_engine()

assert engine.name == "postgresql"
assert str(engine.url) == "postgresql://postgres:***@0.0.0.0:5432/postgres"


def test_sqlite_datastore_config():
"""Tests datastore config."""
db_config = {
"datastore_type": "sqlite",
"datastore_config": {"path": "/tmp/deleteme.sqlite"}, # noqa: S108
}

engine = create_database_engine(db_config)
def test_sqlite_datastore_config(sqlite_settings):
"""Tests datastore config for SQLite."""
engine = sqlite_settings.get_engine()

assert engine.name == "sqlite"
assert str(engine.url) == f"sqlite:///{db_config['datastore_config']['path']}"
assert str(engine.url) == f"sqlite:///{sqlite_settings.datastore_config['path']}"


def test_nonexistent_datastore_config():
"""Tests datastore config."""
db_config = {
"datastore_type": "oracle🤮",
"datastore_config": {
"host": "0.0.0.0", # noqa: S104
"port": "5432",
"database": "oracle",
"user": "oracle",
"password": "bogus",
},
}
"""Tests that datastore config errors on unsupported DB types."""
db_settings = DbSettings(
**{
"datastore_type": "oracle",
"datastore_config": {
"host": "0.0.0.0", # noqa: S104
"port": "5432",
"database": "oracle",
"user": "oracle",
"password": "bogus",
},
}
)
with pytest.raises(ValueError):
_ = create_database_engine(db_config)
_ = db_settings.get_engine()


def test_invalid_datastore_config():
def test_invalid_datastore_config(pg_settings, sqlite_settings):
"""Tests that a datastore with the wrong config raises errors."""
config_1 = {
"datastore_type": "postgresql",
"datastore_config": {"path": "/tmp/deleteme.sqlite"}, # noqa: S108
}
config_2 = {
"datastore_type": "sqlite",
"datastore_config": {
"host": "0.0.0.0", # noqa: S104
"port": "5432",
"database": "postgres",
"user": "postgres",
"password": "bogus",
},
}
tmp = pg_settings.datastore_config
pg_settings.datastore_config = sqlite_settings.datastore_config
sqlite_settings.datastore_config = tmp

with pytest.raises(KeyError):
_ = create_database_engine(config_1)
_ = sqlite_settings.get_engine()
with pytest.raises(KeyError):
_ = create_database_engine(config_2)
_ = pg_settings.get_engine()
19 changes: 19 additions & 0 deletions aana/tests/db/datastore/test_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from aana.repository.datastore.caption_repo import CaptionRepository
from aana.repository.datastore.media_repo import MediaRepository
from aana.repository.datastore.transcript_repo import TranscriptRepository
from aana.repository.datastore.video_repo import VideoRepository


@pytest.fixture
Expand Down Expand Up @@ -36,6 +37,24 @@ def test_create_media(mocked_session):
mocked_session.commit.assert_called_once()


def test_create_media_with_video(mocked_session):
"""Tests that video propery is se on media and vice-versa."""
media_repo = MediaRepository(mocked_session)
video_repo = VideoRepository(mocked_session)
media_type = "video"
media_id = "foo"
media = MediaEntity(id=media_id, media_type=media_type)
video = VideoEntity(media=media)
media.video = video
media2 = media_repo.create(media)
mocked_session.add.assert_called_with(media2)
video2 = video_repo.create(video)
mocked_session.add.assert_called_with(video2)

assert media2.video == video2
assert video2.media == media2


def test_create_caption(mocked_session):
"""Tests caption creation."""
repo = CaptionRepository(mocked_session)
Expand Down
7 changes: 4 additions & 3 deletions aana/tests/db/datastore/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,10 @@ def mock_session(mocker):
def test_save_video(mock_session):
"""Tests save media function."""
media_id = "foobar"
duration = 550.25
path = resources.path("aana.tests.files.videos", "squirrel.mp4")
video = Video(path=path, media_id=media_id)
result = save_video(video)
result = save_video(video, duration)

assert result["media_id"] == media_id
assert result["video_id"] is None
Expand All @@ -60,8 +61,8 @@ def test_save_videos_batch(mock_session):
media_ids = ["foo", "bar"]
path = resources.path("aana.tests.files.videos", "squirrel.mp4")
videos = [Video(path=path, media_id=m_id) for m_id in media_ids]

result = save_video_batch(videos)
durations = [x + 0.1 for x, _ in enumerate(media_ids)]
result = save_video_batch(videos, durations)
movchan74 marked this conversation as resolved.
Show resolved Hide resolved

assert result["media_ids"] == media_ids
assert result["video_ids"] == [None, None]
Expand Down
Loading