Commit

formatting
holgerroth committed Feb 21, 2025
1 parent aaf1e48 commit 0c04e94
Showing 12 changed files with 193 additions and 199 deletions.
31 changes: 11 additions & 20 deletions examples/advanced/bionemo/downstream/bionemo_filters.py
@@ -13,19 +13,15 @@
# limitations under the License.

from typing import Union
from torch import Tensor

from nvflare.apis.dxo import DXO, DataKind, MetaKey
from nvflare.apis.dxo import DXO, DataKind
from nvflare.apis.dxo_filter import DXOFilter
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable


class BioNeMoParamsFilter(DXOFilter):
def __init__(
self,
precision="bf16-mixed"
):
def __init__(self, precision="bf16-mixed"):
"""Filter to add a prefix to global state dict to avoid key mismatches between global and local state dictionaries.
This is needed because of NeMo training framework adding module wrappers depending on the used training precision.
@@ -60,18 +56,15 @@ def process_dxo(self, dxo: DXO, shareable: Shareable, fl_ctx: FLContext) -> Unio
params = dxo.data
new_params = {}
for k, v in params.items():
new_key = self._prefix + k
new_params[new_key] = v

dxo.data = new_params
return dxo


class BioNeMoExcludeParamsFilter(DXOFilter):
def __init__(
self,
exclude_vars="head"
):
def __init__(self, exclude_vars="head"):
"""Filter to remove parameters from state dictionary that shouldn't be shared with other party.
Args:
@@ -84,7 +77,6 @@ def __init__(

self.exclude_vars = exclude_vars


def process_dxo(self, dxo: DXO, shareable: Shareable, fl_ctx: FLContext) -> Union[None, DXO]:
"""Filter process apply to the Shareable object.
@@ -100,14 +92,13 @@ def process_dxo(self, dxo: DXO, shareable: Shareable, fl_ctx: FLContext) -> Unio
params = dxo.data
new_params = {}
for k, v in params.items():
if self.exclude_vars not in k:
new_params[k] = v

if len(new_params) < len(params):
self.log_info(fl_ctx, f"Excluded {len(params)-len(new_params)} parameters matching '{self.exclude_vars}'")
else:
raise ValueError(f"State dictionary did not match any exclude keys that matched '{self.exclude_vars}'")
raise ValueError(f"State dictionary did not match any exclude keys that matched '{self.exclude_vars}'")

dxo.data = new_params
return dxo
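
For orientation, below is a minimal sketch (not part of this commit) of how these two filters might be attached to a client in an NVFlare job. The TASK_DATA placement of BioNeMoParamsFilter mirrors run_sim_sabdab.py later in this commit; attaching BioNeMoExcludeParamsFilter on the TASK_RESULT side so that "head" weights stay local is an assumption for illustration only, and the job/client names are hypothetical.

from nvflare import FilterType
from nvflare.app_opt.pt.job_config.base_fed_job import BaseFedJob

from bionemo_filters import BioNeMoExcludeParamsFilter, BioNeMoParamsFilter

# Hypothetical job and client names for the sketch
job = BaseFedJob(name="bionemo_filter_demo")
client_name = "site-1"

# Prefix incoming global weights so their keys match the precision-dependent module wrappers
job.to(BioNeMoParamsFilter(precision="fp32"), client_name, tasks=["train", "validate"], filter_type=FilterType.TASK_DATA)

# Drop task-head parameters from outgoing results so they stay local and are not aggregated
job.to(BioNeMoExcludeParamsFilter(exclude_vars="head"), client_name, tasks=["train"], filter_type=FilterType.TASK_RESULT)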

73 changes: 32 additions & 41 deletions examples/advanced/bionemo/downstream/finetune_esm2.py
@@ -15,39 +15,29 @@
# Copied and adapted for NVFlare from https://github.com/NVIDIA/bionemo-framework/blob/main/sub-packages/bionemo-esm2/src/bionemo/esm2/scripts/finetune_esm2.py

import shutil
import argparse
import random
from pathlib import Path
from lightning import seed_everything
from typing import Dict, List, Optional, Sequence, Tuple, Type, get_args

from lightning.pytorch.callbacks import Callback, LearningRateMonitor, RichModelSummary
from megatron.core.distributed import DistributedDataParallelConfig
from megatron.core.optimizer import OptimizerConfig
from nemo import lightning as nl
from nemo.collections import llm
from nemo.lightning import resume
from nemo.lightning.pytorch import callbacks as nl_callbacks
from nemo.lightning.pytorch.optim import MegatronOptimizerModule
from typing import List, Optional, Tuple, Type

from bionemo.core.utils.dtypes import PrecisionTypes, get_autocast_dtype
from bionemo.esm2.data.tokenizer import get_tokenizer
from bionemo.esm2.model.finetune.datamodule import ESM2FineTuneDataModule
from bionemo.esm2.model.finetune.dataset import (
InMemoryPerTokenValueDataset,
InMemoryProteinDataset,
InMemorySingleValueDataset,
)
from bionemo.esm2.model.finetune.dataset import InMemoryProteinDataset, InMemorySingleValueDataset
from bionemo.esm2.model.finetune.sequence_model import ESM2FineTuneSeqConfig
from bionemo.esm2.model.finetune.token_model import ESM2FineTuneTokenConfig

# Reuse parser and config constants from bionemo
from bionemo.esm2.scripts.finetune_esm2 import get_parser
from bionemo.llm.model.biobert.lightning import biobert_lightning_module
from bionemo.llm.model.biobert.model import BioBertConfig
from bionemo.llm.model.config import TorchmetricsConfig
from bionemo.llm.utils.datamodule_utils import float_or_int_or_none, infer_global_batch_size
from bionemo.llm.utils.datamodule_utils import infer_global_batch_size
from bionemo.llm.utils.logger_utils import WandbConfig, setup_nemo_lightning_logger

# Reuse parser and config constants from bionemo
from bionemo.esm2.scripts.finetune_esm2 import get_parser, SUPPORTED_CONFIGS, SUPPORTED_DATASETS
from lightning.pytorch.callbacks import Callback, LearningRateMonitor, RichModelSummary
from megatron.core.distributed import DistributedDataParallelConfig
from megatron.core.optimizer import OptimizerConfig
from nemo import lightning as nl
from nemo.collections import llm
from nemo.lightning.pytorch import callbacks as nl_callbacks
from nemo.lightning.pytorch.optim import MegatronOptimizerModule

# (1) import nvflare lightning client API
import nvflare.client.lightning as flare
@@ -111,7 +101,7 @@ def train_model(
average_in_collective: bool = True,
grad_reduce_in_fp32: bool = False,
label_column: str = "labels",
classes: List[str] = None
classes: List[str] = None,
) -> Tuple[Path, Callback | None, nl.Trainer]:
"""Train an ESM2 model on UR data.
@@ -265,7 +255,9 @@ def train_model(
# because after flare.patch the trainer.fit/validate will get the
# global model internally
input_model = flare.receive()
print(f"\n[Current Round={input_model.current_round}, Site = {flare.get_site_name()}, Global model = {input_model} ({len(input_model.params)} params)]\n")
print(
f"\n[Current Round={input_model.current_round}, Site = {flare.get_site_name()}, Global model = {input_model} ({len(input_model.params)} params)]\n"
)
# use a unique result directory for each round

# Remove previous checkpoints to preserve disk space
@@ -274,21 +266,21 @@
previous_ckpt_dir = result_dir / f"round{input_model.current_round-1}" / experiment_name / "dev" / "checkpoints"
if previous_ckpt_dir.is_dir():
print(f"Removing previous checkpoint directory {previous_ckpt_dir}")
shutil.rmtree(previous_ckpt_dir)

# create output folder for this round
result_dir = result_dir / f"round{input_model.current_round}"

# add a learning rate decay for each round
if input_model.current_round > 0:
lr_step_reduce = 1.05 # TODO: make lr_step_reduce configurable
new_lr = lr/(input_model.current_round*lr_step_reduce)
new_lr_multiplier = lr_multiplier/(input_model.current_round*lr_step_reduce)
new_lr = lr / (input_model.current_round * lr_step_reduce)
new_lr_multiplier = lr_multiplier / (input_model.current_round * lr_step_reduce)
print(f"Reduce lr {lr} by {input_model.current_round*lr_step_reduce}: {new_lr}")
else:
new_lr = lr
new_lr_multiplier = lr_multiplier

# remaining bionemo training code
tokenizer = get_tokenizer()

@@ -302,7 +294,7 @@ def train_model(
train_dataset.label_tokenizer.build_vocab([classes])
print(f"Build custom label tokenizer based on label classes: {classes}")
valid_dataset.label_tokenizer = train_dataset.label_tokenizer

data_module = ESM2FineTuneDataModule(
train_dataset=train_dataset,
valid_dataset=valid_dataset,
@@ -379,7 +371,7 @@ def train_model(

module = biobert_lightning_module(config=config, tokenizer=tokenizer, optimizer=optimizer)

#If client should save best local checkpoints, set to `save_local_ckpt=True`,
# If client should save best local checkpoints, set to `save_local_ckpt=True`,
save_local_ckpt = False
if save_local_ckpt:
# Configure our custom Checkpointer
@@ -393,7 +385,7 @@ def train_model(
)
else:
checkpoint_callback = None

# Setup the logger and train the model
nemo_logger = setup_nemo_lightning_logger(
root_dir=result_dir,
@@ -402,15 +394,15 @@ def train_model(
wandb_config=wandb_config,
ckpt_callback=checkpoint_callback,
)

# perform local training starting with the received global model
llm.train(
model=module,
data=data_module,
trainer=trainer,
log=nemo_logger,
resume=None,
)

if checkpoint_callback:
ckpt_path = Path(checkpoint_callback.last_model_path.replace(".ckpt", ""))
@@ -431,7 +423,7 @@ def finetune_esm2_entrypoint():
required=False,
default=None,
help="Unique strings describing the classes for classification. Used to build the same label vocabulary on each client. Should be comma separate list of strings, e.g. 'pos,neg'",
)
args = parser.parse_args()

if args.classes:
@@ -440,7 +432,6 @@ def finetune_esm2_entrypoint():
classes = args.classes.split(",")
else:
classes = None


# to avoid padding for single value labels:
if args.min_seq_length is not None and args.dataset_class is InMemorySingleValueDataset:
@@ -503,10 +494,10 @@ def finetune_esm2_entrypoint():
average_in_collective=not args.no_average_in_collective,
grad_reduce_in_fp32=args.grad_reduce_in_fp32,
label_column=args.label_column,
classes=classes
classes=classes,
)



if __name__ == "__main__":
finetune_esm2_entrypoint()
flare.shutdown()
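
The per-round logic above follows NVFlare's Lightning Client API. For reference, here is a stripped-down sketch of that pattern; the Trainer, LightningModule, and DataModule are generic placeholders rather than the NeMo/ESM2 objects used in this script, which wraps the same receive loop around llm.train instead of trainer.fit.

import lightning as L

import nvflare.client.lightning as flare


def run_federated(model: L.LightningModule, datamodule: L.LightningDataModule):
    trainer = L.Trainer(max_epochs=1)
    flare.patch(trainer)  # patched fit/validate exchange model updates with the NVFlare server

    while flare.is_running():
        input_model = flare.receive()  # current global model for this round
        print(f"round={input_model.current_round}, site={flare.get_site_name()}")

        trainer.validate(model, datamodule=datamodule)  # evaluate the received global model
        trainer.fit(model, datamodule=datamodule)  # local training; the update is sent back on completion

    flare.shutdown()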

Original file line number Diff line number Diff line change
@@ -118,7 +118,7 @@ def main():
if do_clean_chains:
train_df = clean_chains(train_df)
test_df = clean_chains(test_df)

_split_dir = os.path.join(split_dir, "train")
if not os.path.isdir(_split_dir):
os.makedirs(_split_dir)
70 changes: 38 additions & 32 deletions examples/advanced/bionemo/downstream/sabdab/run_sim_sabdab.py
@@ -13,28 +13,24 @@
# limitations under the License.

import argparse
import logging
import os
import sys

from nvflare import FedJob, FilterType
from bionemo.core.data.load import load

from nvflare import FilterType
from nvflare.app_common.launchers.subprocess_launcher import SubprocessLauncher
from nvflare.app_common.workflows.fedavg import FedAvg
from nvflare.app_opt.pt.job_config.base_fed_job import BaseFedJob
from nvflare.job_config.script_runner import ScriptRunner, BaseScriptRunner
from nvflare.app_common.launchers.subprocess_launcher import SubprocessLauncher
from nvflare.job_config.script_runner import BaseScriptRunner

import os
import pandas as pd
import sys
sys.path.append(os.path.join(os.getcwd(), "..")) # include parent folder in path
from bionemo_filters import BioNeMoParamsFilter


def main(args):
# Create BaseFedJob with initial model
job = BaseFedJob(
name=f"{args.exp_name}_sabdab_esm2_{args.model}"
)
job = BaseFedJob(name=f"{args.exp_name}_sabdab_esm2_{args.model}")

# Define the controller and send to server
controller = FedAvg(
@@ -48,11 +44,11 @@ def main(args):

# Define unique strings describing the classes for classification so we can use the same label vocabulary on each client.
classes = "pos,neg"

# Add clients
for i in range(args.num_clients):
client_name = f"site-{i+1}"

# define data paths
# We use the same validation set for each client to make their metrics comparable
val_data_path = "/tmp/data/sabdab_chen/val/sabdab_chen_valid.csv"
@@ -61,29 +57,32 @@ def main(args):
assert args.num_clients == 1, "Use num_clients=1 for simulating 'central' training setting."
assert args.num_rounds == 1, "Use num_rounds=1 for simulating 'central' training setting."
train_data_path = "/tmp/data/sabdab_chen/train/sabdab_chen_full_train.csv"
val_check_interval = int(args.local_steps/20) # 20 times per training
val_check_interval = int(args.local_steps / 20) # 20 times per training
else: # local or fedavg setting
train_data_path = f"/tmp/data/sabdab_chen/train/sabdab_chen_{client_name}_train.csv"
if args.num_rounds > 1:
val_check_interval = args.local_steps
else:
val_check_interval = int(args.local_steps/20) # 20 times per training
val_check_interval = int(args.local_steps / 20) # 20 times per training

# define training script arguments
#precision = "bf16-mixed"
# precision = "bf16-mixed"
precision = "fp32"
script_args = f"--restore-from-checkpoint-path {checkpoint_path} --train-data-path {train_data_path} --valid-data-path {val_data_path} --config-class ESM2FineTuneSeqConfig --dataset-class InMemorySingleValueDataset --task-type classification --mlp-ft-dropout 0.1 --mlp-hidden-size 256 --mlp-target-size 2 --experiment-name {job.name} --num-steps {args.local_steps} --num-gpus 1 --val-check-interval {val_check_interval} --log-every-n-steps 10 --lr 5e-4 --lr-multiplier 1e3 --scale-lr-layer classification_head --result-dir bionemo --micro-batch-size 64 --precision {precision} --save-top-k 1 --limit-val-batches 1.0 --classes {classes}"
print(f"Running {args.train_script} with args: {script_args}")

# Define training script runner
runner = BaseScriptRunner(script=args.train_script,
launch_external_process=True,
framework="pytorch",
params_exchange_format="pytorch",
launcher=SubprocessLauncher(script=f"python3 custom/{args.train_script} {script_args}",
launch_once=False))
runner = BaseScriptRunner(
script=args.train_script,
launch_external_process=True,
framework="pytorch",
params_exchange_format="pytorch",
launcher=SubprocessLauncher(script=f"python3 custom/{args.train_script} {script_args}", launch_once=False),
)
job.to(runner, client_name)
job.to(BioNeMoParamsFilter(precision), client_name, tasks=["train", "validate"], filter_type=FilterType.TASK_DATA)
job.to(
BioNeMoParamsFilter(precision), client_name, tasks=["train", "validate"], filter_type=FilterType.TASK_DATA
)

job.export_job("./exported_jobs")
job.simulator_run(f"/tmp/nvflare/bionemo/sabdab/{job.name}", gpu=args.sim_gpus)
@@ -94,12 +93,19 @@ def main(args):
parser.add_argument("--num_clients", type=int, help="Number of clients", required=False, default=1)
parser.add_argument("--num_rounds", type=int, help="Number of rounds", required=False, default=30)
parser.add_argument("--local_steps", type=int, help="Number of rounds", required=False, default=10)
parser.add_argument("--train_script", type=str, help="Training script", required=False, default="../finetune_esm2.py")
parser.add_argument(
"--train_script", type=str, help="Training script", required=False, default="../finetune_esm2.py"
)
parser.add_argument("--exp_name", type=str, help="Job name prefix", required=False, default="fedavg")
parser.add_argument("--model", choices=["8m", "650m", "3b"], help="ESM2 model", required=False, default="8m")
parser.add_argument("--sim_gpus", type=str, help="GPU indexes to simulate clients, e.g., '0,1,2,3' if you want to run 4 clients, each on a separate GPU. By default run all clients on the same GPU 0.", required=False, default="0")
parser.add_argument(
"--sim_gpus",
type=str,
help="GPU indexes to simulate clients, e.g., '0,1,2,3' if you want to run 4 clients, each on a separate GPU. By default run all clients on the same GPU 0.",
required=False,
default="0",
)

args = parser.parse_args()

main(args)
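
As a usage note, a multi-client FedAvg simulation with this script could be launched with something like `python run_sim_sabdab.py --num_clients 3 --num_rounds 30 --local_steps 10 --exp_name fedavg --model 8m --sim_gpus 0,1,2`. The flag names come from the parser above; the specific values, and the assumption that the SAbDab client splits have already been prepared under /tmp/data/sabdab_chen, are illustrative only.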

