Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve Compose encapsulation #6224

Merged
merged 33 commits into from
Mar 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
1759f96
Initial commit to resolve #6223
atbenmurray Mar 22, 2023
ba4115d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 22, 2023
632119e
[MONAI] code formatting
monai-bot Mar 22, 2023
50a9d06
Initial commit to resolve #6223
atbenmurray Mar 22, 2023
dd5ad30
Merging into autoformat fixes
atbenmurray Mar 22, 2023
10edabd
DCO Remediation Commit for Ben Murray <[email protected]>
atbenmurray Mar 22, 2023
5a8ccaf
DCO Remediation Commit for Ben Murray <[email protected]>
atbenmurray Mar 22, 2023
918cbf6
Merge branch 'compose_refactor' of github.com:project-monai/monai int…
atbenmurray Mar 22, 2023
c3bd8c0
DCO Remediation Commit for Ben Murray <[email protected]>
atbenmurray Mar 22, 2023
0c0baa4
Fixes to make test_cachedataset, test_persistentdataset and test_cach…
atbenmurray Mar 22, 2023
c5a73f6
Documentation for Compose.execute
atbenmurray Mar 22, 2023
41a156a
style/docs
wyli Mar 22, 2023
eda9b4a
Merge remote-tracking branch 'upstream/dev' into compose_refactor
wyli Mar 22, 2023
61f67e4
Added tests; updated documentation
atbenmurray Mar 24, 2023
bd591b7
Merge branch 'compose_refactor' of github.com:project-monai/monai int…
atbenmurray Mar 24, 2023
7f37392
DCO Remediation Commit for Ben Murray <[email protected]>
atbenmurray Mar 24, 2023
5301af1
Honoring the self.copy_cache flag
atbenmurray Mar 24, 2023
5bab340
Merge branch 'dev' into compose_refactor
atbenmurray Mar 24, 2023
6bdedac
Updating for lazy resampling
atbenmurray Mar 24, 2023
d99b2b9
Autoformatting
atbenmurray Mar 24, 2023
0223e62
Merge branch 'dev' into compose_refactor
wyli Mar 27, 2023
d939395
Moving Compose.execute to execute_compose as per @ericspod's request.…
atbenmurray Mar 29, 2023
4ecf2c3
Test fix: missed Compose.execute to execute_compose changes
atbenmurray Mar 29, 2023
948444f
DCO Remediation Commit for Ben Murray <[email protected]>
atbenmurray Mar 29, 2023
9097e07
Bug fix for SomeOf; generate list of transforms in execution order
atbenmurray Mar 29, 2023
d701642
Documentation for Compose.get_index_of_first
atbenmurray Mar 29, 2023
77e465d
Slight documentation reformatting for Compose.get_index_of_first
atbenmurray Mar 29, 2023
5605e71
Updated docstrings for execute_compose. Renamed input_ to data for
atbenmurray Mar 29, 2023
c06e014
Fixing errors reported by flake8-py3 (mypy) output
atbenmurray Mar 29, 2023
191506d
Had to go back to lazy_evaluation default of None for now but this is a
atbenmurray Mar 29, 2023
ec71402
execute_compose type ignore as it can't be fixed without polluting more
atbenmurray Mar 29, 2023
85b18de
type: ignore suppression as this is being addressed separately
atbenmurray Mar 29, 2023
7684edf
Merge branch 'dev' into compose_refactor
wyli Mar 30, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 44 additions & 54 deletions monai/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@
Compose,
Randomizable,
RandomizableTrait,
ThreadUnsafe,
Transform,
apply_transform,
convert_to_contiguous,
Expand Down Expand Up @@ -209,6 +208,11 @@ class PersistentDataset(Dataset):
not guaranteed, so caution should be used when modifying transforms to avoid unexpected
errors. If in doubt, it is advisable to clear the cache directory.

Lazy Resampling:
If you make use of the lazy resampling feature of `monai.transforms.Compose`, please refer to
its documentation to familiarize yourself with the interaction between `PersistentDataset` and
lazy resampling.

"""

def __init__(
Expand Down Expand Up @@ -316,15 +320,15 @@ def _pre_transform(self, item_transformed):
random transform object

"""
for _transform in self.transform.transforms:
# execute all the deterministic transforms
if isinstance(_transform, RandomizableTrait) or not isinstance(_transform, Transform):
break
# this is to be consistent with CacheDataset even though it's not in a multi-thread situation.
_xform = deepcopy(_transform) if isinstance(_transform, ThreadUnsafe) else _transform
item_transformed = self.transform.evaluate_with_overrides(item_transformed, _xform)
item_transformed = apply_transform(_xform, item_transformed)
item_transformed = self.transform.evaluate_with_overrides(item_transformed, None)
if not isinstance(self.transform, Compose):
raise ValueError("transform must be an instance of monai.transforms.Compose.")
wyli marked this conversation as resolved.
Show resolved Hide resolved

first_random = self.transform.get_index_of_first(
lambda t: isinstance(t, RandomizableTrait) or not isinstance(t, Transform)
)

item_transformed = self.transform(item_transformed, end=first_random, threading=True)

