Skip to content

Commit

Permalink
Add IPC agent and exchanger (NVIDIA#2435)
Browse files Browse the repository at this point in the history
* support av ipc agent

* removed unused import

* address PR comments
  • Loading branch information
yanchengnv authored Apr 5, 2024
1 parent 317e579 commit b829303
Show file tree
Hide file tree
Showing 15 changed files with 1,406 additions and 53 deletions.
3 changes: 2 additions & 1 deletion nvflare/apis/dxo.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ class DataKind(object):
COLLECTION = "COLLECTION" # Dict or List of DXO objects
STATISTICS = "STATISTICS"
PSI = "PSI"
APP_DEFINED = "APP_DEFINED" # data format is app defined


class MetaKey(FLMetaKey):
Expand Down Expand Up @@ -128,7 +129,7 @@ def validate(self) -> str:
if self.data is None:
return "missing data"

if not isinstance(self.data, dict):
if self.data_kind != DataKind.APP_DEFINED and not isinstance(self.data, dict):
return "invalid data: expect dict but got {}".format(type(self.data))

if self.meta is not None and not isinstance(self.meta, dict):
Expand Down
1 change: 1 addition & 0 deletions nvflare/apis/fl_constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ class ReturnCode(object):
EARLY_TERMINATION = "EARLY_TERMINATION"
SERVER_NOT_READY = "SERVER_NOT_READY"
SERVICE_UNAVAILABLE = "SERVICE_UNAVAILABLE"
EARLY_TERMINATION = "EARLY_TERMINATION"


class MachineStatus(Enum):
Expand Down
13 changes: 13 additions & 0 deletions nvflare/app_common/app_defined/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
75 changes: 75 additions & 0 deletions nvflare/app_common/app_defined/aggregator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod
from typing import Any

from nvflare.apis.dxo import DXO, DataKind, from_shareable
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable
from nvflare.app_common.abstract.aggregator import Aggregator
from nvflare.app_common.abstract.model import ModelLearnableKey
from nvflare.app_common.app_constant import AppConstants
from nvflare.app_common.app_event_type import AppEventType

from .component_base import ComponentBase


class AppDefinedAggregator(Aggregator, ComponentBase, ABC):
def __init__(self):
Aggregator.__init__(self)
ComponentBase.__init__(self)
self.current_round = None
self.base_model_obj = None

def handle_event(self, event_type, fl_ctx: FLContext):
if event_type == AppEventType.ROUND_STARTED:
self.fl_ctx = fl_ctx
self.current_round = fl_ctx.get_prop(AppConstants.CURRENT_ROUND)
base_model_learnable = fl_ctx.get_prop(AppConstants.GLOBAL_MODEL)
if isinstance(base_model_learnable, dict):
self.base_model_obj = base_model_learnable.get(ModelLearnableKey.WEIGHTS)
self.reset()

@abstractmethod
def reset(self):
pass

@abstractmethod
def processing_training_result(self, client_name: str, trained_weights: Any, trained_meta: dict) -> bool:
pass

@abstractmethod
def aggregate_training_result(self) -> (Any, dict):
pass

def accept(self, shareable: Shareable, fl_ctx: FLContext) -> bool:
dxo = from_shareable(shareable)
trained_weights = dxo.data
trained_meta = dxo.meta
self.fl_ctx = fl_ctx
peer_ctx = fl_ctx.get_peer_context()
client_name = peer_ctx.get_identity_name()
return self.processing_training_result(client_name, trained_weights, trained_meta)

def aggregate(self, fl_ctx: FLContext) -> Shareable:
self.fl_ctx = fl_ctx
aggregated_result, aggregated_meta = self.aggregate_training_result()
dxo = DXO(
data_kind=DataKind.APP_DEFINED,
data=aggregated_result,
meta=aggregated_meta,
)
self.debug(f"learnable_to_shareable: {dxo.data}")
return dxo.to_shareable()
87 changes: 87 additions & 0 deletions nvflare/app_common/app_defined/component_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from nvflare.apis.fl_component import FLComponent


class ComponentBase(FLComponent):
def __init__(self):
FLComponent.__init__(self)
self.fl_ctx = None

def debug(self, msg: str):
"""Convenience method for logging an DEBUG message with contextual info
Args:
msg: the message to be logged
Returns:
"""
self.log_debug(self.fl_ctx, msg)

def info(self, msg: str):
"""Convenience method for logging an INFO message with contextual info
Args:
msg: the message to be logged
Returns:
"""
self.log_info(self.fl_ctx, msg)

def error(self, msg: str):
"""Convenience method for logging an ERROR message with contextual info
Args:
msg: the message to be logged
Returns:
"""
self.log_error(self.fl_ctx, msg)

def warning(self, msg: str):
"""Convenience method for logging a WARNING message with contextual info
Args:
msg: the message to be logged
Returns:
"""
self.log_warning(self.fl_ctx, msg)

def exception(self, msg: str):
"""Convenience method for logging an EXCEPTION message with contextual info
Args:
msg: the message to be logged
Returns:
"""
self.log_exception(self.fl_ctx, msg)

def critical(self, msg: str):
"""Convenience method for logging a CRITICAL message with contextual info
Args:
msg: the message to be logged
Returns:
"""
self.log_critical(self.fl_ctx, msg)
57 changes: 57 additions & 0 deletions nvflare/app_common/app_defined/model_persistor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod
from typing import Any

from nvflare.apis.fl_context import FLContext
from nvflare.app_common.abstract.model import ModelLearnable, ModelLearnableKey, make_model_learnable
from nvflare.app_common.abstract.model_persistor import ModelPersistor

from .component_base import ComponentBase


class AppDefinedModelPersistor(ModelPersistor, ComponentBase, ABC):
def __init__(self):
ModelPersistor.__init__(self)
ComponentBase.__init__(self)

@abstractmethod
def read_model(self) -> Any:
"""Load model object.
Returns: a model object
"""
pass

@abstractmethod
def write_model(self, model_obj: Any):
"""Save the model object
Args:
model_obj: the model object to be saved
Returns: None
"""
pass

def load_model(self, fl_ctx: FLContext) -> ModelLearnable:
self.fl_ctx = fl_ctx
model = self.read_model()
return make_model_learnable(weights=model, meta_props={})

def save_model(self, learnable: ModelLearnable, fl_ctx: FLContext):
self.fl_ctx = fl_ctx
self.write_model(learnable.get(ModelLearnableKey.WEIGHTS))
94 changes: 94 additions & 0 deletions nvflare/app_common/app_defined/shareable_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod
from typing import Any

from nvflare.apis.dxo import DXO, DataKind, from_shareable
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable
from nvflare.app_common.abstract.learnable import Learnable
from nvflare.app_common.abstract.model import ModelLearnable, ModelLearnableKey, make_model_learnable
from nvflare.app_common.abstract.shareable_generator import ShareableGenerator
from nvflare.app_common.app_constant import AppConstants

from .component_base import ComponentBase


class AppDefinedShareableGenerator(ShareableGenerator, ComponentBase, ABC):
def __init__(self):
ShareableGenerator.__init__(self)
ComponentBase.__init__(self)
self.current_round = None

@abstractmethod
def model_to_trainable(self, model_obj: Any) -> (Any, dict):
"""Convert the model weights and meta to a format that can be sent to clients to do training
Args:
model_obj: model object
Returns: a tuple of (weights, meta)
The returned weights and meta will be for training and serializable
"""
pass

@abstractmethod
def update_model(self, model_obj: Any, training_result: Any, meta: dict) -> Any:
"""Update model with training result and meta
Args:
model_obj: base model object to be updated
training_result: training result to be applied to the model object
meta: trained meta
Returns: the updated model object
"""
pass

def learnable_to_shareable(self, learnable: Learnable, fl_ctx: FLContext) -> Shareable:
self.fl_ctx = fl_ctx
self.current_round = fl_ctx.get_prop(AppConstants.CURRENT_ROUND)
self.debug(f"{learnable=}")
base_model_obj = learnable.get(ModelLearnableKey.WEIGHTS)
trainable_weights, trainable_meta = self.model_to_trainable(base_model_obj)
self.debug(f"trainable weights: {trainable_weights}")
dxo = DXO(
data_kind=DataKind.APP_DEFINED,
data=trainable_weights,
meta=trainable_meta,
)
self.debug(f"learnable_to_shareable: {dxo.data}")
return dxo.to_shareable()

def shareable_to_learnable(self, shareable: Shareable, fl_ctx: FLContext) -> Learnable:
self.fl_ctx = fl_ctx
self.current_round = fl_ctx.get_prop(AppConstants.CURRENT_ROUND)
base_model_learnable = fl_ctx.get_prop(AppConstants.GLOBAL_MODEL)

if not base_model_learnable:
self.system_panic(reason="No global base model!", fl_ctx=fl_ctx)
return base_model_learnable

if not isinstance(base_model_learnable, ModelLearnable):
raise ValueError(f"expect global model to be ModelLearnable but got {type(base_model_learnable)}")
base_model_obj = base_model_learnable.get(ModelLearnableKey.WEIGHTS)

dxo = from_shareable(shareable)
training_result = dxo.data
trained_meta = dxo.meta
model_obj = self.update_model(model_obj=base_model_obj, training_result=training_result, meta=trained_meta)
return make_model_learnable(model_obj, {})
Loading

0 comments on commit b829303

Please sign in to comment.