-
Notifications
You must be signed in to change notification settings - Fork 19.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add implementations for random_saturation (#20646)
* Fix a bug in MixUp initialization. * Update formatting and indentation. * Add an implementation of random_saturation. * Change the parse_factor method to an inner method. * Fix failing test cases. * Fix remaining failing test cases. * Add a check for the training argument. * Correct the source code. * Add a description of the value_range argument. * Update the docstring example. * Change the _apply_random_saturation method to be inline.
- Loading branch information
Showing
5 changed files
with
273 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
167 changes: 167 additions & 0 deletions
167
keras/src/layers/preprocessing/image_preprocessing/random_saturation.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,167 @@ | ||
from keras.src.api_export import keras_export | ||
from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import ( # noqa: E501 | ||
BaseImagePreprocessingLayer, | ||
) | ||
from keras.src.random import SeedGenerator | ||
|
||
|
||
@keras_export("keras.layers.RandomSaturation")
class RandomSaturation(BaseImagePreprocessingLayer):
    """Randomly adjusts the saturation on given images.

    This layer will randomly increase/reduce the saturation for the input RGB
    images.

    Args:
        factor: A tuple of two floats or a single float.
            `factor` controls the extent to which the image saturation
            is impacted. `factor=0.5` makes this layer perform a no-op
            operation. `factor=0.0` makes the image fully grayscale.
            `factor=1.0` makes the image fully saturated. Values should
            be between `0.0` and `1.0`. If a tuple is used, a `factor`
            is sampled between the two values for every image augmented.
            If a single float is used, a value between `0.0` and the passed
            float is sampled. To ensure the value is always the same,
            pass a tuple with two identical floats: `(0.5, 0.5)`.
        value_range: the range of values the incoming images will have.
            Represented as a two-number tuple written `[low, high]`. This is
            typically either `[0, 1]` or `[0, 255]` depending on how your
            preprocessing pipeline is set up.
        data_format: Layout of the image dimensions, forwarded to the base
            image preprocessing layer. When it equals `"channels_first"`
            the channel axis is axis 1; otherwise the last axis is used.
        seed: Integer. Used to create a random seed.

    Example:
    ```python
    (images, labels), _ = keras.datasets.cifar10.load_data()
    images = images.astype("float32")
    random_saturation = keras.layers.RandomSaturation(factor=0.2)
    augmented_images = random_saturation(images)
    ```
    """

    _VALUE_RANGE_VALIDATION_ERROR = (
        "The `value_range` argument should be a list of two numbers. "
    )

    def __init__(
        self,
        factor,
        value_range=(0, 255),
        data_format=None,
        seed=None,
        **kwargs,
    ):
        super().__init__(data_format=data_format, **kwargs)
        self._set_factor(factor)
        self._set_value_range(value_range)
        self.seed = seed
        self.generator = SeedGenerator(seed)

    def _set_value_range(self, value_range):
        """Validate `value_range` and store it sorted as `[low, high]`.

        Raises:
            ValueError: if `value_range` is not a tuple/list of length two.
        """
        is_pair = (
            isinstance(value_range, (tuple, list)) and len(value_range) == 2
        )
        if not is_pair:
            raise ValueError(
                self._VALUE_RANGE_VALIDATION_ERROR
                + f"Received: value_range={value_range}"
            )
        self.value_range = sorted(value_range)

    def get_random_transformation(self, data, training=True, seed=None):
        """Sample one saturation multiplier per image.

        Args:
            data: Either the images themselves or a dict holding them under
                the `"images"` key. Images must be rank 3 (single image) or
                rank 4 (batch).
            training: Unused here; kept for interface compatibility.
            seed: Optional seed passed to the backend sampler. When `None`,
                the layer's seed generator for the active backend is used.

        Returns:
            A dict `{"factor": multipliers}` where `multipliers` has shape
            `(batch_size,)` and equals `f / (1 - f)` for `f` sampled
            uniformly from the configured factor range, so `f=0.5` maps to
            a multiplier of ~1 (no-op) and `f=0.0` maps to 0 (grayscale).

        Raises:
            ValueError: if the images are neither rank 3 nor rank 4.
        """
        # Local import keeps this method self-contained; `epsilon()` is the
        # backend-wide fuzz factor used to guard divisions.
        from keras.src.backend.config import epsilon

        images = data["images"] if isinstance(data, dict) else data
        images_shape = self.backend.shape(images)
        rank = len(images_shape)
        if rank == 3:
            batch_size = 1
        elif rank == 4:
            batch_size = images_shape[0]
        else:
            raise ValueError(
                "Expected the input image to be rank 3 or 4. Received: "
                f"inputs.shape={images_shape}"
            )

        if seed is None:
            seed = self._get_seed_generator(self.backend._backend)

        factor = self.backend.random.uniform(
            (batch_size,),
            minval=self.factor[0],
            maxval=self.factor[1],
            seed=seed,
        )
        # Guard the denominator: the documented `factor=1.0` would otherwise
        # divide by zero, giving `inf` and then NaN (`0 * inf`) for pixels
        # whose saturation is exactly zero.
        factor = factor / (1 - factor + epsilon())
        return {"factor": factor}

    def transform_images(self, images, transformation=None, training=True):
        """Scale the HSV saturation channel of `images`.

        Args:
            images: Batched RGB images laid out according to
                `self.data_format`.
            transformation: Dict produced by `get_random_transformation`,
                holding the per-image multipliers under `"factor"`.
            training: When falsy the images are returned unchanged.

        Returns:
            The augmented RGB images, or the input when not training.
        """
        if not training:
            return images

        adjust_factors = self.backend.cast(
            transformation["factor"], self.compute_dtype
        )
        # Append two singleton axes so each per-image multiplier broadcasts
        # over the two spatial axes of the extracted saturation channel.
        adjust_factors = self.backend.numpy.reshape(
            adjust_factors, self.backend.shape(adjust_factors) + (1, 1)
        )

        hsv = self.backend.image.rgb_to_hsv(
            images, data_format=self.data_format
        )
        if self.data_format == "channels_first":
            channel_axis = 1
            hue = hsv[:, 0, :, :]
            saturation = hsv[:, 1, :, :]
            value = hsv[:, 2, :, :]
        else:
            channel_axis = -1
            hue = hsv[..., 0]
            saturation = hsv[..., 1]
            value = hsv[..., 2]

        saturation = self.backend.numpy.multiply(saturation, adjust_factors)
        # NOTE(review): the saturation channel is clipped to `value_range`,
        # exactly as before this change; confirm that this is the intended
        # bound for HSV saturation.
        saturation = self.backend.numpy.clip(
            saturation, self.value_range[0], self.value_range[1]
        )
        hsv = self.backend.numpy.stack(
            [hue, saturation, value], axis=channel_axis
        )
        return self.backend.image.hsv_to_rgb(
            hsv, data_format=self.data_format
        )

    def transform_labels(self, labels, transformation, training=True):
        """Return `labels` unchanged."""
        return labels

    def transform_segmentation_masks(
        self, segmentation_masks, transformation, training=True
    ):
        """Return `segmentation_masks` unchanged."""
        return segmentation_masks

    def transform_bounding_boxes(
        self, bounding_boxes, transformation, training=True
    ):
        """Return `bounding_boxes` unchanged."""
        return bounding_boxes

    def get_config(self):
        """Return the base config extended with this layer's arguments."""
        config = super().get_config()
        config.update(
            {
                "factor": self.factor,
                "value_range": self.value_range,
                "seed": self.seed,
            }
        )
        return config

    def compute_output_shape(self, input_shape):
        """The output shape is the same as the input shape."""
        return input_shape
97 changes: 97 additions & 0 deletions
97
keras/src/layers/preprocessing/image_preprocessing/random_saturation_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
import numpy as np | ||
import pytest | ||
from tensorflow import data as tf_data | ||
|
||
import keras | ||
from keras.src import backend | ||
from keras.src import layers | ||
from keras.src import testing | ||
|
||
|
||
class RandomSaturationTest(testing.TestCase):
    """Tests for the `RandomSaturation` image preprocessing layer."""

    def _random_inputs(self):
        """Return `(data_format, batch)` for the configured image layout.

        The batch holds two random 8x8 RGB images with values in `[0, 1)`.
        """
        data_format = backend.config.image_data_format()
        if data_format == "channels_last":
            shape = (2, 8, 8, 3)
        else:
            shape = (2, 3, 8, 8)
        return data_format, np.random.random(shape)

    @pytest.mark.requires_trainable_backend
    def test_layer(self):
        self.run_layer_test(
            layers.RandomSaturation,
            init_kwargs={
                "factor": 0.75,
                "seed": 1,
            },
            input_shape=(8, 3, 4, 3),
            supports_masking=False,
            expected_output_shape=(8, 3, 4, 3),
        )

    def test_random_saturation_value_range(self):
        image = keras.random.uniform(shape=(3, 3, 3), minval=0, maxval=1)

        layer = layers.RandomSaturation(0.2)
        adjusted_image = layer(image)

        self.assertTrue(keras.ops.numpy.all(adjusted_image >= 0))
        self.assertTrue(keras.ops.numpy.all(adjusted_image <= 1))

    def test_random_saturation_no_op(self):
        _, inputs = self._random_inputs()

        layer = layers.RandomSaturation((0.5, 0.5))
        # NOTE(review): with `training=False` the layer's `transform_images`
        # returns its input untouched, so this checks inference pass-through
        # rather than the `factor=0.5` no-op; consider also covering the
        # training path.
        output = layer(inputs, training=False)
        self.assertAllClose(inputs, output, atol=1e-3, rtol=1e-5)

    def test_random_saturation_full_grayscale(self):
        data_format, inputs = self._random_inputs()
        layer = layers.RandomSaturation(factor=(0.0, 0.0))
        result = layer(inputs)

        # A fully desaturated image has identical R, G and B channels.
        if data_format == "channels_last":
            self.assertAllClose(result[..., 0], result[..., 1])
            self.assertAllClose(result[..., 1], result[..., 2])
        else:
            self.assertAllClose(result[:, 0, :, :], result[:, 1, :, :])
            self.assertAllClose(result[:, 1, :, :], result[:, 2, :, :])

    def test_random_saturation_full_saturation(self):
        data_format, inputs = self._random_inputs()
        layer = layers.RandomSaturation(factor=(1.0, 1.0))
        result = layer(inputs)

        hsv = backend.image.rgb_to_hsv(result, data_format=data_format)
        # Pick the saturation channel along the actual channel axis; the
        # last-axis index is only correct for `channels_last`.
        if data_format == "channels_last":
            s_channel = hsv[..., 1]
        else:
            s_channel = hsv[:, 1, :, :]

        self.assertAllClose(
            keras.ops.numpy.max(s_channel), layer.value_range[1]
        )

    def test_random_saturation_randomness(self):
        image = keras.random.uniform(shape=(3, 3, 3), minval=0, maxval=1)

        layer = layers.RandomSaturation(0.2)
        adjusted_images = layer(image)

        self.assertNotAllClose(adjusted_images, image)

    def test_tf_data_compatibility(self):
        data_format, input_data = self._random_inputs()
        layer = layers.RandomSaturation(
            factor=0.5, data_format=data_format, seed=1337
        )

        ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer)
        # Materialising one batch is enough to prove the layer runs inside a
        # `tf.data` pipeline.
        for output in ds.take(1):
            output.numpy()