diff --git a/specs/dodam/warp_zfish_subchunkable_example.cue b/specs/dodam/warp_zfish_subchunkable_example.cue index eb39bf8c9..cd2a249cc 100644 --- a/specs/dodam/warp_zfish_subchunkable_example.cue +++ b/specs/dodam/warp_zfish_subchunkable_example.cue @@ -13,7 +13,7 @@ } start_coord: [4096 * 3, 4096 * 4, 3003] - end_coord: [12288 * 2, 12288 * 2, 3011] + end_coord: [12288 * 2, 12288 * 2, 3006] coord_resolution: [16, 16, 30] dst_resolution: [32, 32, 30] @@ -24,7 +24,7 @@ //processing_crop_pads: [[0, 0, 0], [16, 16, 0], [32, 32, 0]] //processing_blend_pads: [[0, 0, 0], [0, 0, 0], [16, 16, 0]] //level_intermediaries_dirs: [#TEMP_PATH2, #TEMP_PATH1, #TEMP_PATH0] - processing_chunk_sizes: [[8192, 8192, 1], [2048, 2048, 1]] + processing_chunk_sizes: [[1024, 1024, 1], [512, 512, 1]] processing_crop_pads: [[0, 0, 0], [16, 16, 0]] processing_blend_pads: [[0, 0, 0], [16, 16, 0]] processing_blend_modes: "quadratic" @@ -62,10 +62,11 @@ "@type": "mazepa.execute_locally" target: #FLOW_TMPL -num_procs: 4 +num_procs: 32 semaphores_spec: { read: 4 write: 4 - cuda: 4 + cuda: 1 cpu: 4 } +debug: true diff --git a/tests/unit/common/test_multiprocessing.py b/tests/unit/common/test_multiprocessing.py deleted file mode 100644 index c229bda22..000000000 --- a/tests/unit/common/test_multiprocessing.py +++ /dev/null @@ -1,43 +0,0 @@ -import pytest - -import zetta_utils.common.multiprocessing -from zetta_utils.common.multiprocessing import ( - get_persistent_process_pool, - setup_persistent_process_pool, -) - - -def times2(x): - return 2 * x - - -@pytest.mark.parametrize("args, expected", [[[1, 2, 3], [2, 4, 6]]]) -def test_run_func(args, expected): - with setup_persistent_process_pool(3): - pool = get_persistent_process_pool() - assert pool is not None - future = pool.map(times2, args) - assert list(future.result()) == expected - - -def test_no_init_with_1_proc(): - with setup_persistent_process_pool(1): - assert get_persistent_process_pool() is None - - -def test_double_init_exc(): - with pytest.raises(RuntimeError): - with setup_persistent_process_pool(2): - with setup_persistent_process_pool(2): - pass - - -def test_unalloc_nonexistent_exc(): - with pytest.raises(RuntimeError): - with setup_persistent_process_pool(2): - pool = get_persistent_process_pool() - assert pool is not None - pool.stop() - pool.join() - zetta_utils.common.multiprocessing.PERSISTENT_PROCESS_POOL = None - # exception on exiting context diff --git a/tests/unit/mazepa/test_end_to_end_workflow_sqs.py b/tests/unit/mazepa/test_end_to_end_workflow_sqs.py index a668b6fda..c58aa0a08 100644 --- a/tests/unit/mazepa/test_end_to_end_workflow_sqs.py +++ b/tests/unit/mazepa/test_end_to_end_workflow_sqs.py @@ -11,7 +11,7 @@ from zetta_utils.mazepa.tasks import _TaskableOperation from zetta_utils.message_queues.sqs.queue import SQSQueue -from ..message_queues.test_sqs_queue import aws_credentials, sqs_endpoint +from ..message_queues.sqs.test_sqs_queue import aws_credentials, sqs_endpoint boto3.setup_default_session() diff --git a/tests/unit/common/test_semaphores.py b/tests/unit/mazepa/test_semaphores.py similarity index 88% rename from tests/unit/common/test_semaphores.py rename to tests/unit/mazepa/test_semaphores.py index dc920a6d9..133fc61e6 100644 --- a/tests/unit/common/test_semaphores.py +++ b/tests/unit/mazepa/test_semaphores.py @@ -4,9 +4,10 @@ from typing import List import posix_ipc +import psutil import pytest -from zetta_utils.common.semaphores import ( +from zetta_utils.mazepa.semaphores import ( DummySemaphore, SemaphoreType, configure_semaphores, @@ -21,7 +22,7 @@ def cleanup_semaphores(): sema_types: List[SemaphoreType] = ["read", "write", "cuda", "cpu"] for name in sema_types: try: - # two unlinks in case parent semaphore exists + # two unlinks in case grandparent semaphore exists semaphore(name).unlink() semaphore(name).unlink() except: @@ -65,13 +66,14 @@ def test_unlink_nonexistent_exc(): # exception on exiting context -def test_get_parent_semaphore(): +def test_get_grandparent_semaphore(): + grandpa_pid = psutil.Process(os.getppid()).ppid() try: - sema = posix_ipc.Semaphore(name_to_posix_name("read", os.getppid())) + sema = posix_ipc.Semaphore(name_to_posix_name("read", grandpa_pid)) sema.unlink() except: pass - sema = posix_ipc.Semaphore(name_to_posix_name("read", os.getppid()), flags=posix_ipc.O_CREX) + sema = posix_ipc.Semaphore(name_to_posix_name("read", grandpa_pid), flags=posix_ipc.O_CREX) assert sema.name == semaphore("read").name sema.unlink() diff --git a/tests/unit/message_queues/file/__init__.py b/tests/unit/message_queues/file/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/message_queues/file/test_queue.py b/tests/unit/message_queues/file/test_queue.py new file mode 100644 index 000000000..c2a148bda --- /dev/null +++ b/tests/unit/message_queues/file/test_queue.py @@ -0,0 +1,73 @@ +import time + +import pytest + +from zetta_utils.message_queues.file.queue import FileQueue + + +def success_fn(): + return "Success" + + +def test_make_and_delete_file_queue(): + with FileQueue("test_queue"): + pass + + +def test_get_tq_queue(): + with FileQueue("test_queue"): + FileQueue("test_queue")._get_tq_queue() # pylint:disable = protected-access + + +def test_push_pull(): + with FileQueue("test_queue") as q: + payloads = {None, 1, "asdfadsfdsa", success_fn} + q.push(list(payloads)) + time.sleep(0.1) + result = q.pull(max_num=len(payloads)) + assert len(result) == len(payloads) + received_payloads = {r.payload for r in result} + assert received_payloads == payloads + + +def test_delete(): + with FileQueue("test_queue") as q: + q.push([None]) + time.sleep(0.1) + result = q.pull(max_num=10) + assert len(result) == 1 + result[0].acknowledge_fn() + time.sleep(1.1) + result_empty = q.pull() + assert len(result_empty) == 0 + + +def test_extend_lease(): + with FileQueue("test_queue") as q: + q.push([None]) + time.sleep(0.1) + result = q.pull() + assert len(result) == 1 + result[0].extend_lease_fn(3) + time.sleep(1) + result_empty = q.pull() + assert len(result_empty) == 0 + time.sleep(2.1) + result_nonempty = q.pull() + assert len(result_nonempty) == 1 + + +@pytest.mark.parametrize( + "queue_name", ["fq://test_queue", "file://test_queue", "sqs://test_queue"] +) +def test_prefix_exc(queue_name): + with pytest.raises(ValueError): + with FileQueue(queue_name): + pass + + +def test_double_init_exc(): + with pytest.raises(RuntimeError): + with FileQueue("test_queue"): + with FileQueue("test_queue"): + pass diff --git a/tests/unit/message_queues/sqs/__init__.py b/tests/unit/message_queues/sqs/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/message_queues/test_sqs_queue.py b/tests/unit/message_queues/sqs/test_queue.py similarity index 100% rename from tests/unit/message_queues/test_sqs_queue.py rename to tests/unit/message_queues/sqs/test_queue.py diff --git a/zetta_utils/cloud_management/resource_allocation/k8s/common.py b/zetta_utils/cloud_management/resource_allocation/k8s/common.py index c7fd77824..f405b65a4 100644 --- a/zetta_utils/cloud_management/resource_allocation/k8s/common.py +++ b/zetta_utils/cloud_management/resource_allocation/k8s/common.py @@ -10,7 +10,7 @@ from kubernetes import client as k8s_client # type: ignore from zetta_utils import builder, log -from zetta_utils.common import SemaphoreType +from zetta_utils.mazepa import SemaphoreType from .eks import eks_cluster_data from .gke import gke_cluster_data @@ -68,7 +68,7 @@ def get_mazepa_worker_command( result = ( """ zetta -vv -l try run -s '{ - "@type": "mazepa.run_worker" + "@type": "mazepa.run_worker_manager" """ + f"task_queue: {json.dumps(task_queue_spec)}\n" + f"outcome_queue: {json.dumps(outcome_queue_spec)}\n" diff --git a/zetta_utils/cloud_management/resource_allocation/k8s/deployment.py b/zetta_utils/cloud_management/resource_allocation/k8s/deployment.py index 53d09561c..8c28fd824 100644 --- a/zetta_utils/cloud_management/resource_allocation/k8s/deployment.py +++ b/zetta_utils/cloud_management/resource_allocation/k8s/deployment.py @@ -9,7 +9,7 @@ from kubernetes import client as k8s_client # type: ignore from zetta_utils import builder, log -from zetta_utils.common import SemaphoreType +from zetta_utils.mazepa import SemaphoreType from ..resource_tracker import ( ExecutionResource, diff --git a/zetta_utils/common/__init__.py b/zetta_utils/common/__init__.py index ef5314b8f..1150a881f 100644 --- a/zetta_utils/common/__init__.py +++ b/zetta_utils/common/__init__.py @@ -11,5 +11,3 @@ from .pprint import lrpad from .signal_handlers import custom_signal_handler_ctx from .timer import RepeatTimer -from .semaphores import SemaphoreType, configure_semaphores, semaphore -from .multiprocessing import setup_persistent_process_pool, get_persistent_process_pool diff --git a/zetta_utils/common/multiprocessing.py b/zetta_utils/common/multiprocessing.py deleted file mode 100644 index 4a0d2db67..000000000 --- a/zetta_utils/common/multiprocessing.py +++ /dev/null @@ -1,48 +0,0 @@ -# pylint: disable=global-statement -from __future__ import annotations - -import contextlib - -from pebble import ProcessPool - -from zetta_utils import log - -logger = log.get_logger("zetta_utils") - -PERSISTENT_PROCESS_POOL: ProcessPool | None = None - - -@contextlib.contextmanager -def setup_persistent_process_pool(num_procs: int): - """ - Context manager for creating a persistent pool of workers. - """ - - global PERSISTENT_PROCESS_POOL - try: - if num_procs == 1: - logger.info("Skipping creation because 1 process is requested.") - elif PERSISTENT_PROCESS_POOL is not None: - raise RuntimeError("Persistent process pool already exists.") - else: - logger.info(f"Creating a persistent process pool with {num_procs} processes.") - PERSISTENT_PROCESS_POOL = ProcessPool(num_procs) - yield - finally: - if num_procs == 1: - pass - elif PERSISTENT_PROCESS_POOL is None: - raise RuntimeError("Persistent process pool does not exist.") - else: - PERSISTENT_PROCESS_POOL.stop() - PERSISTENT_PROCESS_POOL.join() - PERSISTENT_PROCESS_POOL = None - logger.info("Cleaned up persistent process pool.") - - -def get_persistent_process_pool() -> ProcessPool | None: - """ - Fetches and returns either the semaphore associated with the current process, - or the semaphore associated with the parent process, in that order. - """ - return PERSISTENT_PROCESS_POOL diff --git a/zetta_utils/common/path.py b/zetta_utils/common/path.py index e6de7db30..4afe29215 100644 --- a/zetta_utils/common/path.py +++ b/zetta_utils/common/path.py @@ -11,10 +11,15 @@ def abspath(path: str) -> str: path_no_prefix = split[-1] if len(prefixes) == 0: prefixes = ["file"] - if prefixes == ["file"]: + if prefixes in (["file"], ["fq"]): path_no_prefix = os.path.abspath(os.path.expanduser(path_no_prefix)) return "://".join(prefixes + [path_no_prefix]) +def strip_prefix(path: str) -> str: # pragma: no cover + return path.split("://")[-1] + + def is_local(path: str) -> bool: # pragma: no cover - return abspath(path).startswith("file://") + local_prefixes = ["file://", "fq://"] + return any(abspath(path).startswith(local_prefix) for local_prefix in local_prefixes) diff --git a/zetta_utils/convnet/utils.py b/zetta_utils/convnet/utils.py index b004f79c2..623bbe1ae 100644 --- a/zetta_utils/convnet/utils.py +++ b/zetta_utils/convnet/utils.py @@ -40,7 +40,7 @@ def _load_model( return result -_load_model_cached = cachetools.cached(cachetools.LRUCache(maxsize=8))(_load_model) +_load_model_cached = cachetools.cached(cachetools.LRUCache(maxsize=2))(_load_model) @typechecked diff --git a/zetta_utils/layer/volumetric/cloudvol/backend.py b/zetta_utils/layer/volumetric/cloudvol/backend.py index a3f663724..ad08aa815 100644 --- a/zetta_utils/layer/volumetric/cloudvol/backend.py +++ b/zetta_utils/layer/volumetric/cloudvol/backend.py @@ -113,7 +113,7 @@ def _set_cv_defaults(self): self.cv_kwargs.setdefault("progress", False) self.cv_kwargs.setdefault("autocrop", False) self.cv_kwargs.setdefault("non_aligned_writes", False) - self.cv_kwargs.setdefault("cache", not self.is_local) + self.cv_kwargs.setdefault("cache", False) self.cv_kwargs.setdefault("compress_cache", False) self.cv_kwargs.setdefault("compress", True) self.cv_kwargs.setdefault("cdn_cache", False) @@ -195,7 +195,6 @@ def clear_disk_cache(self) -> None: # pragma: no cover def clear_cache(self) -> None: # pragma: no cover _clear_cv_cache(self.path) - self.clear_disk_cache() def read(self, idx: VolumetricIndex) -> torch.Tensor: # Data out: cxyz diff --git a/zetta_utils/mazepa/__init__.py b/zetta_utils/mazepa/__init__.py index 36704e055..17cc36d6e 100644 --- a/zetta_utils/mazepa/__init__.py +++ b/zetta_utils/mazepa/__init__.py @@ -29,3 +29,4 @@ from .progress_tracker import progress_ctx_mngr from .execution import Executor, execute from .worker import run_worker +from .semaphores import SemaphoreType, configure_semaphores, semaphore diff --git a/zetta_utils/mazepa/autoexecute_task_queue.py b/zetta_utils/mazepa/autoexecute_task_queue.py index 5f4b8a2f2..becd13539 100644 --- a/zetta_utils/mazepa/autoexecute_task_queue.py +++ b/zetta_utils/mazepa/autoexecute_task_queue.py @@ -6,7 +6,6 @@ from typeguard import typechecked from zetta_utils import log -from zetta_utils.common import get_persistent_process_pool from zetta_utils.mazepa.worker import process_task_message from zetta_utils.message_queues.base import MessageQueue, ReceivedMessage @@ -23,7 +22,6 @@ class AutoexecuteTaskQueue(MessageQueue): tasks_todo: list[Task] = attrs.field(init=False, factory=list) debug: bool = False handle_exceptions: bool = False - parallel_if_pool_exists: bool = False def push(self, payloads: Iterable[Task]): # TODO: Fix progress bar issue with multiple live displays in rich @@ -40,27 +38,9 @@ def pull( if len(self.tasks_todo) == 0: return [] else: - pool = get_persistent_process_pool() - if not self.parallel_if_pool_exists or pool is None: - results: list[ReceivedMessage[OutcomeReport]] = [] - for task in self.tasks_todo[:max_num]: - results.append(execute_task(task, self.debug, self.handle_exceptions)) - # TODO: remove monkey patching from builder so that unit tests work; - # pickle does not handle monkey patched objects inside Python - else: # pragma: no cover - futures = [] - for task in self.tasks_todo[:max_num]: - futures.append( - pool.schedule( - execute_task, - kwargs={ - "task": task, - "debug": self.debug, - "handle_exceptions": self.handle_exceptions, - }, - ) - ) - results = [future.result() for future in futures] + results: list[ReceivedMessage[OutcomeReport]] = [] + for task in self.tasks_todo[:max_num]: + results.append(execute_task(task, self.debug, self.handle_exceptions)) self.tasks_todo = self.tasks_todo[max_num:] return results diff --git a/zetta_utils/mazepa/execution.py b/zetta_utils/mazepa/execution.py index 496de18af..4551d4da2 100644 --- a/zetta_utils/mazepa/execution.py +++ b/zetta_utils/mazepa/execution.py @@ -39,7 +39,6 @@ class Executor: checkpoint: Optional[str] = None checkpoint_interval_sec: Optional[float] = None raise_on_failed_checkpoint: bool = True - parallel_if_pool_exists: bool = False def __call__(self, target: Union[Task, Flow, ExecutionState, ComparablePartial, Callable]): assert (self.task_queue is None and self.outcome_queue is None) or ( @@ -58,7 +57,6 @@ def __call__(self, target: Union[Task, Flow, ExecutionState, ComparablePartial, checkpoint=self.checkpoint, checkpoint_interval_sec=self.checkpoint_interval_sec, raise_on_failed_checkpoint=self.raise_on_failed_checkpoint, - parallel_if_pool_exists=self.parallel_if_pool_exists, ) @@ -77,7 +75,6 @@ def execute( checkpoint: Optional[str] = None, checkpoint_interval_sec: Optional[float] = 150, raise_on_failed_checkpoint: bool = True, - parallel_if_pool_exists: bool = False, ): """ Executes a target until completion using the given execution queue. @@ -122,7 +119,7 @@ def execute( else: assert outcome_queue is None task_queue_ = AutoexecuteTaskQueue( - debug=True, parallel_if_pool_exists=parallel_if_pool_exists + debug=True, ) outcome_queue_ = task_queue_ diff --git a/zetta_utils/common/semaphores.py b/zetta_utils/mazepa/semaphores.py similarity index 89% rename from zetta_utils/common/semaphores.py rename to zetta_utils/mazepa/semaphores.py index 4dafffca9..f35ad982e 100644 --- a/zetta_utils/common/semaphores.py +++ b/zetta_utils/mazepa/semaphores.py @@ -5,6 +5,7 @@ from typing import List, Literal, get_args import attrs +import psutil from posix_ipc import ( # pylint: disable=no-name-in-module O_CREX, ExistentialError, @@ -13,7 +14,7 @@ from zetta_utils import log -logger = log.get_logger("zetta_utils") +logger = log.get_logger("mazepa") SemaphoreType = Literal["read", "write", "cuda", "cpu"] DEFAULT_SEMA_COUNT = 1 @@ -103,7 +104,10 @@ def unlink(self): def semaphore(name: SemaphoreType) -> Semaphore: """ Fetches and returns either the semaphore associated with the current process, - or the semaphore associated with the parent process, or a dummy semaphore, in that order. + or the semaphore associated with the grandparent process, or a dummy semaphore, + in that order. Note the grandparent, and not the parent, process is used - + this is because subprocess.popen is the child, and the `zetta` binary wrapper + ends up being the grandchild. """ if not name in get_args(SemaphoreType): raise ValueError(f"`{name}` is not a valid semaphore type.") @@ -111,6 +115,6 @@ def semaphore(name: SemaphoreType) -> Semaphore: return Semaphore(name_to_posix_name(name, os.getpid())) except ExistentialError: try: - return Semaphore(name_to_posix_name(name, os.getppid())) + return Semaphore(name_to_posix_name(name, psutil.Process(os.getppid()).ppid())) except ExistentialError: return DummySemaphore() diff --git a/zetta_utils/mazepa/worker.py b/zetta_utils/mazepa/worker.py index 4ba375367..d9c53e687 100644 --- a/zetta_utils/mazepa/worker.py +++ b/zetta_utils/mazepa/worker.py @@ -4,18 +4,12 @@ import sys import time import traceback -from contextlib import ExitStack from typing import Any, Callable, Optional import tenacity from zetta_utils import log -from zetta_utils.common import ( - RepeatTimer, - SemaphoreType, - configure_semaphores, - setup_persistent_process_pool, -) +from zetta_utils.common import RepeatTimer from zetta_utils.mazepa import constants, exceptions from zetta_utils.mazepa.exceptions import MazepaCancel, MazepaTimeoutError from zetta_utils.mazepa.task_outcome import OutcomeReport, TaskOutcome @@ -27,14 +21,15 @@ from . import Task -logger = log.get_logger("mazepa") - class AcceptAllTasks: def __call__(self, task: Task): return True +logger = log.get_logger("mazepa") + + def run_worker( task_queue: MessageQueue[Task], outcome_queue: MessageQueue[OutcomeReport], @@ -42,67 +37,62 @@ def run_worker( max_pull_num: int = 1, max_runtime: Optional[float] = None, task_filter_fn: Callable[[Task], bool] = AcceptAllTasks(), - num_procs: int = 1, - semaphores_spec: dict[SemaphoreType, int] | None = None, debug: bool = False, ): - with ExitStack() as stack: - stack.enter_context(configure_semaphores(semaphores_spec)) - stack.enter_context(setup_persistent_process_pool(num_procs)) - start_time = time.time() - while True: - try: - task_msgs = task_queue.pull(max_num=max_pull_num) - except (exceptions.MazepaException, SystemExit, KeyboardInterrupt) as e: - raise e # pragma: no cover - except Exception as e: # pylint: disable=broad-except - # The broad except here is OK because it will be propagated to the outcome - # queue and reraise the exception - logger.error("Failed pulling tasks from the queue:") - logger.exception(e) - exc_type, exception, tb = sys.exc_info() - traceback_text = "".join(traceback.format_exception(exc_type, exception, tb)) - - outcome = TaskOutcome[Any]( - exception=exception, - traceback_text=traceback_text, - execution_sec=0, - return_value=None, - ) - outcome_report = OutcomeReport(task_id=constants.UNKNOWN_TASK_ID, outcome=outcome) - outcome_queue.push([outcome_report]) - raise e - - logger.info(f"Got {len(task_msgs)} tasks.") - - if len(task_msgs) == 0: - logger.info(f"Sleeping for {sleep_sec} secs.") - time.sleep(sleep_sec) - else: - logger.info("STARTING: task batch execution.") - time_start = time.time() - for msg in task_msgs: - task = msg.payload - with log.logging_tag_ctx("task_id", task.id_): - with log.logging_tag_ctx("execution_id", task.execution_id): - if task_filter_fn(task): - ack_task, outcome = process_task_message(msg=msg, debug=debug) - else: - ack_task = True - outcome = TaskOutcome(exception=MazepaCancel()) - - if ack_task: - outcome_report = OutcomeReport( - task_id=msg.payload.id_, outcome=outcome - ) - outcome_queue.push([outcome_report]) - msg.acknowledge_fn() - - time_end = time.time() - logger.info(f"DONE: task batch execution ({time_end - time_start:.2f}sec).") - - if max_runtime is not None and time.time() - start_time > max_runtime: - break + start_time = time.time() + while True: + try: + task_msgs = task_queue.pull(max_num=max_pull_num) + except (exceptions.MazepaException, SystemExit, KeyboardInterrupt) as e: + raise e # pragma: no cover + except Exception as e: # pylint: disable=broad-except + # The broad except here is OK because it will be propagated to the outcome + # queue and reraise the exception + logger.error("Failed pulling tasks from the queue:") + logger.exception(e) + exc_type, exception, tb = sys.exc_info() + traceback_text = "".join(traceback.format_exception(exc_type, exception, tb)) + + outcome = TaskOutcome[Any]( + exception=exception, + traceback_text=traceback_text, + execution_sec=0, + return_value=None, + ) + outcome_report = OutcomeReport(task_id=constants.UNKNOWN_TASK_ID, outcome=outcome) + outcome_queue.push([outcome_report]) + raise e + + logger.info(f"Got {len(task_msgs)} tasks.") + + if len(task_msgs) == 0: + logger.info(f"Sleeping for {sleep_sec} secs.") + time.sleep(sleep_sec) + else: + logger.info("STARTING: task batch execution.") + time_start = time.time() + for msg in task_msgs: + task = msg.payload + with log.logging_tag_ctx("task_id", task.id_): + with log.logging_tag_ctx("execution_id", task.execution_id): + if task_filter_fn(task): + ack_task, outcome = process_task_message(msg=msg, debug=debug) + else: + ack_task = True + outcome = TaskOutcome(exception=MazepaCancel()) + + if ack_task: + outcome_report = OutcomeReport( + task_id=msg.payload.id_, outcome=outcome + ) + outcome_queue.push([outcome_report]) + msg.acknowledge_fn() + + time_end = time.time() + logger.info(f"DONE: task batch execution ({time_end - time_start:.2f}sec).") + + if max_runtime is not None and time.time() - start_time > max_runtime: + break def process_task_message( diff --git a/zetta_utils/mazepa_addons/configurations/execute_locally.py b/zetta_utils/mazepa_addons/configurations/execute_locally.py index 7d11743a7..163ba0b59 100644 --- a/zetta_utils/mazepa_addons/configurations/execute_locally.py +++ b/zetta_utils/mazepa_addons/configurations/execute_locally.py @@ -1,21 +1,19 @@ # pylint: disable=too-many-locals from __future__ import annotations +import os from contextlib import ExitStack from typing import Callable, Optional, Union from typeguard import typechecked from zetta_utils import builder, log -from zetta_utils.common import ( - ComparablePartial, - SemaphoreType, - configure_semaphores, - setup_persistent_process_pool, -) -from zetta_utils.mazepa import Flow, Task, execute -from zetta_utils.mazepa.autoexecute_task_queue import AutoexecuteTaskQueue +from zetta_utils.common import ComparablePartial +from zetta_utils.mazepa import Flow, SemaphoreType, Task, configure_semaphores, execute from zetta_utils.mazepa.execution_state import ExecutionState, InMemoryExecutionState +from zetta_utils.message_queues import FileQueue + +from .worker_pool import setup_local_worker_pool logger = log.get_logger("mazepa") @@ -24,7 +22,6 @@ @builder.register("mazepa.execute_locally") def execute_locally( target: Union[Task, Flow, ExecutionState, ComparablePartial, Callable], - task_queue: AutoexecuteTaskQueue | None = None, max_batch_len: int = 1000, batch_gap_sleep_sec: float = 0.5, state_constructor: Callable[..., ExecutionState] = InMemoryExecutionState, @@ -37,18 +34,32 @@ def execute_locally( raise_on_failed_checkpoint: bool = True, num_procs: int = 1, semaphores_spec: dict[SemaphoreType, int] | None = None, + debug: bool = False, ): + with ExitStack() as stack: logger.info( "Configuring for local execution: " - "allocating semaphores and persistent process pool as needed." + "creating local queues, allocating semaphores, and starting local workers." ) stack.enter_context(configure_semaphores(semaphores_spec)) - stack.enter_context(setup_persistent_process_pool(num_procs)) + + if debug: + logger.info("Debug mode: Using single process execution without local queues.") + task_queue = None + outcome_queue = None + else: + task_queue_name = f"{os.getpid()}_task_queue" + outcome_queue_name = f"{os.getpid()}_outcome_queue" + task_queue = stack.enter_context(FileQueue(task_queue_name)) + outcome_queue = stack.enter_context(FileQueue(outcome_queue_name)) + stack.enter_context( + setup_local_worker_pool(num_procs, task_queue_name, outcome_queue_name) + ) execute( target=target, task_queue=task_queue, - outcome_queue=task_queue, + outcome_queue=outcome_queue, max_batch_len=max_batch_len, batch_gap_sleep_sec=batch_gap_sleep_sec, state_constructor=state_constructor, diff --git a/zetta_utils/mazepa_addons/configurations/execute_on_gcp_with_sqs.py b/zetta_utils/mazepa_addons/configurations/execute_on_gcp_with_sqs.py index e0910d4bf..b4dbf3c86 100644 --- a/zetta_utils/mazepa_addons/configurations/execute_on_gcp_with_sqs.py +++ b/zetta_utils/mazepa_addons/configurations/execute_on_gcp_with_sqs.py @@ -7,8 +7,7 @@ from zetta_utils import builder, log, mazepa from zetta_utils.cloud_management import execution_tracker, resource_allocation -from zetta_utils.common import SemaphoreType -from zetta_utils.mazepa import execute +from zetta_utils.mazepa import SemaphoreType, execute from zetta_utils.mazepa.task_outcome import OutcomeReport from zetta_utils.mazepa.tasks import Task from zetta_utils.message_queues import sqs # pylint: disable=unused-import @@ -131,6 +130,7 @@ def execute_on_gcp_with_sqs( # pylint: disable=too-many-locals show_progress: bool = True, do_dryrun_estimation: bool = True, local_test: bool = False, + debug: bool = False, checkpoint: Optional[str] = None, checkpoint_interval_sec: float = 300.0, raise_on_failed_checkpoint: bool = True, @@ -146,6 +146,8 @@ def execute_on_gcp_with_sqs( # pylint: disable=too-many-locals execution_tracker.record_execution_run(execution_id) ctx_managers = copy.copy(list(extra_ctx_managers)) + if debug and not local_test: + raise ValueError("`debug` can only be set to `True` when `local_test` is also `True`.") if local_test: execution_tracker.register_execution(execution_id, []) else: @@ -189,7 +191,6 @@ def execute_on_gcp_with_sqs( # pylint: disable=too-many-locals if local_test: execute_locally( target=target, - task_queue=None, execution_id=execution_id, max_batch_len=max_batch_len, batch_gap_sleep_sec=batch_gap_sleep_sec, @@ -200,6 +201,7 @@ def execute_on_gcp_with_sqs( # pylint: disable=too-many-locals raise_on_failed_checkpoint=raise_on_failed_checkpoint, num_procs=num_procs, semaphores_spec=semaphores_spec, + debug=debug, ) else: execute( diff --git a/zetta_utils/mazepa_addons/configurations/worker_pool.py b/zetta_utils/mazepa_addons/configurations/worker_pool.py new file mode 100644 index 000000000..f3e6a8145 --- /dev/null +++ b/zetta_utils/mazepa_addons/configurations/worker_pool.py @@ -0,0 +1,126 @@ +from __future__ import annotations + +import contextlib +import json +import os +import subprocess +import tempfile +import time +from contextlib import ExitStack + +import psutil + +from zetta_utils import builder, log +from zetta_utils.mazepa import SemaphoreType, Task, configure_semaphores +from zetta_utils.mazepa.task_outcome import OutcomeReport +from zetta_utils.message_queues import SQSQueue + +logger = log.get_logger("mazepa") + + +def detach_from_process_group() -> None: + os.setpgrp() + + +def get_local_worker_command( + task_queue_name: str, outcome_queue_name: str, local: bool = True, sleep_sec: float = 0.1 +): + if local: + task_queue_spec = { + "@type": "FileQueue", + "name": "./FileQueue_" + task_queue_name, + } + outcome_queue_spec = { + "@type": "FileQueue", + "name": "./FileQueue_" + outcome_queue_name, + "pull_wait_sec": 1.0, + } + else: + task_queue_spec = { + "@type": "SQSQueue", + "name": task_queue_name, + } + outcome_queue_spec = { + "@type": "SQSQueue", + "name": outcome_queue_name, + "pull_wait_sec": 1.0, + } + + result = ( + """ + zetta -vv -l try run -s '{ + "@type": "mazepa.run_worker" + """ + + f"task_queue: {json.dumps(task_queue_spec)}\n" + + f"outcome_queue: {json.dumps(outcome_queue_spec)}\n" + + f"sleep_sec: {sleep_sec}\n" + + """ + max_pull_num: 1 + }' + """ + ) + return result + + +@contextlib.contextmanager +def setup_local_worker_pool( + num_procs: int, + task_queue_name: str, + outcome_queue_name: str, + local: bool = True, + sleep_sec: float = 0.1, +): + """ + Context manager for creating task/outcome queues, alongside a persistent pool of workers. + """ + worker_procs = [] + with tempfile.TemporaryFile() as iofile: + try: + worker_procs = [ + psutil.Process( + subprocess.Popen( # pylint: disable=subprocess-popen-preexec-fn + get_local_worker_command( + task_queue_name, outcome_queue_name, local=local, sleep_sec=sleep_sec + ), + shell=True, + stdin=iofile, + stdout=iofile, + stderr=iofile, + preexec_fn=detach_from_process_group, + ).pid + ) + for _ in range(num_procs) + ] + logger.info( + f"Created {num_procs} local workers attached to queues " + f"`{task_queue_name}`/`{outcome_queue_name}`." + ) + yield + finally: + for proc in worker_procs: + for pid in proc.children(recursive=True): + pid.kill() + proc.kill() + logger.info( + f"Cleaned up {num_procs} local workers that were attached to queues " + f"`{task_queue_name}`/`{outcome_queue_name}`." + ) + + +@builder.register("mazepa.run_worker_manager") +def run_worker_manager( + task_queue: SQSQueue[Task], + outcome_queue: SQSQueue[OutcomeReport], + sleep_sec: float = 1.0, + num_procs: int = 1, + semaphores_spec: dict[SemaphoreType, int] | None = None, +): + with ExitStack() as stack: + stack.enter_context(configure_semaphores(semaphores_spec)) + stack.enter_context( + setup_local_worker_pool( + num_procs, task_queue.name, outcome_queue.name, local=False, sleep_sec=sleep_sec + ) + ) + while True: + time.sleep(1) diff --git a/zetta_utils/mazepa_layer_processing/alignment/aced_relaxation_flow.py b/zetta_utils/mazepa_layer_processing/alignment/aced_relaxation_flow.py index 1704a3e52..00c69102e 100644 --- a/zetta_utils/mazepa_layer_processing/alignment/aced_relaxation_flow.py +++ b/zetta_utils/mazepa_layer_processing/alignment/aced_relaxation_flow.py @@ -3,6 +3,7 @@ from typing import Literal, Optional, Sequence import attrs +import torch from zetta_utils import alignment, builder, mazepa, tensor_ops from zetta_utils.geometry import BBox3D, Vec3D @@ -12,6 +13,7 @@ VolumetricLayer, VolumetricLayerSet, ) +from zetta_utils.mazepa.semaphores import semaphore from ..common import build_chunked_volumetric_callable_flow_schema @@ -67,16 +69,19 @@ def __call__( max_dist: int, ): idx_padded = idx.padded(self.crop_pad) - tissue_mask_data = tissue_mask[idx_padded] - if (tissue_mask_data != 0).sum() > 0: - result = alignment.aced_relaxation.get_aced_match_offsets( - tissue_mask=tissue_mask_data, - misalignment_masks={k: v[idx_padded] for k, v in misalignment_masks.items()}, - pairwise_fields={k: v[idx_padded] for k, v in pairwise_fields.items()}, - pairwise_fields_inv={k: v[idx_padded] for k, v in pairwise_fields_inv.items()}, - max_dist=max_dist, - ) + with semaphore("read"): + tissue_mask_data = tissue_mask[idx_padded] + with semaphore("cuda"): + if (tissue_mask_data != 0).sum() > 0: + result = alignment.aced_relaxation.get_aced_match_offsets( + tissue_mask=tissue_mask_data, + misalignment_masks={k: v[idx_padded] for k, v in misalignment_masks.items()}, + pairwise_fields={k: v[idx_padded] for k, v in pairwise_fields.items()}, + pairwise_fields_inv={k: v[idx_padded] for k, v in pairwise_fields_inv.items()}, + max_dist=max_dist, + ) result = {k: tensor_ops.crop(v, self.crop_pad) for k, v in result.items()} + with semaphore("write"): dst[idx] = result @@ -112,30 +117,35 @@ def __call__( first_section_idx_padded = idx_padded.translated_end((0, 0, 1 - idx_padded.shape[-1])) last_section_idx_padded = idx_padded.translated_start((0, 0, idx_padded.shape[-1] - 1)) - match_offsets_data = match_offsets[idx_padded] + with semaphore("read"): + match_offsets_data = match_offsets[idx_padded] if (match_offsets_data != 0).sum() > 0: - result = alignment.aced_relaxation.perform_aced_relaxation( - match_offsets=match_offsets_data, - pfields={k: v[idx_padded] for k, v in pfields.items()}, - rigidity_masks=rigidity_masks[idx_padded] if rigidity_masks else None, - first_section_fix_field=( - first_section_fix_field[first_section_idx_padded] - if first_section_fix_field - else None - ), - last_section_fix_field=( - last_section_fix_field[last_section_idx_padded] - if last_section_fix_field - else None - ), - num_iter=num_iter, - lr=lr, - rigidity_weight=rigidity_weight, - fix=fix, - ) - result_cropped = tensor_ops.crop(result, self.crop_pad) - dst[idx] = result_cropped + with semaphore("cuda"): + result = alignment.aced_relaxation.perform_aced_relaxation( + match_offsets=match_offsets_data, + pfields={k: v[idx_padded] for k, v in pfields.items()}, + rigidity_masks=rigidity_masks[idx_padded] if rigidity_masks else None, + first_section_fix_field=( + first_section_fix_field[first_section_idx_padded] + if first_section_fix_field + else None + ), + last_section_fix_field=( + last_section_fix_field[last_section_idx_padded] + if last_section_fix_field + else None + ), + num_iter=num_iter, + lr=lr, + rigidity_weight=rigidity_weight, + fix=fix, + ) + result_cropped = tensor_ops.crop(result, self.crop_pad) + torch.cuda.empty_cache() + + with semaphore("write"): + dst[idx] = result_cropped @builder.register("build_aced_relaxation_flow") diff --git a/zetta_utils/mazepa_layer_processing/alignment/compute_field_flow.py b/zetta_utils/mazepa_layer_processing/alignment/compute_field_flow.py index 6bc1a53ea..600e63a9d 100644 --- a/zetta_utils/mazepa_layer_processing/alignment/compute_field_flow.py +++ b/zetta_utils/mazepa_layer_processing/alignment/compute_field_flow.py @@ -15,6 +15,7 @@ VolumetricIndexTranslator, VolumetricLayer, ) +from zetta_utils.mazepa import semaphore from zetta_utils.mazepa_layer_processing.alignment.common import ( translation_adjusted_download, ) @@ -77,20 +78,22 @@ def __call__( src_field: Optional[VolumetricLayer], tgt_field: Optional[VolumetricLayer], ): - idx_input = copy.deepcopy(idx) - idx_input.resolution = self.get_input_resolution(idx.resolution) - idx_input_padded = idx_input.padded(self.crop_pad) - - src_data, src_field_data, src_translation = translation_adjusted_download( - src=src, - field=src_field, - idx=idx_input_padded, - ) - if src_data.abs().sum() > 0: - tgt_data, tgt_field_data, _ = translation_adjusted_download( - src=tgt, field=tgt_field, idx=idx_input_padded + with semaphore("read"): + idx_input = copy.deepcopy(idx) + idx_input.resolution = self.get_input_resolution(idx.resolution) + idx_input_padded = idx_input.padded(self.crop_pad) + + src_data, src_field_data, src_translation = translation_adjusted_download( + src=src, + field=src_field, + idx=idx_input_padded, ) + if src_data.abs().sum() > 0: + tgt_data, tgt_field_data, _ = translation_adjusted_download( + src=tgt, field=tgt_field, idx=idx_input_padded + ) + with semaphore("cpu"): if tgt_field_data is not None: tgt_field_data_zcxy = einops.rearrange(tgt_field_data, "C X Y Z -> Z C X Y") tgt_data_zcxy = einops.rearrange(tgt_data, "C X Y Z -> Z C X Y") @@ -109,13 +112,16 @@ def __call__( else: tgt_data_final = tgt_data + with semaphore("cuda"): result_raw = self.fn( src=src_data, tgt=tgt_data_final, src_field=src_field_data, ) result = tensor_ops.crop(result_raw, crop=self.output_crop_px) + torch.cuda.empty_cache() + with semaphore("write"): result[0] += src_translation[0] result[1] += src_translation[1] dst[idx] = result diff --git a/zetta_utils/mazepa_layer_processing/alignment/warp_operation.py b/zetta_utils/mazepa_layer_processing/alignment/warp_operation.py index b9aedaacd..dbc981e4c 100644 --- a/zetta_utils/mazepa_layer_processing/alignment/warp_operation.py +++ b/zetta_utils/mazepa_layer_processing/alignment/warp_operation.py @@ -7,9 +7,9 @@ import torchfields # pylint: disable=unused-import # monkeypatch from zetta_utils import builder, mazepa, tensor_ops -from zetta_utils.common import semaphore from zetta_utils.geometry import Vec3D from zetta_utils.layer.volumetric import VolumetricIndex, VolumetricLayer +from zetta_utils.mazepa import semaphore from zetta_utils.mazepa_layer_processing.alignment.common import ( translation_adjusted_download, ) diff --git a/zetta_utils/mazepa_layer_processing/common/subchunkable_apply_flow.py b/zetta_utils/mazepa_layer_processing/common/subchunkable_apply_flow.py index 64f8679a4..116e78097 100644 --- a/zetta_utils/mazepa_layer_processing/common/subchunkable_apply_flow.py +++ b/zetta_utils/mazepa_layer_processing/common/subchunkable_apply_flow.py @@ -24,12 +24,12 @@ from typing_extensions import ParamSpec from zetta_utils import builder, log, mazepa -from zetta_utils.common import SemaphoreType from zetta_utils.common.pprint import lrpad, utcnow_ISO8601 from zetta_utils.geometry import BBox3D, Vec3D from zetta_utils.layer.volumetric import VolumetricBasedLayerProtocol, VolumetricIndex from zetta_utils.layer.volumetric.cloudvol.build import build_cv_layer from zetta_utils.layer.volumetric.tensorstore.build import build_ts_layer +from zetta_utils.mazepa import SemaphoreType from zetta_utils.ng.link_builder import make_ng_link from zetta_utils.typing import ensure_seq_of_seq @@ -55,7 +55,6 @@ class DelegatedSubchunkedOperation(Generic[P]): flow_schema: VolumetricApplyFlowSchema[P, None] operation_name: str level: int - parallel_if_pool_exists: bool def get_input_resolution( # pylint: disable=no-self-use self, dst_resolution: Vec3D @@ -75,7 +74,6 @@ def __call__( mazepa.Executor( do_dryrun_estimation=False, show_progress=False, - parallel_if_pool_exists=self.parallel_if_pool_exists, )(self.flow_schema(idx, dst, op_args, op_kwargs)) @@ -874,7 +872,6 @@ def _build_subchunkable_apply_flow( # pylint: disable=keyword-arg-before-vararg flow_schema, op_name, level, - parallel_if_pool_exists=(level == 1), ), processing_chunk_size=processing_chunk_sizes[-level - 1], max_reduction_chunk_size=max_reduction_chunk_sizes[-level - 1], diff --git a/zetta_utils/mazepa_layer_processing/common/volumetric_apply_flow.py b/zetta_utils/mazepa_layer_processing/common/volumetric_apply_flow.py index 7b2e1ab78..674c815af 100644 --- a/zetta_utils/mazepa_layer_processing/common/volumetric_apply_flow.py +++ b/zetta_utils/mazepa_layer_processing/common/volumetric_apply_flow.py @@ -12,13 +12,13 @@ from typing_extensions import ParamSpec from zetta_utils import log, mazepa -from zetta_utils.common import semaphore from zetta_utils.geometry import Vec3D from zetta_utils.layer.volumetric import ( VolumetricBasedLayerProtocol, VolumetricIndex, VolumetricIndexChunker, ) +from zetta_utils.mazepa import semaphore from ..operation_protocols import VolumetricOpProtocol diff --git a/zetta_utils/mazepa_layer_processing/common/volumetric_callable_operation.py b/zetta_utils/mazepa_layer_processing/common/volumetric_callable_operation.py index 656caf9b6..105931e7d 100644 --- a/zetta_utils/mazepa_layer_processing/common/volumetric_callable_operation.py +++ b/zetta_utils/mazepa_layer_processing/common/volumetric_callable_operation.py @@ -10,10 +10,10 @@ from typing_extensions import ParamSpec from zetta_utils import builder, mazepa, tensor_ops -from zetta_utils.common import SemaphoreType, semaphore from zetta_utils.geometry import Vec3D from zetta_utils.layer import IndexChunker from zetta_utils.layer.volumetric import VolumetricIndex, VolumetricLayer +from zetta_utils.mazepa import SemaphoreType, semaphore from . import ChunkedApplyFlowSchema from .callable_operation import _process_callable_kwargs @@ -93,6 +93,7 @@ def __call__( # pylint: disable=keyword-arg-before-vararg for semaphore_type in self.fn_semaphores: semaphore_stack.enter_context(semaphore(semaphore_type)) result_raw = self.fn(**task_kwargs) + torch.cuda.empty_cache() # Data crop amount is determined by the index pad and the # difference between the resolutions of idx and dst_idx. # Padding was applied before the first read processor, so cropping diff --git a/zetta_utils/message_queues/__init__.py b/zetta_utils/message_queues/__init__.py index 34f941d95..10fdaf971 100644 --- a/zetta_utils/message_queues/__init__.py +++ b/zetta_utils/message_queues/__init__.py @@ -1,3 +1,4 @@ -from .base import ReceivedMessage, MessageQueue +from .base import ReceivedMessage, MessageQueue, TQTask from . import serialization -from . import sqs +from .file import FileQueue +from .sqs import SQSQueue diff --git a/zetta_utils/message_queues/base.py b/zetta_utils/message_queues/base.py index 74c3f9111..c0b2fc7fe 100644 --- a/zetta_utils/message_queues/base.py +++ b/zetta_utils/message_queues/base.py @@ -2,6 +2,7 @@ from typing import Callable, Generic, Sequence, TypeVar import attrs +import taskqueue T = TypeVar("T") @@ -10,6 +11,20 @@ def return_none() -> None: # pragma: no cover return None +class TQTask(taskqueue.RegisteredTask): + """ + Wrapper that makes Mazepa tasks submittable with `python-task-queue`. + """ + + def __init__(self, task_ser: str): + super().__init__( + task_ser=task_ser, + ) + + def execute(self): # pragma: no cover + raise NotImplementedError() + + @attrs.frozen class ReceivedMessage(Generic[T]): """ diff --git a/zetta_utils/message_queues/file/__init__.py b/zetta_utils/message_queues/file/__init__.py new file mode 100644 index 000000000..7c37ce4e0 --- /dev/null +++ b/zetta_utils/message_queues/file/__init__.py @@ -0,0 +1,2 @@ +from . import queue +from .queue import FileQueue diff --git a/zetta_utils/message_queues/file/queue.py b/zetta_utils/message_queues/file/queue.py new file mode 100644 index 000000000..ac33316d7 --- /dev/null +++ b/zetta_utils/message_queues/file/queue.py @@ -0,0 +1,110 @@ +from __future__ import annotations + +import os +import shutil +from typing import Any, Sequence, TypeVar + +import attrs +import taskqueue +from typeguard import typechecked + +from zetta_utils import builder +from zetta_utils.common.partial import ComparablePartial +from zetta_utils.common.path import abspath, strip_prefix +from zetta_utils.log import get_logger +from zetta_utils.message_queues.base import MessageQueue + +from .. import ReceivedMessage, TQTask, serialization + +logger = get_logger("zetta_utils") +T = TypeVar("T") + + +@builder.register("FileQueue") +@typechecked +@attrs.mutable +class FileQueue(MessageQueue[T]): + name: str + _queue: Any = attrs.field(init=False, default=None) + pull_wait_sec: float = 0.5 + pull_lease_sec: int = 10 # TODO: get a better value + + def _check_name_no_prefix(self) -> None: + if self.name != strip_prefix(self.name): + raise ValueError( + "Invalid FileQueue name: Name cannot contain a prefix with `://`; " + f"received `{self.name}`" + ) + + def __enter__(self) -> FileQueue: + self._check_name_no_prefix() + queue_path = abspath(f"fq://{self.name}") + queue_folder_path = strip_prefix(queue_path) + if os.path.exists(queue_folder_path): + raise RuntimeError( + "Could not create FileQueue: " f"{queue_folder_path} already exists." + ) + self._queue = taskqueue.TaskQueue(queue_path) + logger.info(f"Initialised FileQueue at `{queue_folder_path}`.") + return self + + def __exit__(self, *args) -> None: + queue_path = abspath(f"fq://{self.name}") + queue_folder_path = strip_prefix(queue_path) + shutil.rmtree(queue_folder_path) + logger.info(f"Cleaned up FileQueue at `{queue_folder_path}`.") + + def _get_tq_queue(self) -> Any: + if self._queue is None: + self._check_name_no_prefix() + self._queue = taskqueue.TaskQueue(abspath(f"fq://{self.name}")) + return self._queue + + def push(self, payloads: Sequence[T]) -> None: + if len(payloads) > 0: + tq_tasks = [] + for e in payloads: + tq_task = TQTask(serialization.serialize(e)) + tq_tasks.append(tq_task) + self._get_tq_queue().insert(tq_tasks) + + def _delete_task(self, task_id: str) -> None: + self._get_tq_queue().delete(task_id) + + def _extend_task_lease(self, duration_sec: int, task: TQTask): + self._get_tq_queue().renew(task, duration_sec) + + def pull(self, max_num: int = 500) -> list[ReceivedMessage[T]]: + results: list[ReceivedMessage[T]] = [] + try: + lease_result = self._get_tq_queue().lease( + num_tasks=max_num, + seconds=self.pull_lease_sec, + wait_sec=self.pull_wait_sec, + ) + tasks = [lease_result] if isinstance(lease_result, TQTask) else lease_result + except taskqueue.QueueEmptyError: + return results + + for task in tasks: + # Deserialize task object + payload = serialization.deserialize(task.task_ser) + acknowledge_fn = ComparablePartial( + self._delete_task, + task_id=task.id, + ) + + extend_lease_fn = ComparablePartial( + self._extend_task_lease, + task=task, + ) + + result = ReceivedMessage[T]( + payload=payload, + approx_receive_count=1, + acknowledge_fn=acknowledge_fn, + extend_lease_fn=extend_lease_fn, + ) + + results.append(result) + return results diff --git a/zetta_utils/message_queues/sqs/queue.py b/zetta_utils/message_queues/sqs/queue.py index bf9833b9b..d4aebb84e 100644 --- a/zetta_utils/message_queues/sqs/queue.py +++ b/zetta_utils/message_queues/sqs/queue.py @@ -11,24 +11,10 @@ from zetta_utils.common.partial import ComparablePartial from zetta_utils.message_queues.base import MessageQueue -from .. import ReceivedMessage, serialization +from .. import ReceivedMessage, TQTask, serialization from . import utils -class TQTask(taskqueue.RegisteredTask): - """ - Wrapper that makes Mazepa tasks submittable with `python-task-queue`. - """ - - def __init__(self, task_ser: str): - super().__init__( - task_ser=task_ser, - ) - - def execute(self): # pragma: no cover - raise NotImplementedError() - - def _delete_task_message( receipt_handle: str, queue_name: str,