forked from NVIDIA/NVFlare
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
29 changed files
with
1,077 additions
and
239 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
# Secure Federated Kaplan-Meier Analysis via Homomorphic Encryption | ||
|
||
This example illustrates two features: | ||
* How to perform Kaplan-Meier survival analysis in federated setting securely via Homomorphic Encryption (HE). | ||
* How to use the Flare Workflow Controller API to contract a workflow to facilitate HE under simulator mode. | ||
|
||
## Secure Multi-party Kaplan-Meier Analysis | ||
Kaplan-Meier survival analysis is a one-shot (non-iterative) analysis performed on a list of events and their corresponding time. In this example, we use [lifelines](https://zenodo.org/records/10456828) to perform this analysis. | ||
|
||
Essentially, the estimator needs to get access to the event list, and under the setting of federated analysis, the aggregated event list from all participants. | ||
|
||
However, this poses a data security concern - by sharing the event list, the raw data can be exposed to external parties, which break the core value of federated analysis. | ||
|
||
Therefore, we would like to design a secure mechanism to enable collaborative Kaplan-Meier analysis without the risk of exposing any raw information from a certain participant (at server end). This is achieved by two techniques: | ||
|
||
- Condense the raw event list to two histograms (one for observed events and the other for censored event) binned at certain interval (e.g. a week), such that events happened within the same bin from different participants can be aggregated and will not be distinguishable for the final aggregated histograms. | ||
- The local histograms will be encrypted as one single vector before sending to server, and the global aggregation operation at server side will be performed entirely within encryption space with HE. | ||
|
||
With these two settings, the server will have no access to any knowledge regarding local submissions, and participants will only receive global aggregated histograms that will not contain distinguishable information regarding any individual participants (client number >= 3 - if only two participants, one can infer the other party's info by subtracting its own histograms). | ||
|
||
The final Kaplan-Meier survival analysis will be performed locally on the global aggregated event list, recovered from global histograms. | ||
|
||
|
||
## Simulated HE Analysis via FLARE Workflow Controller API | ||
|
||
The Flare Workflow Controller API (`WFController`) provides the functionality of flexible FLModel payloads for each round of federated analysis. This gives us the flexibility of transmitting various information needed by our scheme at different stages of federated learning. | ||
|
||
Our [existing HE examples](https://github.com/NVIDIA/NVFlare/tree/main/examples/advanced/cifar10/cifar10-real-world) uses data filter mechanism for HE, provisioning the HE context information (specs and keys) for both client and server of the federated job under [CKKS](https://github.com/NVIDIA/NVFlare/blob/main/nvflare/app_opt/he/model_encryptor.py) scheme. In this example, we would like to illustrate WFController's capability in supporting customized needs beyond the existing HE functionalities (designed mainly for encrypting deep learning models). | ||
- different HE schemes (BFV) rather than CKKS | ||
- different content at different rounds of federated learning, and only specific payload needs to be encrypted | ||
|
||
With the WFController API, such "proof of concept" experiment becomes easy. In this example, the federated analysis pipeline includes 3 rounds: | ||
1. Server send the simple start message without any payload. | ||
2. Clients collect the information of the local maximum bin number (for event time) and send to server, where server aggregates the information by selecting the maximum among all clients. The global maximum number is then distributed back to clients. This step is necessary because we would like to standardize the histograms generated by all clients, such that they will have the exact same length and can be encrypted as vectors of same size, which will be addable. | ||
3. Clients condense their local raw event lists into two histograms with the global length received, encrypt the histrogram value vectors, and send to server. Server aggregated the received histograms by adding the encrypted vectors together, and sends the aggregated histograms back to clients. | ||
|
||
After Round 3, the federated work is completed. Then at each client, the aggregated histograms will be decrypted and converted back to an event list, and Kaplan-Meier analysis can be performed on the global information. | ||
|
||
## Run the job | ||
We first run a baseline analysis with full event information: | ||
```commandline | ||
python baseline_kaplan_meier.py | ||
``` | ||
By default, this will generate a KM curve image `km_curve_baseline.png` under the current working directory. | ||
|
||
Then we run a 5-client federated job with simulator, begin with splitting and generating the data files for each client: | ||
```commandline | ||
python utils/prepare_data.py --out_path "/tmp/flare/dataset/km_data" | ||
``` | ||
Then we prepare HE context for clients and server, note that this step is done by secure provisioning for real-life applications, but in this study experimenting with BFV scheme, we use this step to distribute the HE context. | ||
```commandline | ||
python utils/prepare_he_context.py --out_path "/tmp/flare/he_context" | ||
``` | ||
|
||
Next, we set the location of the job templates directory. | ||
```commandline | ||
nvflare config -jt ./job_templates | ||
``` | ||
|
||
Then we can generate the job configuration from the `kaplan_meier_he` template: | ||
|
||
```commandline | ||
N_CLIENTS=5 | ||
nvflare job create -force -j "./jobs/kaplan-meier-he" -w "kaplan_meier_he" -sd "./src" \ | ||
-f config_fed_client.conf app_script="kaplan_meier_train.py" app_config="--data_root /tmp/flare/dataset/km_data --he_context_path /tmp/flare/he_context/he_context_client.txt" \ | ||
-f config_fed_server.conf min_clients=${N_CLIENTS} he_context_path="/tmp/flare/he_context/he_context_server.txt" | ||
``` | ||
|
||
And we can run the federated job: | ||
```commandline | ||
nvflare simulator -w workspace_km_he -n 5 -t 5 jobs/kaplan-meier-he | ||
``` | ||
By default, this will generate a KM curve image `km_curve_fl.png` under each client's directory. | ||
|
||
## Display Result | ||
|
||
By comparing the two curves, we can observe that the two are identical: | ||
 | ||
 |
72 changes: 72 additions & 0 deletions
72
examples/advanced/kaplan-meier-he/baseline_kaplan_meier.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import argparse | ||
|
||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
from lifelines import KaplanMeierFitter | ||
from sksurv.datasets import load_veterans_lung_cancer | ||
|
||
|
||
def args_parser(): | ||
parser = argparse.ArgumentParser(description="Kaplan Meier Survival Analysis Baseline") | ||
parser.add_argument( | ||
"--output_curve_path", | ||
type=str, | ||
default="./km_curve_baseline.png", | ||
help="save path for the output curve", | ||
) | ||
return parser | ||
|
||
|
||
def prepare_data(bin_days: int = 7): | ||
data_x, data_y = load_veterans_lung_cancer() | ||
total_data_num = data_x.shape[0] | ||
print(f"Total data count: {total_data_num}") | ||
event = data_y["Status"] | ||
time = data_y["Survival_in_days"] | ||
# Categorize data to a bin, default is a week (7 days) | ||
time = np.ceil(time / bin_days).astype(int) | ||
return event, time | ||
|
||
|
||
def main(): | ||
parser = args_parser() | ||
args = parser.parse_args() | ||
|
||
# Set parameters | ||
output_curve_path = args.output_curve_path | ||
|
||
# Generate data | ||
event, time = prepare_data() | ||
|
||
# Fit and plot Kaplan Meier curve with lifelines | ||
kmf = KaplanMeierFitter() | ||
# Fit the survival data | ||
kmf.fit(time, event) | ||
# Plot and save the Kaplan-Meier survival curve | ||
plt.figure() | ||
plt.title("Baseline") | ||
kmf.plot_survival_function() | ||
plt.ylim(0, 1) | ||
plt.ylabel("prob") | ||
plt.xlabel("time") | ||
plt.legend("", frameon=False) | ||
plt.tight_layout() | ||
plt.savefig(output_curve_path) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
116 changes: 116 additions & 0 deletions
116
examples/advanced/kaplan-meier-he/job_templates/kaplan_meier_he/config_fed_client.conf
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
{ | ||
# version of the configuration | ||
format_version = 2 | ||
|
||
# This is the application script which will be invoked. Client can replace this script with user's own training script. | ||
app_script = "kaplan_meier_train.py" | ||
|
||
# Additional arguments needed by the training code. For example, in lightning, these can be --trainer.batch_size=xxx. | ||
app_config = "--data_root /tmp/flare/dataset/km_data --he_context_path /tmp/flare/he_context/he_context_client.txt" | ||
|
||
# Client Computing Executors. | ||
executors = [ | ||
{ | ||
# tasks the executors are defined to handle | ||
tasks = ["train"] | ||
|
||
# This particular executor | ||
executor { | ||
|
||
# This is an executor for Client API. The underline data exchange is using Pipe. | ||
path = "nvflare.app_opt.pt.client_api_launcher_executor.ClientAPILauncherExecutor" | ||
|
||
args { | ||
# launcher_id is used to locate the Launcher object in "components" | ||
launcher_id = "launcher" | ||
|
||
# pipe_id is used to locate the Pipe object in "components" | ||
pipe_id = "pipe" | ||
|
||
# Timeout in seconds for waiting for a heartbeat from the training script. Defaults to 30 seconds. | ||
# Please refer to the class docstring for all available arguments | ||
heartbeat_timeout = 60 | ||
|
||
# format of the exchange parameters | ||
params_exchange_format = "raw" | ||
|
||
# if the transfer_type is FULL, then it will be sent directly | ||
# if the transfer_type is DIFF, then we will calculate the | ||
# difference VS received parameters and send the difference | ||
params_transfer_type = "FULL" | ||
|
||
# if train_with_evaluation is true, the executor will expect | ||
# the custom code need to send back both the trained parameters and the evaluation metric | ||
# otherwise only trained parameters are expected | ||
train_with_evaluation = false | ||
} | ||
} | ||
} | ||
], | ||
|
||
# this defined an array of task data filters. If provided, it will control the data from server controller to client executor | ||
task_data_filters = [] | ||
|
||
# this defined an array of task result filters. If provided, it will control the result from client executor to server controller | ||
task_result_filters = [] | ||
|
||
components = [ | ||
{ | ||
# component id is "launcher" | ||
id = "launcher" | ||
|
||
# the class path of this component | ||
path = "nvflare.app_common.launchers.subprocess_launcher.SubprocessLauncher" | ||
|
||
args { | ||
# the launcher will invoke the script | ||
script = "python3 custom/{app_script} {app_config} " | ||
# if launch_once is true, the SubprocessLauncher will launch once for the whole job | ||
# if launch_once is false, the SubprocessLauncher will launch a process for each task it receives from server | ||
launch_once = true | ||
} | ||
} | ||
{ | ||
id = "pipe" | ||
path = "nvflare.fuel.utils.pipe.cell_pipe.CellPipe" | ||
args { | ||
mode = "PASSIVE" | ||
site_name = "{SITE_NAME}" | ||
token = "{JOB_ID}" | ||
root_url = "{ROOT_URL}" | ||
secure_mode = "{SECURE_MODE}" | ||
workspace_dir = "{WORKSPACE}" | ||
} | ||
} | ||
{ | ||
id = "metrics_pipe" | ||
path = "nvflare.fuel.utils.pipe.cell_pipe.CellPipe" | ||
args { | ||
mode = "PASSIVE" | ||
site_name = "{SITE_NAME}" | ||
token = "{JOB_ID}" | ||
root_url = "{ROOT_URL}" | ||
secure_mode = "{SECURE_MODE}" | ||
workspace_dir = "{WORKSPACE}" | ||
} | ||
}, | ||
{ | ||
id = "metric_relay" | ||
path = "nvflare.app_common.widgets.metric_relay.MetricRelay" | ||
args { | ||
pipe_id = "metrics_pipe" | ||
event_type = "fed.analytix_log_stats" | ||
# how fast should it read from the peer | ||
read_interval = 0.1 | ||
} | ||
}, | ||
{ | ||
# we use this component so the client api `flare.init()` can get required information | ||
id = "config_preparer" | ||
path = "nvflare.app_common.widgets.external_configurator.ExternalConfigurator" | ||
args { | ||
component_ids = ["metric_relay"] | ||
} | ||
} | ||
] | ||
} |
20 changes: 20 additions & 0 deletions
20
examples/advanced/kaplan-meier-he/job_templates/kaplan_meier_he/config_fed_server.conf
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
{ | ||
# version of the configuration | ||
format_version = 2 | ||
task_data_filters =[] | ||
task_result_filters = [] | ||
|
||
workflows = [ | ||
{ | ||
id = "km" | ||
path = "kaplan_meier_wf.KM" | ||
args { | ||
min_clients = 5 | ||
he_context_path = "/tmp/flare/he_context/he_context_server.txt" | ||
} | ||
} | ||
] | ||
|
||
components = [] | ||
|
||
} |
5 changes: 5 additions & 0 deletions
5
examples/advanced/kaplan-meier-he/job_templates/kaplan_meier_he/info.conf
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
{ | ||
description = "Kaplan-Meier survival analysis with homomorphic encryption" | ||
execution_api_type = "client_api" | ||
controller_type = "server" | ||
} |
11 changes: 11 additions & 0 deletions
11
examples/advanced/kaplan-meier-he/job_templates/kaplan_meier_he/info.md
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
# Job Template Information Card | ||
|
||
## kaplan_meier_he | ||
name = "kaplan_meier_he" | ||
description = "Kaplan-Meier survival analysis with homomorphic encryption" | ||
class_name = "KM" | ||
controller_type = "server" | ||
executor_type = "launcher_executor" | ||
contributor = "NVIDIA" | ||
init_publish_date = "2024-04-09" | ||
last_updated_date = "2024-04-09" |
8 changes: 8 additions & 0 deletions
8
examples/advanced/kaplan-meier-he/job_templates/kaplan_meier_he/meta.conf
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
name = "kaplan_meier_he" | ||
resource_spec {} | ||
min_clients = 2 | ||
deploy_map { | ||
app = [ | ||
"@ALL" | ||
] | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
lifelines | ||
tenseal | ||
scikit-survival |
Oops, something went wrong.