Skip to content

Commit

Permalink
minor updates
Browse files Browse the repository at this point in the history
  • Loading branch information
holgerroth committed Apr 12, 2024
1 parent 534b4ac commit 9c11323
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
num_rounds = 2
train_script = "code/cifar10_fl.py"

job = FedJob(name="cifar10_fedavg")
job = FedJob(name="cifar10_fedavg_privacy")

controller = FedAvg(
min_clients=n_clients,
Expand All @@ -40,7 +40,7 @@
task_script_path=train_script,
task_script_args=""
)
client_app = ExecutorApp(executors=[executor], external_scripts=[train_script])
client_app = ExecutorApp(executors=[executor], external_scripts=[train_script]) # TODO: add_executor?

# add privacy filter. # TODO: is there a better way to handle task names?
client_app.add_task_result_filter(["train"], PercentilePrivacy(percentile=10, gamma=0.01))
Expand Down
10 changes: 5 additions & 5 deletions nvflare/fed_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,15 @@ def __init__(self):
def get_app_config(self):
return self.app

def add_task_result_filter(self, tasks: List[str], filter: Filter):
self.app.add_task_result_filter(tasks, filter)
def add_task_result_filter(self, tasks: List[str], task_filter: Filter):
self.app.add_task_result_filter(tasks, task_filter)

def add_task_data_filter(self, tasks: List[str], filter: Filter):
self.app.add_task_data_filter(tasks, filter)
def add_task_data_filter(self, tasks: List[str], task_filter: Filter):
self.app.add_task_data_filter(tasks, task_filter)


class FedJob:
def __init__(self, name="client-api-pt", min_clients=1, mandatory_clients=None) -> None:
def __init__(self, name="fed_job", min_clients=1, mandatory_clients=None) -> None:
self.job_name = name
self.gpus = []
self.clients = []
Expand Down

0 comments on commit 9c11323

Please sign in to comment.