Skip to content

Commit

Permalink
Add implementations for random_saturation (#20646)
Browse files Browse the repository at this point in the history
* Correct bug for MixUp initialization.

* Update format indent

* Add implementations for random_saturation

* change parse_factor method to inner method.

* correct test cases failed.

* correct failed test cases

* Add training argument check condition

* correct source code

* add value_range args description

* update description example

* change _apply_random_saturation method to inline
  • Loading branch information
shashaka authored Dec 17, 2024
1 parent 2b6c800 commit 9bcf324
Show file tree
Hide file tree
Showing 5 changed files with 273 additions and 0 deletions.
3 changes: 3 additions & 0 deletions keras/api/_tf_keras/keras/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,9 @@
from keras.src.layers.preprocessing.image_preprocessing.random_rotation import (
RandomRotation,
)
from keras.src.layers.preprocessing.image_preprocessing.random_saturation import (
RandomSaturation,
)
from keras.src.layers.preprocessing.image_preprocessing.random_translation import (
RandomTranslation,
)
Expand Down
3 changes: 3 additions & 0 deletions keras/api/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,9 @@
from keras.src.layers.preprocessing.image_preprocessing.random_rotation import (
RandomRotation,
)
from keras.src.layers.preprocessing.image_preprocessing.random_saturation import (
RandomSaturation,
)
from keras.src.layers.preprocessing.image_preprocessing.random_translation import (
RandomTranslation,
)
Expand Down
3 changes: 3 additions & 0 deletions keras/src/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,9 @@
from keras.src.layers.preprocessing.image_preprocessing.random_rotation import (
RandomRotation,
)
from keras.src.layers.preprocessing.image_preprocessing.random_saturation import (
RandomSaturation,
)
from keras.src.layers.preprocessing.image_preprocessing.random_translation import (
RandomTranslation,
)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
from keras.src.api_export import keras_export
from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import ( # noqa: E501
BaseImagePreprocessingLayer,
)
from keras.src.random import SeedGenerator


@keras_export("keras.layers.RandomSaturation")
class RandomSaturation(BaseImagePreprocessingLayer):
"""Randomly adjusts the saturation on given images.
This layer will randomly increase/reduce the saturation for the input RGB
images.
Args:
factor: A tuple of two floats or a single float.
`factor` controls the extent to which the image saturation
is impacted. `factor=0.5` makes this layer perform a no-op
operation. `factor=0.0` makes the image fully grayscale.
`factor=1.0` makes the image fully saturated. Values should
be between `0.0` and `1.0`. If a tuple is used, a `factor`
is sampled between the two values for every image augmented.
If a single float is used, a value between `0.0` and the passed
float is sampled. To ensure the value is always the same,
pass a tuple with two identical floats: `(0.5, 0.5)`.
value_range: the range of values the incoming images will have.
Represented as a two-number tuple written `[low, high]`. This is
typically either `[0, 1]` or `[0, 255]` depending on how your
preprocessing pipeline is set up.
seed: Integer. Used to create a random seed.
Example:
```python
(images, labels), _ = keras.datasets.cifar10.load_data()
images = images.astype("float32")
random_saturation = keras.layers.RandomSaturation(factor=0.2)
augmented_images = random_saturation(images)
```
"""

_VALUE_RANGE_VALIDATION_ERROR = (
"The `value_range` argument should be a list of two numbers. "
)

def __init__(
self,
factor,
value_range=(0, 255),
data_format=None,
seed=None,
**kwargs,
):
super().__init__(data_format=data_format, **kwargs)
self._set_factor(factor)
self._set_value_range(value_range)
self.seed = seed
self.generator = SeedGenerator(seed)

def _set_value_range(self, value_range):
if not isinstance(value_range, (tuple, list)):
raise ValueError(
self._VALUE_RANGE_VALIDATION_ERROR
+ f"Received: value_range={value_range}"
)
if len(value_range) != 2:
raise ValueError(
self._VALUE_RANGE_VALIDATION_ERROR
+ f"Received: value_range={value_range}"
)
self.value_range = sorted(value_range)

def get_random_transformation(self, data, training=True, seed=None):
if isinstance(data, dict):
images = data["images"]
else:
images = data
images_shape = self.backend.shape(images)
rank = len(images_shape)
if rank == 3:
batch_size = 1
elif rank == 4:
batch_size = images_shape[0]
else:
raise ValueError(
"Expected the input image to be rank 3 or 4. Received: "
f"inputs.shape={images_shape}"
)

if seed is None:
seed = self._get_seed_generator(self.backend._backend)

factor = self.backend.random.uniform(
(batch_size,),
minval=self.factor[0],
maxval=self.factor[1],
seed=seed,
)
factor = factor / (1 - factor)
return {"factor": factor}

def transform_images(self, images, transformation=None, training=True):
def _apply_random_saturation(images, transformation):
adjust_factors = transformation["factor"]
adjust_factors = self.backend.cast(
adjust_factors, self.compute_dtype
)
adjust_factors = self.backend.numpy.reshape(
adjust_factors, self.backend.shape(adjust_factors) + (1, 1)
)
images = self.backend.image.rgb_to_hsv(
images, data_format=self.data_format
)
if self.data_format == "channels_first":
s_channel = self.backend.numpy.multiply(
images[:, 1, :, :], adjust_factors
)
s_channel = self.backend.numpy.clip(
s_channel, self.value_range[0], self.value_range[1]
)
images = self.backend.numpy.stack(
[images[:, 0, :, :], s_channel, images[:, 2, :, :]], axis=1
)
else:
s_channel = self.backend.numpy.multiply(
images[..., 1], adjust_factors
)
s_channel = self.backend.numpy.clip(
s_channel, self.value_range[0], self.value_range[1]
)
images = self.backend.numpy.stack(
[images[..., 0], s_channel, images[..., 2]], axis=-1
)
images = self.backend.image.hsv_to_rgb(
images, data_format=self.data_format
)
return images

if training:
images = _apply_random_saturation(images, transformation)
return images

def transform_labels(self, labels, transformation, training=True):
return labels

def transform_segmentation_masks(
self, segmentation_masks, transformation, training=True
):
return segmentation_masks

def transform_bounding_boxes(
self, bounding_boxes, transformation, training=True
):
return bounding_boxes

def get_config(self):
config = super().get_config()
config.update(
{
"factor": self.factor,
"value_range": self.value_range,
"seed": self.seed,
}
)
return config

def compute_output_shape(self, input_shape):
return input_shape
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import numpy as np
import pytest
from tensorflow import data as tf_data

import keras
from keras.src import backend
from keras.src import layers
from keras.src import testing


class RandomSaturationTest(testing.TestCase):
@pytest.mark.requires_trainable_backend
def test_layer(self):
self.run_layer_test(
layers.RandomSaturation,
init_kwargs={
"factor": 0.75,
"seed": 1,
},
input_shape=(8, 3, 4, 3),
supports_masking=False,
expected_output_shape=(8, 3, 4, 3),
)

def test_random_saturation_value_range(self):
image = keras.random.uniform(shape=(3, 3, 3), minval=0, maxval=1)

layer = layers.RandomSaturation(0.2)
adjusted_image = layer(image)

self.assertTrue(keras.ops.numpy.all(adjusted_image >= 0))
self.assertTrue(keras.ops.numpy.all(adjusted_image <= 1))

def test_random_saturation_no_op(self):
data_format = backend.config.image_data_format()
if data_format == "channels_last":
inputs = np.random.random((2, 8, 8, 3))
else:
inputs = np.random.random((2, 3, 8, 8))

layer = layers.RandomSaturation((0.5, 0.5))
output = layer(inputs, training=False)
self.assertAllClose(inputs, output, atol=1e-3, rtol=1e-5)

def test_random_saturation_full_grayscale(self):
data_format = backend.config.image_data_format()
if data_format == "channels_last":
inputs = np.random.random((2, 8, 8, 3))
else:
inputs = np.random.random((2, 3, 8, 8))
layer = layers.RandomSaturation(factor=(0.0, 0.0))
result = layer(inputs)

if data_format == "channels_last":
self.assertAllClose(result[..., 0], result[..., 1])
self.assertAllClose(result[..., 1], result[..., 2])
else:
self.assertAllClose(result[:, 0, :, :], result[:, 1, :, :])
self.assertAllClose(result[:, 1, :, :], result[:, 2, :, :])

def test_random_saturation_full_saturation(self):
data_format = backend.config.image_data_format()
if data_format == "channels_last":
inputs = np.random.random((2, 8, 8, 3))
else:
inputs = np.random.random((2, 3, 8, 8))
layer = layers.RandomSaturation(factor=(1.0, 1.0))
result = layer(inputs)

hsv = backend.image.rgb_to_hsv(result)
s_channel = hsv[..., 1]

self.assertAllClose(
keras.ops.numpy.max(s_channel), layer.value_range[1]
)

def test_random_saturation_randomness(self):
image = keras.random.uniform(shape=(3, 3, 3), minval=0, maxval=1)[:5]

layer = layers.RandomSaturation(0.2)
adjusted_images = layer(image)

self.assertNotAllClose(adjusted_images, image)

def test_tf_data_compatibility(self):
data_format = backend.config.image_data_format()
if data_format == "channels_last":
input_data = np.random.random((2, 8, 8, 3))
else:
input_data = np.random.random((2, 3, 8, 8))
layer = layers.RandomSaturation(
factor=0.5, data_format=data_format, seed=1337
)

ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer)
for output in ds.take(1):
output.numpy()

0 comments on commit 9bcf324

Please sign in to comment.