forked from NVIDIA/NVFlare
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add IPC agent and exchanger (NVIDIA#2435)
* support av ipc agent * removed unused import * address PR comments
- Loading branch information
1 parent
317e579
commit b829303
Showing 15 changed files with 1,406 additions and 53 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from abc import ABC, abstractmethod | ||
from typing import Any | ||
|
||
from nvflare.apis.dxo import DXO, DataKind, from_shareable | ||
from nvflare.apis.fl_context import FLContext | ||
from nvflare.apis.shareable import Shareable | ||
from nvflare.app_common.abstract.aggregator import Aggregator | ||
from nvflare.app_common.abstract.model import ModelLearnableKey | ||
from nvflare.app_common.app_constant import AppConstants | ||
from nvflare.app_common.app_event_type import AppEventType | ||
|
||
from .component_base import ComponentBase | ||
|
||
|
||
class AppDefinedAggregator(Aggregator, ComponentBase, ABC):
    """Aggregator base class that exchanges app-defined weights and meta.

    This class unpacks each submitted Shareable into DXO data/meta and hands
    them to ``processing_training_result``; at aggregation time it wraps the
    output of ``aggregate_training_result`` into an APP_DEFINED DXO.
    Subclasses implement ``reset``, ``processing_training_result`` and
    ``aggregate_training_result``.
    """

    def __init__(self):
        Aggregator.__init__(self)
        ComponentBase.__init__(self)
        # Value of the CURRENT_ROUND prop, captured on each ROUND_STARTED event.
        self.current_round = None
        # WEIGHTS entry of the GLOBAL_MODEL prop, captured on ROUND_STARTED.
        self.base_model_obj = None

    def handle_event(self, event_type, fl_ctx: FLContext):
        """Capture per-round state and reset the aggregator when a round starts.

        Args:
            event_type: the event being fired; only ROUND_STARTED is handled
            fl_ctx: FL context of the event

        Returns: None
        """
        if event_type == AppEventType.ROUND_STARTED:
            self.fl_ctx = fl_ctx
            self.current_round = fl_ctx.get_prop(AppConstants.CURRENT_ROUND)
            base_model_learnable = fl_ctx.get_prop(AppConstants.GLOBAL_MODEL)
            # base_model_obj is only refreshed when the global model is a dict;
            # otherwise the previously captured value is kept.
            if isinstance(base_model_learnable, dict):
                self.base_model_obj = base_model_learnable.get(ModelLearnableKey.WEIGHTS)
            self.reset()

    @abstractmethod
    def reset(self):
        """Reset aggregation state. Called from handle_event on ROUND_STARTED.

        Returns: None
        """
        pass

    @abstractmethod
    def processing_training_result(self, client_name: str, trained_weights: Any, trained_meta: dict) -> bool:
        """Process the training result submitted by one client.

        Args:
            client_name: identity name taken from the submitting peer's context
            trained_weights: the ``data`` of the DXO received from the client
            trained_meta: the ``meta`` of the DXO received from the client

        Returns: a bool that ``accept`` returns unchanged to its caller
        """
        pass

    @abstractmethod
    def aggregate_training_result(self) -> (Any, dict):
        """Produce the aggregated result from what has been processed so far.

        Returns: a tuple of (aggregated result, aggregated meta)
        """
        pass

    def accept(self, shareable: Shareable, fl_ctx: FLContext) -> bool:
        """Unpack a client's submission and pass it to processing_training_result.

        Args:
            shareable: the client's submission, expected to carry a DXO
            fl_ctx: FL context; its peer context supplies the client name

        Returns: the value returned by ``processing_training_result``
        """
        dxo = from_shareable(shareable)
        trained_weights = dxo.data
        trained_meta = dxo.meta
        self.fl_ctx = fl_ctx
        peer_ctx = fl_ctx.get_peer_context()
        client_name = peer_ctx.get_identity_name()
        return self.processing_training_result(client_name, trained_weights, trained_meta)

    def aggregate(self, fl_ctx: FLContext) -> Shareable:
        """Wrap the aggregated result into a Shareable.

        Args:
            fl_ctx: FL context of the aggregation call

        Returns: a Shareable built from an APP_DEFINED DXO holding the
            aggregated result and meta
        """
        self.fl_ctx = fl_ctx
        aggregated_result, aggregated_meta = self.aggregate_training_result()
        dxo = DXO(
            data_kind=DataKind.APP_DEFINED,
            data=aggregated_result,
            meta=aggregated_meta,
        )
        # Label the message with this method's own purpose; the previous text
        # ("learnable_to_shareable") was copied from the shareable generator.
        self.debug(f"aggregated result: {dxo.data}")
        return dxo.to_shareable()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from nvflare.apis.fl_component import FLComponent | ||
|
||
|
||
class ComponentBase(FLComponent):
    """FLComponent that remembers an FL context and logs against it.

    Each convenience method forwards the message, together with the stored
    ``self.fl_ctx``, to the matching ``log_*`` method.
    """

    def __init__(self):
        FLComponent.__init__(self)
        # Starts as None; code that uses this base assigns the current context.
        self.fl_ctx = None

    def debug(self, msg: str):
        """Log a DEBUG message using the stored FL context.

        Args:
            msg: the message to be logged

        Returns: None
        """
        self.log_debug(self.fl_ctx, msg)

    def info(self, msg: str):
        """Log an INFO message using the stored FL context.

        Args:
            msg: the message to be logged

        Returns: None
        """
        self.log_info(self.fl_ctx, msg)

    def error(self, msg: str):
        """Log an ERROR message using the stored FL context.

        Args:
            msg: the message to be logged

        Returns: None
        """
        self.log_error(self.fl_ctx, msg)

    def warning(self, msg: str):
        """Log a WARNING message using the stored FL context.

        Args:
            msg: the message to be logged

        Returns: None
        """
        self.log_warning(self.fl_ctx, msg)

    def exception(self, msg: str):
        """Log an EXCEPTION message using the stored FL context.

        Args:
            msg: the message to be logged

        Returns: None
        """
        self.log_exception(self.fl_ctx, msg)

    def critical(self, msg: str):
        """Log a CRITICAL message using the stored FL context.

        Args:
            msg: the message to be logged

        Returns: None
        """
        self.log_critical(self.fl_ctx, msg)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from abc import ABC, abstractmethod | ||
from typing import Any | ||
|
||
from nvflare.apis.fl_context import FLContext | ||
from nvflare.app_common.abstract.model import ModelLearnable, ModelLearnableKey, make_model_learnable | ||
from nvflare.app_common.abstract.model_persistor import ModelPersistor | ||
|
||
from .component_base import ComponentBase | ||
|
||
|
||
class AppDefinedModelPersistor(ModelPersistor, ComponentBase, ABC):
    """Model persistor base class that works with app-defined model objects.

    Subclasses implement ``read_model`` and ``write_model``; this class
    converts between the model object and a ModelLearnable.
    """

    def __init__(self):
        ModelPersistor.__init__(self)
        ComponentBase.__init__(self)

    @abstractmethod
    def read_model(self) -> Any:
        """Load the model object.

        Returns: a model object
        """
        pass

    @abstractmethod
    def write_model(self, model_obj: Any):
        """Save the model object.

        Args:
            model_obj: the model object to be saved

        Returns: None
        """
        pass

    def load_model(self, fl_ctx: FLContext) -> ModelLearnable:
        """Read the model object and wrap it as a ModelLearnable.

        Args:
            fl_ctx: FL context, stored on the instance before reading

        Returns: a ModelLearnable whose weights are the object returned by
            ``read_model`` and whose meta props are empty
        """
        # Store the context first so read_model can log through it.
        self.fl_ctx = fl_ctx
        return make_model_learnable(weights=self.read_model(), meta_props={})

    def save_model(self, learnable: ModelLearnable, fl_ctx: FLContext):
        """Pass the learnable's WEIGHTS entry to ``write_model``.

        Args:
            learnable: the model learnable to be saved
            fl_ctx: FL context, stored on the instance before writing

        Returns: None
        """
        self.fl_ctx = fl_ctx
        model_obj = learnable.get(ModelLearnableKey.WEIGHTS)
        self.write_model(model_obj)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from abc import ABC, abstractmethod | ||
from typing import Any | ||
|
||
from nvflare.apis.dxo import DXO, DataKind, from_shareable | ||
from nvflare.apis.fl_context import FLContext | ||
from nvflare.apis.shareable import Shareable | ||
from nvflare.app_common.abstract.learnable import Learnable | ||
from nvflare.app_common.abstract.model import ModelLearnable, ModelLearnableKey, make_model_learnable | ||
from nvflare.app_common.abstract.shareable_generator import ShareableGenerator | ||
from nvflare.app_common.app_constant import AppConstants | ||
|
||
from .component_base import ComponentBase | ||
|
||
|
||
class AppDefinedShareableGenerator(ShareableGenerator, ComponentBase, ABC):
    """Shareable generator base class that works with app-defined model objects.

    Subclasses implement ``model_to_trainable`` and ``update_model``; this
    class converts between Learnable and Shareable around those two hooks.
    """

    def __init__(self):
        ShareableGenerator.__init__(self)
        ComponentBase.__init__(self)
        # Value of the CURRENT_ROUND prop, refreshed on every conversion call.
        self.current_round = None

    @abstractmethod
    def model_to_trainable(self, model_obj: Any) -> (Any, dict):
        """Convert the model object into weights and meta that can be sent to clients for training.

        Args:
            model_obj: model object

        Returns: a tuple of (weights, meta)
            The returned weights and meta will be for training and serializable
        """
        pass

    @abstractmethod
    def update_model(self, model_obj: Any, training_result: Any, meta: dict) -> Any:
        """Update the model with a training result and its meta.

        Args:
            model_obj: base model object to be updated
            training_result: training result to be applied to the model object
            meta: trained meta

        Returns: the updated model object
        """
        pass

    def learnable_to_shareable(self, learnable: Learnable, fl_ctx: FLContext) -> Shareable:
        """Build the Shareable to send out from the given learnable.

        Args:
            learnable: learnable whose WEIGHTS entry holds the model object
            fl_ctx: FL context of the call

        Returns: a Shareable built from an APP_DEFINED DXO holding the
            output of ``model_to_trainable``
        """
        self.fl_ctx = fl_ctx
        self.current_round = fl_ctx.get_prop(AppConstants.CURRENT_ROUND)
        self.debug(f"{learnable=}")

        model_obj = learnable.get(ModelLearnableKey.WEIGHTS)
        trainable_weights, trainable_meta = self.model_to_trainable(model_obj)
        self.debug(f"trainable weights: {trainable_weights}")

        dxo = DXO(data_kind=DataKind.APP_DEFINED, data=trainable_weights, meta=trainable_meta)
        self.debug(f"learnable_to_shareable: {dxo.data}")
        return dxo.to_shareable()

    def shareable_to_learnable(self, shareable: Shareable, fl_ctx: FLContext) -> Learnable:
        """Apply a training result to the global model and return the new learnable.

        Args:
            shareable: carries the DXO with the training result and meta
            fl_ctx: FL context; its GLOBAL_MODEL prop supplies the base model

        Returns: a model learnable wrapping the object returned by
            ``update_model``; if there is no global model, the falsy
            GLOBAL_MODEL value itself, after ``system_panic`` has been called

        Raises:
            ValueError: if the global model is not a ModelLearnable
        """
        self.fl_ctx = fl_ctx
        self.current_round = fl_ctx.get_prop(AppConstants.CURRENT_ROUND)
        global_learnable = fl_ctx.get_prop(AppConstants.GLOBAL_MODEL)

        if not global_learnable:
            self.system_panic(reason="No global base model!", fl_ctx=fl_ctx)
            return global_learnable

        if not isinstance(global_learnable, ModelLearnable):
            raise ValueError(f"expect global model to be ModelLearnable but got {type(global_learnable)}")

        received = from_shareable(shareable)
        updated_obj = self.update_model(
            model_obj=global_learnable.get(ModelLearnableKey.WEIGHTS),
            training_result=received.data,
            meta=received.meta,
        )
        return make_model_learnable(weights=updated_obj, meta_props={})
Oops, something went wrong.