Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(breadbox): Celery checker #144

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions breadbox/breadbox/api/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@
from breadbox.config import Settings, get_settings
from depmap_compute import models

from breadbox.schemas.custom_http_exception import UserError, HTTPError
from breadbox.schemas.custom_http_exception import UserError, CeleryConnectionError
from ..schemas.compute import ComputeParams, ComputeResponse
from ..compute import analysis_tasks
from .dependencies import get_user
from ..celery_task.utils import format_task_status, cast_celery_task
from ..celery_task import utils


router = APIRouter(prefix="/compute", tags=["compute"])
Expand Down Expand Up @@ -69,6 +69,8 @@ def compute_univariate_associations(
user: str = Depends(get_user),
settings: Settings = Depends(get_settings),
):
utils.check_celery()

resultsDirPrefix = settings.compute_results_location
dataset_id = computeParams.datasetId
vector_variable_type = computeParams.vectorVariableType
Expand Down Expand Up @@ -97,7 +99,7 @@ def compute_univariate_associations(
)

try:
result = cast_celery_task(analysis_tasks.run_custom_analysis).delay(
result = utils.cast_celery_task(analysis_tasks.run_custom_analysis).delay(
user=user,
analysis_type=analysis_type,
query_node_id=computeParams.queryId,
Expand All @@ -113,9 +115,11 @@ def compute_univariate_associations(
except PermissionError as e:
raise HTTPException(403, detail=str(e))

return format_task_status(result)
return utils.format_task_status(result)


@router.get("/test_task", operation_id="test_task")
def test_task(message):
    """Fire-and-forget diagnostic endpoint: enqueue a trivial Celery task.

    Fails fast with a 503 (CeleryConnectionError) when the broker/workers
    are unreachable, instead of silently queueing into the void.
    """
    # Verify broker + workers are up before submitting the task.
    utils.check_celery()

    utils.cast_celery_task(analysis_tasks.test_task).delay(message)
3 changes: 3 additions & 0 deletions breadbox/breadbox/api/dataset_uploads.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from .dependencies import get_user

from ..celery_task import utils
from breadbox.schemas.custom_http_exception import CeleryConnectionError

router = APIRouter(prefix="/dataset-v2", tags=["datasets"])

Expand Down Expand Up @@ -64,6 +65,8 @@ def add_dataset_uploads(
- `col_type`: Annotation type for the column. Annotation types may include: `continuous`, `categorical`, `binary`, `text`, or `list_strings`

"""
utils.check_celery()

# Convert the pydantic model to something JSON-compatible (here, a dict). Celery's default JSON serializer can't handle pydantic models directly, and pydantic's built-in .dict() doesn't convert enums to strings (which celery can't JSON serialize), so we use fastapi's jsonable_encoder(), which serializes enums correctly.
dataset_json = jsonable_encoder(dataset)
result = run_dataset_upload.delay(dataset_json, user) # pyright: ignore
Expand Down
5 changes: 5 additions & 0 deletions breadbox/breadbox/api/downloads.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@

from ..config import Settings, get_settings
from .dependencies import get_user, get_db_with_user
from breadbox.schemas.custom_http_exception import CeleryConnectionError


router = APIRouter(prefix="/downloads", tags=["downloads"])
log = getLogger(__name__)
Expand Down Expand Up @@ -67,6 +69,7 @@ def export_dataset(
user: str = Depends(get_user),
settings: Settings = Depends(get_settings),
):
utils.check_celery()

dataset_id = exportParams.datasetId
feature_labels = exportParams.featureLabels
Expand Down Expand Up @@ -103,6 +106,8 @@ def export_merged_dataset(
user: str = Depends(get_user),
settings: Settings = Depends(get_settings),
):
utils.check_celery()

dataset_ids = exportParams.datasetIds
feature_labels = exportParams.featureLabels
sample_ids = exportParams.cellLineIds
Expand Down
24 changes: 23 additions & 1 deletion breadbox/breadbox/celery_task/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from fastapi import HTTPException
from typing import Any, Optional
from ..compute.celery import app
from breadbox.schemas.custom_http_exception import UserError
from breadbox.schemas.custom_http_exception import UserError, CeleryConnectionError
from typing import Protocol, cast, Callable
from celery.result import AsyncResult, EagerResult

Expand Down Expand Up @@ -204,3 +204,25 @@ def update_state(
meta["message"] = message

task.update_state(state=state, meta=meta)


def check_celery():
    """
    Checks to see if celery redis broker is connected.
    Check worker stats to see if any workers are running.

    Raises:
        CeleryConnectionError: if the broker connection cannot be established
            (after up to 3 retries) or if no worker responds to a stats query.
    """
    inspect = app.control.inspect()
    try:
        # Tries to connect to celery broker. A few retries absorb transient
        # hiccups, at the cost of a few seconds when the broker is truly down.
        app.broker_connection().ensure_connection(max_retries=3)
    except Exception as exc:
        raise CeleryConnectionError(
            "Failed to connect to celery redis broker!"
        ) from exc
    # Pings workers to see if any of them respond. Returns None if no response.
    # NOTE: app.control.broadcast("ping", reply=True, limit=1) or inspect.ping()
    # pings all workers but will not return if all workers are busy.
    stats = inspect.stats()
    if stats is None:
        raise CeleryConnectionError(
            "Celery workers are not responding. Check if workers are running!"
        )
13 changes: 12 additions & 1 deletion breadbox/breadbox/health_check/health_check.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from fastapi import APIRouter
from breadbox.health_check import site_check_task
from breadbox.celery_task.utils import format_task_status
from breadbox.celery_task.utils import format_task_status, check_celery
from breadbox.schemas.custom_http_exception import CeleryConnectionError

import logging

router = APIRouter(prefix="/health_check", tags=["health_check"])
Expand All @@ -26,6 +28,8 @@ def log_test():
"/ok", operation_id="ok",
)
def ok():
check_celery()

task = site_check_task.is_ok.delay()
task.wait(timeout=60, interval=0.5)

Expand All @@ -35,3 +39,10 @@ def ok():
@router.get("/simulate-error", operation_id="simulate_error")
def simulate_error():
    """Always raise a generic Exception — presumably used to exercise error reporting."""
    raise Exception("Simulated error")


@router.get("/celery", operation_id="celery_check")
def celery_check():
    """Report ok when Celery is reachable; check_celery raises (503) otherwise."""
    # Raises CeleryConnectionError if the broker or workers are unreachable.
    check_celery()
    return {"message": "ok"}
6 changes: 6 additions & 0 deletions breadbox/breadbox/schemas/custom_http_exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ class HTTPError(BaseModel):
status.HTTP_403_FORBIDDEN,
status.HTTP_404_NOT_FOUND,
status.HTTP_409_CONFLICT,
status.HTTP_503_SERVICE_UNAVAILABLE,
]
ERROR_RESPONSES = dict.fromkeys(ERROR_CODES, {"model": HTTPError})

Expand Down Expand Up @@ -64,3 +65,8 @@ def __init__(self, msg):
class ComputeLinearFitError(UserError):
def __init__(self, msg):
super().__init__(msg)


class CeleryConnectionError(HTTPException):
    """Raised when the Celery broker or its workers cannot be reached.

    Maps to an HTTP 503 Service Unavailable response by default.
    """

    def __init__(self, msg, error_code=503):
        # HTTPException is constructed positionally as (status code, detail).
        super().__init__(error_code, msg)
13 changes: 13 additions & 0 deletions breadbox/tests/api/test_celery.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from breadbox.models.dataset import ValueType
from breadbox.compute import analysis_tasks
from breadbox.compute.analysis_tasks import _subset_feature_df
from breadbox.celery_task import utils
from tests import factories
from breadbox.compute.analysis_tasks import (
run_custom_analysis,
Expand Down Expand Up @@ -88,6 +89,12 @@ def mock_db_context(user, commit=True):
def get_test_settings():
return settings

def mock_check_celery():
    """Test stub: report Celery as available without touching a real broker."""
    return True

# Monkeypatch check_celery and pretend celery is running for test
monkeypatch.setattr(utils, "check_celery", mock_check_celery)

monkeypatch.setattr(
analysis_tasks, "get_settings", get_test_settings,
)
Expand Down Expand Up @@ -226,6 +233,12 @@ def mock_db_context(user, commit=True):
def get_test_settings():
return settings

def mock_check_celery():
    """Test stub: report Celery as available without touching a real broker."""
    return True

# Monkeypatch check_celery and pretend celery is running for test
monkeypatch.setattr(utils, "check_celery", mock_check_celery)

monkeypatch.setattr(
analysis_tasks, "get_settings", get_test_settings,
)
Expand Down
7 changes: 7 additions & 0 deletions breadbox/tests/api/test_compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
)
from breadbox.compute import analysis_tasks
from breadbox.compute.analysis_tasks import get_feature_data_slice_values
from breadbox.celery_task import utils
from breadbox.crud.dataset import get_dataset
from tests import factories

Expand Down Expand Up @@ -691,6 +692,12 @@ def mock_db_context(user, **kwargs):
def get_test_settings():
return settings

def mock_check_celery():
    """Test stub: report Celery as available without touching a real broker."""
    return True

# Monkeypatch check_celery and pretend celery is running for test
monkeypatch.setattr(utils, "check_celery", mock_check_celery)

monkeypatch.setattr(
analysis_tasks, "get_settings", get_test_settings,
)
Expand Down
6 changes: 6 additions & 0 deletions breadbox/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,12 @@ def mock_db_context(user, commit=False):
def get_test_settings():
return settings

def mock_check_celery():
    """Test stub: report Celery as available without touching a real broker."""
    return True

# Monkeypatch check_celery and pretend celery is running for test
monkeypatch.setattr(utils, "check_celery", mock_check_celery)

# The endpoint uses celery, and needs monkeypatching to replace db_context and get_settings,
# which are not passed in as params due to the limits of redis serialization.
monkeypatch.setattr(dataset_tasks, "db_context", mock_db_context)
Expand Down
Loading