Skip to content

Commit

Permalink
Merge branch 'main' into edge_executor
Browse files Browse the repository at this point in the history
  • Loading branch information
YuanTingHsieh authored Feb 21, 2025
2 parents cf8f633 + 0a8d273 commit a445b48
Show file tree
Hide file tree
Showing 18 changed files with 158 additions and 228 deletions.
3 changes: 1 addition & 2 deletions integration/monai/monai_nvflare/nvflare_stats_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,8 @@
from monai.config import IgniteInfo
from monai.utils import is_scalar, min_version, optional_import

from nvflare.apis.analytix import AnalyticsDataType
from nvflare.apis.analytix import AnalyticsDataType, LogWriterName
from nvflare.app_common.tracking.log_writer import LogWriter
from nvflare.app_common.tracking.tracker_types import LogWriterName

Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events")

Expand Down
39 changes: 36 additions & 3 deletions nvflare/apis/analytix.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,44 @@

from nvflare.apis.dxo import DXO, DataKind

# TODO: api should not depend on app_common
from nvflare.app_common.tracking.tracker_types import LogWriterName, TrackConst

# Keys used when packing analytics data into a DXO payload.
_DATA_TYPE_KEY = "analytics_data_type"
_KWARGS_KEY = "analytics_kwargs"
# Event type fired when streaming analytics/metrics records between components.
ANALYTIC_EVENT_TYPE = "analytix_log_stats"


class LogWriterName(Enum):
    """Identifies which experiment-tracking backend a log writer targets."""

    TORCH_TB = "TORCH_TENSORBOARD"  # PyTorch TensorBoard SummaryWriter-style tracking
    MLFLOW = "MLFLOW"  # MLflow tracking
    WANDB = "WEIGHTS_AND_BIASES"  # Weights & Biases tracking


class TrackConst(object):
    """String keys used to pass experiment-tracking metadata between senders and receivers."""

    # Identifies which tracker/backend a record is intended for.
    TRACKER_KEY = "tracker_key"

    # Key/value of a single tracked datum.
    TRACK_KEY = "track_key"
    TRACK_VALUE = "track_value"

    TAG_KEY = "tag_key"
    TAGS_KEY = "tags_key"

    # NOTE(review): same value as TAGS_KEY — looks like an alias kept for
    # backward compatibility; confirm before consolidating.
    EXP_TAGS_KEY = "tags_key"

    GLOBAL_STEP_KEY = "global_step"
    PATH_KEY = "path"
    DATA_TYPE_KEY = "analytics_data_type"
    KWARGS_KEY = "analytics_kwargs"

    PROJECT_NAME = "project_name"
    # NOTE(review): duplicates PROJECT_NAME's value; possibly intended to be
    # "project_tags" — verify against receivers before changing, since the
    # string value is part of the wire contract.
    PROJECT_TAGS = "project_name"

    EXPERIMENT_NAME = "experiment_name"
    RUN_NAME = "run_name"
    EXPERIMENT_TAGS = "experiment_tags"
    INIT_CONFIG = "init_config"
    RUN_TAGS = "run_tags"

    # Origin metadata attached to each record.
    SITE_KEY = "site"
    JOB_ID_KEY = "job_id"


class AnalyticsDataType(Enum):
Expand Down
5 changes: 1 addition & 4 deletions nvflare/apis/utils/analytix_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from nvflare.apis.analytix import AnalyticsData, AnalyticsDataType
from nvflare.apis.analytix import ANALYTIC_EVENT_TYPE, AnalyticsData, AnalyticsDataType, LogWriterName
from nvflare.apis.dxo import DXO
from nvflare.apis.fl_component import FLComponent
from nvflare.apis.fl_constant import FLContextKey
from nvflare.apis.fl_context import FLContext

# TODO: api should not depend on app_common
from nvflare.app_common.tracking.tracker_types import ANALYTIC_EVENT_TYPE, LogWriterName


