From 14bd5e722ffd96ad2c5b1ac9c2bd053d7b92a0a9 Mon Sep 17 00:00:00 2001 From: Holger Roth Date: Wed, 24 Apr 2024 14:30:48 -0400 Subject: [PATCH] fix kmeans --- .../python_jobs/pt/client_api_kmeans.py | 8 +++-- .../app_common/executors/script_executor.py | 34 +++++++++---------- 2 files changed, 22 insertions(+), 20 deletions(-) diff --git a/examples/hello-world/python_jobs/pt/client_api_kmeans.py b/examples/hello-world/python_jobs/pt/client_api_kmeans.py index efd9a860c2..3267cabb9f 100644 --- a/examples/hello-world/python_jobs/pt/client_api_kmeans.py +++ b/examples/hello-world/python_jobs/pt/client_api_kmeans.py @@ -22,6 +22,7 @@ from nvflare.app_common.aggregators.collect_and_assemble_aggregator import CollectAndAssembleAggregator from nvflare.app_common.workflows.scatter_and_gather import ScatterAndGather from nvflare.app_common.shareablegenerators.full_model_shareable_generator import FullModelShareableGenerator +from nvflare.client.config import ExchangeFormat def split_higgs(input_data_path, input_header_path, output_dir, site_num, sample_rate, site_name_prefix="site-"): @@ -35,7 +36,7 @@ def split_higgs(input_data_path, input_header_path, output_dir, site_num, sample if __name__ == "__main__": - n_clients = 2 + n_clients = 3 num_rounds = 2 train_script = "code/kmeans_fl.py" data_input_dir = "/tmp/nvflare/higgs/data" @@ -111,9 +112,10 @@ def split_higgs(input_data_path, input_header_path, output_dir, site_num, sample # Add clients for i in range(n_clients): - executor = ScriptExecutor( # TODO: ScriptExecutor() + executor = ScriptExecutor( task_script_path=train_script, - task_script_args=f"--data_root_dir {data_output_dir}" + task_script_args=f"--data_root_dir {data_output_dir}", + params_exchange_format=ExchangeFormat.RAW # kmeans requires raw values only rather than PyTorch Tensors (the default) ) job.to(executor, f"site-{i+1}", gpu=0) # HIGGs data splitter assumes site names start from 1 diff --git a/nvflare/app_common/executors/script_executor.py b/nvflare/app_common/executors/script_executor.py index 4d8173568d..1c10b7dae3 100644 --- a/nvflare/app_common/executors/script_executor.py +++ b/nvflare/app_common/executors/script_executor.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + from typing import Optional from nvflare.app_common.app_constant import AppConstants @@ -38,22 +39,22 @@ def __init__( submit_model_task_name: str = "submit_model", params_exchange_format=ExchangeFormat.PYTORCH, ): + super(ScriptExecutor, self).__init__( + task_script_path=task_script_path, + task_script_args=task_script_args, + task_wait_time=task_wait_time, + result_pull_interval=result_pull_interval, + train_with_evaluation=train_with_evaluation, + train_task_name=train_task_name, + evaluate_task_name=evaluate_task_name, + submit_model_task_name=submit_model_task_name, + from_nvflare_converter_id=from_nvflare_converter_id, + to_nvflare_converter_id=to_nvflare_converter_id, + params_exchange_format=params_exchange_format, + params_transfer_type=params_transfer_type, + log_pull_interval=log_pull_interval, + ) if params_exchange_format == ExchangeFormat.PYTORCH: - super(ScriptExecutor, self).__init__( - task_script_path=task_script_path, - task_script_args=task_script_args, - task_wait_time=task_wait_time, - result_pull_interval=result_pull_interval, - train_with_evaluation=train_with_evaluation, - train_task_name=train_task_name, - evaluate_task_name=evaluate_task_name, - submit_model_task_name=submit_model_task_name, - from_nvflare_converter_id=from_nvflare_converter_id, - to_nvflare_converter_id=to_nvflare_converter_id, - params_exchange_format=params_exchange_format, - params_transfer_type=params_transfer_type, - log_pull_interval=log_pull_interval, - ) fobs.register(TensorDecomposer) if self._from_nvflare_converter is None: @@ -64,5 +65,4 @@ def __init__( self._to_nvflare_converter = PTToNumpyParamsConverter( [AppConstants.TASK_TRAIN, AppConstants.TASK_SUBMIT_MODEL] ) - else: - raise NotImplementedError(f"`params_exchange_format`={params_exchange_format} not supported!") + # TODO: support other params_exchange_format \ No newline at end of file