Skip to content

Commit

Permalink
refactor: Implement services_ctx and refactoring using this
Browse files Browse the repository at this point in the history
  • Loading branch information
jopemachine committed Feb 3, 2025
1 parent 447962b commit b0c186f
Show file tree
Hide file tree
Showing 20 changed files with 388 additions and 313 deletions.
1 change: 1 addition & 0 deletions src/ai/backend/manager/api/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ async def _handle_gql_common(request: web.Request, params: Any) -> ExecutionResu
manager_status=manager_status,
known_slot_types=known_slot_types,
background_task_manager=root_ctx.background_task_manager,
services_ctx=root_ctx.services_ctx,
storage_manager=root_ctx.storage_manager,
registry=root_ctx.registry,
idle_checker_host=root_ctx.idle_checker_host,
Expand Down
2 changes: 2 additions & 0 deletions src/ai/backend/manager/api/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import attrs

from ai.backend.common.metrics.metric import CommonMetricRegistry
from ai.backend.manager.api.services.base import ServicesContext
from ai.backend.manager.plugin.network import NetworkPluginContext

if TYPE_CHECKING:
Expand Down Expand Up @@ -50,6 +51,7 @@ class RootContext(BaseContext):
storage_manager: StorageSessionManager
hook_plugin_ctx: HookPluginContext
network_plugin_ctx: NetworkPluginContext
services_ctx: ServicesContext

registry: AgentRegistry
agent_cache: AgentRPCCache
Expand Down
28 changes: 7 additions & 21 deletions src/ai/backend/manager/api/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,6 @@

from ai.backend.common import validators as tx
from ai.backend.logging import BraceStyleAdapter
from ai.backend.manager.models.gql_models.container_registry_utils import (
HarborQuotaManager,
)
from ai.backend.manager.models.rbac import ProjectScope

if TYPE_CHECKING:
Expand Down Expand Up @@ -40,11 +37,8 @@ async def update_registry_quota(request: web.Request, params: Any) -> web.Respon
scope_id = ProjectScope(project_id=group_id, domain_name=None)
quota = int(params["quota"])

async with root_ctx.db.begin_session() as db_sess:
manager = await HarborQuotaManager.new(db_sess, scope_id)
await manager.update(quota)

return web.json_response({})
await root_ctx.services_ctx.per_project_container_registries_quota.update(scope_id, quota)
return web.Response(status=204)


@server_status_required(READ_ALLOWED)
Expand All @@ -60,11 +54,8 @@ async def delete_registry_quota(request: web.Request, params: Any) -> web.Respon
group_id = params["group_id"]
scope_id = ProjectScope(project_id=group_id, domain_name=None)

async with root_ctx.db.begin_session() as db_sess:
manager = await HarborQuotaManager.new(db_sess, scope_id)
await manager.delete()

return web.json_response({})
await root_ctx.services_ctx.per_project_container_registries_quota.delete(scope_id)
return web.Response(status=204)


@server_status_required(READ_ALLOWED)
Expand All @@ -82,11 +73,8 @@ async def create_registry_quota(request: web.Request, params: Any) -> web.Respon
scope_id = ProjectScope(project_id=group_id, domain_name=None)
quota = int(params["quota"])

async with root_ctx.db.begin_session() as db_sess:
manager = await HarborQuotaManager.new(db_sess, scope_id)
await manager.create(quota)

return web.json_response({})
await root_ctx.services_ctx.per_project_container_registries_quota.create(scope_id, quota)
return web.Response(status=204)


@server_status_required(READ_ALLOWED)
Expand All @@ -102,9 +90,7 @@ async def read_registry_quota(request: web.Request, params: Any) -> web.Response
group_id = params["group_id"]
scope_id = ProjectScope(project_id=group_id, domain_name=None)

async with root_ctx.db.begin_session() as db_sess:
manager = await HarborQuotaManager.new(db_sess, scope_id)
quota = await manager.read()
quota = await root_ctx.services_ctx.per_project_container_registries_quota.read(scope_id)

return web.json_response({"result": quota})

Expand Down
1 change: 1 addition & 0 deletions src/ai/backend/manager/api/services/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
python_sources(name="src")
Empty file.
25 changes: 25 additions & 0 deletions src/ai/backend/manager/api/services/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from ai.backend.manager.models.utils import ExtendedAsyncSAEngine

from .container_registries.base import PerProjectRegistryQuotaRepository
from .container_registries.harbor import (
PerProjectContainerRegistryQuotaProtocol,
PerProjectContainerRegistryQuotaService,
)


class ServicesContext:
"""
In the API layer, requests are processed through the ServicesContext and
its subordinate layers, including the DB, Client, and Repository layers.
Each layer separates the responsibilities specific to its respective level.
"""

db: ExtendedAsyncSAEngine

def __init__(self, db: ExtendedAsyncSAEngine) -> None:
self.db = db

@property
def per_project_container_registries_quota(self) -> PerProjectContainerRegistryQuotaProtocol:
repository = PerProjectRegistryQuotaRepository(db=self.db)
return PerProjectContainerRegistryQuotaService(repository=repository)
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
python_sources(name="src")
Empty file.
73 changes: 73 additions & 0 deletions src/ai/backend/manager/api/services/container_registries/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
from __future__ import annotations

import logging
from typing import TYPE_CHECKING

import sqlalchemy as sa
from sqlalchemy.orm import load_only

from ai.backend.logging import BraceStyleAdapter
from ai.backend.manager.api.exceptions import (
ContainerRegistryNotFound,
)
from ai.backend.manager.models.container_registry import ContainerRegistryRow
from ai.backend.manager.models.group import GroupRow
from ai.backend.manager.models.rbac import ProjectScope
from ai.backend.manager.models.utils import ExtendedAsyncSAEngine

if TYPE_CHECKING:
pass

log = BraceStyleAdapter(logging.getLogger(__spec__.name))


class PerProjectRegistryQuotaRepository:
""" """

def __init__(self, db: ExtendedAsyncSAEngine):
self.db = db

@classmethod
def _is_valid_group_row(cls, group_row: GroupRow) -> bool:
return (
group_row
and group_row.container_registry
and "registry" in group_row.container_registry
and "project" in group_row.container_registry
)

async def fetch_container_registry_row(self, scope_id: ProjectScope) -> ContainerRegistryRow:
async with self.db.begin_readonly_session() as db_sess:
project_id = scope_id.project_id
group_query = (
sa.select(GroupRow)
.where(GroupRow.id == project_id)
.options(load_only(GroupRow.container_registry))
)
result = await db_sess.execute(group_query)
group_row = result.scalar_one_or_none()

if not PerProjectRegistryQuotaRepository._is_valid_group_row(group_row):
raise ContainerRegistryNotFound(
f"Container registry info does not exist or is invalid in the group. (gr: {project_id})"
)

registry_name, project = (
group_row.container_registry["registry"],
group_row.container_registry["project"],
)

registry_query = sa.select(ContainerRegistryRow).where(
(ContainerRegistryRow.registry_name == registry_name)
& (ContainerRegistryRow.project == project)
)

result = await db_sess.execute(registry_query)
registry = result.scalars().one_or_none()

if not registry:
raise ContainerRegistryNotFound(
f"Specified container registry row not found. (cr: {registry_name}, gr: {project})"
)

return registry
Loading

0 comments on commit b0c186f

Please sign in to comment.