Skip to content

Commit

Permalink
Update example and logic for mix_up (#20642)
Browse files Browse the repository at this point in the history
* Update example and logic for mix_up

* remove tf from example
  • Loading branch information
shashaka authored Dec 13, 2024
1 parent 7a16b8e commit 4aa6a67
Showing 1 changed file with 15 additions and 19 deletions.
34 changes: 15 additions & 19 deletions keras/src/layers/preprocessing/image_preprocessing/mix_up.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,10 @@ class MixUp(BaseImagePreprocessingLayer):
Example:
```python
(images, labels), _ = keras.datasets.cifar10.load_data()
images, labels = images[:10], labels[:10]
# Labels must be floating-point and one-hot encoded
labels = tf.cast(tf.one_hot(labels, 10), tf.float32)
mixup = keras.layers.MixUp(alpha=0.2)
augmented_images, updated_labels = mixup(
{'images': images, 'labels': labels}
)
# output == {'images': updated_images, 'labels': updated_labels}
images, labels = images[:8], labels[:8]
labels = keras.ops.cast(keras.ops.one_hot(labels.flatten(), 10), "float32")
mix_up = keras.layers.MixUp(alpha=0.2)
output = mix_up({"images": images, "labels": labels})
```
"""

Expand Down Expand Up @@ -62,7 +58,7 @@ def get_random_transformation(self, data, training=True, seed=None):
)

mix_weight = self.backend.random.beta(
(1,), self.alpha, self.alpha, seed=seed
(batch_size,), self.alpha, self.alpha, seed=seed
)
return {
"mix_weight": mix_weight,
Expand All @@ -79,26 +75,26 @@ def transform_images(self, images, transformation=None, training=True):
dtype=self.compute_dtype,
)

mixup_images = self.backend.cast(
mix_up_images = self.backend.cast(
self.backend.numpy.take(images, permutation_order, axis=0),
dtype=self.compute_dtype,
)

images = mix_weight * images + (1.0 - mix_weight) * mixup_images
images = mix_weight * images + (1.0 - mix_weight) * mix_up_images

return images

def transform_labels(self, labels, transformation, training=True):
mix_weight = transformation["mix_weight"]
permutation_order = transformation["permutation_order"]

labels_for_mixup = self.backend.numpy.take(
labels_for_mix_up = self.backend.numpy.take(
labels, permutation_order, axis=0
)

mix_weight = self.backend.numpy.reshape(mix_weight, [-1, 1])

labels = mix_weight * labels + (1.0 - mix_weight) * labels_for_mixup
labels = mix_weight * labels + (1.0 - mix_weight) * labels_for_mix_up

return labels

Expand All @@ -110,11 +106,11 @@ def transform_bounding_boxes(
):
permutation_order = transformation["permutation_order"]
boxes, classes = bounding_boxes["boxes"], bounding_boxes["classes"]
boxes_for_mixup = self.backend.numpy.take(boxes, permutation_order)
classes_for_mixup = self.backend.numpy.take(classes, permutation_order)
boxes = self.backend.numpy.concat([boxes, boxes_for_mixup], axis=1)
boxes_for_mix_up = self.backend.numpy.take(boxes, permutation_order)
classes_for_mix_up = self.backend.numpy.take(classes, permutation_order)
boxes = self.backend.numpy.concat([boxes, boxes_for_mix_up], axis=1)
classes = self.backend.numpy.concat(
[classes, classes_for_mixup], axis=1
[classes, classes_for_mix_up], axis=1
)
return {"boxes": boxes, "classes": classes}

Expand All @@ -126,13 +122,13 @@ def transform_segmentation_masks(

mix_weight = self.backend.numpy.reshape(mix_weight, [-1, 1, 1, 1])

segmentation_masks_for_mixup = self.backend.numpy.take(
segmentation_masks_for_mix_up = self.backend.numpy.take(
segmentation_masks, permutation_order
)

segmentation_masks = (
mix_weight * segmentation_masks
+ (1.0 - mix_weight) * segmentation_masks_for_mixup
+ (1.0 - mix_weight) * segmentation_masks_for_mix_up
)

return segmentation_masks
Expand Down

0 comments on commit 4aa6a67

Please sign in to comment.