Skip to content

Commit

Permalink
set process method to spawn
Browse files Browse the repository at this point in the history
Slower but prevents deadlock when OpenMP is used.
Also improved dataloader example to handle torch.Tensor better.
  • Loading branch information
nlgranger committed Sep 16, 2023
1 parent 96abc22 commit 23c6922
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 65 deletions.
18 changes: 13 additions & 5 deletions docs/examples/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import copyreg
import numbers
from functools import partial

import torch

Expand All @@ -10,10 +11,13 @@

# overload torch.Tensor pickling to benefit from zero copy on buffer
def pickle_tensor(t: torch.Tensor):
    """Reduce a ``torch.Tensor`` for pickling via its NumPy buffer.

    Registered with :func:`copyreg.pickle` so that sending a tensor to a
    worker process serializes the underlying buffer directly instead of
    going through torch's default (slower) pickling path.

    Parameters
    ----------
    t : torch.Tensor
        Tensor to serialize.  NOTE(review): ``.numpy()`` only works for
        CPU tensors that do not require grad — confirm callers never pass
        CUDA/grad tensors.

    Returns
    -------
    tuple
        ``(torch.from_numpy, (ndarray,))`` — the standard ``copyreg``
        reduce form: on unpickling, ``torch.from_numpy`` rebuilds the
        tensor from the array.
    """
    # .contiguous() ensures the exported array is one flat buffer;
    # presumably required for the zero-copy transfer path — a view with
    # exotic strides would otherwise be shipped as-is.
    return torch.from_numpy, (t.contiguous().numpy(),)


copyreg.pickle(torch.Tensor, pickle_tensor)
def worker_init_fn_wrapper(user_fn, *args, **kwargs):
    """Per-worker startup hook: install the tensor pickler, then defer to the user.

    Registers the zero-copy ``torch.Tensor`` reducer inside the worker
    process, then invokes ``user_fn`` (if any) with the remaining
    arguments.  Meant to be bound with ``functools.partial`` and handed
    to the prefetch machinery as its start hook.
    """
    copyreg.pickle(torch.Tensor, pickle_tensor)
    if user_fn is None:
        return
    user_fn(*args, **kwargs)


def pin_tensors_memory(value):
Expand Down Expand Up @@ -45,6 +49,10 @@ def default_collate_fn(values):
)


def gather_items(a, items):
    """Return the elements of ``a`` at the given indices, in order.

    Module-level (picklable) replacement for the former inline lambda,
    so it can be shipped to worker processes.
    """
    return list(map(a.__getitem__, items))


