diff --git a/aana/routers/task.py b/aana/routers/task.py index 8b7ccf0d..620e8fe5 100644 --- a/aana/routers/task.py +++ b/aana/routers/task.py @@ -2,7 +2,7 @@ from typing import Annotated from fastapi import APIRouter, HTTPException -from pydantic import BaseModel, Field +from pydantic import BaseModel, ConfigDict, Field from aana.api.security import UserIdDependency from aana.configs.settings import settings as aana_settings @@ -39,9 +39,35 @@ class TaskList(BaseModel): tasks: list[TaskInfo] = Field(..., description="The list of tasks.") +# fmt: off +class TaskCount(BaseModel): + """Response for a count of tasks by status.""" + + created: int | None = Field(None, description="The number of tasks in the CREATED status.") + assigned: int | None = Field(None, description="The number of tasks in the ASSIGNED status.") + completed: int | None = Field(None, description="The number of tasks in the COMPLETED status.") + running: int | None = Field(None, description="The number of tasks in the RUNNING status.") + failed: int | None = Field(None, description="The number of tasks in the FAILED status.") + not_finished: int | None = Field(None, description="The number of tasks in the NOT_FINISHED status.") + total: int = Field(..., description="The total number of tasks.") +# fmt: on + # Endpoints +@router.get( + "/tasks/count", + summary="Count Tasks", + description="Count tasks per status.", + response_model_exclude_none=True, +) +async def count_tasks(db: GetDbDependency, user_id: UserIdDependency) -> TaskCount: + """Count tasks by status.""" + task_repo = TaskRepository(db) + counts = task_repo.count(user_id=user_id) + return TaskCount(**counts) + + @router.get( "/tasks/{task_id}", summary="Get Task Status", diff --git a/aana/storage/models/task.py b/aana/storage/models/task.py index d9c30fab..3bc7d721 100644 --- a/aana/storage/models/task.py +++ b/aana/storage/models/task.py @@ -12,7 +12,16 @@ class Status(str, Enum): - """Enum for task status.""" + """Enum for task status. + + Attributes: + CREATED: The task is created. + ASSIGNED: The task is assigned to a worker. + COMPLETED: The task is completed. + RUNNING: The task is running. + FAILED: The task has failed. + NOT_FINISHED: The task is not finished. + """ CREATED = "created" ASSIGNED = "assigned" diff --git a/aana/storage/repository/task.py b/aana/storage/repository/task.py index 176ceb08..5e9c21d8 100644 --- a/aana/storage/repository/task.py +++ b/aana/storage/repository/task.py @@ -18,7 +18,7 @@ def __init__(self, session: Session): """Constructor.""" super().__init__(session, TaskEntity) - def read(self, task_id: str | UUID, check: bool = True) -> TaskEntity: + def read(self, task_id: str | UUID, check: bool = True) -> TaskEntity | None: """Reads a single task by id from the database. Args: @@ -31,8 +31,11 @@ def read(self, task_id: str | UUID, check: bool = True) -> TaskEntity: Raises: NotFoundException if the entity is not found and `check` is True. """ - if isinstance(task_id, str): - task_id = UUID(task_id) + try: + if isinstance(task_id, str): + task_id = UUID(task_id) + except ValueError: + return None return super().read(task_id, check=check) def delete(self, task_id: str | UUID, check: bool = False) -> TaskEntity | None: @@ -48,8 +51,11 @@ def delete(self, task_id: str | UUID, check: bool = False) -> TaskEntity | None: Raises: NotFoundException: The id does not correspond to a record in the database. """ - if isinstance(task_id, str): - task_id = UUID(task_id) + try: + if isinstance(task_id, str): + task_id = UUID(task_id) + except ValueError: + return None return super().delete(task_id, check) def save( @@ -442,3 +448,22 @@ def get_tasks( .all() ) return tasks + + def count(self, user_id: str | None = None) -> dict[str, int]: + """Count tasks by status. + + Args: + user_id (str | None): The user ID. If None, all tasks are counted. + + Returns: + dict[str, int]: The count of tasks by status. + """ + counts = ( + self.session.query(TaskEntity.status, func.count(TaskEntity.id)) + .filter(TaskEntity.user_id == user_id if user_id else True) + .group_by(TaskEntity.status) + .all() + ) + count_dict = {status.value: count for status, count in counts} + count_dict["total"] = sum(count_dict.values()) + return count_dict diff --git a/aana/storage/repository/webhook.py b/aana/storage/repository/webhook.py index 727e1a60..21b8681a 100644 --- a/aana/storage/repository/webhook.py +++ b/aana/storage/repository/webhook.py @@ -15,7 +15,7 @@ def __init__(self, session: Session): """Constructor.""" super().__init__(session, WebhookEntity) - def read(self, item_id: str | UUID, check: bool = True) -> WebhookEntity: + def read(self, item_id: str | UUID, check: bool = True) -> WebhookEntity | None: """Reads a single webhook from the database. Args: @@ -28,11 +28,14 @@ def read(self, item_id: str | UUID, check: bool = True) -> WebhookEntity: Raises: NotFoundException if the entity is not found and `check` is True. """ - if isinstance(item_id, str): - item_id = UUID(item_id) + try: + if isinstance(item_id, str): + item_id = UUID(item_id) + except ValueError: + return None return super().read(item_id, check) - def delete(self, item_id: str | UUID, check: bool = True) -> WebhookEntity: + def delete(self, item_id: str | UUID, check: bool = True) -> WebhookEntity | None: """Delete a webhook from the database. Args: @@ -45,8 +48,11 @@ def delete(self, item_id: str | UUID, check: bool = True) -> WebhookEntity: Raises: NotFoundException: The id does not correspond to a record in the database. """ - if isinstance(item_id, str): - item_id = UUID(item_id) + try: + if isinstance(item_id, str): + item_id = UUID(item_id) + except ValueError: + return None return super().delete(item_id, check) def save(self, webhook: WebhookEntity) -> WebhookEntity: diff --git a/aana/tests/units/test_task_queue.py b/aana/tests/units/test_task_queue.py index 66e6ef4e..77ba0b2b 100644 --- a/aana/tests/units/test_task_queue.py +++ b/aana/tests/units/test_task_queue.py @@ -260,3 +260,8 @@ def test_task_queue(create_app): # noqa: C901 response = response.json() task_status = response.get("status") assert task_status == "completed", response + + # Test task count endpoint + response = requests.get(f"http://localhost:{port}/tasks/count") + assert response.status_code == 200 + assert response.json() == {"total": 31, "completed": 31}