Skip to content

Commit

Permalink
feat: add LocalQueue and local pool, refactor semaphores
Browse files Browse the repository at this point in the history
  • Loading branch information
dodamih committed Oct 31, 2023
1 parent 3147f19 commit 9576650
Show file tree
Hide file tree
Showing 33 changed files with 483 additions and 291 deletions.
9 changes: 5 additions & 4 deletions specs/dodam/warp_zfish_subchunkable_example.cue
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
}

start_coord: [4096 * 3, 4096 * 4, 3003]
end_coord: [12288 * 2, 12288 * 2, 3011]
end_coord: [12288 * 2, 12288 * 2, 3006]
coord_resolution: [16, 16, 30]

dst_resolution: [32, 32, 30]
Expand All @@ -24,7 +24,7 @@
//processing_crop_pads: [[0, 0, 0], [16, 16, 0], [32, 32, 0]]
//processing_blend_pads: [[0, 0, 0], [0, 0, 0], [16, 16, 0]]
//level_intermediaries_dirs: [#TEMP_PATH2, #TEMP_PATH1, #TEMP_PATH0]
processing_chunk_sizes: [[8192, 8192, 1], [2048, 2048, 1]]
processing_chunk_sizes: [[1024, 1024, 1], [512, 512, 1]]
processing_crop_pads: [[0, 0, 0], [16, 16, 0]]
processing_blend_pads: [[0, 0, 0], [16, 16, 0]]
processing_blend_modes: "quadratic"
Expand Down Expand Up @@ -62,10 +62,11 @@
"@type": "mazepa.execute_locally"
target:
#FLOW_TMPL
num_procs: 4
num_procs: 32
semaphores_spec: {
read: 4
write: 4
cuda: 4
cuda: 1
cpu: 4
}
debug: true
43 changes: 0 additions & 43 deletions tests/unit/common/test_multiprocessing.py

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
from typing import List

import posix_ipc
import psutil
import pytest

