Skip to content

Commit

Permalink
fix kmeans
Browse files Browse the repository at this point in the history
  • Loading branch information
holgerroth committed Apr 24, 2024
1 parent 98006fe commit 14bd5e7
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 20 deletions.
8 changes: 5 additions & 3 deletions examples/hello-world/python_jobs/pt/client_api_kmeans.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from nvflare.app_common.aggregators.collect_and_assemble_aggregator import CollectAndAssembleAggregator
from nvflare.app_common.workflows.scatter_and_gather import ScatterAndGather
from nvflare.app_common.shareablegenerators.full_model_shareable_generator import FullModelShareableGenerator
from nvflare.client.config import ExchangeFormat


def split_higgs(input_data_path, input_header_path, output_dir, site_num, sample_rate, site_name_prefix="site-"):
Expand All @@ -35,7 +36,7 @@ def split_higgs(input_data_path, input_header_path, output_dir, site_num, sample


if __name__ == "__main__":
n_clients = 2
n_clients = 3
num_rounds = 2
train_script = "code/kmeans_fl.py"
data_input_dir = "/tmp/nvflare/higgs/data"
Expand Down Expand Up @@ -111,9 +112,10 @@ def split_higgs(input_data_path, input_header_path, output_dir, site_num, sample

# Add clients
for i in range(n_clients):
executor = ScriptExecutor( # TODO: ScriptExecutor()
executor = ScriptExecutor(
task_script_path=train_script,
task_script_args=f"--data_root_dir {data_output_dir}"
task_script_args=f"--data_root_dir {data_output_dir}",
params_exchange_format=ExchangeFormat.RAW # kmeans requires raw values only rather than PyTorch Tensors (the default)
)
job.to(executor, f"site-{i+1}", gpu=0) # HIGGs data splitter assumes site names start from 1

Expand Down
34 changes: 17 additions & 17 deletions nvflare/app_common/executors/script_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

from nvflare.app_common.app_constant import AppConstants
Expand Down Expand Up @@ -38,22 +39,22 @@ def __init__(
submit_model_task_name: str = "submit_model",
params_exchange_format=ExchangeFormat.PYTORCH,
):
super(ScriptExecutor, self).__init__(
task_script_path=task_script_path,
task_script_args=task_script_args,
task_wait_time=task_wait_time,
result_pull_interval=result_pull_interval,
train_with_evaluation=train_with_evaluation,
train_task_name=train_task_name,
evaluate_task_name=evaluate_task_name,
submit_model_task_name=submit_model_task_name,
from_nvflare_converter_id=from_nvflare_converter_id,
to_nvflare_converter_id=to_nvflare_converter_id,
params_exchange_format=params_exchange_format,
params_transfer_type=params_transfer_type,
log_pull_interval=log_pull_interval,
)
if params_exchange_format == ExchangeFormat.PYTORCH:
super(ScriptExecutor, self).__init__(
task_script_path=task_script_path,
task_script_args=task_script_args,
task_wait_time=task_wait_time,
result_pull_interval=result_pull_interval,
train_with_evaluation=train_with_evaluation,
train_task_name=train_task_name,
evaluate_task_name=evaluate_task_name,
submit_model_task_name=submit_model_task_name,
from_nvflare_converter_id=from_nvflare_converter_id,
to_nvflare_converter_id=to_nvflare_converter_id,
params_exchange_format=params_exchange_format,
params_transfer_type=params_transfer_type,
log_pull_interval=log_pull_interval,
)
fobs.register(TensorDecomposer)

if self._from_nvflare_converter is None:
Expand All @@ -64,5 +65,4 @@ def __init__(
self._to_nvflare_converter = PTToNumpyParamsConverter(
[AppConstants.TASK_TRAIN, AppConstants.TASK_SUBMIT_MODEL]
)
else:
raise NotImplementedError(f"`params_exchange_format`={params_exchange_format} not supported!")
# TODO: support other params_exchange_format

0 comments on commit 14bd5e7

Please sign in to comment.