Skip to content

Commit

Permalink
The node running the central algorithm is excluded from the list of n…
Browse files Browse the repository at this point in the history
…odes that will perform the partial model training, given that in the MDT setting, the aggregator node will have no data
  • Loading branch information
hcadavid committed Jan 14, 2025
1 parent c0fc4f3 commit 247307f
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 6 deletions.
6 changes: 5 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -129,4 +129,8 @@ venv.bak/
# mypy
.mypy_cache/

.vscode/
.vscode/

# model training performance statistics
federated_cvdm_training_poc/figure_ci_results_*
federated_cvdm_training_poc/ttest_ci/
11 changes: 10 additions & 1 deletion federated_cvdm_training_poc/central_ci.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,16 @@ def central_ci(
""" Central part of the algorithm """
# central function.
# get all organizations (ids) within the collaboration so you can send a task to them.

organizations = client.organization.list()
org_ids = [organization.get("id") for organization in organizations]

# The central function is expected to be executed from an 'aggregator' node with
# no data. Hence, it is excluded (client.organization_id) from the list of organizations
# that will perform the partial model traning with their local datasets.
print(client.organization.list())
org_ids.remove(client.organization_id)

global_ci_list = [] # List of C-statistic for performance evaluation
local_ci_list = [] # List of C-statistic for performance evaluation

Expand Down Expand Up @@ -174,6 +181,8 @@ def central_ci(
"iterations":i,
"runtime":end_time - start_time,
"predictor_cols":predictor_cols,
"outcome_cols":outcome_cols
"outcome_cols":outcome_cols,
"data_nodes":org_ids,
"aggregator":client.organization_id
})

1 change: 1 addition & 0 deletions test/dummy_test_data/empty.dataset.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
empty,only,for,testing
15 changes: 11 additions & 4 deletions test/test_lifelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,15 @@ def main():
"db_type": "csv",
"input_data": {}
}],
# Fourth organization (the aggregator) which shouldn't be used for the partial trainings - hence,
# an empty database is configured for it.
[{
"database": str(current_path/"dummy_test_data"/"empty.dataset.csv"),
"db_type": "csv",
"input_data": {}
}],


],
module="federated_cvdm_training_poc"
)
Expand Down Expand Up @@ -113,17 +122,15 @@ def main():
"agg_weight_filename": output_pth
}
},
# organizations=[org_ids[0]],
organizations=[1],
#The last organization (the aggregator)
organizations=[3],
)

results = client.wait_for_results(central_task.get("id"))
print(results)





if __name__ == '__main__':
main()

0 comments on commit 247307f

Please sign in to comment.