Skip to content

Commit

Permalink
2099 - enhances dataloader/randomizable docstrings (Project-MONAI#2188)
Browse files Browse the repository at this point in the history
* fixes docstring

Signed-off-by: Wenqi Li <[email protected]>

* 2099 docstring updates

Signed-off-by: Wenqi Li <[email protected]>
  • Loading branch information
wyli authored May 13, 2021
1 parent 4ef7d22 commit a5cf253
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 9 deletions.
44 changes: 39 additions & 5 deletions monai/data/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,54 @@


class DataLoader(_TorchDataLoader):
"""Generates images/labels for train/validation/testing from dataset.
It inherits from PyTorch DataLoader and adds default callbacks for `collate`
and `worker_fn` if user doesn't set them.
"""
Provides an iterable over the given `dataset`. It inherits the PyTorch
DataLoader and adds enhanced `collate_fn` and `worker_fn` by default.
Although this class could be configured to be the same as
`torch.utils.data.DataLoader`, its default configuration is
recommended, mainly for the following extra features:
More information about PyTorch DataLoader, please check:
- It handles MONAI randomizable objects with appropriate random state
managements for deterministic behaviour.
- It is aware of the patch-based transform (such as
:py:class:`monai.transforms.RandSpatialCropSamplesDict`) samples for
preprocessing with enhanced data collating behaviour.
See: :py:class:`monai.transforms.Compose`.
For more details about :py:class:`torch.utils.data.DataLoader`, please see:
https://github.com/pytorch/pytorch/blob/master/torch/utils/data/dataloader.py
For example, to construct a randomized dataset and iterate with the data loader:
.. code-block:: python
import torch
from monai.data import DataLoader
from monai.transforms import Randomizable
class RandomDataset(torch.utils.data.Dataset, Randomizable):
def __getitem__(self, index):
return self.R.randint(0, 1000, (1,))
def __len__(self):
return 16
dataset = RandomDataset()
dataloader = DataLoader(dataset, batch_size=2, num_workers=4)
for epoch in range(2):
for i, batch in enumerate(dataloader):
print(epoch, i, batch.data.numpy().flatten().tolist())
Args:
dataset: dataset from which to load the data.
num_workers: how many subprocesses to use for data
loading. ``0`` means that the data will be loaded in the main process.
(default: ``0``)
kwargs: other parameters for PyTorch DataLoader.
"""

def __init__(self, dataset: Dataset, num_workers: int = 0, **kwargs) -> None:
Expand Down
5 changes: 4 additions & 1 deletion monai/transforms/croppad/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ class PadListDataCollate(InvertibleTransform):
pass the inverse through multiprocessing.
Args:
batch: batch of data to pad-collate
method: padding method (see :py:class:`monai.transforms.SpatialPad`)
mode: padding mode (see :py:class:`monai.transforms.SpatialPad`)
"""
Expand All @@ -72,6 +71,10 @@ def __init__(
self.mode = mode

def __call__(self, batch: Any):
"""
Args:
batch: batch of data to pad-collate
"""
# data is either list of dicts or list of lists
is_list_of_dicts = isinstance(batch[0], dict)
# loop over items inside of each element in a batch
Expand Down
8 changes: 6 additions & 2 deletions monai/transforms/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,12 @@ def _log_stats(data, prefix: Optional[str] = "Data"):

class Randomizable(ABC):
"""
An interface for handling random state locally, currently based on a class variable `R`,
which is an instance of `np.random.RandomState`.
An interface for handling random state locally, currently based on a class
variable `R`, which is an instance of `np.random.RandomState`. This
provides the flexibility of component-specific determinism without
affecting the global states. It is recommended to use this API with
:py:class:`monai.data.DataLoader` for deterministic behaviour of the
preprocessing pipelines.
"""

R: np.random.RandomState = np.random.RandomState()
Expand Down
32 changes: 31 additions & 1 deletion tests/test_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
from parameterized import parameterized

from monai.data import CacheDataset, DataLoader, Dataset
from monai.transforms import Compose, DataStatsd, SimulateDelayd
from monai.transforms import Compose, DataStatsd, Randomizable, SimulateDelayd
from monai.utils import set_determinism

TEST_CASE_1 = [
[
Expand Down Expand Up @@ -64,5 +65,34 @@ def test_exception(self, datalist):
pass


class _RandomDataset(torch.utils.data.Dataset, Randomizable):
def __getitem__(self, index):
return self.R.randint(0, 1000, (1,))

def __len__(self):
return 8


class TestLoaderRandom(unittest.TestCase):
"""
Testing data loader working with the randomizable interface
"""

def setUp(self):
set_determinism(0)

def tearDown(self):
set_determinism(None)

def test_randomize(self):
dataset = _RandomDataset()
dataloader = DataLoader(dataset, batch_size=2, num_workers=3)
output = []
for _ in range(2):
for batch in dataloader:
output.extend(batch.data.numpy().flatten().tolist())
self.assertListEqual(output, [594, 170, 524, 778, 370, 906, 292, 589, 762, 763, 156, 886, 42, 405, 221, 166])


if __name__ == "__main__":
unittest.main()

0 comments on commit a5cf253

Please sign in to comment.