From 1d2bf4dc55c5dad503873974f3653a55e55a03d5 Mon Sep 17 00:00:00 2001 From: nvkevlu <55759229+nvkevlu@users.noreply.github.com> Date: Mon, 8 Jan 2024 17:35:30 -0500 Subject: [PATCH] Add global round info for logging metrics global step (#2258) * Add global round info for logging metrics global step * Fix ci * update how to get current round * improvements * fixes --- .../experiment-tracking/pt/learner_with_mlflow.py | 9 ++++++--- .../advanced/experiment-tracking/pt/learner_with_tb.py | 9 ++++++--- .../experiment-tracking/pt/learner_with_wandb.py | 9 ++++++--- 3 files changed, 18 insertions(+), 9 deletions(-) diff --git a/examples/advanced/experiment-tracking/pt/learner_with_mlflow.py b/examples/advanced/experiment-tracking/pt/learner_with_mlflow.py index dd8949a46f..bfd591f961 100644 --- a/examples/advanced/experiment-tracking/pt/learner_with_mlflow.py +++ b/examples/advanced/experiment-tracking/pt/learner_with_mlflow.py @@ -24,7 +24,7 @@ from torchvision.transforms import Compose, Normalize, ToTensor from nvflare.apis.dxo import DXO, DataKind, MetaKey, from_shareable -from nvflare.apis.fl_constant import ReservedKey, ReturnCode +from nvflare.apis.fl_constant import FLContextKey, ReservedKey, ReturnCode from nvflare.apis.fl_context import FLContext from nvflare.apis.shareable import Shareable, make_reply from nvflare.apis.signal import Signal @@ -44,6 +44,7 @@ def __init__(self, data_path="~/data", lr=0.01, epochs=5, exclude_vars=None, ana """Simple PyTorch Learner that trains and validates a simple network on the CIFAR10 dataset. Args: + data_path (str): Path that the data will be stored at. Defaults to "~/data". lr (float, optional): Learning rate. Defaults to 0.01 epochs (int, optional): Epochs. Defaults to 5 exclude_vars (list): List of variables to exclude during model loading. @@ -63,6 +64,7 @@ def __init__(self, data_path="~/data", lr=0.01, epochs=5, exclude_vars=None, ana self.loss = None self.device = None self.model = None + self.data_path = data_path self.lr = lr self.epochs = epochs @@ -147,6 +149,7 @@ def train(self, data: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Sha def local_train(self, fl_ctx, abort_signal): # Basic training + current_round = fl_ctx.get_prop(FLContextKey.TASK_DATA).get_header("current_round") for epoch in range(self.epochs): self.model.train() running_loss = 0.0 @@ -174,12 +177,12 @@ def local_train(self, fl_ctx, abort_signal): ) # Stream training loss at each step - current_step = len(self.train_loader) * epoch + i + current_step = self.n_iterations * self.epochs * current_round + self.n_iterations * epoch + i self.writer.log_metrics({"train_loss": cost.item(), "running_loss": running_loss}, current_step) # Stream validation accuracy at the end of each epoch metric = self.local_validate(abort_signal) - self.writer.log_metric("validation_accuracy", metric, epoch) + self.writer.log_metric("validation_accuracy", metric, current_step) def get_model_for_validation(self, model_name: str, fl_ctx: FLContext) -> Shareable: run_dir = fl_ctx.get_engine().get_workspace().get_run_dir(fl_ctx.get_job_id()) diff --git a/examples/advanced/experiment-tracking/pt/learner_with_tb.py b/examples/advanced/experiment-tracking/pt/learner_with_tb.py index 49860c2913..50bd45eb69 100644 --- a/examples/advanced/experiment-tracking/pt/learner_with_tb.py +++ b/examples/advanced/experiment-tracking/pt/learner_with_tb.py @@ -43,7 +43,7 @@ class PTLearner(Learner): def __init__( self, - data_path="/tmp/nvflare/tensorboard-streaming", + data_path="~/data", lr=0.01, epochs=5, exclude_vars=None, @@ -52,6 +52,7 @@ def __init__( """Simple PyTorch Learner that trains and validates a simple network on the CIFAR10 dataset. Args: + data_path (str): Path that the data will be stored at. Defaults to "~/data". lr (float, optional): Learning rate. Defaults to 0.01 epochs (int, optional): Epochs. Defaults to 5 exclude_vars (list): List of variables to exclude during model loading. @@ -71,6 +72,7 @@ def __init__( self.loss = None self.device = None self.model = None + self.data_path = data_path self.lr = lr self.epochs = epochs @@ -150,6 +152,7 @@ def train(self, data: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Sha def local_train(self, fl_ctx, abort_signal): # Basic training + current_round = fl_ctx.get_prop(FLContextKey.TASK_DATA).get_header("current_round") for epoch in range(self.epochs): self.model.train() running_loss = 0.0 @@ -173,12 +176,12 @@ def local_train(self, fl_ctx, abort_signal): running_loss = 0.0 # Stream training loss at each step - current_step = len(self.train_loader) * epoch + i + current_step = self.n_iterations * self.epochs * current_round + self.n_iterations * epoch + i self.writer.add_scalar("train_loss", cost.item(), current_step) # Stream validation accuracy at the end of each epoch metric = self.local_validate(abort_signal) - self.writer.add_scalar("validation_accuracy", metric, epoch) + self.writer.add_scalar("validation_accuracy", metric, current_step) def get_model_for_validation(self, model_name: str, fl_ctx: FLContext) -> Shareable: run_dir = fl_ctx.get_engine().get_workspace().get_run_dir(fl_ctx.get_job_id()) diff --git a/examples/advanced/experiment-tracking/pt/learner_with_wandb.py b/examples/advanced/experiment-tracking/pt/learner_with_wandb.py index ad5ee47cc1..0245b9aeb2 100644 --- a/examples/advanced/experiment-tracking/pt/learner_with_wandb.py +++ b/examples/advanced/experiment-tracking/pt/learner_with_wandb.py @@ -24,7 +24,7 @@ from torchvision.transforms import Compose, Normalize, ToTensor from nvflare.apis.dxo import DXO, DataKind, MetaKey, from_shareable -from nvflare.apis.fl_constant import ReservedKey, ReturnCode +from nvflare.apis.fl_constant import FLContextKey, ReservedKey, ReturnCode from nvflare.apis.fl_context import FLContext from nvflare.apis.shareable import Shareable, make_reply from nvflare.apis.signal import Signal @@ -44,6 +44,7 @@ def __init__(self, data_path="~/data", lr=0.01, epochs=5, exclude_vars=None, ana """Simple PyTorch Learner that trains and validates a simple network on the CIFAR10 dataset. Args: + data_path (str): Path that the data will be stored at. Defaults to "~/data". lr (float, optional): Learning rate. Defaults to 0.01 epochs (int, optional): Epochs. Defaults to 5 exclude_vars (list): List of variables to exclude during model loading. @@ -63,6 +64,7 @@ def __init__(self, data_path="~/data", lr=0.01, epochs=5, exclude_vars=None, ana self.loss = None self.device = None self.model = None + self.data_path = data_path self.lr = lr self.epochs = epochs @@ -141,6 +143,7 @@ def train(self, data: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Sha def local_train(self, fl_ctx, abort_signal): # Basic training + current_round = fl_ctx.get_prop(FLContextKey.TASK_DATA).get_header("current_round") for epoch in range(self.epochs): self.model.train() running_loss = 0.0 @@ -164,12 +167,12 @@ def local_train(self, fl_ctx, abort_signal): running_loss = 0.0 # Stream training loss at each step - current_step = len(self.train_loader) * epoch + i + current_step = self.n_iterations * self.epochs * current_round + self.n_iterations * epoch + i self.writer.log({"train_loss": cost.item()}, current_step) # Stream validation accuracy at the end of each epoch metric = self.local_validate(abort_signal) - self.writer.log({"validation_accuracy": metric}, epoch) + self.writer.log({"validation_accuracy": metric}, current_step) def get_model_for_validation(self, model_name: str, fl_ctx: FLContext) -> Shareable: run_dir = fl_ctx.get_engine().get_workspace().get_run_dir(fl_ctx.get_job_id())