From 30ca1953e9876dccf9862bf3fa2cb18bcff720bd Mon Sep 17 00:00:00 2001 From: HongYu <20734616+james77777778@users.noreply.github.com> Date: Thu, 21 Dec 2023 01:58:03 +0800 Subject: [PATCH] Fix compiled crossentropy (#18975) * Fix compiled crossentropy * Update test case naming --- keras/trainers/compile_utils.py | 2 ++ keras/trainers/compile_utils_test.py | 15 +++++++++++++++ 2 files changed, 17 insertions(+) diff --git a/keras/trainers/compile_utils.py b/keras/trainers/compile_utils.py index 8f830f4c9de..789844c31d7 100644 --- a/keras/trainers/compile_utils.py +++ b/keras/trainers/compile_utils.py @@ -468,6 +468,8 @@ def build(self, y_true, y_pred): "must be a callable. " f"Received instead:\nloss={loss} of type {type(loss)}" ) + if isinstance(y_pred, list) and len(y_pred) == 1: + y_pred = y_pred[0] if is_function_like(loss) and tree.is_nested(y_pred): # The model has multiple outputs but only one loss fn diff --git a/keras/trainers/compile_utils_test.py b/keras/trainers/compile_utils_test.py index f8927f111db..0d53dcacc35 100644 --- a/keras/trainers/compile_utils_test.py +++ b/keras/trainers/compile_utils_test.py @@ -251,6 +251,21 @@ def test_single_output_case(self): value = compile_loss(y_true, y_pred) self.assertAllClose(value, 0.068333, atol=1e-5) + def test_single_output_case_with_crossentropy_loss(self): + compile_loss = CompileLoss(loss="crossentropy") + + # Test symbolic build + y_true, y_pred = backend.KerasTensor((3, 4)), backend.KerasTensor( + (3, 4) + ) + compile_loss.build(y_true, y_pred) + # Test eager build + y_true = np.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]) + y_pred = np.array([[0.4, 0.1], [0.2, 0.6], [0.6, 0.1]]) + compile_loss.build(y_true, y_pred) + value = compile_loss(y_true, y_pred) + self.assertAllClose(value, 0.706595, atol=1e-5) + @parameterized.parameters(True, False) def test_list_output_case(self, broadcast): if broadcast: