Skip to content

Commit

Permalink
Remove torchvision dep and simplify resize and rgb_to_grayscale
Browse files Browse the repository at this point in the history
… in torch backend (#20868)

* Remove `torchvision` dependency and simplify `resize`.

* Add pillow as the testing requirement
  • Loading branch information
james77777778 authored Feb 7, 2025
1 parent 3906e32 commit b2d5b88
Show file tree
Hide file tree
Showing 8 changed files with 136 additions and 140 deletions.
2 changes: 1 addition & 1 deletion integration_tests/import_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
BACKEND_REQ = {
"tensorflow": ("tensorflow-cpu", ""),
"torch": (
"torch torchvision",
"torch",
"--extra-index-url https://download.pytorch.org/whl/cpu ",
),
"jax": ("jax[cpu]", ""),
Expand Down
268 changes: 134 additions & 134 deletions keras/src/backend/torch/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,16 @@
import operator

import torch
import torch.nn.functional as F

from keras.src import backend
from keras.src.backend.torch.core import convert_to_tensor
from keras.src.utils.module_utils import torchvision

# Maps the interpolation names accepted by `resize` to the `mode` strings
# that are forwarded to `torch.nn.functional.interpolate`.
RESIZE_INTERPOLATIONS = {
    "bilinear": "bilinear",
    "nearest": "nearest-exact",
    "bicubic": "bicubic",
}

UNSUPPORTED_INTERPOLATIONS = (
"lanczos3",
Expand All @@ -19,23 +23,27 @@
def rgb_to_grayscale(images, data_format=None):
    """Convert RGB images to single-channel grayscale using plain torch ops.

    Args:
        images: Input convertible to a tensor, of rank 3 (single image) or
            rank 4 (batch of images). The channel axis must have size 3
            (RGB) or 1 (already grayscale).
        data_format: `"channels_first"`, `"channels_last"` or `None`; it is
            normalized through `backend.standardize_data_format`.

    Returns:
        A tensor with the same rank and dtype as the converted input whose
        channel axis has size 1. A 1-channel input is returned as a copy.

    Raises:
        ValueError: If the rank is not 3 or 4, or if the channel axis size
            is neither 1 nor 3.
    """
    images = convert_to_tensor(images)
    data_format = backend.standardize_data_format(data_format)
    if images.ndim not in (3, 4):
        raise ValueError(
            "Invalid images rank: expected rank 3 (single image) "
            "or rank 4 (batch of images). Received input with shape: "
            f"images.shape={images.shape}"
        )
    # Negative indices address the channel axis for both rank 3 and rank 4,
    # so no permutation to a canonical layout is needed.
    channel_axis = -3 if data_format == "channels_first" else -1
    num_channels = images.shape[channel_axis]
    if num_channels not in (1, 3):
        raise ValueError(
            "Invalid channel size: expected 3 (RGB) or 1 (Grayscale). "
            f"Received input with shape: images.shape={images.shape}"
        )

    # This implementation is based on
    # https://github.com/pytorch/vision/blob/main/torchvision/transforms/_functional_tensor.py
    if num_channels == 3:
        r, g, b = images.unbind(dim=channel_axis)
        # The weighted sum is computed in the promoted dtype and then cast
        # back to the input dtype.
        gray = (0.2989 * r + 0.587 * g + 0.114 * b).to(images.dtype)
        # `unbind` dropped the channel axis; restore it with size 1.
        return gray.unsqueeze(dim=channel_axis)
    # Already grayscale: return a copy so the result never aliases the input.
    return images.clone()


Expand Down Expand Up @@ -129,6 +137,40 @@ def hsv_planes_to_rgb_planes(hue, saturation, value):
return images


def _cast_squeeze_in(image, req_dtypes):
need_squeeze = False
# make image NCHW
if image.ndim < 4:
image = image.unsqueeze(dim=0)
need_squeeze = True

out_dtype = image.dtype
need_cast = False
if out_dtype not in req_dtypes:
need_cast = True
req_dtype = req_dtypes[0]
image = image.to(req_dtype)
return image, need_cast, need_squeeze, out_dtype


def _cast_squeeze_out(image, need_cast, need_squeeze, out_dtype):
if need_squeeze:
image = image.squeeze(dim=0)

if need_cast:
if out_dtype in (
torch.uint8,
torch.int8,
torch.int16,
torch.int32,
torch.int64,
):
# it is better to round before cast
image = torch.round(image)
image = image.to(out_dtype)
return image


def resize(
images,
size,
Expand All @@ -141,13 +183,6 @@ def resize(
data_format=None,
):
data_format = backend.standardize_data_format(data_format)
RESIZE_INTERPOLATIONS.update(
{
"bilinear": torchvision.transforms.InterpolationMode.BILINEAR,
"nearest": torchvision.transforms.InterpolationMode.NEAREST_EXACT,
"bicubic": torchvision.transforms.InterpolationMode.BICUBIC,
}
)
if interpolation in UNSUPPORTED_INTERPOLATIONS:
raise ValueError(
"Resizing with Lanczos interpolation is "
Expand Down Expand Up @@ -182,11 +217,11 @@ def resize(
"or rank 4 (batch of images). Received input with shape: "
f"images.shape={images.shape}"
)
images, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(
images, [torch.float32, torch.float64]
)
if data_format == "channels_last":
if images.ndim == 4:
images = images.permute((0, 3, 1, 2))
else:
images = images.permute((2, 0, 1))
images = images.permute((0, 3, 1, 2))

if crop_to_aspect_ratio:
shape = images.shape
Expand All @@ -198,19 +233,12 @@ def resize(
crop_width = max(min(width, crop_width), 1)
crop_box_hstart = int(float(height - crop_height) / 2)
crop_box_wstart = int(float(width - crop_width) / 2)
if len(images.shape) == 4:
images = images[
:,
:,
crop_box_hstart : crop_box_hstart + crop_height,
crop_box_wstart : crop_box_wstart + crop_width,
]
else:
images = images[
:,
crop_box_hstart : crop_box_hstart + crop_height,
crop_box_wstart : crop_box_wstart + crop_width,
]
images = images[
:,
:,
crop_box_hstart : crop_box_hstart + crop_height,
crop_box_wstart : crop_box_wstart + crop_width,
]
elif pad_to_aspect_ratio:
shape = images.shape
height, width = shape[-2], shape[-1]
Expand All @@ -221,105 +249,77 @@ def resize(
pad_width = max(width, pad_width)
img_box_hstart = int(float(pad_height - height) / 2)
img_box_wstart = int(float(pad_width - width) / 2)
if len(images.shape) == 4:
batch_size = images.shape[0]
channels = images.shape[1]
if img_box_hstart > 0:
padded_img = torch.cat(
[
torch.ones(
(batch_size, channels, img_box_hstart, width),
dtype=images.dtype,
device=images.device,
)
* fill_value,
images,
torch.ones(
(batch_size, channels, img_box_hstart, width),
dtype=images.dtype,
device=images.device,
)
* fill_value,
],
axis=2,
)
else:
padded_img = images

if img_box_wstart > 0:
padded_img = torch.cat(
[
torch.ones(
(batch_size, channels, height, img_box_wstart),
dtype=images.dtype,
device=images.device,
),
padded_img,
torch.ones(
(batch_size, channels, height, img_box_wstart),
dtype=images.dtype,
device=images.device,
)
* fill_value,
],
axis=3,
)

batch_size = images.shape[0]
channels = images.shape[1]
if img_box_hstart > 0:
padded_img = torch.cat(
[
torch.ones(
(batch_size, channels, img_box_hstart, width),
dtype=images.dtype,
device=images.device,
)
* fill_value,
images,
torch.ones(
(batch_size, channels, img_box_hstart, width),
dtype=images.dtype,
device=images.device,
)
* fill_value,
],
axis=2,
)
else:
channels = images.shape[0]
if img_box_wstart > 0:
padded_img = torch.cat(
[
torch.ones(
(channels, img_box_hstart, width),
dtype=images.dtype,
device=images.device,
)
* fill_value,
images,
torch.ones(
(channels, img_box_hstart, width),
dtype=images.dtype,
device=images.device,
)
* fill_value,
],
axis=1,
)
else:
padded_img = images
if img_box_wstart > 0:
torch.cat(
[
torch.ones(
(channels, height, img_box_wstart),
dtype=images.dtype,
device=images.device,
)
* fill_value,
padded_img,
torch.ones(
(channels, height, img_box_wstart),
dtype=images.dtype,
device=images.device,
)
* fill_value,
],
axis=2,
)
padded_img = images
if img_box_wstart > 0:
padded_img = torch.cat(
[
torch.ones(
(batch_size, channels, height, img_box_wstart),
dtype=images.dtype,
device=images.device,
),
padded_img,
torch.ones(
(batch_size, channels, height, img_box_wstart),
dtype=images.dtype,
device=images.device,
)
* fill_value,
],
axis=3,
)
images = padded_img

resized = torchvision.transforms.functional.resize(
img=images,
# This implementation is based on
# https://github.com/pytorch/vision/blob/main/torchvision/transforms/_functional_tensor.py
if antialias and interpolation not in ("bilinear", "bicubic"):
# We manually set it to False to avoid an error downstream in
# interpolate(). This behaviour is documented: the parameter is
# irrelevant for modes that are not bilinear or bicubic. We used to
# raise an error here, but now we don't use True as the default.
antialias = False
# Define align_corners to avoid warnings
align_corners = False if interpolation in ("bilinear", "bicubic") else None
resized = F.interpolate(
images,
size=size,
interpolation=RESIZE_INTERPOLATIONS[interpolation],
mode=RESIZE_INTERPOLATIONS[interpolation],
align_corners=align_corners,
antialias=antialias,
)
if interpolation == "bicubic" and out_dtype == torch.uint8:
resized = resized.clamp(min=0, max=255)
if data_format == "channels_last":
if len(images.shape) == 4:
resized = resized.permute((0, 2, 3, 1))
elif len(images.shape) == 3:
resized = resized.permute((1, 2, 0))
resized = resized.permute((0, 2, 3, 1))
resized = _cast_squeeze_out(
resized,
need_cast=need_cast,
need_squeeze=need_squeeze,
out_dtype=out_dtype,
)
return resized


Expand Down
1 change: 0 additions & 1 deletion keras/src/utils/module_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ def __repr__(self):
tensorflow_io = LazyModule("tensorflow_io")
scipy = LazyModule("scipy")
jax = LazyModule("jax")
torchvision = LazyModule("torchvision")
torch_xla = LazyModule(
"torch_xla",
import_error_msg=(
Expand Down
1 change: 1 addition & 0 deletions requirements-common.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ pytest
numpy
scipy
scikit-learn
pillow
pandas
absl-py
requests
Expand Down
1 change: 0 additions & 1 deletion requirements-jax-cuda.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ tf2onnx
# Torch cpu-only version (needed for testing).
--extra-index-url https://download.pytorch.org/whl/cpu
torch==2.6.0+cpu
torchvision==0.21.0+cpu

# Jax with cuda support.
--find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Expand Down
1 change: 0 additions & 1 deletion requirements-tensorflow-cuda.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ tf2onnx
# Torch cpu-only version (needed for testing).
--extra-index-url https://download.pytorch.org/whl/cpu
torch==2.6.0+cpu
torchvision==0.21.0+cpu

# Jax cpu-only version (needed for testing).
jax[cpu]
Expand Down
1 change: 0 additions & 1 deletion requirements-torch-cuda.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ tf2onnx
# - torch-xla is pinned to a version that supports GPU (2.6 doesn't)
--extra-index-url https://download.pytorch.org/whl/cu121
torch==2.5.1+cu121
torchvision==0.20.1+cu121
torch-xla==2.5.1;sys_platform != 'darwin'

# Jax cpu-only version (needed for testing).
Expand Down
1 change: 0 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ tf2onnx
# Torch.
--extra-index-url https://download.pytorch.org/whl/cpu
torch==2.6.0+cpu
torchvision==0.21.0+cpu
torch-xla==2.6.0;sys_platform != 'darwin'

# Jax.
Expand Down

0 comments on commit b2d5b88

Please sign in to comment.