Skip to content

Commit

Permalink
Address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
YuanTingHsieh committed Jul 21, 2023
1 parent 50ea350 commit c63a3eb
Show file tree
Hide file tree
Showing 5 changed files with 12 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
net = Net()

# initializes NVFlare interface
flare.init(config="config/config_exchange.json")
flare.init()
input_model, input_meta = flare.receive_model()

# get model from NVFlare
Expand Down Expand Up @@ -69,7 +69,6 @@
if i % 2000 == 1999: # print every 2000 mini-batches
print(f"[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}")
running_loss = 0.0
break

print("Finished Training")

Expand Down
2 changes: 1 addition & 1 deletion examples/advanced/ml-to-fl/jobs/client_api/meta.json
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
{
"name": "subprocess with file pipe",
"name": "subprocess with file pipe with pytorch",
"resource_spec": {},
"min_clients" : 2,
"deploy_map": {
Expand Down
4 changes: 4 additions & 0 deletions nvflare/client/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@

from enum import Enum

MODEL_ATTRS = ("optimizer_params", "current_round")
SYS_ATTRS = ("job_id", "site_name", "total_rounds")
CONST_ATTRS = ("total_rounds",)


class ModelExchangeFormat(str, Enum):
RAW = "raw"
Expand Down
8 changes: 3 additions & 5 deletions nvflare/client/model_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,9 @@
from nvflare.app_common.model_exchange.model_exchanger import ModelExchanger

from .config import ClientConfig
from .constants import MODEL_ATTRS, SYS_ATTRS
from .utils import copy_fl_model_attributes, get_meta_from_fl_model, numerical_params_diff, set_fl_model_with_meta

IN_ATTRS = ("optimizer_params", "current_round")
SYS_ATTRS = ("job_id", "site_name", "total_rounds")

DIFF_MAP = {"numerical_params_diff": numerical_params_diff}


Expand Down Expand Up @@ -50,7 +48,7 @@ def __init__(self, model_exchanger: ModelExchanger, config: ClientConfig):

def _get_model(self):
self.input_model = self.model_exchanger.receive_model()
self.meta = get_meta_from_fl_model(self.input_model, IN_ATTRS)
self.meta = get_meta_from_fl_model(self.input_model, MODEL_ATTRS)
self.sys_meta = get_meta_from_fl_model(self.input_model, SYS_ATTRS)

def construct_fl_model(self, params) -> FLModel:
Expand All @@ -74,7 +72,7 @@ def construct_fl_model(self, params) -> FLModel:
fl_model.params = params_diff_func(self.input_model.params, fl_model.params)
fl_model.params_type = ParamsType.DIFF

set_fl_model_with_meta(fl_model, self.meta, IN_ATTRS)
set_fl_model_with_meta(fl_model, self.meta, MODEL_ATTRS)
copy_fl_model_attributes(self.input_model, fl_model)
fl_model.meta = self.meta
return fl_model
Expand Down
6 changes: 3 additions & 3 deletions nvflare/client/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from nvflare.app_common.abstract.fl_model import FLModel

OUT_ATTRS = "total_rounds"
from .constants import CONST_ATTRS


def get_meta_from_fl_model(fl_model: FLModel, attrs: Iterable[str]) -> Dict:
Expand Down Expand Up @@ -53,8 +53,8 @@ def set_fl_model_with_meta(fl_model: FLModel, meta: Dict, attrs):
meta.pop(attr)


def copy_fl_model_attributes(src: FLModel, dst: FLModel, attrs=OUT_ATTRS):
"""Copies FLModel attributes.
def copy_fl_model_attributes(src: FLModel, dst: FLModel, attrs=CONST_ATTRS):
"""Copies FLModel attributes from source to destination.
Args:
src: source FLModel object.
Expand Down

0 comments on commit c63a3eb

Please sign in to comment.