diff --git a/integration_tests/import_test.py b/integration_tests/import_test.py
index fbb97b9d810e..25facbc50ce1 100644
--- a/integration_tests/import_test.py
+++ b/integration_tests/import_test.py
@@ -8,7 +8,7 @@
 BACKEND_REQ = {
     "tensorflow": ("tensorflow-cpu", ""),
     "torch": (
-        "torch torchvision",
+        "torch",
         "--extra-index-url https://download.pytorch.org/whl/cpu ",
     ),
     "jax": ("jax[cpu]", ""),
diff --git a/keras/src/backend/torch/image.py b/keras/src/backend/torch/image.py
index d466ffeb5efa..ea7c91c78b12 100644
--- a/keras/src/backend/torch/image.py
+++ b/keras/src/backend/torch/image.py
@@ -3,12 +3,16 @@
 import operator
 
 import torch
+import torch.nn.functional as F
 
 from keras.src import backend
 from keras.src.backend.torch.core import convert_to_tensor
-from keras.src.utils.module_utils import torchvision
 
-RESIZE_INTERPOLATIONS = {}  # populated after torchvision import
+RESIZE_INTERPOLATIONS = {
+    "bilinear": "bilinear",
+    "nearest": "nearest-exact",
+    "bicubic": "bicubic",
+}
 
 UNSUPPORTED_INTERPOLATIONS = (
     "lanczos3",
@@ -19,23 +23,27 @@
 def rgb_to_grayscale(images, data_format=None):
     images = convert_to_tensor(images)
     data_format = backend.standardize_data_format(data_format)
-    if data_format == "channels_last":
-        if images.ndim == 4:
-            images = images.permute((0, 3, 1, 2))
-        elif images.ndim == 3:
-            images = images.permute((2, 0, 1))
-        else:
-            raise ValueError(
-                "Invalid images rank: expected rank 3 (single image) "
-                "or rank 4 (batch of images). Received input with shape: "
-                f"images.shape={images.shape}"
-            )
-    images = torchvision.transforms.functional.rgb_to_grayscale(img=images)
-    if data_format == "channels_last":
-        if len(images.shape) == 4:
-            images = images.permute((0, 2, 3, 1))
-        elif len(images.shape) == 3:
-            images = images.permute((1, 2, 0))
+    if images.ndim not in (3, 4):
+        raise ValueError(
+            "Invalid images rank: expected rank 3 (single image) "
+            "or rank 4 (batch of images). Received input with shape: "
+            f"images.shape={images.shape}"
+        )
+    channel_axis = -3 if data_format == "channels_first" else -1
+    if images.shape[channel_axis] not in (1, 3):
+        raise ValueError(
+            "Invalid channel size: expected 3 (RGB) or 1 (Grayscale). "
" + f"Received input with shape: images.shape={images.shape}" + ) + + # This implementation is based on + # https://github.com/pytorch/vision/blob/main/torchvision/transforms/_functional_tensor.py + if images.shape[channel_axis] == 3: + r, g, b = images.unbind(dim=channel_axis) + images = (0.2989 * r + 0.587 * g + 0.114 * b).to(images.dtype) + images = images.unsqueeze(dim=channel_axis) + else: + images = images.clone() return images @@ -129,6 +137,40 @@ def hsv_planes_to_rgb_planes(hue, saturation, value): return images +def _cast_squeeze_in(image, req_dtypes): + need_squeeze = False + # make image NCHW + if image.ndim < 4: + image = image.unsqueeze(dim=0) + need_squeeze = True + + out_dtype = image.dtype + need_cast = False + if out_dtype not in req_dtypes: + need_cast = True + req_dtype = req_dtypes[0] + image = image.to(req_dtype) + return image, need_cast, need_squeeze, out_dtype + + +def _cast_squeeze_out(image, need_cast, need_squeeze, out_dtype): + if need_squeeze: + image = image.squeeze(dim=0) + + if need_cast: + if out_dtype in ( + torch.uint8, + torch.int8, + torch.int16, + torch.int32, + torch.int64, + ): + # it is better to round before cast + image = torch.round(image) + image = image.to(out_dtype) + return image + + def resize( images, size, @@ -141,13 +183,6 @@ def resize( data_format=None, ): data_format = backend.standardize_data_format(data_format) - RESIZE_INTERPOLATIONS.update( - { - "bilinear": torchvision.transforms.InterpolationMode.BILINEAR, - "nearest": torchvision.transforms.InterpolationMode.NEAREST_EXACT, - "bicubic": torchvision.transforms.InterpolationMode.BICUBIC, - } - ) if interpolation in UNSUPPORTED_INTERPOLATIONS: raise ValueError( "Resizing with Lanczos interpolation is " @@ -182,11 +217,11 @@ def resize( "or rank 4 (batch of images). 
             f"images.shape={images.shape}"
         )
+    images, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(
+        images, [torch.float32, torch.float64]
+    )
     if data_format == "channels_last":
-        if images.ndim == 4:
-            images = images.permute((0, 3, 1, 2))
-        else:
-            images = images.permute((2, 0, 1))
+        images = images.permute((0, 3, 1, 2))
 
     if crop_to_aspect_ratio:
         shape = images.shape
@@ -198,19 +233,12 @@
         crop_width = max(min(width, crop_width), 1)
         crop_box_hstart = int(float(height - crop_height) / 2)
         crop_box_wstart = int(float(width - crop_width) / 2)
-        if len(images.shape) == 4:
-            images = images[
-                :,
-                :,
-                crop_box_hstart : crop_box_hstart + crop_height,
-                crop_box_wstart : crop_box_wstart + crop_width,
-            ]
-        else:
-            images = images[
-                :,
-                crop_box_hstart : crop_box_hstart + crop_height,
-                crop_box_wstart : crop_box_wstart + crop_width,
-            ]
+        images = images[
+            :,
+            :,
+            crop_box_hstart : crop_box_hstart + crop_height,
+            crop_box_wstart : crop_box_wstart + crop_width,
+        ]
     elif pad_to_aspect_ratio:
         shape = images.shape
         height, width = shape[-2], shape[-1]
@@ -221,105 +249,77 @@
         pad_width = max(width, pad_width)
         img_box_hstart = int(float(pad_height - height) / 2)
         img_box_wstart = int(float(pad_width - width) / 2)
-        if len(images.shape) == 4:
-            batch_size = images.shape[0]
-            channels = images.shape[1]
-            if img_box_hstart > 0:
-                padded_img = torch.cat(
-                    [
-                        torch.ones(
-                            (batch_size, channels, img_box_hstart, width),
-                            dtype=images.dtype,
-                            device=images.device,
-                        )
-                        * fill_value,
-                        images,
-                        torch.ones(
-                            (batch_size, channels, img_box_hstart, width),
-                            dtype=images.dtype,
-                            device=images.device,
-                        )
-                        * fill_value,
-                    ],
-                    axis=2,
-                )
-            else:
-                padded_img = images
-
-            if img_box_wstart > 0:
-                padded_img = torch.cat(
-                    [
-                        torch.ones(
-                            (batch_size, channels, height, img_box_wstart),
-                            dtype=images.dtype,
-                            device=images.device,
-                        ),
-                        padded_img,
-                        torch.ones(
-                            (batch_size, channels, height, img_box_wstart),
-                            dtype=images.dtype,
-                            device=images.device,
-                        )
-                        * fill_value,
-                    ],
-                    axis=3,
-                )
+        batch_size = images.shape[0]
+        channels = images.shape[1]
+        if img_box_hstart > 0:
+            padded_img = torch.cat(
+                [
+                    torch.ones(
+                        (batch_size, channels, img_box_hstart, width),
+                        dtype=images.dtype,
+                        device=images.device,
+                    )
+                    * fill_value,
+                    images,
+                    torch.ones(
+                        (batch_size, channels, img_box_hstart, width),
+                        dtype=images.dtype,
+                        device=images.device,
+                    )
+                    * fill_value,
+                ],
+                axis=2,
+            )
         else:
-            channels = images.shape[0]
-            if img_box_wstart > 0:
-                padded_img = torch.cat(
-                    [
-                        torch.ones(
-                            (channels, img_box_hstart, width),
-                            dtype=images.dtype,
-                            device=images.device,
-                        )
-                        * fill_value,
-                        images,
-                        torch.ones(
-                            (channels, img_box_hstart, width),
-                            dtype=images.dtype,
-                            device=images.device,
-                        )
-                        * fill_value,
-                    ],
-                    axis=1,
-                )
-            else:
-                padded_img = images
-            if img_box_wstart > 0:
-                torch.cat(
-                    [
-                        torch.ones(
-                            (channels, height, img_box_wstart),
-                            dtype=images.dtype,
-                            device=images.device,
-                        )
-                        * fill_value,
-                        padded_img,
-                        torch.ones(
-                            (channels, height, img_box_wstart),
-                            dtype=images.dtype,
-                            device=images.device,
-                        )
-                        * fill_value,
-                    ],
-                    axis=2,
-                )
+            padded_img = images
+        if img_box_wstart > 0:
+            padded_img = torch.cat(
+                [
+                    torch.ones(
+                        (batch_size, channels, height, img_box_wstart),
+                        dtype=images.dtype,
+                        device=images.device,
+                    )
+                    * fill_value,
+                    padded_img,
+                    torch.ones(
+                        (batch_size, channels, height, img_box_wstart),
+                        dtype=images.dtype,
+                        device=images.device,
+                    )
+                    * fill_value,
+                ],
+                axis=3,
+            )
         images = padded_img
 
-    resized = torchvision.transforms.functional.resize(
-        img=images,
+    # This implementation is based on
+    # https://github.com/pytorch/vision/blob/main/torchvision/transforms/_functional_tensor.py
+    if antialias and interpolation not in ("bilinear", "bicubic"):
+        # antialias is only relevant to the "bilinear" and "bicubic" modes;
+        # passing antialias=True with any other mode makes interpolate()
+        # raise an error, so it is disabled here. This matches torchvision's
+        # documented behavior.
+        antialias = False
+    # Define align_corners to avoid warnings
+    align_corners = False if interpolation in ("bilinear", "bicubic") else None
+    resized = F.interpolate(
+        images,
         size=size,
-        interpolation=RESIZE_INTERPOLATIONS[interpolation],
+        mode=RESIZE_INTERPOLATIONS[interpolation],
+        align_corners=align_corners,
         antialias=antialias,
     )
+    if interpolation == "bicubic" and out_dtype == torch.uint8:
+        resized = resized.clamp(min=0, max=255)
     if data_format == "channels_last":
-        if len(images.shape) == 4:
-            resized = resized.permute((0, 2, 3, 1))
-        elif len(images.shape) == 3:
-            resized = resized.permute((1, 2, 0))
+        resized = resized.permute((0, 2, 3, 1))
+    resized = _cast_squeeze_out(
+        resized,
+        need_cast=need_cast,
+        need_squeeze=need_squeeze,
+        out_dtype=out_dtype,
+    )
     return resized
diff --git a/keras/src/utils/module_utils.py b/keras/src/utils/module_utils.py
index 190bc8dc72fe..d81ec05028e4 100644
--- a/keras/src/utils/module_utils.py
+++ b/keras/src/utils/module_utils.py
@@ -44,7 +44,6 @@ def __repr__(self):
 tensorflow_io = LazyModule("tensorflow_io")
 scipy = LazyModule("scipy")
 jax = LazyModule("jax")
-torchvision = LazyModule("torchvision")
 torch_xla = LazyModule(
     "torch_xla",
     import_error_msg=(
diff --git a/requirements-common.txt b/requirements-common.txt
index 51c682f9ef41..cc9ce873edaa 100644
--- a/requirements-common.txt
+++ b/requirements-common.txt
@@ -4,6 +4,7 @@ pytest
 numpy
 scipy
 scikit-learn
+pillow
 pandas
 absl-py
 requests
diff --git a/requirements-jax-cuda.txt b/requirements-jax-cuda.txt
index a6b8671ad6fb..1cc1a1b75985 100644
--- a/requirements-jax-cuda.txt
+++ b/requirements-jax-cuda.txt
@@ -5,7 +5,6 @@ tf2onnx
 # Torch cpu-only version (needed for testing).
 --extra-index-url https://download.pytorch.org/whl/cpu
 torch==2.6.0+cpu
-torchvision==0.21.0+cpu
 
 # Jax with cuda support.
 --find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
diff --git a/requirements-tensorflow-cuda.txt b/requirements-tensorflow-cuda.txt
index c52ccf568948..dbbf7a3b5106 100644
--- a/requirements-tensorflow-cuda.txt
+++ b/requirements-tensorflow-cuda.txt
@@ -5,7 +5,6 @@ tf2onnx
 # Torch cpu-only version (needed for testing).
 --extra-index-url https://download.pytorch.org/whl/cpu
 torch==2.6.0+cpu
-torchvision==0.21.0+cpu
 
 # Jax cpu-only version (needed for testing).
 jax[cpu]
diff --git a/requirements-torch-cuda.txt b/requirements-torch-cuda.txt
index b0957e94fcb2..7b8eb7434a0e 100644
--- a/requirements-torch-cuda.txt
+++ b/requirements-torch-cuda.txt
@@ -7,7 +7,6 @@ tf2onnx
 # - torch-xla is pinned to a version that supports GPU (2.6 doesn't)
 --extra-index-url https://download.pytorch.org/whl/cu121
 torch==2.5.1+cu121
-torchvision==0.20.1+cu121
 torch-xla==2.5.1;sys_platform != 'darwin'
 
 # Jax cpu-only version (needed for testing).
diff --git a/requirements.txt b/requirements.txt
index bb9881e1f435..134b89d6f6e6 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -7,7 +7,6 @@ tf2onnx
 # Torch.
 --extra-index-url https://download.pytorch.org/whl/cpu
 torch==2.6.0+cpu
-torchvision==0.21.0+cpu
 torch-xla==2.6.0;sys_platform != 'darwin'
 
 # Jax.
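
Sanity check: the two torchvision-free paths introduced above can be exercised
with plain torch. A minimal sketch (illustrative only, not part of the patch;
shapes, sizes, and values are arbitrary):

import torch
import torch.nn.functional as F

# A float32 channels_last batch: two 32x32 RGB images.
images = torch.rand(2, 32, 32, 3)

# rgb_to_grayscale now reduces the channel axis with the same BT.601 weights
# torchvision uses, instead of calling into torchvision.
r, g, b = images.unbind(dim=-1)
gray = (0.2989 * r + 0.587 * g + 0.114 * b).unsqueeze(dim=-1)
assert gray.shape == (2, 32, 32, 1)

# resize now permutes to NCHW and calls F.interpolate directly; "nearest"
# maps to torch's "nearest-exact" mode to keep torchvision's semantics.
out = F.interpolate(
    images.permute(0, 3, 1, 2),  # NHWC -> NCHW
    size=(16, 16),
    mode="bilinear",
    align_corners=False,
    antialias=True,
)
assert out.permute(0, 2, 3, 1).shape == (2, 16, 16, 3)

Integer inputs take the _cast_squeeze_in / _cast_squeeze_out round trip (cast
to float, interpolate, round, cast back), which is why the sketch sticks to
float32.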