Skip to content

Commit

Permalink
Add task count endpoint.
Browse files Browse the repository at this point in the history
  • Loading branch information
Aleksandr Movchan committed Feb 4, 2025
1 parent 6dacd94 commit e95212a
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 13 deletions.
28 changes: 27 additions & 1 deletion aana/routers/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Annotated

from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, Field
from pydantic import BaseModel, ConfigDict, Field

from aana.api.security import UserIdDependency
from aana.configs.settings import settings as aana_settings
Expand Down Expand Up @@ -39,9 +39,35 @@ class TaskList(BaseModel):
tasks: list[TaskInfo] = Field(..., description="The list of tasks.")


# fmt: off
class TaskCount(BaseModel):
"""Response for a count of tasks by status."""

created: int | None = Field(None, description="The number of tasks in the CREATED status.")
assigned: int | None = Field(None, description="The number of tasks in the ASSIGNED status.")
completed: int | None = Field(None, description="The number of tasks in the COMPLETED status.")
running: int | None = Field(None, description="The number of tasks in the RUNNING status.")
failed: int | None = Field(None, description="The number of tasks in the FAILED status.")
not_finished: int | None = Field(None, description="The number of tasks in the NOT_FINISHED status.")
total: int = Field(..., description="The total number of tasks.")
# fmt: on

# Endpoints


@router.get(
"/tasks/count",
summary="Count Tasks",
description="Count tasks per status.",
response_model_exclude_none=True,
)
async def count_tasks(db: GetDbDependency, user_id: UserIdDependency) -> TaskCount:
"""Count tasks by status."""
task_repo = TaskRepository(db)
counts = task_repo.count(user_id=user_id)
return TaskCount(**counts)


@router.get(
"/tasks/{task_id}",
summary="Get Task Status",
Expand Down
11 changes: 10 additions & 1 deletion aana/storage/models/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,16 @@


class Status(str, Enum):
"""Enum for task status."""
"""Enum for task status.
Attributes:
CREATED: The task is created.
ASSIGNED: The task is assigned to a worker.
COMPLETED: The task is completed.
RUNNING: The task is running.
FAILED: The task has failed.
NOT_FINISHED: The task is not finished.
"""

CREATED = "created"
ASSIGNED = "assigned"
Expand Down
35 changes: 30 additions & 5 deletions aana/storage/repository/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def __init__(self, session: Session):
"""Constructor."""
super().__init__(session, TaskEntity)

def read(self, task_id: str | UUID, check: bool = True) -> TaskEntity:
def read(self, task_id: str | UUID, check: bool = True) -> TaskEntity | None:
"""Reads a single task by id from the database.
Args:
Expand All @@ -31,8 +31,11 @@ def read(self, task_id: str | UUID, check: bool = True) -> TaskEntity:
Raises:
NotFoundException if the entity is not found and `check` is True.
"""
if isinstance(task_id, str):
task_id = UUID(task_id)
try:
if isinstance(task_id, str):
task_id = UUID(task_id)
except ValueError:
return None
return super().read(task_id, check=check)

def delete(self, task_id: str | UUID, check: bool = False) -> TaskEntity | None:
Expand All @@ -48,8 +51,11 @@ def delete(self, task_id: str | UUID, check: bool = False) -> TaskEntity | None:
Raises:
NotFoundException: The id does not correspond to a record in the database.
"""
if isinstance(task_id, str):
task_id = UUID(task_id)
try:
if isinstance(task_id, str):
task_id = UUID(task_id)
except ValueError:
return None
return super().delete(task_id, check)

def save(
Expand Down Expand Up @@ -442,3 +448,22 @@ def get_tasks(
.all()
)
return tasks

def count(self, user_id: str | None = None) -> dict[str, int]:
"""Count tasks by status.
Args:
user_id (str | None): The user ID. If None, all tasks are counted.
Returns:
dict[str, int]: The count of tasks by status.
"""
counts = (
self.session.query(TaskEntity.status, func.count(TaskEntity.id))
.filter(TaskEntity.user_id == user_id if user_id else True)
.group_by(TaskEntity.status)
.all()
)
count_dict = {status.value: count for status, count in counts}
count_dict["total"] = sum(count_dict.values())
return count_dict
18 changes: 12 additions & 6 deletions aana/storage/repository/webhook.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def __init__(self, session: Session):
"""Constructor."""
super().__init__(session, WebhookEntity)

def read(self, item_id: str | UUID, check: bool = True) -> WebhookEntity:
def read(self, item_id: str | UUID, check: bool = True) -> WebhookEntity | None:
"""Reads a single webhook from the database.
Args:
Expand All @@ -28,11 +28,14 @@ def read(self, item_id: str | UUID, check: bool = True) -> WebhookEntity:
Raises:
NotFoundException if the entity is not found and `check` is True.
"""
if isinstance(item_id, str):
item_id = UUID(item_id)
try:
if isinstance(item_id, str):
item_id = UUID(item_id)
except ValueError:
return None
return super().read(item_id, check)

def delete(self, item_id: str | UUID, check: bool = True) -> WebhookEntity:
def delete(self, item_id: str | UUID, check: bool = True) -> WebhookEntity | None:
"""Delete a webhook from the database.
Args:
Expand All @@ -45,8 +48,11 @@ def delete(self, item_id: str | UUID, check: bool = True) -> WebhookEntity:
Raises:
NotFoundException: The id does not correspond to a record in the database.
"""
if isinstance(item_id, str):
item_id = UUID(item_id)
try:
if isinstance(item_id, str):
item_id = UUID(item_id)
except ValueError:
return None
return super().delete(item_id, check)

def save(self, webhook: WebhookEntity) -> WebhookEntity:
Expand Down
5 changes: 5 additions & 0 deletions aana/tests/units/test_task_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,3 +260,8 @@ def test_task_queue(create_app): # noqa: C901
response = response.json()
task_status = response.get("status")
assert task_status == "completed", response

# Test task count endpoint
response = requests.get(f"http://localhost:{port}/tasks/count")
assert response.status_code == 200
assert response.json() == {"total": 31, "completed": 31}

0 comments on commit e95212a

Please sign in to comment.