if self.reset_ops_id:
reset_ops_id(item_transformed)
return item_transformed
Expand All @@ -342,17 +346,12 @@ def _post_transform(self, item_transformed):
"""
if not isinstance(self.transform, Compose):
raise ValueError("transform must be an instance of monai.transforms.Compose.")
start_post_randomize_run = False
for _transform in self.transform.transforms:
if (
start_post_randomize_run
or isinstance(_transform, RandomizableTrait)
or not isinstance(_transform, Transform)
):
start_post_randomize_run = True
item_transformed = self.transform.evaluate_with_overrides(item_transformed, _transform)
item_transformed = apply_transform(_transform, item_transformed)
item_transformed = self.transform.evaluate_with_overrides(item_transformed, None)

first_random = self.transform.get_index_of_first(
lambda t: isinstance(t, RandomizableTrait) or not isinstance(t, Transform)
)
if first_random is not None:
item_transformed = self.transform(item_transformed, start=first_random)
return item_transformed

def _cachecheck(self, item_transformed):
Expand Down Expand Up @@ -496,13 +495,9 @@ def _pre_transform(self, item_transformed):
"""
if not isinstance(self.transform, Compose):
raise ValueError("transform must be an instance of monai.transforms.Compose.")
for i, _transform in enumerate(self.transform.transforms):
if i == self.cache_n_trans:
break
_xform = deepcopy(_transform) if isinstance(_transform, ThreadUnsafe) else _transform
item_transformed = self.transform.evaluate_with_overrides(item_transformed, _xform)
item_transformed = apply_transform(_xform, item_transformed)
item_transformed = self.transform.evaluate_with_overrides(item_transformed, None)

item_transformed = self.transform(item_transformed, end=self.cache_n_trans, threading=True)

reset_ops_id(item_transformed)
return item_transformed

Expand All @@ -518,12 +513,8 @@ def _post_transform(self, item_transformed):
"""
if not isinstance(self.transform, Compose):
raise ValueError("transform must be an instance of monai.transforms.Compose.")
for i, _transform in enumerate(self.transform.transforms):
if i >= self.cache_n_trans:
item_transformed = self.transform.evaluate_with_overrides(item_transformed, item_transformed)
item_transformed = apply_transform(_transform, item_transformed)
item_transformed = self.transform.evaluate_with_overrides(item_transformed, None)
return item_transformed

return self.transform(item_transformed, start=self.cache_n_trans)


class LMDBDataset(PersistentDataset):
Expand Down Expand Up @@ -748,6 +739,11 @@ class CacheDataset(Dataset):
So to debug or verify the program before real training, users can set `cache_rate=0.0` or `cache_num=0` to
temporarily skip caching.

Lazy Resampling:
If you make use of the lazy resampling feature of `monai.transforms.Compose`, please refer to
its documentation to familiarize yourself with the interaction between `CacheDataset` and
lazy resampling.

"""

def __init__(
Expand Down Expand Up @@ -887,14 +883,12 @@ def _load_cache_item(self, idx: int):
idx: the index of the input data sequence.
"""
item = self.data[idx]
for _transform in self.transform.transforms:
# execute all the deterministic transforms
if isinstance(_transform, RandomizableTrait) or not isinstance(_transform, Transform):
break
_xform = deepcopy(_transform) if isinstance(_transform, ThreadUnsafe) else _transform
item = self.transform.evaluate_with_overrides(item, _xform)
item = apply_transform(_xform, item)
item = self.transform.evaluate_with_overrides(item, None)

first_random = self.transform.get_index_of_first(
lambda t: isinstance(t, RandomizableTrait) or not isinstance(t, Transform)
)
item = self.transform(item, end=first_random, threading=True)

if self.as_contiguous:
item = convert_to_contiguous(item, memory_format=torch.contiguous_format)
return item
Expand All @@ -921,19 +915,16 @@ def _transform(self, index: int):
data = self._cache[cache_index] = self._load_cache_item(cache_index)

# load data from cache and execute from the first random transform
start_run = False
if not isinstance(self.transform, Compose):
raise ValueError("transform must be an instance of monai.transforms.Compose.")
for _transform in self.transform.transforms:
if start_run or isinstance(_transform, RandomizableTrait) or not isinstance(_transform, Transform):
# only need to deep copy data on first non-deterministic transform
if not start_run:
start_run = True
if self.copy_cache:
data = deepcopy(data)
data = self.transform.evaluate_with_overrides(data, _transform)
data = apply_transform(_transform, data)
data = self.transform.evaluate_with_overrides(data, None)

first_random = self.transform.get_index_of_first(
lambda t: isinstance(t, RandomizableTrait) or not isinstance(t, Transform)
)
if first_random is not None:
wyli marked this conversation as resolved.
Show resolved Hide resolved
data = deepcopy(data) if self.copy_cache is True else data
data = self.transform(data, start=first_random)

return data


Expand Down Expand Up @@ -1008,7 +999,6 @@ class SmartCacheDataset(Randomizable, CacheDataset):
as_contiguous: whether to convert the cached NumPy array or PyTorch tensor to be contiguous.
it may help improve the performance of following logic.
runtime_cache: Default to `False`, other options are not implemented yet.

"""

def __init__(
Expand Down
Loading