def send_analytic_dxo(
comp: FLComponent, dxo: DXO, fl_ctx: FLContext, event_type: str = ANALYTIC_EVENT_TYPE, fire_fed_event: bool = False
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,23 +11,23 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import threading
import time
from typing import Optional

from nvflare.apis.analytix import ANALYTIC_EVENT_TYPE
from nvflare.apis.event_type import EventType
from nvflare.apis.executor import Executor
from nvflare.apis.fl_constant import FLContextKey, FLMetaKey, ReturnCode
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable, make_reply
from nvflare.apis.signal import Signal
from nvflare.apis.utils.analytix_utils import create_analytic_dxo
from nvflare.apis.utils.analytix_utils import create_analytic_dxo, send_analytic_dxo
from nvflare.apis.workspace import Workspace
from nvflare.app_common.abstract.params_converter import ParamsConverter
from nvflare.app_common.app_constant import AppConstants
from nvflare.app_common.executors.task_script_runner import TaskScriptRunner
from nvflare.app_common.tracking.tracker_types import ANALYTIC_EVENT_TYPE
from nvflare.app_common.widgets.streaming import send_analytic_dxo
from nvflare.client.api_spec import CLIENT_API_KEY
from nvflare.client.config import ConfigKey, ExchangeFormat, TransferType
from nvflare.client.in_process.api import (
Expand Down
5 changes: 2 additions & 3 deletions nvflare/app_common/tracking/log_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,11 @@
from abc import ABC, abstractmethod
from typing import Optional

from nvflare.apis.analytix import AnalyticsDataType
from nvflare.apis.analytix import ANALYTIC_EVENT_TYPE, AnalyticsDataType, LogWriterName
from nvflare.apis.event_type import EventType
from nvflare.apis.fl_component import FLComponent
from nvflare.apis.fl_context import FLContext
from nvflare.app_common.tracking.tracker_types import LogWriterName
from nvflare.app_common.widgets.streaming import ANALYTIC_EVENT_TYPE, AnalyticsSender
from nvflare.app_common.widgets.streaming import AnalyticsSender


class LogWriter(FLComponent, ABC):
Expand Down
26 changes: 0 additions & 26 deletions nvflare/app_common/tracking/track_exception.py

This file was deleted.

52 changes: 0 additions & 52 deletions nvflare/app_common/tracking/tracker_types.py

This file was deleted.

2 changes: 1 addition & 1 deletion nvflare/app_common/widgets/metric_relay.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@

from typing import Tuple

from nvflare.apis.analytix import ANALYTIC_EVENT_TYPE
from nvflare.apis.dxo import DXO
from nvflare.apis.event_type import EventType
from nvflare.apis.fl_context import FLContext
from nvflare.apis.utils.analytix_utils import send_analytic_dxo
from nvflare.app_common.tracking.tracker_types import ANALYTIC_EVENT_TYPE
from nvflare.client.config import ConfigKey
from nvflare.fuel.utils.attributes_exportable import AttributesExportable
from nvflare.fuel.utils.constants import PipeChannelName
Expand Down
116 changes: 54 additions & 62 deletions nvflare/app_common/widgets/streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,12 @@
from threading import Lock
from typing import List, Optional

from nvflare.apis.analytix import AnalyticsDataType
from nvflare.apis.analytix import ANALYTIC_EVENT_TYPE, AnalyticsDataType, LogWriterName, TrackConst
from nvflare.apis.event_type import EventType
from nvflare.apis.fl_constant import EventScope, FLContextKey, ReservedKey
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable
from nvflare.apis.utils.analytix_utils import create_analytic_dxo, send_analytic_dxo
from nvflare.app_common.tracking.tracker_types import ANALYTIC_EVENT_TYPE, LogWriterName, TrackConst
from nvflare.fuel.utils.deprecated import deprecated
from nvflare.widgets.widget import Widget


Expand Down Expand Up @@ -72,70 +70,26 @@ def add(self, tag: str, value, data_type: AnalyticsDataType, global_step: Option
with self.engine.new_context() as fl_ctx:
send_analytic_dxo(self, dxo=dxo, fl_ctx=fl_ctx, event_type=self.event_type)

@deprecated(
    "This method is deprecated, please use :py:class:`TBWriter <nvflare.app_opt.tracking.tb.tb_writer.TBWriter>` instead."
)
def add_scalar(self, tag: str, scalar: float, global_step: Optional[int] = None, **kwargs):
    """Send one scalar value to the receiver (legacy API).

    Kept only to mirror the PyTorch ``SummaryWriter.add_scalar`` signature so
    pre-existing code keeps working; new code should use ``TBWriter`` instead.

    Args:
        tag (str): Data identifier.
        scalar (float): Value to send.
        global_step (optional, int): Global step value.
        **kwargs: Additional arguments to pass to the receiver side.
    """
    # Delegate straight to the generic add() with the SCALAR data type.
    self.add(tag, scalar, AnalyticsDataType.SCALAR, global_step, **kwargs)

@deprecated(
    "This method is deprecated, please use :py:class:`TBWriter <nvflare.app_opt.tracking.tb.tb_writer.TBWriter>` instead."
)
def add_scalars(self, tag: str, scalars: dict, global_step: Optional[int] = None, **kwargs):
    """Send a group of named scalars to the receiver (legacy API).

    Kept only to mirror the PyTorch ``SummaryWriter.add_scalars`` signature so
    pre-existing code keeps working; new code should use ``TBWriter`` instead.

    Args:
        tag (str): The parent name for the tags.
        scalars (dict): Key-value pair storing the tag and corresponding values.
        global_step (optional, int): Global step value.
        **kwargs: Additional arguments to pass to the receiver side.
    """
    # Delegate straight to the generic add() with the SCALARS data type.
    self.add(tag, scalars, AnalyticsDataType.SCALARS, global_step, **kwargs)

@deprecated(
    "This method is deprecated, please use :py:class:`TBWriter <nvflare.app_opt.tracking.tb.tb_writer.TBWriter>` instead."
)
def flush(self):
    """Legacy no-op flush.

    This follows the signature from PyTorch SummaryWriter and is here in case
    it is used in previous code. If you are writing new code, use
    :py:class:`TBWriter <nvflare.app_opt.tracking.tb.tb_writer.TBWriter>` instead.

    This intentionally does nothing: records are sent immediately by add(),
    so there is no buffer to flush; the method only mimics SummaryWriter.
    """
    pass

def close(self):
    """Close resources.

    Drops the reference to the engine so the sender can no longer emit
    events; add() calls after close() will see engine is None.
    """
    if self.engine:
        self.engine = None


class AnalyticsReceiver(Widget, ABC):
def __init__(self, events: Optional[List[str]] = None, client_side_supported: bool = False):
    """Receives analytic data.

    Note: the page showed two conflicting ``__init__`` signatures back-to-back
    (old and new diff lines); only the newer, backward-compatible one is kept —
    ``client_side_supported`` has a default, so existing callers are unaffected.

    Args:
        events (optional, List[str]): A list of events that this receiver will handle.
        client_side_supported (bool): Whether the client side is supported.
    """
    super().__init__()
    # Default to both the local analytics event and its federated counterpart.
    if events is None:
        events = [ANALYTIC_EVENT_TYPE, f"fed.{ANALYTIC_EVENT_TYPE}"]
    self.events = events
    self.client_side_supported = client_side_supported
    self._initialized = False  # set True only after initialize() succeeds
    self._save_lock = Lock()  # serializes save()/finalize() across event threads
    self._end = False  # set once END_RUN is seen

Expand Down Expand Up @@ -176,8 +130,28 @@ def finalize(self, fl_ctx: FLContext):

def handle_event(self, event_type: str, fl_ctx: FLContext):
    """Dispatch framework events to the start/data/end handlers.

    Args:
        event_type (str): The event being fired.
        fl_ctx (FLContext): The FL context for this event.
    """
    if event_type == EventType.START_RUN:
        # Only the guarded start handler runs initialize(); calling
        # self.initialize(fl_ctx) directly here as well (leftover old diff
        # line) would initialize twice and bypass the support/error checks.
        self._handle_start_run_event(fl_ctx)
    elif event_type in self.events:
        self._handle_data_event(event_type, fl_ctx)
    elif event_type == EventType.END_RUN:
        self._handle_end_run_event(fl_ctx)

def _handle_start_run_event(self, fl_ctx: FLContext):
    """Initialize the receiver on START_RUN without letting failures kill the job."""
    supported = self._is_supported(fl_ctx)
    if not supported:
        message = f"This receiver is not supported on the site {fl_ctx.get_identity_name()}."
        self.log_error(fl_ctx, message, fire_event=False)
        return
    try:
        self.initialize(fl_ctx)
        # Mark ready only once initialize() completed without raising.
        self._initialized = True
    except Exception as e:
        # Swallow the error so the overall job can continue running.
        self.log_error(fl_ctx, f"Receiver initialize failed with {e}.", fire_event=False)

def _handle_data_event(self, event_type: str, fl_ctx: FLContext):
if self._initialized:
if self._end:
self.log_debug(fl_ctx, f"Already received end run event, drop event {event_type}.", fire_event=False)
return
Expand All @@ -191,17 +165,35 @@ def handle_event(self, event_type: str, fl_ctx: FLContext):
)
return

# if fed event use peer name to save
if fl_ctx.get_prop(FLContextKey.EVENT_SCOPE) == EventScope.FEDERATION:
record_origin = data.get_peer_prop(ReservedKey.IDENTITY_NAME, None)
else:
record_origin = fl_ctx.get_identity_name()

record_origin = self._get_record_origin(fl_ctx, data)
if record_origin is None:
self.log_error(fl_ctx, "record_origin can't be None.", fire_event=False)
return
with self._save_lock:
self.save(shareable=data, fl_ctx=fl_ctx, record_origin=record_origin)
elif event_type == EventType.END_RUN:

try:
with self._save_lock:
self.save(shareable=data, fl_ctx=fl_ctx, record_origin=record_origin)
except Exception as e:
self.log_error(fl_ctx, f"Receiver save method failed with {e}.", fire_event=False)

def _handle_end_run_event(self, fl_ctx: FLContext):
    """Finalize the receiver exactly once when the run ends.

    Args:
        fl_ctx (FLContext): The FL context for the END_RUN event.
    """
    if self._initialized:
        self._end = True
        # finalize() must run only once and only under the save lock; the
        # extra unguarded finalize call (leftover old diff line) would run it
        # twice and race with in-flight save() calls.
        try:
            with self._save_lock:
                self.finalize(fl_ctx)
        except Exception as e:
            # catch the exception so the job can continue
            self.log_error(fl_ctx, f"Receiver finalize failed with {e}.", fire_event=False)

def _is_supported(self, fl_ctx: FLContext) -> bool:
    """Tell whether this receiver may run at the current site."""
    if self.client_side_supported:
        # Receiver explicitly allows client sites, so it runs anywhere.
        return True
    # Server-only receiver: permit only identities containing "server".
    identity = fl_ctx.get_identity_name()
    return "server" in identity

def _get_record_origin(self, fl_ctx: FLContext, data: Shareable) -> Optional[str]:
    """Return the site name a record came from (the peer name for fed events)."""
    scope = fl_ctx.get_prop(FLContextKey.EVENT_SCOPE)
    if scope == EventScope.FEDERATION:
        # Federated event: the true origin is the peer that sent the data.
        return data.get_peer_prop(ReservedKey.IDENTITY_NAME, None)
    # Local event: the origin is this site itself.
    return fl_ctx.get_identity_name()
Loading

0 comments on commit a445b48

Please sign in to comment.