Skip to content

Commit

Permalink
Fix Randomhue (#20652)
Browse files Browse the repository at this point in the history
* Small fix in random hue

* use self.backend for seed
  • Loading branch information
IMvision12 authored Dec 17, 2024
1 parent 4c05e0c commit d8afc05
Showing 1 changed file with 5 additions and 24 deletions.
29 changes: 5 additions & 24 deletions keras/src/layers/preprocessing/image_preprocessing/random_hue.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,6 @@
from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import ( # noqa: E501
BaseImagePreprocessingLayer,
)
from keras.src.random import SeedGenerator


def transform_value_range(images, original_range, target_range):
if (
original_range[0] == target_range[0]
and original_range[1] == target_range[1]
):
return images

original_min_value, original_max_value = original_range
target_min_value, target_max_value = target_range

# images in the [0, 1] scale
images = (images - original_min_value) / (
original_max_value - original_min_value
)

scale_factor = target_max_value - target_min_value
return (images * scale_factor) + target_min_value


@keras_export("keras.layers.RandomHue")
Expand Down Expand Up @@ -70,7 +50,7 @@ def __init__(
self._set_factor(factor)
self.value_range = value_range
self.seed = seed
self.generator = SeedGenerator(seed)
self.generator = self.backend.random.SeedGenerator(seed)

def get_random_transformation(self, data, training=True, seed=None):
if isinstance(data, dict):
Expand Down Expand Up @@ -107,7 +87,8 @@ def get_random_transformation(self, data, training=True, seed=None):
return {"factor": invert * factor * 0.5}

def transform_images(self, images, transformation=None, training=True):
images = transform_value_range(images, self.value_range, (0, 1))
images = self.backend.cast(images, self.compute_dtype)
images = self._transform_value_range(images, self.value_range, (0, 1))
adjust_factors = transformation["factor"]
adjust_factors = self.backend.cast(adjust_factors, images.dtype)
adjust_factors = self.backend.numpy.expand_dims(adjust_factors, -1)
Expand Down Expand Up @@ -144,8 +125,8 @@ def transform_images(self, images, transformation=None, training=True):
)

images = self.backend.numpy.clip(images, 0, 1)
images = transform_value_range(images, (0, 1), self.value_range)

images = self._transform_value_range(images, (0, 1), self.value_range)
images = self.backend.cast(images, self.compute_dtype)
return images

def transform_labels(self, labels, transformation, training=True):
Expand Down

0 comments on commit d8afc05

Please sign in to comment.