From 23c69224eccc30ec0fafe88337973f60424e9707 Mon Sep 17 00:00:00 2001
From: Nicolas Granger
Date: Sat, 16 Sep 2023 17:44:39 +0200
Subject: [PATCH] set process method to spawn

Spawning is slower than forking but prevents deadlocks when OpenMP is used.
Also improved the dataloader example to handle torch.Tensor better.
---
 docs/examples/dataloader.py |  18 ++++--
 seqtools/evaluation.py      |   2 +-
 tests/test_evaluation.py    | 116 ++++++++++++++++++------------------
 3 files changed, 71 insertions(+), 65 deletions(-)

diff --git a/docs/examples/dataloader.py b/docs/examples/dataloader.py
index c1da39c..d850a81 100644
--- a/docs/examples/dataloader.py
+++ b/docs/examples/dataloader.py
@@ -2,6 +2,7 @@
 
 import copyreg
 import numbers
+from functools import partial
 
 import torch
 
@@ -10,10 +11,13 @@
 
 # overload torch.Tensor pickling to benefit from zero copy on buffer
 def pickle_tensor(t: torch.Tensor):
-    return torch.from_numpy, (t.numpy(),)
+    return torch.from_numpy, (t.contiguous().numpy(),)
 
 
-copyreg.pickle(torch.Tensor, pickle_tensor)
+def worker_init_fn_wrapper(user_fn, *kargs, **kwargs):
+    copyreg.pickle(torch.Tensor, pickle_tensor)
+    if user_fn is not None:
+        user_fn(*kargs, **kwargs)
 
 
 def pin_tensors_memory(value):
@@ -45,6 +49,10 @@ def default_collate_fn(values):
     )
 
 
