Skip to content

Commit

Permalink
fixed the in_time_accumulate_weighted_aggregator_test.py
Browse files Browse the repository at this point in the history
  • Loading branch information
yhwen committed Apr 11, 2024
1 parent 7869697 commit 1152511
Showing 1 changed file with 12 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,14 @@ class TestInTimeAccumulateWeightedAggregator:
)
def test_invalid_create(self, exclude_vars, aggregation_weights, expected_data_kind, error, error_msg):
with pytest.raises(error, match=re.escape(error_msg)):
_ = InTimeAccumulateWeightedAggregator(
aggregator = InTimeAccumulateWeightedAggregator(
exclude_vars=exclude_vars,
aggregation_weights=aggregation_weights,
expected_data_kind=expected_data_kind,
)
aggregator._initialize(
aggregator.aggregation_weights, aggregator.exclude_vars, aggregator.expected_data_kind
)

@pytest.mark.parametrize(
"exclude_vars,aggregation_weights,expected_data_kind,expected_object",
Expand Down Expand Up @@ -115,9 +118,13 @@ def test_invalid_create(self, exclude_vars, aggregation_weights, expected_data_k
],
)
def test_create(self, exclude_vars, aggregation_weights, expected_data_kind, expected_object):
expected_object._initialize(
expected_object.aggregation_weights, expected_object.exclude_vars, expected_object.expected_data_kind
)
result = InTimeAccumulateWeightedAggregator(
exclude_vars=exclude_vars, aggregation_weights=aggregation_weights, expected_data_kind=expected_data_kind
)
result._initialize(result.aggregation_weights, result.exclude_vars, result.expected_data_kind)
assert result.exclude_vars == expected_object.exclude_vars
assert result.aggregation_weights == expected_object.aggregation_weights
assert result.expected_data_kind == expected_object.expected_data_kind
Expand All @@ -126,6 +133,7 @@ def test_create(self, exclude_vars, aggregation_weights, expected_data_kind, exp
def test_accept(self, current_round, contribution_round, expected):
aggregation_weights = {f"client_{i}": random.random() for i in range(2)}
agg = InTimeAccumulateWeightedAggregator(aggregation_weights=aggregation_weights)
agg._initialize(agg.aggregation_weights, agg.exclude_vars, agg.expected_data_kind)
client_name = "client_0"
iter_number = 1
weights = np.random.random(4)
Expand Down Expand Up @@ -192,6 +200,7 @@ def test_accept(self, current_round, contribution_round, expected):
def test_aggregate(self, received, expected):
aggregation_weights = {k: v["weight"] for k, v in received.items()}
agg = InTimeAccumulateWeightedAggregator(aggregation_weights=aggregation_weights)
agg._initialize(agg.aggregation_weights, agg.exclude_vars, agg.expected_data_kind)
fl_ctx = FLContext()
fl_ctx.set_prop(AppConstants.CURRENT_ROUND, 0)
for k, v in received.items():
Expand All @@ -216,6 +225,7 @@ def test_aggregate(self, received, expected):
def test_aggregate_random(self, shape, n_clients):
aggregation_weights = {f"client_{i}": random.random() for i in range(n_clients)}
agg = InTimeAccumulateWeightedAggregator(aggregation_weights=aggregation_weights)
agg._initialize(agg.aggregation_weights, agg.exclude_vars, agg.expected_data_kind)
weighted_sum = np.zeros(shape)
sum_of_weights = 0
fl_ctx = FLContext()
Expand Down Expand Up @@ -254,6 +264,7 @@ def test_aggregate_random_dxos(self, num_dxo, shape, n_clients):
aggregation_weights=aggregation_weights,
expected_data_kind={dxo_name: DataKind.WEIGHT_DIFF for dxo_name in dxo_names},
)
agg._initialize(agg.aggregation_weights, agg.exclude_vars, agg.expected_data_kind)
weighted_sum = {dxo_name: np.zeros(shape) for dxo_name in dxo_names}
sum_of_weights = {dxo_name: 0 for dxo_name in dxo_names}
fl_ctx = FLContext()
Expand Down

0 comments on commit 1152511

Please sign in to comment.