diff --git a/monai/data/dataset.py b/monai/data/dataset.py
index 5ef8d7e903..7df19d88d3 100644
--- a/monai/data/dataset.py
+++ b/monai/data/dataset.py
@@ -39,7 +39,6 @@
     Compose,
     Randomizable,
     RandomizableTrait,
-    ThreadUnsafe,
     Transform,
     apply_transform,
     convert_to_contiguous,
@@ -209,6 +208,11 @@ class PersistentDataset(Dataset):
         not guaranteed, so caution should be used when modifying transforms to avoid unexpected
         errors. If in doubt, it is advisable to clear the cache directory.

+    Lazy Resampling:
+        If you make use of the lazy resampling feature of `monai.transforms.Compose`, please refer to
+        its documentation to familiarize yourself with the interaction between `PersistentDataset` and
+        lazy resampling.
+
     """

     def __init__(
@@ -316,15 +320,15 @@ def _pre_transform(self, item_transformed):
             random transform object

        """
-        for _transform in self.transform.transforms:
-            # execute all the deterministic transforms
-            if isinstance(_transform, RandomizableTrait) or not isinstance(_transform, Transform):
-                break
-            # this is to be consistent with CacheDataset even though it's not in a multi-thread situation.
-            _xform = deepcopy(_transform) if isinstance(_transform, ThreadUnsafe) else _transform
-            item_transformed = self.transform.evaluate_with_overrides(item_transformed, _xform)
-            item_transformed = apply_transform(_xform, item_transformed)
-        item_transformed = self.transform.evaluate_with_overrides(item_transformed, None)
+        if not isinstance(self.transform, Compose):
+            raise ValueError("transform must be an instance of monai.transforms.Compose.")
+
+        first_random = self.transform.get_index_of_first(
+            lambda t: isinstance(t, RandomizableTrait) or not isinstance(t, Transform)
+        )
+
+        item_transformed = self.transform(item_transformed, end=first_random, threading=True)
+
         if self.reset_ops_id:
             reset_ops_id(item_transformed)
         return item_transformed
@@ -342,17 +346,12 @@ def _post_transform(self, item_transformed):
         """
         if not isinstance(self.transform, Compose):
             raise ValueError("transform must be an instance of monai.transforms.Compose.")
-        start_post_randomize_run = False
-        for _transform in self.transform.transforms:
-            if (
-                start_post_randomize_run
-                or isinstance(_transform, RandomizableTrait)
-                or not isinstance(_transform, Transform)
-            ):
-                start_post_randomize_run = True
-                item_transformed = self.transform.evaluate_with_overrides(item_transformed, _transform)
-                item_transformed = apply_transform(_transform, item_transformed)
-        item_transformed = self.transform.evaluate_with_overrides(item_transformed, None)
+
+        first_random = self.transform.get_index_of_first(
+            lambda t: isinstance(t, RandomizableTrait) or not isinstance(t, Transform)
+        )
+        if first_random is not None:
+            item_transformed = self.transform(item_transformed, start=first_random)
         return item_transformed

     def _cachecheck(self, item_transformed):
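The pattern introduced above, finding the first random transform and executing the pipeline in two
stages around it, can be exercised directly. A minimal sketch (the transform choices, tensor shape,
and variable names are illustrative, not part of the patch):

    import torch
    from monai.transforms import Compose, Flip, RandFlip, RandomizableTrait, Transform

    pipeline = Compose([Flip(spatial_axis=0), RandFlip(prob=0.5, spatial_axis=1)])
    first_random = pipeline.get_index_of_first(
        lambda t: isinstance(t, RandomizableTrait) or not isinstance(t, Transform)
    )
    sample = torch.arange(16.0).reshape(1, 4, 4)
    head = pipeline(sample, end=first_random, threading=True)  # deterministic prefix, safe to cache
    out = pipeline(head, start=first_random)  # random tail, re-run on every fetch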
@@ -496,13 +495,9 @@ def _pre_transform(self, item_transformed):
         """
         if not isinstance(self.transform, Compose):
             raise ValueError("transform must be an instance of monai.transforms.Compose.")
-        for i, _transform in enumerate(self.transform.transforms):
-            if i == self.cache_n_trans:
-                break
-            _xform = deepcopy(_transform) if isinstance(_transform, ThreadUnsafe) else _transform
-            item_transformed = self.transform.evaluate_with_overrides(item_transformed, _xform)
-            item_transformed = apply_transform(_xform, item_transformed)
-        item_transformed = self.transform.evaluate_with_overrides(item_transformed, None)
+
+        item_transformed = self.transform(item_transformed, end=self.cache_n_trans, threading=True)
+
         reset_ops_id(item_transformed)
         return item_transformed

@@ -518,12 +513,8 @@ def _post_transform(self, item_transformed):
         """
         if not isinstance(self.transform, Compose):
             raise ValueError("transform must be an instance of monai.transforms.Compose.")
-        for i, _transform in enumerate(self.transform.transforms):
-            if i >= self.cache_n_trans:
-                item_transformed = self.transform.evaluate_with_overrides(item_transformed, item_transformed)
-                item_transformed = apply_transform(_transform, item_transformed)
-        item_transformed = self.transform.evaluate_with_overrides(item_transformed, None)
-        return item_transformed
+
+        return self.transform(item_transformed, start=self.cache_n_trans)


 class LMDBDataset(PersistentDataset):
@@ -748,6 +739,11 @@ class CacheDataset(Dataset):
     So to debug or verify the program before real training, users can set `cache_rate=0.0` or `cache_num=0`
     to temporarily skip caching.

+    Lazy Resampling:
+        If you make use of the lazy resampling feature of `monai.transforms.Compose`, please refer to
+        its documentation to familiarize yourself with the interaction between `CacheDataset` and
+        lazy resampling.
+
     """

     def __init__(
@@ -887,14 +883,12 @@ def _load_cache_item(self, idx: int):
             idx: the index of the input data sequence.
         """
         item = self.data[idx]
-        for _transform in self.transform.transforms:
-            # execute all the deterministic transforms
-            if isinstance(_transform, RandomizableTrait) or not isinstance(_transform, Transform):
-                break
-            _xform = deepcopy(_transform) if isinstance(_transform, ThreadUnsafe) else _transform
-            item = self.transform.evaluate_with_overrides(item, _xform)
-            item = apply_transform(_xform, item)
-        item = self.transform.evaluate_with_overrides(item, None)
+
+        first_random = self.transform.get_index_of_first(
+            lambda t: isinstance(t, RandomizableTrait) or not isinstance(t, Transform)
+        )
+        item = self.transform(item, end=first_random, threading=True)
+
         if self.as_contiguous:
             item = convert_to_contiguous(item, memory_format=torch.contiguous_format)
         return item
@@ -921,19 +915,16 @@ def _transform(self, index: int):
                 data = self._cache[cache_index] = self._load_cache_item(cache_index)

         # load data from cache and execute from the first random transform
-        start_run = False
         if not isinstance(self.transform, Compose):
             raise ValueError("transform must be an instance of monai.transforms.Compose.")
-        for _transform in self.transform.transforms:
-            if start_run or isinstance(_transform, RandomizableTrait) or not isinstance(_transform, Transform):
-                # only need to deep copy data on first non-deterministic transform
-                if not start_run:
-                    start_run = True
-                    if self.copy_cache:
-                        data = deepcopy(data)
-                data = self.transform.evaluate_with_overrides(data, _transform)
-                data = apply_transform(_transform, data)
-        data = self.transform.evaluate_with_overrides(data, None)
+
+        first_random = self.transform.get_index_of_first(
+            lambda t: isinstance(t, RandomizableTrait) or not isinstance(t, Transform)
+        )
+        if first_random is not None:
+            data = deepcopy(data) if self.copy_cache is True else data
+            data = self.transform(data, start=first_random)
+
         return data
@@ -1008,7 +999,6 @@ class SmartCacheDataset(Randomizable, CacheDataset):
         as_contiguous: whether to convert the cached NumPy array or PyTorch tensor to be contiguous.
             it may help improve the performance of following logic.
         runtime_cache: Default to `False`, other options are not implemented yet.
-
     """

     def __init__(
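`CacheNTransDataset` expresses the same split with a fixed index instead of a predicate: everything
before `cache_n_trans` is cached, everything after runs on every fetch. A rough sketch of the calling
convention (names and transforms are illustrative):

    import torch
    from monai.transforms import Compose, Flip, RandFlip

    pipeline = Compose([Flip(spatial_axis=0), Flip(spatial_axis=1), RandFlip(prob=1.0, spatial_axis=0)])
    cache_n_trans = 2  # cache the output of the first two (deterministic) transforms
    item = torch.arange(16.0).reshape(1, 4, 4)
    to_cache = pipeline(item, end=cache_n_trans, threading=True)  # executed once, then persisted
    result = pipeline(to_cache, start=cache_n_trans)              # executed on every fetch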
- """ def __init__( diff --git a/monai/transforms/compose.py b/monai/transforms/compose.py index 6cdd1b3d55..8a8518c92b 100644 --- a/monai/transforms/compose.py +++ b/monai/transforms/compose.py @@ -16,6 +16,7 @@ import warnings from collections.abc import Callable, Mapping, Sequence +from copy import deepcopy from typing import Any import numpy as np @@ -23,7 +24,9 @@ import monai import monai.transforms as mt from monai.apps.utils import get_logger +from monai.config import NdarrayOrTensor from monai.transforms.inverse import InvertibleTransform +from monai.transforms.traits import ThreadUnsafe # For backwards compatibility (so this still works: from monai.transforms.compose import MapTransform) from monai.transforms.transform import ( # noqa: F401 @@ -115,6 +118,91 @@ def evaluate_with_overrides( return data +def execute_compose( + data: NdarrayOrTensor | Sequence[NdarrayOrTensor] | Mapping[Any, NdarrayOrTensor], + transforms: Sequence[Any], + map_items: bool = True, + unpack_items: bool = False, + start: int = 0, + end: int | None = None, + lazy_evaluation: bool = False, + overrides: dict | None = None, + override_keys: Sequence[str] | None = None, + threading: bool = False, + log_stats: bool = False, + verbose: bool = False, +) -> NdarrayOrTensor | Sequence[NdarrayOrTensor] | Mapping[Any, NdarrayOrTensor]: + """ + ``execute_compose`` provides the implementation that the ``Compose`` class uses to execute a sequence + of transforms. As well as being used by Compose, it can be used by subclasses of + Compose and by code that doesn't have a Compose instance but needs to execute a + sequence of transforms is if it were executed by Compose. It should only be used directly + when it is not possible to use ``Compose.__call__`` to achieve the same goal. + Args: + data: a tensor-like object to be transformed + transforms: a sequence of transforms to be carried out + map_items: whether to apply transform to each item in the input `data` if `data` is a list or tuple. + defaults to `True`. + unpack_items: whether to unpack input `data` with `*` as parameters for the callable function of transform. + defaults to `False`. + start: the index of the first transform to be executed. If not set, this defaults to 0 + end: the index after the last transform to be exectued. If set, the transform at index-1 + is the last transform that is executed. If this is not set, it defaults to len(transforms) + lazy_evaluation: whether to enable lazy evaluation for lazy transforms. If False, transforms will be + carried out on a transform by transform basis. If True, all lazy transforms will + be executed by accumulating changes and resampling as few times as possible. + A `monai.transforms.Identity[D]` transform in the pipeline will trigger the evaluation of + the pending operations and make the primary data up-to-date. + overrides: this optional parameter allows you to specify a dictionary of parameters that should be overridden + when executing a pipeline. These each parameter that is compatible with a given transform is then applied + to that transform before it is executed. Note that overrides are currently only applied when lazy_evaluation + is True. If lazy_evaluation is False they are ignored. + currently supported args are: + {``"mode"``, ``"padding_mode"``, ``"dtype"``, ``"align_corners"``, ``"resample_mode"``, ``device``}, + please see also :py:func:`monai.transforms.lazy.apply_transforms` for more details. + override_keys: this optional parameter specifies the keys to which ``overrides`` are to be applied. 
@@ -183,6 +271,37 @@ class Compose(Randomizable, InvertibleTransform):
     calls your pre-processing functions taking into account that not all of them are
     called on the labels.

+    Lazy resampling:
+        Lazy resampling is an experimental feature introduced in 1.2. Its purpose is
+        to reduce the number of resample operations that must be carried out when executing
+        a pipeline of transforms. This can provide significant performance improvements in
+        terms of pipeline execution speed and memory usage, and can also significantly
+        reduce the loss of information that occurs when performing a number of spatial
+        resamples in succession.
+
+        Lazy resampling can be thought of as acting in a similar fashion to the `Affine` & `RandAffine`
+        transforms, in that it allows several spatial transform operations to be specified and carried out with
+        a single resample step. Unlike these transforms, however, lazy resampling can operate on any set of
+        transforms specified in any ordering. The user is free to mix monai transforms with transforms from other
+        libraries; lazy resampling will determine the minimum number of resample steps required in order to
+        execute the pipeline.
+
+        Lazy resampling works with monai `Dataset` classes that provide caching and persistence. However, if you
+        are implementing your own caching dataset implementation and wish to make use of lazy resampling, you
+        should ensure that you fully execute the part of the pipeline that generates the data to be cached
+        before caching it. This is easily done, however, as shown by the following example.
+
+        Example:
+            # run the part of the pipeline that needs to be cached
+            data = self.transform(data, end=self.post_cache_index)
+
+            # ---
+
+            # fetch the data from the cache and run the rest of the pipeline
+            data = get_data_from_my_cache(data)
+            data = self.transform(data, start=self.post_cache_index)
+
+
     Args:
         transforms: sequence of callables.
         map_items: whether to apply transform to each item in the input `data` if `data` is a list or tuple.
@@ -258,6 +377,41 @@ def randomize(self, data: Any | None = None) -> None:
                     f'Transform "{tfm_name}" in Compose not randomized\n{tfm_name}.{type_error}.', RuntimeWarning
                 )

+    def get_index_of_first(self, predicate):
+        """
+        get_index_of_first takes a ``predicate`` and returns the index of the first transform that
+        satisfies the predicate (i.e. makes the predicate return True). If it is unable to find
+        a transform that satisfies the ``predicate``, it returns None.
+
+        Example:
+            c = Compose([Flip(...), Rotate90(...), Zoom(...), RandRotate(...), Resize(...)])
+
+            print(c.get_index_of_first(lambda t: isinstance(t, RandomizableTrait)))
+            >>> 3
+            print(c.get_index_of_first(lambda t: isinstance(t, Compose)))
+            >>> None
+
+        Note:
+            This is only performed on the transforms directly held by this instance. If this
+            instance has nested ``Compose`` transforms or other transforms that contain transforms,
+            it does not iterate into them.
+
+
+        Args:
+            predicate: a callable that takes a single argument and returns a bool. When called
+                it is passed a transform from the sequence of transforms contained by this compose
+                instance.
+
+        Returns:
+            The index of the first transform in the sequence for which ``predicate`` returns
+            True, or None if no transform satisfies the ``predicate``.
+
+        """
+        for i in range(len(self.transforms)):
+            if predicate(self.transforms[i]):
+                return i
+        return None
+
     def flatten(self):
         """Return a Composition with a simple list of transforms, as opposed to any nested Compositions.
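One consequence of the note above: predicates only see the top-level transforms, so a transform
wrapped in a nested ``Compose`` is invisible to them. A small illustrative sketch:

    from monai.transforms import Compose, Flip, RandFlip

    c = Compose([Flip(spatial_axis=0), Compose([RandFlip(prob=0.5, spatial_axis=0)])])
    print(c.get_index_of_first(lambda t: isinstance(t, Flip)))      # 0
    print(c.get_index_of_first(lambda t: isinstance(t, RandFlip)))  # None: nested transforms are not searched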
@@ -293,12 +447,21 @@ def evaluate_with_overrides(self, input_, upcoming_xform):
             verbose=self.verbose,
         )

-    def __call__(self, input_):
-        for _transform in self.transforms:
-            input_ = self.evaluate_with_overrides(input_, _transform)
-            input_ = apply_transform(_transform, input_, self.map_items, self.unpack_items, self.log_stats)
-        input_ = self.evaluate_with_overrides(input_, None)
-        return input_
+    def __call__(self, input_, start=0, end=None, threading=False):
+        return execute_compose(
+            input_,
+            self.transforms,
+            start=start,
+            end=end,
+            map_items=self.map_items,
+            unpack_items=self.unpack_items,
+            lazy_evaluation=self.lazy_evaluation,  # type: ignore
+            overrides=self.overrides,
+            override_keys=self.override_keys,
+            threading=threading,
+            log_stats=self.log_stats,
+            verbose=self.verbose,
+        )  # type: ignore

     def inverse(self, data):
         invertible_transforms = [t for t in self.flatten().transforms if isinstance(t, InvertibleTransform)]
@@ -397,12 +560,23 @@ def flatten(self):
                 weights.append(w)
         return OneOf(transforms, weights, self.map_items, self.unpack_items)

-    def __call__(self, data):
+    def __call__(self, data, start=0, end=None, threading=False):
         if len(self.transforms) == 0:
             return data
+
         index = self.R.multinomial(1, self.weights).argmax()
         _transform = self.transforms[index]
-        data = apply_transform(_transform, data, self.map_items, self.unpack_items, self.log_stats)
+
+        data = execute_compose(
+            data,
+            [_transform],
+            map_items=self.map_items,
+            unpack_items=self.unpack_items,
+            start=start,
+            end=end,
+            threading=threading,
+        )
+
         # if the data is a mapping (dictionary), append the OneOf transform to the end
         if isinstance(data, monai.data.MetaTensor):
             self.push_transform(data, extra_info={"index": index})
@@ -481,14 +655,22 @@ def __init__(
             transforms, map_items, unpack_items, log_stats, lazy_evaluation, overrides, override_keys, verbose
         )

-    def __call__(self, input_):
+    def __call__(self, input_, start=0, end=None, threading=False):
         if len(self.transforms) == 0:
             return input_
         num = len(self.transforms)
         applied_order = self.R.permutation(range(num))
-        for index in applied_order:
-            input_ = apply_transform(self.transforms[index], input_, self.map_items, self.unpack_items, self.log_stats)
+
+        input_ = execute_compose(
+            input_,
+            [self.transforms[ind] for ind in applied_order],
+            map_items=self.map_items,
+            unpack_items=self.unpack_items,
+            start=start,
+            end=end,
+            threading=threading,
+        )
+
         # if the data is a mapping (dictionary), append the RandomOrder transform to the end
         if isinstance(input_, monai.data.MetaTensor):
             self.push_transform(input_, extra_info={"applied_order": applied_order})
@@ -618,15 +800,22 @@ def _normalize_probabilities(self, weights):

         return ensure_tuple(list(weights))

-    def __call__(self, data):
+    def __call__(self, data, start=0, end=None, threading=False):
         if len(self.transforms) == 0:
             return data

         sample_size = self.R.randint(self.min_num_transforms, self.max_num_transforms + 1)
         applied_order = self.R.choice(len(self.transforms), sample_size, replace=self.replace, p=self.weights).tolist()

-        for i in applied_order:
-            data = apply_transform(self.transforms[i], data, self.map_items, self.unpack_items, self.log_stats)
+        data = execute_compose(
+            data,
+            [self.transforms[a] for a in applied_order],
+            map_items=self.map_items,
+            unpack_items=self.unpack_items,
+            start=start,
+            end=end,
+            threading=threading,
+        )
+
         if isinstance(data, monai.data.MetaTensor):
             self.push_transform(data, extra_info={"applied_order": applied_order})
         elif isinstance(data, Mapping):
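The tests that follow check precisely the property that makes the ``start``/``end`` arguments safe:
for a deterministic pipeline, running the two halves back to back matches a single full run. In
sketch form (the pipeline is illustrative):

    import torch
    from monai.transforms import Compose, Flip, Rotate90, Zoom

    c = Compose([Flip(spatial_axis=0), Rotate90(k=1), Zoom(zoom=0.8)])
    data = torch.arange(24.0 * 32).reshape(1, 24, 32)
    for cutoff in range(len(c.transforms)):
        assert torch.allclose(c(data), c(c(data, end=cutoff), start=cutoff))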
diff --git a/tests/test_compose.py b/tests/test_compose.py
index ddb7ce25d8..47869b02aa 100644
--- a/tests/test_compose.py
+++ b/tests/test_compose.py
@@ -13,9 +13,15 @@

 import sys
 import unittest
+from copy import deepcopy
+
+import numpy as np
+import torch
+from parameterized import parameterized

 from monai.data import DataLoader, Dataset
-from monai.transforms import AddChannel, Compose
+from monai.transforms import AddChannel, Compose, Flip, NormalizeIntensity, Rotate, Rotate90, Rotated, Zoom
+from monai.transforms.compose import execute_compose
 from monai.transforms.transform import Randomizable
 from monai.utils import set_determinism

@@ -56,8 +62,12 @@ def b(d):
             d["b"] += 1
             return d

-        c = Compose([a, b, a, b, a])
-        self.assertDictEqual(c({"a": 0, "b": 0}), {"a": 3, "b": 2})
+        transforms = [a, b, a, b, a]
+        data = {"a": 0, "b": 0}
+        expected = {"a": 3, "b": 2}
+
+        self.assertDictEqual(Compose(transforms)(data), expected)
+        self.assertDictEqual(execute_compose(data, transforms), expected)

     def test_list_dict_compose(self):
         def a(d):  # transform to handle dict data
@@ -76,10 +86,15 @@ def c(d):  # transform to handle dict data
             d["c"] += 1
             return d

-        transforms = Compose([a, a, b, c, c])
-        value = transforms({"a": 0, "b": 0, "c": 0})
+        transforms = [a, a, b, c, c]
+        data = {"a": 0, "b": 0, "c": 0}
+        expected = {"a": 2, "b": 1, "c": 2}
+        value = Compose(transforms)(data)
+        for item in value:
+            self.assertDictEqual(item, expected)
+        value = execute_compose(data, transforms)
         for item in value:
-            self.assertDictEqual(item, {"a": 2, "b": 1, "c": 2})
+            self.assertDictEqual(item, expected)

     def test_non_dict_compose_with_unpack(self):
         def a(i, i2):
@@ -88,8 +103,11 @@ def a(i, i2):
         def b(i, i2):
             return i + "b", i2 + "b2"

-        c = Compose([a, b, a, b], map_items=False, unpack_items=True)
-        self.assertEqual(c(("", "")), ("abab", "a2b2a2b2"))
+        transforms = [a, b, a, b]
+        data = ("", "")
+        expected = ("abab", "a2b2a2b2")
+        self.assertEqual(Compose(transforms, map_items=False, unpack_items=True)(data), expected)
+        self.assertEqual(execute_compose(data, transforms, map_items=False, unpack_items=True), expected)

     def test_list_non_dict_compose_with_unpack(self):
         def a(i, i2):
@@ -98,8 +116,11 @@ def a(i, i2):
         def b(i, i2):
             return i + "b", i2 + "b2"

-        c = Compose([a, b, a, b], unpack_items=True)
-        self.assertEqual(c([("", ""), ("t", "t")]), [("abab", "a2b2a2b2"), ("tabab", "ta2b2a2b2")])
+        transforms = [a, b, a, b]
+        data = [("", ""), ("t", "t")]
+        expected = [("abab", "a2b2a2b2"), ("tabab", "ta2b2a2b2")]
+        self.assertEqual(Compose(transforms, unpack_items=True)(data), expected)
+        self.assertEqual(execute_compose(data, transforms, unpack_items=True), expected)

     def test_list_dict_compose_no_map(self):
         def a(d):  # transform to handle dict data
@@ -119,10 +140,15 @@ def c(d):  # transform to handle dict data
                 di["c"] += 1
             return d

-        transforms = Compose([a, a, b, c, c], map_items=False)
-        value = transforms({"a": 0, "b": 0, "c": 0})
+        transforms = [a, a, b, c, c]
+        data = {"a": 0, "b": 0, "c": 0}
+        expected = {"a": 2, "b": 1, "c": 2}
+        value = Compose(transforms, map_items=False)(data)
         for item in value:
-            self.assertDictEqual(item, {"a": 2, "b": 1, "c": 2})
+            self.assertDictEqual(item, expected)
+        value = execute_compose(data, transforms, map_items=False)
+        for item in value:
+            self.assertDictEqual(item, expected)

     def test_random_compose(self):
         class _Acc(Randomizable):
@@ -220,5 +246,97 @@ def test_backwards_compatible_imports(self):
         from monai.transforms.compose import MapTransform, RandomizableTransform, Transform  # noqa: F401


+TEST_COMPOSE_EXECUTE_TEST_CASES = [
+    [None, tuple()],
+    [None, (Rotate(np.pi / 8),)],
+    [None, (Flip(0), Flip(1), Rotate90(1), Zoom(0.8), NormalizeIntensity())],
+    [("a",), (Rotated(("a",), np.pi / 8),)],
+]
+
+
+class TestComposeExecute(unittest.TestCase):
+    @parameterized.expand(TEST_COMPOSE_EXECUTE_TEST_CASES)
+    def test_compose_execute_equivalence(self, keys, pipeline):
+        if keys is None:
+            data = torch.unsqueeze(torch.tensor(np.arange(24 * 32).reshape(24, 32)), axis=0)
+        else:
+            data = {}
+            for i_k, k in enumerate(keys):
+                data[k] = torch.unsqueeze(torch.tensor(np.arange(24 * 32)).reshape(24, 32) + i_k * 768, axis=0)
+
+        expected = Compose(deepcopy(pipeline))(data)
+
+        for cutoff in range(len(pipeline)):
+            c = Compose(deepcopy(pipeline))
+            actual = c(c(data, end=cutoff), start=cutoff)
+            if isinstance(actual, dict):
+                for k in actual.keys():
+                    self.assertTrue(torch.allclose(expected[k], actual[k]))
+            else:
+                self.assertTrue(torch.allclose(expected, actual))
+
+            p = deepcopy(pipeline)
+            actual = execute_compose(execute_compose(data, p, start=0, end=cutoff), p, start=cutoff)
+            if isinstance(actual, dict):
+                for k in actual.keys():
+                    self.assertTrue(torch.allclose(expected[k], actual[k]))
+            else:
+                self.assertTrue(torch.allclose(expected, actual))
+
+
+class TestOps:
+    @staticmethod
+    def concat(value):
+        def _inner(data):
+            return data + value
+
+        return _inner
+
+    @staticmethod
+    def concatd(value):
+        def _inner(data):
+            return {k: v + value for k, v in data.items()}
+
+        return _inner
+
+    @staticmethod
+    def concata(value):
+        def _inner(data1, data2):
+            return data1 + value, data2 + value
+
+        return _inner
+
+
+TEST_COMPOSE_EXECUTE_FLAG_TEST_CASES = [
+    [{}, ("",), (TestOps.concat("a"), TestOps.concat("b"))],
+    [{"unpack_items": True}, ("x", "y"), (TestOps.concat("a"), TestOps.concat("b"))],
+    [{"map_items": False}, {"x": "1", "y": "2"}, (TestOps.concatd("a"), TestOps.concatd("b"))],
+    [{"unpack_items": True, "map_items": False}, ("x", "y"), (TestOps.concata("a"), TestOps.concata("b"))],
+]
+
+
+class TestComposeExecuteWithFlags(unittest.TestCase):
+    @parameterized.expand(TEST_COMPOSE_EXECUTE_FLAG_TEST_CASES)
+    def test_compose_execute_equivalence_with_flags(self, flags, data, pipeline):
+        expected = Compose(pipeline, **flags)(data)
+
+        for cutoff in range(len(pipeline)):
+            c = Compose(deepcopy(pipeline), **flags)
+            actual = c(c(data, end=cutoff), start=cutoff)
+            if isinstance(actual, dict):
+                for k in actual.keys():
+                    self.assertEqual(expected[k], actual[k])
+            else:
+                self.assertEqual(expected, actual)
+
+            p = deepcopy(pipeline)
+            actual = execute_compose(execute_compose(data, p, start=0, end=cutoff, **flags), p, start=cutoff, **flags)
+            if isinstance(actual, dict):
+                for k in actual.keys():
+                    self.assertEqual(expected[k], actual[k])
+            else:
+                self.assertEqual(expected, actual)
+
+
 if __name__ == "__main__":
     unittest.main()