From 0a8d2735f9cc8a5dcc8993efb3bcc31ec49a6e1c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Yuan-Ting=20Hsieh=20=28=E8=AC=9D=E6=B2=85=E5=BB=B7=29?= Date: Fri, 21 Feb 2025 11:30:52 -0800 Subject: [PATCH] Improve metrics streaming (#2722) --- .../monai_nvflare/nvflare_stats_handler.py | 3 +- nvflare/apis/analytix.py | 39 +++++- nvflare/apis/utils/analytix_utils.py | 5 +- .../in_process_client_api_executor.py | 6 +- nvflare/app_common/tracking/log_writer.py | 5 +- .../app_common/tracking/track_exception.py | 26 ---- nvflare/app_common/tracking/tracker_types.py | 52 -------- nvflare/app_common/widgets/metric_relay.py | 2 +- nvflare/app_common/widgets/streaming.py | 116 ++++++++---------- .../tracking/mlflow/mlflow_receiver.py | 80 ++++++------ .../app_opt/tracking/mlflow/mlflow_writer.py | 4 +- nvflare/app_opt/tracking/tb/tb_receiver.py | 2 +- nvflare/app_opt/tracking/tb/tb_writer.py | 4 +- .../app_opt/tracking/wandb/wandb_receiver.py | 29 ++--- .../app_opt/tracking/wandb/wandb_writer.py | 4 +- nvflare/client/tracking.py | 3 +- tests/unit_test/apis/analytix_test.py | 3 +- .../app_common/widgets/streaming_test.py | 3 +- 18 files changed, 158 insertions(+), 228 deletions(-) delete mode 100644 nvflare/app_common/tracking/track_exception.py delete mode 100644 nvflare/app_common/tracking/tracker_types.py diff --git a/integration/monai/monai_nvflare/nvflare_stats_handler.py b/integration/monai/monai_nvflare/nvflare_stats_handler.py index 2a7ab7d840..775d390cc4 100644 --- a/integration/monai/monai_nvflare/nvflare_stats_handler.py +++ b/integration/monai/monai_nvflare/nvflare_stats_handler.py @@ -22,9 +22,8 @@ from monai.config import IgniteInfo from monai.utils import is_scalar, min_version, optional_import -from nvflare.apis.analytix import AnalyticsDataType +from nvflare.apis.analytix import AnalyticsDataType, LogWriterName from nvflare.app_common.tracking.log_writer import LogWriter -from nvflare.app_common.tracking.tracker_types import LogWriterName Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events") diff --git a/nvflare/apis/analytix.py b/nvflare/apis/analytix.py index 0cb0ab1b45..801019e108 100644 --- a/nvflare/apis/analytix.py +++ b/nvflare/apis/analytix.py @@ -16,11 +16,44 @@ from nvflare.apis.dxo import DXO, DataKind -# TODO: api should not depend on app_common -from nvflare.app_common.tracking.tracker_types import LogWriterName, TrackConst - _DATA_TYPE_KEY = "analytics_data_type" _KWARGS_KEY = "analytics_kwargs" +ANALYTIC_EVENT_TYPE = "analytix_log_stats" + + +class LogWriterName(Enum): + TORCH_TB = "TORCH_TENSORBOARD" + MLFLOW = "MLFLOW" + WANDB = "WEIGHTS_AND_BIASES" + + +class TrackConst(object): + TRACKER_KEY = "tracker_key" + + TRACK_KEY = "track_key" + TRACK_VALUE = "track_value" + + TAG_KEY = "tag_key" + TAGS_KEY = "tags_key" + + EXP_TAGS_KEY = "tags_key" + + GLOBAL_STEP_KEY = "global_step" + PATH_KEY = "path" + DATA_TYPE_KEY = "analytics_data_type" + KWARGS_KEY = "analytics_kwargs" + + PROJECT_NAME = "project_name" + PROJECT_TAGS = "project_name" + + EXPERIMENT_NAME = "experiment_name" + RUN_NAME = "run_name" + EXPERIMENT_TAGS = "experiment_tags" + INIT_CONFIG = "init_config" + RUN_TAGS = "run_tags" + + SITE_KEY = "site" + JOB_ID_KEY = "job_id" class AnalyticsDataType(Enum): diff --git a/nvflare/apis/utils/analytix_utils.py b/nvflare/apis/utils/analytix_utils.py index 1c55890ed7..04ae9791c2 100644 --- a/nvflare/apis/utils/analytix_utils.py +++ b/nvflare/apis/utils/analytix_utils.py @@ -12,15 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from nvflare.apis.analytix import AnalyticsData, AnalyticsDataType +from nvflare.apis.analytix import ANALYTIC_EVENT_TYPE, AnalyticsData, AnalyticsDataType, LogWriterName from nvflare.apis.dxo import DXO from nvflare.apis.fl_component import FLComponent from nvflare.apis.fl_constant import FLContextKey from nvflare.apis.fl_context import FLContext -# TODO: api should not depend on app_common -from nvflare.app_common.tracking.tracker_types import ANALYTIC_EVENT_TYPE, LogWriterName - def send_analytic_dxo( comp: FLComponent, dxo: DXO, fl_ctx: FLContext, event_type: str = ANALYTIC_EVENT_TYPE, fire_fed_event: bool = False diff --git a/nvflare/app_common/executors/in_process_client_api_executor.py b/nvflare/app_common/executors/in_process_client_api_executor.py index c920f1b4f9..d2549f8e84 100644 --- a/nvflare/app_common/executors/in_process_client_api_executor.py +++ b/nvflare/app_common/executors/in_process_client_api_executor.py @@ -11,23 +11,23 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import threading import time from typing import Optional +from nvflare.apis.analytix import ANALYTIC_EVENT_TYPE from nvflare.apis.event_type import EventType from nvflare.apis.executor import Executor from nvflare.apis.fl_constant import FLContextKey, FLMetaKey, ReturnCode from nvflare.apis.fl_context import FLContext from nvflare.apis.shareable import Shareable, make_reply from nvflare.apis.signal import Signal -from nvflare.apis.utils.analytix_utils import create_analytic_dxo +from nvflare.apis.utils.analytix_utils import create_analytic_dxo, send_analytic_dxo from nvflare.apis.workspace import Workspace from nvflare.app_common.abstract.params_converter import ParamsConverter from nvflare.app_common.app_constant import AppConstants from nvflare.app_common.executors.task_script_runner import TaskScriptRunner -from nvflare.app_common.tracking.tracker_types import ANALYTIC_EVENT_TYPE -from nvflare.app_common.widgets.streaming import send_analytic_dxo from nvflare.client.api_spec import CLIENT_API_KEY from nvflare.client.config import ConfigKey, ExchangeFormat, TransferType from nvflare.client.in_process.api import ( diff --git a/nvflare/app_common/tracking/log_writer.py b/nvflare/app_common/tracking/log_writer.py index fa7162d3e4..78f8109f8b 100644 --- a/nvflare/app_common/tracking/log_writer.py +++ b/nvflare/app_common/tracking/log_writer.py @@ -15,12 +15,11 @@ from abc import ABC, abstractmethod from typing import Optional -from nvflare.apis.analytix import AnalyticsDataType +from nvflare.apis.analytix import ANALYTIC_EVENT_TYPE, AnalyticsDataType, LogWriterName from nvflare.apis.event_type import EventType from nvflare.apis.fl_component import FLComponent from nvflare.apis.fl_context import FLContext -from nvflare.app_common.tracking.tracker_types import LogWriterName -from nvflare.app_common.widgets.streaming import ANALYTIC_EVENT_TYPE, AnalyticsSender +from nvflare.app_common.widgets.streaming import AnalyticsSender class LogWriter(FLComponent, ABC): diff --git a/nvflare/app_common/tracking/track_exception.py b/nvflare/app_common/tracking/track_exception.py deleted file mode 100644 index 6cad709b48..0000000000 --- a/nvflare/app_common/tracking/track_exception.py +++ /dev/null @@ -1,26 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -class ExpTrackingException(Exception): - def __init__(self, message, **kwargs): - """ - Args: - message: The message or exception describing the error that occurred. - **kwargs: Additional key-value pairs - """ - message = str(message) - self.message = message - self.kwargs = kwargs - super().__init__(message) diff --git a/nvflare/app_common/tracking/tracker_types.py b/nvflare/app_common/tracking/tracker_types.py deleted file mode 100644 index d257b1c0e3..0000000000 --- a/nvflare/app_common/tracking/tracker_types.py +++ /dev/null @@ -1,52 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from enum import Enum - -ANALYTIC_EVENT_TYPE = "analytix_log_stats" - - -class LogWriterName(Enum): - TORCH_TB = "TORCH_TENSORBOARD" - MLFLOW = "MLFLOW" - WANDB = "WEIGHTS_AND_BIASES" - - -class TrackConst(object): - TRACKER_KEY = "tracker_key" - - TRACK_KEY = "track_key" - TRACK_VALUE = "track_value" - - TAG_KEY = "tag_key" - TAGS_KEY = "tags_key" - - EXP_TAGS_KEY = "tags_key" - - GLOBAL_STEP_KEY = "global_step" - PATH_KEY = "path" - DATA_TYPE_KEY = "analytics_data_type" - KWARGS_KEY = "analytics_kwargs" - - PROJECT_NAME = "project_name" - PROJECT_TAGS = "project_name" - - EXPERIMENT_NAME = "experiment_name" - RUN_NAME = "run_name" - EXPERIMENT_TAGS = "experiment_tags" - INIT_CONFIG = "init_config" - RUN_TAGS = "run_tags" - - SITE_KEY = "site" - JOB_ID_KEY = "job_id" diff --git a/nvflare/app_common/widgets/metric_relay.py b/nvflare/app_common/widgets/metric_relay.py index 7b1372911e..e5137f186a 100644 --- a/nvflare/app_common/widgets/metric_relay.py +++ b/nvflare/app_common/widgets/metric_relay.py @@ -14,11 +14,11 @@ from typing import Tuple +from nvflare.apis.analytix import ANALYTIC_EVENT_TYPE from nvflare.apis.dxo import DXO from nvflare.apis.event_type import EventType from nvflare.apis.fl_context import FLContext from nvflare.apis.utils.analytix_utils import send_analytic_dxo -from nvflare.app_common.tracking.tracker_types import ANALYTIC_EVENT_TYPE from nvflare.client.config import ConfigKey from nvflare.fuel.utils.attributes_exportable import AttributesExportable from nvflare.fuel.utils.constants import PipeChannelName diff --git a/nvflare/app_common/widgets/streaming.py b/nvflare/app_common/widgets/streaming.py index 481467757a..4ebec89c83 100644 --- a/nvflare/app_common/widgets/streaming.py +++ b/nvflare/app_common/widgets/streaming.py @@ -16,14 +16,12 @@ from threading import Lock from typing import List, Optional -from nvflare.apis.analytix import AnalyticsDataType +from nvflare.apis.analytix import ANALYTIC_EVENT_TYPE, AnalyticsDataType, LogWriterName, TrackConst from nvflare.apis.event_type import EventType from nvflare.apis.fl_constant import EventScope, FLContextKey, ReservedKey from nvflare.apis.fl_context import FLContext from nvflare.apis.shareable import Shareable from nvflare.apis.utils.analytix_utils import create_analytic_dxo, send_analytic_dxo -from nvflare.app_common.tracking.tracker_types import ANALYTIC_EVENT_TYPE, LogWriterName, TrackConst -from nvflare.fuel.utils.deprecated import deprecated from nvflare.widgets.widget import Widget @@ -72,53 +70,6 @@ def add(self, tag: str, value, data_type: AnalyticsDataType, global_step: Option with self.engine.new_context() as fl_ctx: send_analytic_dxo(self, dxo=dxo, fl_ctx=fl_ctx, event_type=self.event_type) - @deprecated( - "This method is deprecated, please use :py:class:`TBWriter ` instead." - ) - def add_scalar(self, tag: str, scalar: float, global_step: Optional[int] = None, **kwargs): - """Legacy method to send a scalar. - - This follows the signature from PyTorch SummaryWriter and is here in case it is used in previous code. If - you are writing new code, use :py:class:`TBWriter ` instead. - - Args: - tag (str): Data identifier. - scalar (float): Value to send. - global_step (optional, int): Global step value. - **kwargs: Additional arguments to pass to the receiver side. - """ - self.add(tag=tag, value=scalar, data_type=AnalyticsDataType.SCALAR, global_step=global_step, **kwargs) - - @deprecated( - "This method is deprecated, please use :py:class:`TBWriter ` instead." - ) - def add_scalars(self, tag: str, scalars: dict, global_step: Optional[int] = None, **kwargs): - """Legacy method to send scalars. - - This follows the signature from PyTorch SummaryWriter and is here in case it is used in previous code. If - you are writing new code, use :py:class:`TBWriter ` instead. - - Args: - tag (str): The parent name for the tags. - scalars (dict): Key-value pair storing the tag and corresponding values. - global_step (optional, int): Global step value. - **kwargs: Additional arguments to pass to the receiver side. - """ - self.add(tag=tag, value=scalars, data_type=AnalyticsDataType.SCALARS, global_step=global_step, **kwargs) - - @deprecated( - "This method is deprecated, please use :py:class:`TBWriter ` instead." - ) - def flush(self): - """Legacy method to flush out the message. - - This follows the signature from PyTorch SummaryWriter and is here in case it is used in previous code. If - you are writing new code, use :py:class:`TBWriter ` instead. - - This does nothing, it is defined to mimic the PyTorch SummaryWriter. - """ - pass - def close(self): """Close resources.""" if self.engine: @@ -126,16 +77,19 @@ def close(self): class AnalyticsReceiver(Widget, ABC): - def __init__(self, events: Optional[List[str]] = None): + def __init__(self, events: Optional[List[str]] = None, client_side_supported: bool = False): """Receives analytic data. Args: events (optional, List[str]): A list of event that this receiver will handle. + client_side_supported (bool): Whether the client side is supported. """ super().__init__() if events is None: events = [ANALYTIC_EVENT_TYPE, f"fed.{ANALYTIC_EVENT_TYPE}"] self.events = events + self.client_side_supported = client_side_supported + self._initialized = False self._save_lock = Lock() self._end = False @@ -176,8 +130,28 @@ def finalize(self, fl_ctx: FLContext): def handle_event(self, event_type: str, fl_ctx: FLContext): if event_type == EventType.START_RUN: - self.initialize(fl_ctx) + self._handle_start_run_event(fl_ctx) elif event_type in self.events: + self._handle_data_event(event_type, fl_ctx) + elif event_type == EventType.END_RUN: + self._handle_end_run_event(fl_ctx) + + def _handle_start_run_event(self, fl_ctx: FLContext): + if not self._is_supported(fl_ctx): + self.log_error( + fl_ctx, f"This receiver is not supported on the site {fl_ctx.get_identity_name()}.", fire_event=False + ) + return + try: + self.initialize(fl_ctx) + except Exception as e: + # catch the exception so the job can continue + self.log_error(fl_ctx, f"Receiver initialize failed with {e}.", fire_event=False) + return + self._initialized = True + + def _handle_data_event(self, event_type: str, fl_ctx: FLContext): + if self._initialized: if self._end: self.log_debug(fl_ctx, f"Already received end run event, drop event {event_type}.", fire_event=False) return @@ -191,17 +165,35 @@ def handle_event(self, event_type: str, fl_ctx: FLContext): ) return - # if fed event use peer name to save - if fl_ctx.get_prop(FLContextKey.EVENT_SCOPE) == EventScope.FEDERATION: - record_origin = data.get_peer_prop(ReservedKey.IDENTITY_NAME, None) - else: - record_origin = fl_ctx.get_identity_name() - + record_origin = self._get_record_origin(fl_ctx, data) if record_origin is None: self.log_error(fl_ctx, "record_origin can't be None.", fire_event=False) return - with self._save_lock: - self.save(shareable=data, fl_ctx=fl_ctx, record_origin=record_origin) - elif event_type == EventType.END_RUN: + + try: + with self._save_lock: + self.save(shareable=data, fl_ctx=fl_ctx, record_origin=record_origin) + except Exception as e: + self.log_error(fl_ctx, f"Receiver save method failed with {e}.", fire_event=False) + + def _handle_end_run_event(self, fl_ctx: FLContext): + if self._initialized: self._end = True - self.finalize(fl_ctx) + try: + with self._save_lock: + self.finalize(fl_ctx) + except Exception as e: + # catch the exception so the job can continue + self.log_error(fl_ctx, f"Receiver finalize failed with {e}.", fire_event=False) + + def _is_supported(self, fl_ctx: FLContext) -> bool: + if not self.client_side_supported: + identity = fl_ctx.get_identity_name() + return "server" in identity + return True + + def _get_record_origin(self, fl_ctx: FLContext, data: Shareable) -> Optional[str]: + if fl_ctx.get_prop(FLContextKey.EVENT_SCOPE) == EventScope.FEDERATION: + return data.get_peer_prop(ReservedKey.IDENTITY_NAME, None) + else: + return fl_ctx.get_identity_name() diff --git a/nvflare/app_opt/tracking/mlflow/mlflow_receiver.py b/nvflare/app_opt/tracking/mlflow/mlflow_receiver.py index b76821f023..3659188a77 100644 --- a/nvflare/app_opt/tracking/mlflow/mlflow_receiver.py +++ b/nvflare/app_opt/tracking/mlflow/mlflow_receiver.py @@ -15,18 +15,16 @@ import os import time import timeit -from typing import Dict, Optional +from typing import Dict, List, Optional import mlflow from mlflow.entities import Metric, Param, RunTag from mlflow.tracking.client import MlflowClient -from nvflare.apis.analytix import AnalyticsData, AnalyticsDataType +from nvflare.apis.analytix import ANALYTIC_EVENT_TYPE, AnalyticsData, AnalyticsDataType, LogWriterName, TrackConst from nvflare.apis.dxo import from_shareable from nvflare.apis.fl_context import FLContext from nvflare.apis.shareable import Shareable -from nvflare.app_common.tracking.track_exception import ExpTrackingException -from nvflare.app_common.tracking.tracker_types import ANALYTIC_EVENT_TYPE, LogWriterName, TrackConst from nvflare.app_common.widgets.streaming import AnalyticsReceiver @@ -46,7 +44,7 @@ def __init__( tracking_uri: Optional[str] = None, kw_args: Optional[dict] = None, artifact_location: Optional[str] = None, - events=None, + events: Optional[List[str]] = None, buffer_flush_time=1, ): """MLflowReceiver receives log events from clients and deliver them to the MLflow tracking server. @@ -64,7 +62,7 @@ def __init__( When provided, it displays as run description field on the MLflow UI. You can use Markdown syntax for the description. artifact_location (Optional[str], optional): Relative location of artifacts. Currently only text is supported at the moment. - events (_type_, optional): The event the receiver is listening to. By default, it listens to "fed.analytix_log_stats". + events (optional, List[str]): A list of event that this receiver will handle. buffer_flush_time (int, optional): The time in seconds between deliveries of event data to the MLflow tracking server. The data is buffered and then delivered to the MLflow tracking server in batches, and the buffer_flush_time controls the frequency of the sending. By default, the buffer @@ -108,12 +106,16 @@ def initialize(self, fl_ctx: FLContext): art_full_path = self.get_artifact_location(self.artifact_location) experiment_name = self.kw_args.get(TrackConst.EXPERIMENT_NAME, "FLARE FL Experiment") + if not experiment_name: + raise ValueError("Experiment name can't be empty.") experiment_tags = self._get_tags(TrackConst.EXPERIMENT_TAGS, kwargs=self.kw_args) sites = fl_ctx.get_engine().get_clients() - self._init_buffer(sites) + self.mlflow_setup(art_full_path, experiment_name, experiment_tags, sites) + self._init_buffer(sites) + def mlflow_setup(self, art_full_path, experiment_name, experiment_tags, sites): """Set up an MlflowClient for each client site and create an experiment and run. @@ -128,7 +130,7 @@ def mlflow_setup(self, art_full_path, experiment_name, experiment_tags, sites): if not mlflow_client: mlflow_client = MlflowClient() self.mlflow_clients[site.name] = mlflow_client - self.experiment_id = self._create_experiment( + self.experiment_id = self._get_or_create_experiment( mlflow_client, experiment_name, art_full_path, experiment_tags ) run_group_id = str(int(time.time())) @@ -182,7 +184,7 @@ def get_artifact_location(self, relative_path: str): root_log_dir = os.path.join(run_dir, relative_path) return root_log_dir - def _create_experiment( + def _get_or_create_experiment( self, mlflow_client: MlflowClient, experiment_name: str, @@ -190,26 +192,20 @@ def _create_experiment( experiment_tags: Optional[dict] = None, ) -> Optional[str]: experiment_id = None - if experiment_name: + experiment = mlflow_client.get_experiment_by_name(name=experiment_name) + if not experiment: + self.logger.info(f"Experiment with name '{experiment_name}' does not exist. Creating a new experiment.") + import pathlib + + artifact_location_uri = pathlib.Path(artifact_location).as_uri() + experiment_id = mlflow_client.create_experiment( + name=experiment_name, artifact_location=artifact_location_uri, tags=experiment_tags + ) experiment = mlflow_client.get_experiment_by_name(name=experiment_name) - if not experiment: - self.logger.info(f"Experiment with name '{experiment_name}' does not exist. Creating a new experiment.") - try: - import pathlib - - artifact_location_uri = pathlib.Path(artifact_location).as_uri() - experiment_id = mlflow_client.create_experiment( - name=experiment_name, artifact_location=artifact_location_uri, tags=experiment_tags - ) - except Exception as e: - raise ExpTrackingException( - f"Could not create an MLflow Experiment with name {experiment_name}. {e}" - ) - experiment = mlflow_client.get_experiment_by_name(name=experiment_name) - else: - experiment_id = experiment.experiment_id - - self.logger.info(f"Experiment={experiment}") + else: + experiment_id = experiment.experiment_id + + self.logger.info(f"Experiment={experiment}") return experiment_id def save(self, fl_ctx: FLContext, shareable: Shareable, record_origin: str): @@ -226,6 +222,9 @@ def save(self, fl_ctx: FLContext, shareable: Shareable, record_origin: str): if not mlflow_client: raise RuntimeError(f"mlflow client is None for site {record_origin}.") run_id = self.get_run_id(record_origin) + if not run_id: + raise RuntimeError(f"run_id is missing for site {record_origin}.") + if data.kwargs.get("path", None): mlflow_client.log_text(run_id=run_id, text=data.value, artifact_file=data.kwargs.get("path")) elif data.data_type == AnalyticsDataType.MODEL: @@ -238,7 +237,7 @@ def save(self, fl_ctx: FLContext, shareable: Shareable, record_origin: str): self.buffer_data(data, record_origin) self.time_since_flush += timeit.default_timer() - self.time_start if self.time_since_flush >= self.buff_flush_time: - self.flush_buffer(record_origin) + self.flush_buffers(record_origin) def buffer_data(self, data: AnalyticsData, record_origin: str) -> None: """Buffer the data to send later. @@ -279,8 +278,8 @@ def get_target_type(self, data_type: AnalyticsDataType): else: return data_type - def flush_buffer(self, record_origin): - """Flush the buffer and send all the data to the MLflow tracking server. + def flush_buffers(self, record_origin): + """Flush buffers and send all the data to the MLflow tracking server. Args: record_origin (str): Origin of the data, or site name. @@ -290,27 +289,28 @@ def flush_buffer(self, record_origin): raise RuntimeError(f"mlflow client is None for site {record_origin}.") run_id = self.get_run_id(record_origin) + if not run_id: + raise RuntimeError(f"run_id is missing for site {record_origin}.") site_buff = self.buffer[record_origin] - metrics_arr = self.pop_from_buffer(site_buff[AnalyticsDataType.METRICS]) - params_arr = self.pop_from_buffer(site_buff[AnalyticsDataType.PARAMETERS]) - tags_arr = self.pop_from_buffer(site_buff[AnalyticsDataType.TAGS]) + metrics_arr = self.flush_buffer(site_buff[AnalyticsDataType.METRICS]) + params_arr = self.flush_buffer(site_buff[AnalyticsDataType.PARAMETERS]) + tags_arr = self.flush_buffer(site_buff[AnalyticsDataType.TAGS]) mlflow_client.log_batch(run_id=run_id, metrics=metrics_arr, params=params_arr, tags=tags_arr) self.time_start = 0 self.time_since_flush = 0 - def pop_from_buffer(self, log_buffer): - item_arr = [] - for _ in range(len(log_buffer)): - item_arr.append(log_buffer.pop()) + def flush_buffer(self, log_buffer: List): + item_arr = list(log_buffer) + log_buffer.clear() return item_arr def finalize(self, fl_ctx: FLContext): for site_name in self.buffer: - self.flush_buffer(site_name) + self.flush_buffers(site_name) for site_name in self.run_ids: run_id = self.run_ids[site_name] @@ -318,7 +318,7 @@ def finalize(self, fl_ctx: FLContext): if run_id: mlflow_client.set_terminated(run_id) - def get_run_id(self, site_id: str) -> str: + def get_run_id(self, site_id: str) -> Optional[str]: return self.run_ids.get(site_id, None) def get_mlflow_client(self, site_id: str) -> MlflowClient: diff --git a/nvflare/app_opt/tracking/mlflow/mlflow_writer.py b/nvflare/app_opt/tracking/mlflow/mlflow_writer.py index 751254eb90..118156bcae 100644 --- a/nvflare/app_opt/tracking/mlflow/mlflow_writer.py +++ b/nvflare/app_opt/tracking/mlflow/mlflow_writer.py @@ -14,10 +14,8 @@ from typing import Dict, Optional -from nvflare.apis.analytix import AnalyticsDataType +from nvflare.apis.analytix import ANALYTIC_EVENT_TYPE, AnalyticsDataType, LogWriterName from nvflare.app_common.tracking.log_writer import LogWriter -from nvflare.app_common.tracking.tracker_types import LogWriterName -from nvflare.app_common.widgets.streaming import ANALYTIC_EVENT_TYPE class MLflowWriter(LogWriter): diff --git a/nvflare/app_opt/tracking/tb/tb_receiver.py b/nvflare/app_opt/tracking/tb/tb_receiver.py index fb00c0c396..b213a2a0d9 100644 --- a/nvflare/app_opt/tracking/tb/tb_receiver.py +++ b/nvflare/app_opt/tracking/tb/tb_receiver.py @@ -68,7 +68,7 @@ def __init__(self, tb_folder="tb_events", events: Optional[List[str]] = None): - peer_name_2: """ - super().__init__(events=events) + super().__init__(events=events, client_side_supported=True) self.writers_table = {} self.tb_folder = tb_folder self.root_log_dir = None diff --git a/nvflare/app_opt/tracking/tb/tb_writer.py b/nvflare/app_opt/tracking/tb/tb_writer.py index a60ad45d0a..dd886672f2 100644 --- a/nvflare/app_opt/tracking/tb/tb_writer.py +++ b/nvflare/app_opt/tracking/tb/tb_writer.py @@ -14,10 +14,8 @@ from typing import Optional -from nvflare.apis.analytix import AnalyticsDataType +from nvflare.apis.analytix import ANALYTIC_EVENT_TYPE, AnalyticsDataType, LogWriterName from nvflare.app_common.tracking.log_writer import LogWriter -from nvflare.app_common.tracking.tracker_types import LogWriterName -from nvflare.app_common.widgets.streaming import ANALYTIC_EVENT_TYPE class TBWriter(LogWriter): diff --git a/nvflare/app_opt/tracking/wandb/wandb_receiver.py b/nvflare/app_opt/tracking/wandb/wandb_receiver.py index 674ee1fbf3..84913007b9 100644 --- a/nvflare/app_opt/tracking/wandb/wandb_receiver.py +++ b/nvflare/app_opt/tracking/wandb/wandb_receiver.py @@ -15,15 +15,14 @@ import os import time from multiprocessing import Process, Queue -from typing import NamedTuple, Optional +from typing import List, NamedTuple, Optional import wandb -from nvflare.apis.analytix import AnalyticsData, AnalyticsDataType +from nvflare.apis.analytix import AnalyticsData, AnalyticsDataType, LogWriterName from nvflare.apis.dxo import from_shareable from nvflare.apis.fl_context import FLContext from nvflare.apis.shareable import Shareable -from nvflare.app_common.tracking.tracker_types import LogWriterName from nvflare.app_common.widgets.streaming import AnalyticsReceiver @@ -35,9 +34,9 @@ class WandBTask(NamedTuple): class WandBReceiver(AnalyticsReceiver): - def __init__(self, kwargs: dict, mode: str = "offline", events=None, process_timeout=10): - if events is None: - events = ["fed.analytix_log_stats"] + def __init__( + self, kwargs: dict, mode: str = "offline", events: Optional[List[str]] = None, process_timeout: float = 10.0 + ): super().__init__(events=events) self.fl_ctx = None self.mode = mode @@ -49,7 +48,7 @@ def __init__(self, kwargs: dict, mode: str = "offline", events=None, process_tim # os.environ["WANDB_API_KEY"] = YOUR_KEY_HERE os.environ["WANDB_MODE"] = self.mode - def job(self, queue): + def process_queue_tasks(self, queue): cnt = 0 run = None try: @@ -75,7 +74,6 @@ def job(self, queue): run.finish() def initialize(self, fl_ctx: FLContext): - self.fl_ctx = fl_ctx sites = fl_ctx.get_engine().get_clients() run_group_id = str(int(time.time())) @@ -103,11 +101,10 @@ def initialize(self, fl_ctx: FLContext): q = Queue() wandb_task = WandBTask(task_owner=site.name, task_type="init", task_data=self.kwargs, step=0) - # q.put_nowait(wandb_task) q.put(wandb_task) self.queues[site.name] = q - p = Process(target=self.job, args=(q,)) + p = Process(target=self.process_queue_tasks, args=(q,)) self.processes[site.name] = p p.start() time.sleep(0.2) @@ -125,7 +122,7 @@ def save(self, fl_ctx: FLContext, shareable: Shareable, record_origin: str): if not data: return - q: Optional[Queue] = self.get_job_queue(record_origin) + q: Optional[Queue] = self.get_task_queue(record_origin) if q: if data.data_type == AnalyticsDataType.PARAMETER or data.data_type == AnalyticsDataType.METRIC: log_data = {data.tag: data.value} @@ -141,7 +138,7 @@ def finalize(self, fl_ctx: FLContext): """ for site in self.processes: self.log_info(self.fl_ctx, f"inform {site} to stop") - q: Optional[Queue] = self.get_job_queue(site) + q: Optional[Queue] = self.get_task_queue(site) q.put(WandBTask(task_owner=site, task_type="stop", task_data={}, step=0)) for site in self.processes: @@ -149,15 +146,15 @@ def finalize(self, fl_ctx: FLContext): p.join(self.process_timeout) p.terminate() - def get_job_queue(self, record_origin): + def get_task_queue(self, record_origin): return self.queues.get(record_origin, None) def check_kwargs(self, kwargs): if "project" not in kwargs: - raise ValueError("must provide `project' value") + raise ValueError("must provide 'project' value") if "group" not in kwargs: - raise ValueError("must provide `group' value") + raise ValueError("must provide 'group' value") if "job_type" not in kwargs: - raise ValueError("must provide `job_type' value") + raise ValueError("must provide 'job_type' value") diff --git a/nvflare/app_opt/tracking/wandb/wandb_writer.py b/nvflare/app_opt/tracking/wandb/wandb_writer.py index e1f496bfc1..6d38baff5c 100644 --- a/nvflare/app_opt/tracking/wandb/wandb_writer.py +++ b/nvflare/app_opt/tracking/wandb/wandb_writer.py @@ -14,10 +14,8 @@ from typing import Dict, Optional -from nvflare.apis.analytix import AnalyticsDataType +from nvflare.apis.analytix import ANALYTIC_EVENT_TYPE, AnalyticsDataType, LogWriterName from nvflare.app_common.tracking.log_writer import LogWriter -from nvflare.app_common.tracking.tracker_types import LogWriterName -from nvflare.app_common.widgets.streaming import ANALYTIC_EVENT_TYPE class WandBWriter(LogWriter): diff --git a/nvflare/client/tracking.py b/nvflare/client/tracking.py index 6bbb370a6a..9659a8be18 100644 --- a/nvflare/client/tracking.py +++ b/nvflare/client/tracking.py @@ -14,8 +14,7 @@ from typing import Dict, Optional -from nvflare.apis.analytix import AnalyticsDataType -from nvflare.app_common.tracking.tracker_types import LogWriterName +from nvflare.apis.analytix import AnalyticsDataType, LogWriterName # flake8: noqa from .api import default_context as default_context diff --git a/tests/unit_test/apis/analytix_test.py b/tests/unit_test/apis/analytix_test.py index ab56bf08d1..931d6d5213 100644 --- a/tests/unit_test/apis/analytix_test.py +++ b/tests/unit_test/apis/analytix_test.py @@ -14,10 +14,9 @@ import pytest -from nvflare.apis.analytix import _DATA_TYPE_KEY, AnalyticsData, AnalyticsDataType +from nvflare.apis.analytix import _DATA_TYPE_KEY, AnalyticsData, AnalyticsDataType, LogWriterName, TrackConst from nvflare.apis.dxo import DXO, DataKind from nvflare.apis.utils.analytix_utils import create_analytic_dxo -from nvflare.app_common.tracking.tracker_types import LogWriterName, TrackConst FROM_DXO_TEST_CASES = [ ("hello", 3.0, 1, AnalyticsDataType.SCALAR), diff --git a/tests/unit_test/app_common/widgets/streaming_test.py b/tests/unit_test/app_common/widgets/streaming_test.py index a5d9351e11..51a7965d3d 100644 --- a/tests/unit_test/app_common/widgets/streaming_test.py +++ b/tests/unit_test/app_common/widgets/streaming_test.py @@ -16,12 +16,11 @@ import pytest -from nvflare.apis.analytix import AnalyticsDataType +from nvflare.apis.analytix import AnalyticsDataType, LogWriterName, TrackConst from nvflare.apis.dxo import DXO, DataKind from nvflare.apis.fl_component import FLComponent from nvflare.apis.fl_context import FLContext from nvflare.apis.utils.analytix_utils import create_analytic_dxo, send_analytic_dxo -from nvflare.app_common.tracking.tracker_types import LogWriterName, TrackConst INVALID_TEST_CASES = [ (list(), dict(), FLContext(), TypeError, f"expect comp to be an instance of FLComponent, but got {type(list())}"),