From 247307f95e2e7b3bed97fee4f69590783f4640c3 Mon Sep 17 00:00:00 2001 From: Hector Cadavid Date: Tue, 14 Jan 2025 17:01:52 +0100 Subject: [PATCH] The node running the central algorithm is excluded from the list of nodes that will perform the partial model training, given that in the MDT setting, the aggregator node will have no data --- .gitignore | 6 +++++- federated_cvdm_training_poc/central_ci.py | 11 ++++++++++- test/dummy_test_data/empty.dataset.csv | 1 + test/test_lifelines.py | 15 +++++++++++---- 4 files changed, 27 insertions(+), 6 deletions(-) create mode 100644 test/dummy_test_data/empty.dataset.csv diff --git a/.gitignore b/.gitignore index 9747db7..83c211c 100644 --- a/.gitignore +++ b/.gitignore @@ -129,4 +129,8 @@ venv.bak/ # mypy .mypy_cache/ -.vscode/ \ No newline at end of file +.vscode/ + +# model training performance statistics +federated_cvdm_training_poc/figure_ci_results_* +federated_cvdm_training_poc/ttest_ci/ \ No newline at end of file diff --git a/federated_cvdm_training_poc/central_ci.py b/federated_cvdm_training_poc/central_ci.py index f4116c6..5bab9d3 100644 --- a/federated_cvdm_training_poc/central_ci.py +++ b/federated_cvdm_training_poc/central_ci.py @@ -43,9 +43,16 @@ def central_ci( """ Central part of the algorithm """ # central function. # get all organizations (ids) within the collaboration so you can send a task to them. + organizations = client.organization.list() org_ids = [organization.get("id") for organization in organizations] + # The central function is expected to be executed from an 'aggregator' node with + # no data. Hence, it is excluded (client.organization_id) from the list of organizations + # that will perform the partial model training with their local datasets. 
+ print(client.organization.list()) + org_ids.remove(client.organization_id) + global_ci_list = [] # List of C-statistic for performance evaluation local_ci_list = [] # List of C-statistic for performance evaluation @@ -174,6 +181,8 @@ def central_ci( "iterations":i, "runtime":end_time - start_time, "predictor_cols":predictor_cols, - "outcome_cols":outcome_cols + "outcome_cols":outcome_cols, + "data_nodes":org_ids, + "aggregator":client.organization_id }) diff --git a/test/dummy_test_data/empty.dataset.csv b/test/dummy_test_data/empty.dataset.csv new file mode 100644 index 0000000..3206c2b --- /dev/null +++ b/test/dummy_test_data/empty.dataset.csv @@ -0,0 +1 @@ +empty,only,for,testing \ No newline at end of file diff --git a/test/test_lifelines.py b/test/test_lifelines.py index bdb31ad..4bde81b 100644 --- a/test/test_lifelines.py +++ b/test/test_lifelines.py @@ -69,6 +69,15 @@ def main(): "db_type": "csv", "input_data": {} }], + # Fourth organization (the aggregator) which shouldn't be used for the partial trainings - hence, + # an empty database is configured for it. + [{ + "database": str(current_path/"dummy_test_data"/"empty.dataset.csv"), + "db_type": "csv", + "input_data": {} + }], + + ], module="federated_cvdm_training_poc" ) @@ -113,8 +122,8 @@ def main(): "agg_weight_filename": output_pth } }, - # organizations=[org_ids[0]], - organizations=[1], + #The last organization (the aggregator) + organizations=[3], ) results = client.wait_for_results(central_task.get("id")) @@ -122,8 +131,6 @@ def main(): - - if __name__ == '__main__': main()