Skip to content

Commit

Permalink
Validate parameters for keras.Model.fit()
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 595627079
  • Loading branch information
ronshapiro authored and tensorflower-gardener committed Jan 4, 2024
1 parent 6f412b4 commit 87a2c04
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 17 deletions.
41 changes: 31 additions & 10 deletions tf_keras/engine/data_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -1259,20 +1259,25 @@ def __init__(
`Model` should always set this to `True`.
pss_evaluation_shards: See `Model.fit`.
"""
if batch_size is not None:
_check_positive("batch_size", batch_size)
if steps_per_epoch not in (None, -1) and steps_per_epoch <= 0:
raise ValueError(
"steps_per_epoch must be positive, None or -1. Received "
f"{steps_per_epoch}. See `Model.fit`."
)
self._initial_epoch = _check_non_negative(
"initial_epoch", initial_epoch
)
_check_positive("max_queue_size", max_queue_size)
_check_positive("workers", workers)
if steps_per_execution is not None:
_check_positive("steps_per_execution", steps_per_execution)

self._initial_epoch = initial_epoch
self._initial_step = 0
self._epochs = epochs
self._epochs = _check_positive("epochs", epochs)
self._insufficient_data = False
self._model = model

if steps_per_epoch == 0:
raise ValueError(
"Unexpected value for `steps_per_epoch`. Received value is 0. "
"Please check the docstring for `model.fit()` for supported "
"values."
)

self._steps_per_epoch = steps_per_epoch

# `steps_per_execution_value` is the cached initial value.
Expand Down Expand Up @@ -1954,6 +1959,22 @@ def _check_data_cardinality(data):
raise ValueError(msg)


def _check_non_negative(name, value):
if value < 0:
raise ValueError(
f"Expected {name} to be non-negative. Received is {value}."
)
return value


def _check_positive(name, value):
if value <= 0:
raise ValueError(
f"Expected {name} to be positive. Received is {value}."
)
return value


def _get_tensor_types():
if pd is None:
return (tf.Tensor, np.ndarray)
Expand Down
2 changes: 1 addition & 1 deletion tf_keras/engine/data_adapter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1447,7 +1447,7 @@ def test_error_if_zero_steps_per_epoch(self):

with self.assertRaisesRegex(
ValueError,
"Unexpected value for `steps_per_epoch`. Received value is 0.",
"steps_per_epoch must be positive, None or -1. Received 0.",
):
data_adapter.DataHandler(
data, initial_epoch=0, epochs=2, steps_per_epoch=0
Expand Down
12 changes: 6 additions & 6 deletions tf_keras/engine/training_generator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def test_fit_generator_method(self):
steps_per_epoch=5,
validation_data=custom_generator(),
validation_steps=1,
workers=0,
workers=1,
)

@test_combinations.run_with_all_model_types
Expand Down Expand Up @@ -166,7 +166,7 @@ def test_evaluate_generator_method(self):
steps=5,
max_queue_size=10,
use_multiprocessing=False,
workers=0,
workers=1,
)

@test_combinations.run_with_all_model_types
Expand All @@ -192,7 +192,7 @@ def test_predict_generator_method(self):
use_multiprocessing=False,
)
model.predict_generator(
custom_generator(), steps=5, max_queue_size=10, workers=0
custom_generator(), steps=5, max_queue_size=10, workers=1
)
# Test generator with just inputs (no targets)
model.predict_generator(
Expand All @@ -209,7 +209,7 @@ def test_predict_generator_method(self):
use_multiprocessing=False,
)
model.predict_generator(
custom_generator(mode=1), steps=5, max_queue_size=10, workers=0
custom_generator(mode=1), steps=5, max_queue_size=10, workers=1
)

@test_combinations.run_with_all_model_types
Expand Down Expand Up @@ -453,7 +453,7 @@ def __len__(self):
validation_data=custom_generator(),
validation_steps=1,
max_queue_size=10,
workers=0,
workers=1,
use_multiprocessing=True,
)
model.fit_generator(
Expand All @@ -462,7 +462,7 @@ def __len__(self):
validation_data=custom_generator(),
validation_steps=1,
max_queue_size=10,
workers=0,
workers=1,
use_multiprocessing=False,
)

Expand Down

0 comments on commit 87a2c04

Please sign in to comment.