diff --git a/tests/benchmarks/scenarios/test_task_aware.py b/tests/benchmarks/scenarios/test_task_aware.py index 6686a224f..a7822f661 100644 --- a/tests/benchmarks/scenarios/test_task_aware.py +++ b/tests/benchmarks/scenarios/test_task_aware.py @@ -16,7 +16,8 @@ class TestsTaskAware(unittest.TestCase): def test_taskaware(self): """Common use case: add tas labels to class-incremental benchmark.""" n_classes, n_samples_per_class, n_features = 10, 3, 7 - while True: + + for _ in range(10000): dataset = make_classification( n_samples=n_classes * n_samples_per_class, n_classes=n_classes, @@ -25,6 +26,9 @@ def test_taskaware(self): n_redundant=0, ) + # The following check is required to ensure that at least 2 exemplars + # per class are generated. Otherwise, the train_test_split function will + # fail. _, unique_count = np.unique(dataset[1], return_counts=True) if np.min(unique_count) > 1: break