diff --git a/src/colonyscanalyser/imaging.py b/src/colonyscanalyser/imaging.py index 2d7fbc7..73b576b 100644 --- a/src/colonyscanalyser/imaging.py +++ b/src/colonyscanalyser/imaging.py @@ -254,7 +254,10 @@ def image_as_rgb(image: ndarray) -> ndarray: return gray2rgb(image) # Remove alpha channel if present - return rgba2rgb(image) + if image.shape[-1] == 4: + image = rgba2rgb(image) + + return image def remove_background_mask(image: ndarray, smoothing: float = 1, sigmoid_cutoff: float = 0.4, **filter_args) -> ndarray: diff --git a/tests/unit/test_imaging.py b/tests/unit/test_imaging.py index f77a5f5..843da40 100644 --- a/tests/unit/test_imaging.py +++ b/tests/unit/test_imaging.py @@ -226,6 +226,17 @@ def test_grayscale(self, image): assert len(result.shape) == 3 assert result.shape[-1] == 3 + def test_rgb(self, image): + from numpy import empty + + image = empty(image.shape + (3, ), dtype = image.dtype) + result = image_as_rgb(image) + + assert len(image.shape) == 3 + assert image.shape[-1] == 3 + assert len(result.shape) == 3 + assert result.shape[-1] == 3 + def test_rgba(self, image): from numpy import empty