+def gather_items(a, items):
+    return [a[i] for i in items]
+
+
 class DataLoader:
     def __init__(
         self,
@@ -84,7 +92,7 @@ def __init__(
         self.collate_fn = collate_fn or default_collate_fn
         self.pin_memory = pin_memory
         self.drop_last = drop_last
-        self.worker_init_fn = worker_init_fn
+        self.worker_init_fn = partial(worker_init_fn_wrapper, worker_init_fn)
         self.prefetch_factor = prefetch_factor
         self.shm_size = shm_size
 
@@ -103,7 +111,7 @@ def make_sequence(self):
         # shuffling
         if self.batch_sampler:
             batch_indices = list(self.batch_sampler)
-            out = seqtools.smap(lambda bi: [self.dataset[i] for i in bi], batch_indices)
+            out = seqtools.smap(partial(gather_items, self.dataset), batch_indices)
         elif self.sampler:
             shuffle_indices = list(self.sampler)
             out = seqtools.gather(self.dataset, shuffle_indices)
@@ -128,7 +136,7 @@ def make_sequence(self):
         if self.num_workers > 0:
             out = seqtools.prefetch(
                 out,
-                max_buffered=self.num_workers * self.prefetch_factor,
+                max_buffered=max(4, self.num_workers * self.prefetch_factor),
                 nworkers=self.num_workers,
                 method="process",
                 start_hook=self.worker_init_fn,
diff --git a/seqtools/evaluation.py b/seqtools/evaluation.py
index 1e4bbed..8bafd97 100644
--- a/seqtools/evaluation.py
+++ b/seqtools/evaluation.py
@@ -23,7 +23,6 @@
 pickling_support.install()
 
 logger = get_logger(__name__)
-mp_ctx = multiprocessing.get_context()
 
 
 # Asynchronous item fetching backends -----------------------------------------
@@ -64,6 +63,7 @@ def __init__(self, seq, num_workers=0, buffer_size=10, init_fn=None, shm_size=0)
         self.free_shm_slots = set()
 
         # initialize workers
+        mp_ctx = multiprocessing.get_context(method="spawn")  # spawn is OpenMP-friendly
         self.job_queue = mp_ctx.Queue()
         self.result_pipes = []
 
diff --git a/tests/test_evaluation.py b/tests/test_evaluation.py
index 483460b..4d76faa 100644
--- a/tests/test_evaluation.py
+++ b/tests/test_evaluation.py
@@ -6,6 +6,7 @@
 import string
 import tempfile
 import threading
+from functools import partial
 from multiprocessing import Process
 from time import sleep, time
 
@@ -13,8 +14,7 @@
 import pytest
 from numpy.testing import assert_array_equal
 
-from seqtools import EvaluationError, prefetch, seterr, smap, repeat
-
+from seqtools import EvaluationError, prefetch, repeat, seterr, smap
 
 logging.basicConfig(level=logging.DEBUG)
 seed = int(random.random() * 100000)
@@ -110,6 +110,8 @@ def test_prefetch_infinite(prefetch_kwargs):
 
 
 tls = None
+
+
 def set_seed(*kargs):
     global tls
     tls = threading.local()
@@ -132,17 +134,18 @@ def test_prefetch_start_hook(prefetch_kwargs):
         compare_random_objects(y, z)
 
 
+def sleep_and_return(x):
+    sleep(0.005 * (1 + random.random()))
+    return x
+
+
 @pytest.mark.parametrize("prefetch_kwargs", prefetch_kwargs_set)
 @pytest.mark.timeout(15)
 def test_prefetch_timings(prefetch_kwargs):
-    def f1(x):
-        sleep(0.005 * (1 + random.random()))
-        return x
-
     start_hook = random.seed
 
     arr = np.random.rand(100, 10)
-    y = smap(f1, arr)
+    y = smap(sleep_and_return, arr)
     y = prefetch(
         y, nworkers=4, max_buffered=10, start_hook=start_hook, **prefetch_kwargs
     )
@@ -151,14 +154,14 @@ def f1(x):
 
     # overly large buffer
     arr = np.random.rand(10, 10)
-    y = smap(f1, arr)
+    y = smap(sleep_and_return, arr)
     y = prefetch(y, nworkers=4, max_buffered=50, **prefetch_kwargs)
     y = [y_.copy() for y_ in y]
     assert_array_equal(np.stack(y), arr)
 
     # multiple restarts
     arr = np.random.rand(100, 10)
-    y = smap(f1, arr)
+    y = smap(sleep_and_return, arr)
     y = prefetch(y, nworkers=4, max_buffered=10, **prefetch_kwargs)
     for _ in range(10):
         n = np.random.randint(0, 99)
@@ -200,48 +203,48 @@ def f1(x):
     assert duration < 4.5
 
 
+def raise_if_none(x, picklable_err):
+    class CustomError(Exception):
+        pass
+
+    if x is None:
+        if picklable_err:
+            raise ValueError()
+        else:
+            raise CustomError()
+
+    return x
+
+
 @pytest.mark.parametrize("error_mode", ["wrap", "passthrough"])
 @pytest.mark.parametrize("prefetch_kwargs", prefetch_kwargs_set)
 @pytest.mark.parametrize("picklable_err", [False, True])
 @pytest.mark.timeout(10)
 def test_prefetch_errors(error_mode, prefetch_kwargs, picklable_err):
-    class CustomError(Exception):
-        pass
-
-    def f1(x):
-        if x is None:
-            raise ValueError("blablabla") if picklable_err else CustomError()
-        else:
-            return x
+    seterr(error_mode)
 
     arr1 = [np.random.rand(10), np.random.rand(10), np.random.rand(10), None]
-    arr2 = smap(f1, arr1)
+    arr2 = smap(partial(raise_if_none, picklable_err=picklable_err), arr1)
     y = prefetch(arr2, nworkers=2, max_buffered=4, **prefetch_kwargs)
 
-    seterr(error_mode)
-    if (
-        prefetch_kwargs["method"] != "thread" and not picklable_err
-    ) or error_mode == "wrap":
-        error_t = EvaluationError
-    else:
-        error_t = ValueError if picklable_err else CustomError
+    if error_mode == "wrap":
+        expected_err = "EvaluationError"
+    else:  # passthrough
+        if prefetch_kwargs["method"] == "process" and not picklable_err:
+            expected_err = "EvaluationError"
+        elif picklable_err:
+            expected_err = "ValueError"
+        else:
+            expected_err = "CustomError"
 
     for i in range(3):
         assert_array_equal(y[i], arr1[i])
     try:
         a = y[3]
     except Exception as e:
-        assert type(e) == error_t
-
-    if (prefetch_kwargs["method"] == "process") and error_mode == "passthrough":
-
-        class CustomObject:  # unpicklable object
-            pass
-
-        arr1 = [np.random.rand(10), CustomObject(), np.random.rand(10)]
-        y = prefetch(arr1, nworkers=2, max_buffered=4, **prefetch_kwargs)
-        with pytest.raises(ValueError):
-            y[1]
+        assert type(e).__name__ == expected_err
+    else:
+        assert False, "Should have raised"
 
 
 def check_pid(pid):
@@ -254,6 +257,11 @@ def check_pid(pid):
     return True
 
 
+def write_pid_file(directory, worker_id):
+    with open("{}/{}".format(directory, os.getpid()), "w"):
+        pass
+
+
 @pytest.mark.timeout(10)
 def test_worker_crash():
     if platform.python_implementation() == "PyPy":
@@ -261,22 +269,15 @@ def test_worker_crash():
 
     # worker dies
    with tempfile.TemporaryDirectory() as d:
-
-        def init_fn(worker_id):
-            with open("{}/{}".format(d, os.getpid()), "w"):
-                pass
-
-        def f1(x):
-            sleep(0.02 + 0.01 * (random.random() - 0.5))
-            return x
-
         arr = np.random.rand(1000, 10)
-        y = smap(f1, arr)
+        y = smap(sleep_and_return, arr)
         y = prefetch(
-            y, method="process", max_buffered=40, nworkers=4, start_hook=init_fn
+            y,
+            method="process",
+            max_buffered=40,
+            nworkers=4,
+            start_hook=partial(write_pid_file, d),
         )
-        sleep(0.1)
-
         while len(os.listdir(d)) == 0:
             sleep(0.05)
 
@@ -295,18 +296,15 @@ def test_orphan_workers_die():
 
     with tempfile.TemporaryDirectory() as d:
 
-        def init_fn(worker_id):
-            with open("{}/{}".format(d, os.getpid()), "w"):
-                pass
-
-        def f1(x):
-            return x
-
         def target():
             arr = np.random.rand(1000, 10)
-            y = smap(f1, arr)
+            y = smap(sleep_and_return, arr)
             y = prefetch(
-                y, method="process", max_buffered=4, nworkers=4, start_hook=init_fn
+                y,
+                method="process",
+                max_buffered=4,
+                nworkers=4,
+                start_hook=partial(write_pid_file, d),
             )
 
             for i in range(0, 1000):
@@ -321,7 +319,7 @@ def target():
             sleep(0.05)
 
         os.kill(p.pid, signal.SIGKILL)  # parent process crashes
-
+        sleep(3)  # wait for workers to time out
 
         for pid in map(int, os.listdir(d)):