Skip to content

Commit

Permalink
8267 fix normalize intensity (#8286)
Browse files Browse the repository at this point in the history
Fixes #8267 .

### Description

Fix channel-wise intensity normalization for integer type inputs. 

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [ ] Non-breaking change (fix or new feature that would not break
existing functionality).
- [x] Breaking change (fix or new feature that would cause existing
functionality to change).
- [x] New tests added to cover the changes.
- [x] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [x] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: advcu987 <[email protected]>
Signed-off-by: advcu <[email protected]>
Co-authored-by: Eric Kerfoot <[email protected]>
  • Loading branch information
advcu987 and ericspod authored Jan 20, 2025
1 parent 56d1f62 commit e39bad9
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 0 deletions.
4 changes: 4 additions & 0 deletions monai/transforms/intensity/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -821,6 +821,7 @@ class NormalizeIntensity(Transform):
mean and std on each channel separately.
When `channel_wise` is True, the first dimension of `subtrahend` and `divisor` should
be the number of image channels if they are not None.
If the input is not of floating point type, it will be converted to float32
Args:
subtrahend: the amount to subtract by (usually the mean).
Expand Down Expand Up @@ -907,6 +908,9 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
if self.divisor is not None and len(self.divisor) != len(img):
raise ValueError(f"img has {len(img)} channels, but divisor has {len(self.divisor)} components.")

if not img.dtype.is_floating_point:
img, *_ = convert_data_type(img, dtype=torch.float32)

for i, d in enumerate(img):
img[i] = self._normalize( # type: ignore
d,
Expand Down
21 changes: 21 additions & 0 deletions tests/test_normalize_intensity.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,27 @@ def test_channel_wise(self, im_type):
normalized = normalizer(input_data)
assert_allclose(normalized, im_type(expected), type_test="tensor")

@parameterized.expand([[p] for p in TEST_NDARRAYS])
def test_channel_wise_int(self, im_type):
normalizer = NormalizeIntensity(nonzero=True, channel_wise=True)
input_data = im_type(torch.arange(1, 25).reshape(2, 3, 4))
expected = np.array(
[
[
[-1.593255, -1.3035723, -1.0138896, -0.7242068],
[-0.4345241, -0.1448414, 0.1448414, 0.4345241],
[0.7242068, 1.0138896, 1.3035723, 1.593255],
],
[
[-1.593255, -1.3035723, -1.0138896, -0.7242068],
[-0.4345241, -0.1448414, 0.1448414, 0.4345241],
[0.7242068, 1.0138896, 1.3035723, 1.593255],
],
]
)
normalized = normalizer(input_data)
assert_allclose(normalized, im_type(expected), type_test="tensor", rtol=1e-7, atol=1e-7) # tolerance

@parameterized.expand([[p] for p in TEST_NDARRAYS])
def test_value_errors(self, im_type):
input_data = im_type(np.array([[0.0, 3.0, 0.0, 4.0], [0.0, 4.0, 0.0, 5.0]]))
Expand Down

0 comments on commit e39bad9

Please sign in to comment.