class DataLoader:
def __init__(
self,
Expand Down Expand Up @@ -84,7 +92,7 @@ def __init__(
self.collate_fn = collate_fn or default_collate_fn
self.pin_memory = pin_memory
self.drop_last = drop_last
self.worker_init_fn = worker_init_fn
self.worker_init_fn = partial(worker_init_fn_wrapper, worker_init_fn)
self.prefetch_factor = prefetch_factor
self.shm_size = shm_size

Expand All @@ -103,7 +111,7 @@ def make_sequence(self):
# shuffling
if self.batch_sampler:
batch_indices = list(self.batch_sampler)
out = seqtools.smap(lambda bi: [self.dataset[i] for i in bi], batch_indices)
out = seqtools.smap(partial(gather_items, self.dataset), batch_indices)
elif self.sampler:
shuffle_indices = list(self.sampler)
out = seqtools.gather(self.dataset, shuffle_indices)
Expand All @@ -128,7 +136,7 @@ def make_sequence(self):
if self.num_workers > 0:
out = seqtools.prefetch(
out,
max_buffered=self.num_workers * self.prefetch_factor,
max_buffered=max(4, self.num_workers * self.prefetch_factor),
nworkers=self.num_workers,
method="process",
start_hook=self.worker_init_fn,
Expand Down
2 changes: 1 addition & 1 deletion seqtools/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
pickling_support.install()

logger = get_logger(__name__)
mp_ctx = multiprocessing.get_context()


# Asynchronous item fetching backends -----------------------------------------
Expand Down Expand Up @@ -64,6 +63,7 @@ def __init__(self, seq, num_workers=0, buffer_size=10, init_fn=None, shm_size=0)
self.free_shm_slots = set()

# initialize workers
mp_ctx = multiprocessing.get_context(method="spawn") # spawn is OpenMP-friendly
self.job_queue = mp_ctx.Queue()
self.result_pipes = []

Expand Down
116 changes: 57 additions & 59 deletions tests/test_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,15 @@
import string
import tempfile
import threading
from functools import partial
from multiprocessing import Process
from time import sleep, time

import numpy as np
import pytest
from numpy.testing import assert_array_equal

from seqtools import EvaluationError, prefetch, seterr, smap, repeat

from seqtools import EvaluationError, prefetch, repeat, seterr, smap

logging.basicConfig(level=logging.DEBUG)
seed = int(random.random() * 100000)
Expand Down Expand Up @@ -110,6 +110,8 @@ def test_prefetch_infinite(prefetch_kwargs):


tls = None


def set_seed(*kargs):
global tls
tls = threading.local()
Expand All @@ -132,17 +134,18 @@ def test_prefetch_start_hook(prefetch_kwargs):
compare_random_objects(y, z)


def sleep_and_return(x):
    """Identity function with a small random delay (5–10 ms).

    Shared helper for the timing tests: simulates per-item work while
    returning its input unchanged.
    """
    delay = 0.005 * (1 + random.random())
    sleep(delay)
    return x


@pytest.mark.parametrize("prefetch_kwargs", prefetch_kwargs_set)
@pytest.mark.timeout(15)
def test_prefetch_timings(prefetch_kwargs):
def f1(x):
sleep(0.005 * (1 + random.random()))
return x

start_hook = random.seed

arr = np.random.rand(100, 10)
y = smap(f1, arr)
y = smap(sleep_and_return, arr)
y = prefetch(
y, nworkers=4, max_buffered=10, start_hook=start_hook, **prefetch_kwargs
)
Expand All @@ -151,14 +154,14 @@ def f1(x):

# overly large buffer
arr = np.random.rand(10, 10)
y = smap(f1, arr)
y = smap(sleep_and_return, arr)
y = prefetch(y, nworkers=4, max_buffered=50, **prefetch_kwargs)
y = [y_.copy() for y_ in y]
assert_array_equal(np.stack(y), arr)

# multiple restarts
arr = np.random.rand(100, 10)
y = smap(f1, arr)
y = smap(sleep_and_return, arr)
y = prefetch(y, nworkers=4, max_buffered=10, **prefetch_kwargs)
for _ in range(10):
n = np.random.randint(0, 99)
Expand Down Expand Up @@ -200,48 +203,48 @@ def f1(x):
assert duration < 4.5


def raise_if_none(x, picklable_err):
    """Pass ``x`` through unchanged, or raise when it is ``None``.

    When raising, ``picklable_err`` selects a plain ``ValueError``
    (picklable) versus a locally-defined ``CustomError`` (defined inside
    this function, hence *not* importable/picklable in a worker) — used
    by the error-propagation tests to cover both transport cases.
    """
    class CustomError(Exception):
        pass

    if x is not None:
        return x

    exc = ValueError() if picklable_err else CustomError()
    raise exc


@pytest.mark.parametrize("error_mode", ["wrap", "passthrough"])
@pytest.mark.parametrize("prefetch_kwargs", prefetch_kwargs_set)
@pytest.mark.parametrize("picklable_err", [False, True])
@pytest.mark.timeout(10)
def test_prefetch_errors(error_mode, prefetch_kwargs, picklable_err):
class CustomError(Exception):
pass

def f1(x):
if x is None:
raise ValueError("blablabla") if picklable_err else CustomError()
else:
return x
seterr(error_mode)

arr1 = [np.random.rand(10), np.random.rand(10), np.random.rand(10), None]
arr2 = smap(f1, arr1)
arr2 = smap(partial(raise_if_none, picklable_err=picklable_err), arr1)
y = prefetch(arr2, nworkers=2, max_buffered=4, **prefetch_kwargs)

seterr(error_mode)
if (
prefetch_kwargs["method"] != "thread" and not picklable_err
) or error_mode == "wrap":
error_t = EvaluationError
else:
error_t = ValueError if picklable_err else CustomError
if error_mode == "wrap":
expected_err = "EvaluationError"
else: # passthrough
if prefetch_kwargs["method"] == "process" and not picklable_err:
expected_err = "EvaluationError"
elif picklable_err:
expected_err = "ValueError"
else:
expected_err = "CustomError"

for i in range(3):
assert_array_equal(y[i], arr1[i])
try:
a = y[3]
except Exception as e:
assert type(e) == error_t

if (prefetch_kwargs["method"] == "process") and error_mode == "passthrough":

class CustomObject: # unpicklable object
pass

arr1 = [np.random.rand(10), CustomObject(), np.random.rand(10)]
y = prefetch(arr1, nworkers=2, max_buffered=4, **prefetch_kwargs)
with pytest.raises(ValueError):
y[1]
assert type(e).__name__ == expected_err
else:
assert False, "Should have raised"


def check_pid(pid):
Expand All @@ -254,30 +257,28 @@ def check_pid(pid):
return True


def write_pid_file(directory, worker_id):
    """Create an empty file named after the current PID inside *directory*.

    Module-level (picklable) worker start hook: the parent test watches
    the directory to learn which worker processes have started.
    ``worker_id`` is accepted to match the hook signature but unused.
    """
    pid_path = "{}/{}".format(directory, os.getpid())
    open(pid_path, "w").close()


@pytest.mark.timeout(10)
def test_worker_crash():
if platform.python_implementation() == "PyPy":
pytest.skip("broken with pypy")

# worker dies
with tempfile.TemporaryDirectory() as d:

def init_fn(worker_id):
with open("{}/{}".format(d, os.getpid()), "w"):
pass

def f1(x):
sleep(0.02 + 0.01 * (random.random() - 0.5))
return x

arr = np.random.rand(1000, 10)
y = smap(f1, arr)
y = smap(sleep_and_return, arr)
y = prefetch(
y, method="process", max_buffered=40, nworkers=4, start_hook=init_fn
y,
method="process",
max_buffered=40,
nworkers=4,
start_hook=partial(write_pid_file, d),
)

sleep(0.1)

while len(os.listdir(d)) == 0:
sleep(0.05)

Expand All @@ -295,18 +296,15 @@ def test_orphan_workers_die():

with tempfile.TemporaryDirectory() as d:

def init_fn(worker_id):
with open("{}/{}".format(d, os.getpid()), "w"):
pass

def f1(x):
return x

def target():
arr = np.random.rand(1000, 10)
y = smap(f1, arr)
y = smap(sleep_and_return, arr)
y = prefetch(
y, method="process", max_buffered=4, nworkers=4, start_hook=init_fn
y,
method="process",
max_buffered=4,
nworkers=4,
start_hook=partial(write_pid_file, d),
)

for i in range(0, 1000):
Expand All @@ -321,7 +319,7 @@ def target():
sleep(0.05)

os.kill(p.pid, signal.SIGKILL) # parent process crashes

sleep(3) # wait for workers to time out

for pid in map(int, os.listdir(d)):
Expand Down

0 comments on commit 23c6922

Please sign in to comment.