diff --git a/tests/unit_test/app_common/aggregators/in_time_accumulate_weighted_aggregator_test.py b/tests/unit_test/app_common/aggregators/in_time_accumulate_weighted_aggregator_test.py index e96c81ea5f..3f5ae02e4c 100644 --- a/tests/unit_test/app_common/aggregators/in_time_accumulate_weighted_aggregator_test.py +++ b/tests/unit_test/app_common/aggregators/in_time_accumulate_weighted_aggregator_test.py @@ -72,11 +72,14 @@ class TestInTimeAccumulateWeightedAggregator: ) def test_invalid_create(self, exclude_vars, aggregation_weights, expected_data_kind, error, error_msg): with pytest.raises(error, match=re.escape(error_msg)): - _ = InTimeAccumulateWeightedAggregator( + aggregator = InTimeAccumulateWeightedAggregator( exclude_vars=exclude_vars, aggregation_weights=aggregation_weights, expected_data_kind=expected_data_kind, ) + aggregator._initialize( + aggregator.aggregation_weights, aggregator.exclude_vars, aggregator.expected_data_kind + ) @pytest.mark.parametrize( "exclude_vars,aggregation_weights,expected_data_kind,expected_object", @@ -115,9 +118,13 @@ def test_invalid_create(self, exclude_vars, aggregation_weights, expected_data_k ], ) def test_create(self, exclude_vars, aggregation_weights, expected_data_kind, expected_object): + expected_object._initialize( + expected_object.aggregation_weights, expected_object.exclude_vars, expected_object.expected_data_kind + ) result = InTimeAccumulateWeightedAggregator( exclude_vars=exclude_vars, aggregation_weights=aggregation_weights, expected_data_kind=expected_data_kind ) + result._initialize(result.aggregation_weights, result.exclude_vars, result.expected_data_kind) assert result.exclude_vars == expected_object.exclude_vars assert result.aggregation_weights == expected_object.aggregation_weights assert result.expected_data_kind == expected_object.expected_data_kind @@ -126,6 +133,7 @@ def test_create(self, exclude_vars, aggregation_weights, expected_data_kind, exp def test_accept(self, current_round, contribution_round, expected): aggregation_weights = {f"client_{i}": random.random() for i in range(2)} agg = InTimeAccumulateWeightedAggregator(aggregation_weights=aggregation_weights) + agg._initialize(agg.aggregation_weights, agg.exclude_vars, agg.expected_data_kind) client_name = "client_0" iter_number = 1 weights = np.random.random(4) @@ -192,6 +200,7 @@ def test_accept(self, current_round, contribution_round, expected): def test_aggregate(self, received, expected): aggregation_weights = {k: v["weight"] for k, v in received.items()} agg = InTimeAccumulateWeightedAggregator(aggregation_weights=aggregation_weights) + agg._initialize(agg.aggregation_weights, agg.exclude_vars, agg.expected_data_kind) fl_ctx = FLContext() fl_ctx.set_prop(AppConstants.CURRENT_ROUND, 0) for k, v in received.items(): @@ -216,6 +225,7 @@ def test_aggregate(self, received, expected): def test_aggregate_random(self, shape, n_clients): aggregation_weights = {f"client_{i}": random.random() for i in range(n_clients)} agg = InTimeAccumulateWeightedAggregator(aggregation_weights=aggregation_weights) + agg._initialize(agg.aggregation_weights, agg.exclude_vars, agg.expected_data_kind) weighted_sum = np.zeros(shape) sum_of_weights = 0 fl_ctx = FLContext() @@ -254,6 +264,7 @@ def test_aggregate_random_dxos(self, num_dxo, shape, n_clients): aggregation_weights=aggregation_weights, expected_data_kind={dxo_name: DataKind.WEIGHT_DIFF for dxo_name in dxo_names}, ) + agg._initialize(agg.aggregation_weights, agg.exclude_vars, agg.expected_data_kind) weighted_sum = {dxo_name: np.zeros(shape) for dxo_name in dxo_names} sum_of_weights = {dxo_name: 0 for dxo_name in dxo_names} fl_ctx = FLContext()