from zetta_utils.common.semaphores import (
from zetta_utils.mazepa.semaphores import (
DummySemaphore,
SemaphoreType,
configure_semaphores,
Expand All @@ -21,7 +22,7 @@ def cleanup_semaphores():
sema_types: List[SemaphoreType] = ["read", "write", "cuda", "cpu"]
for name in sema_types:
try:
# two unlinks in case parent semaphore exists
# two unlinks in case grandparent semaphore exists
semaphore(name).unlink()
semaphore(name).unlink()
except:
Expand Down Expand Up @@ -65,13 +66,14 @@ def test_unlink_nonexistent_exc():
# exception on exiting context


def test_get_parent_semaphore():
def test_get_grandparent_semaphore():
grandpa_pid = psutil.Process(os.getppid()).ppid()
try:
sema = posix_ipc.Semaphore(name_to_posix_name("read", os.getppid()))
sema = posix_ipc.Semaphore(name_to_posix_name("read", grandpa_pid))
sema.unlink()
except:
pass
sema = posix_ipc.Semaphore(name_to_posix_name("read", os.getppid()), flags=posix_ipc.O_CREX)
sema = posix_ipc.Semaphore(name_to_posix_name("read", grandpa_pid), flags=posix_ipc.O_CREX)
assert sema.name == semaphore("read").name
sema.unlink()

Expand Down
Empty file.
64 changes: 64 additions & 0 deletions tests/unit/message_queues/local/test_local_queue.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import time

import pytest

from zetta_utils.message_queues.local.queue import LocalQueue


def success_fn():
    """Trivial callable used as a queue payload in the tests in this module."""
    return "Success"


def test_make_and_delete_local_queue():
    """Queue creation and teardown via the context manager should not raise."""
    queue = LocalQueue("test_queue")
    with queue:
        pass


def test_get_tq_queue():
    """A second handle to an existing queue can fetch the underlying tq queue."""
    with LocalQueue("test_queue"):
        second_handle = LocalQueue("test_queue")
        second_handle._get_tq_queue()  # pylint:disable = protected-access


def test_push_pull():
    """Round-trip a set of heterogeneous payloads and verify all are delivered."""
    with LocalQueue("test_queue") as queue:
        expected = {None, 1, "asdfadsfdsa", success_fn}
        queue.push(list(expected))
        time.sleep(0.1)
        messages = queue.pull(max_num=len(expected))
        assert len(messages) == len(expected)
        delivered = {message.payload for message in messages}
        assert delivered == expected


def test_delete():
    """An acknowledged message must not be redelivered after its lease expires."""
    with LocalQueue("test_queue") as queue:
        queue.push([None])
        time.sleep(0.1)
        pulled = queue.pull(max_num=10)
        assert len(pulled) == 1
        # Acknowledge, then wait past the lease period; nothing should come back.
        pulled[0].acknowledge_fn()
        time.sleep(1.1)
        assert len(queue.pull()) == 0


def test_extend_lease():
    """Extending a lease defers redelivery until the extended period lapses."""
    with LocalQueue("test_queue") as queue:
        queue.push([None])
        time.sleep(0.1)
        pulled = queue.pull()
        assert len(pulled) == 1
        pulled[0].extend_lease_fn(3)
        # Within the extended lease: message is still invisible.
        time.sleep(1)
        assert len(queue.pull()) == 0
        # After the lease lapses: message is redelivered.
        time.sleep(2.1)
        assert len(queue.pull()) == 1


def test_double_init_exc():
    """Entering two queues with the same name must raise RuntimeError."""
    with pytest.raises(RuntimeError):
        with LocalQueue("test_queue"), LocalQueue("test_queue"):
            pass
Empty file.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from kubernetes import client as k8s_client # type: ignore
from zetta_utils import builder, log
from zetta_utils.common import SemaphoreType
from zetta_utils.mazepa import SemaphoreType

from .eks import eks_cluster_data
from .gke import gke_cluster_data
Expand Down Expand Up @@ -68,7 +68,7 @@ def get_mazepa_worker_command(
result = (
"""
zetta -vv -l try run -s '{
"@type": "mazepa.run_worker"
"@type": "mazepa.run_worker_manager"
"""
+ f"task_queue: {json.dumps(task_queue_spec)}\n"
+ f"outcome_queue: {json.dumps(outcome_queue_spec)}\n"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from kubernetes import client as k8s_client # type: ignore
from zetta_utils import builder, log
from zetta_utils.common import SemaphoreType
from zetta_utils.mazepa import SemaphoreType

from ..resource_tracker import (
ExecutionResource,
Expand Down
2 changes: 0 additions & 2 deletions zetta_utils/common/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,3 @@
from .pprint import lrpad
from .signal_handlers import custom_signal_handler_ctx
from .timer import RepeatTimer
from .semaphores import SemaphoreType, configure_semaphores, semaphore
from .multiprocessing import setup_persistent_process_pool, get_persistent_process_pool
48 changes: 0 additions & 48 deletions zetta_utils/common/multiprocessing.py

This file was deleted.

9 changes: 7 additions & 2 deletions zetta_utils/common/path.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,15 @@ def abspath(path: str) -> str:
path_no_prefix = split[-1]
if len(prefixes) == 0:
prefixes = ["file"]
if prefixes == ["file"]:
if prefixes in (["file"], ["fq"]):
path_no_prefix = os.path.abspath(os.path.expanduser(path_no_prefix))
return "://".join(prefixes + [path_no_prefix])


def strip_prefix(path: str) -> str: # pragma: no cover
return path.split("://")[-1]


def is_local(path: str) -> bool: # pragma: no cover
return abspath(path).startswith("file://")
local_prefixes = ["file://", "fq://"]
return any(abspath(path).startswith(local_prefix) for local_prefix in local_prefixes)
2 changes: 1 addition & 1 deletion zetta_utils/convnet/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def _load_model(
return result


_load_model_cached = cachetools.cached(cachetools.LRUCache(maxsize=8))(_load_model)
_load_model_cached = cachetools.cached(cachetools.LRUCache(maxsize=2))(_load_model)


@typechecked
Expand Down
3 changes: 1 addition & 2 deletions zetta_utils/layer/volumetric/cloudvol/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def _set_cv_defaults(self):
self.cv_kwargs.setdefault("progress", False)
self.cv_kwargs.setdefault("autocrop", False)
self.cv_kwargs.setdefault("non_aligned_writes", False)
self.cv_kwargs.setdefault("cache", not self.is_local)
self.cv_kwargs.setdefault("cache", False)
self.cv_kwargs.setdefault("compress_cache", False)
self.cv_kwargs.setdefault("compress", True)
self.cv_kwargs.setdefault("cdn_cache", False)
Expand Down Expand Up @@ -195,7 +195,6 @@ def clear_disk_cache(self) -> None: # pragma: no cover

def clear_cache(self) -> None: # pragma: no cover
_clear_cv_cache(self.path)
self.clear_disk_cache()

def read(self, idx: VolumetricIndex) -> torch.Tensor:
# Data out: cxyz
Expand Down
1 change: 1 addition & 0 deletions zetta_utils/mazepa/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,4 @@
from .progress_tracker import progress_ctx_mngr
from .execution import Executor, execute
from .worker import run_worker
from .semaphores import SemaphoreType, configure_semaphores, semaphore
26 changes: 3 additions & 23 deletions zetta_utils/mazepa/autoexecute_task_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from typeguard import typechecked

from zetta_utils import log
from zetta_utils.common import get_persistent_process_pool
from zetta_utils.mazepa.worker import process_task_message
from zetta_utils.message_queues.base import MessageQueue, ReceivedMessage

Expand All @@ -23,7 +22,6 @@ class AutoexecuteTaskQueue(MessageQueue):
tasks_todo: list[Task] = attrs.field(init=False, factory=list)
debug: bool = False
handle_exceptions: bool = False
parallel_if_pool_exists: bool = False

def push(self, payloads: Iterable[Task]):
# TODO: Fix progress bar issue with multiple live displays in rich
Expand All @@ -40,27 +38,9 @@ def pull(
if len(self.tasks_todo) == 0:
return []
else:
pool = get_persistent_process_pool()
if not self.parallel_if_pool_exists or pool is None:
results: list[ReceivedMessage[OutcomeReport]] = []
for task in self.tasks_todo[:max_num]:
results.append(execute_task(task, self.debug, self.handle_exceptions))
# TODO: remove monkey patching from builder so that unit tests work;
# pickle does not handle monkey patched objects inside Python
else: # pragma: no cover
futures = []
for task in self.tasks_todo[:max_num]:
futures.append(
pool.schedule(
execute_task,
kwargs={
"task": task,
"debug": self.debug,
"handle_exceptions": self.handle_exceptions,
},
)
)
results = [future.result() for future in futures]
results: list[ReceivedMessage[OutcomeReport]] = []
for task in self.tasks_todo[:max_num]:
results.append(execute_task(task, self.debug, self.handle_exceptions))
self.tasks_todo = self.tasks_todo[max_num:]
return results

Expand Down
5 changes: 1 addition & 4 deletions zetta_utils/mazepa/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ class Executor:
checkpoint: Optional[str] = None
checkpoint_interval_sec: Optional[float] = None
raise_on_failed_checkpoint: bool = True
parallel_if_pool_exists: bool = False

def __call__(self, target: Union[Task, Flow, ExecutionState, ComparablePartial, Callable]):
assert (self.task_queue is None and self.outcome_queue is None) or (
Expand All @@ -58,7 +57,6 @@ def __call__(self, target: Union[Task, Flow, ExecutionState, ComparablePartial,
checkpoint=self.checkpoint,
checkpoint_interval_sec=self.checkpoint_interval_sec,
raise_on_failed_checkpoint=self.raise_on_failed_checkpoint,
parallel_if_pool_exists=self.parallel_if_pool_exists,
)


Expand All @@ -77,7 +75,6 @@ def execute(
checkpoint: Optional[str] = None,
checkpoint_interval_sec: Optional[float] = 150,
raise_on_failed_checkpoint: bool = True,
parallel_if_pool_exists: bool = False,
):
"""
Executes a target until completion using the given execution queue.
Expand Down Expand Up @@ -122,7 +119,7 @@ def execute(
else:
assert outcome_queue is None
task_queue_ = AutoexecuteTaskQueue(
debug=True, parallel_if_pool_exists=parallel_if_pool_exists
debug=True,
)
outcome_queue_ = task_queue_

Expand Down
Loading

0 comments on commit 9576650

Please sign in to comment.