diff --git a/python/text_utils/api/trainer.py b/python/text_utils/api/trainer.py
index 684eb8c..d9368f5 100644
--- a/python/text_utils/api/trainer.py
+++ b/python/text_utils/api/trainer.py
@@ -88,10 +88,9 @@ def parser(cls, name: str, description: str) -> argparse.ArgumentParser:
         )
         parser.add_argument(
             "--profile",
-            type=str,
-            default=None,
-            help="Run cProfile profile on main process and output stats to this file "
-            "(only respected if platform=local)"
+            action="store_true",
+            help="Run cProfile profile on main process and output stats to 'profile.pstat' "
+            "in experiment directory (only works for platform=local)"
         )
         return parser
 
@@ -166,10 +165,14 @@ def __init__(
         assert dist_type in {"DDP", "FSDP"}, \
             f"distributed training type must be either DDP or FSDP, but got {dist_type}"
 
-        if dist_type == "DDP":
+        if self.info.is_single_gpu:
+            self.model = model.to(self.info.device)
+            dist_type = "single GPU"
+        elif dist_type == "DDP":
             self.model = DDP(
                 model.to(self.info.device),
-                static_graph=compile
+                static_graph=compile,
+                gradient_as_bucket_view=True
             )
         else:
             offload_params = dist_cfg.get("offload", False)
@@ -234,7 +237,7 @@ def __init__(
                 )
             )
 
-        self.model: DDP | FSDP = torch.compile(
+        self.model: nn.Module | DDP | FSDP = torch.compile(
             self.model,
             fullgraph=True,
             disable=not compile
@@ -849,7 +852,7 @@ def _train_local_distributed(
         port: int,
         cfg: dict[str, Any],
         directories: dict[str, str],
-        profile: str | None = None
+        profile: bool
     ):
         logging.setup_logging()
         os.environ["MASTER_ADDR"] = "localhost"
@@ -874,16 +877,25 @@ def _train_local_distributed(
 
         assert dist.is_initialized(), "failed to initialize process group"
 
-        if info.is_main_process and profile is not None:
+        if info.is_main_process and profile:
             import cProfile
+            torch.cuda.memory._record_memory_history()
             cProfile.runctx(
                 "cls(cfg, directories, info).run()",
                 globals(),
                 locals(),
-                filename=profile
+                filename=os.path.join(
+                    directories["experiment"],
+                    "profile.pstat"
+                )
             )
+            torch.cuda.memory._dump_snapshot(os.path.join(
+                directories["experiment"],
+                "memory_profile.pickle"
+            ))
         else:
             cls(cfg, directories, info).run()
 
+        dist.destroy_process_group()
 
     @classmethod
@@ -968,7 +980,13 @@ def train_slurm(cls, work_dir: str, experiment_dir: str, config_path: str):
         dist.destroy_process_group()
 
     @classmethod
-    def train_local(cls, work_dir: str, experiment_dir: str, config_path: str, profile: str | None = None):
+    def train_local(
+        cls,
+        work_dir: str,
+        experiment_dir: str,
+        config_path: str,
+        profile: bool
+    ):
         logging.setup_logging()
         logger = logging.get_logger("LOCAL_INITIALIZATION")
         num_gpus = torch.cuda.device_count()
@@ -1133,7 +1151,7 @@ def step(
         ), torch.autograd.set_detect_anomaly(
             os.environ.get("TORCH_SET_DETECT_ANOMALY", "") != ""
         ):
-            if i < len(batches) - 1:
+            if i < len(batches) - 1 and self.info.is_distributed:
                 with self.model.no_sync():
                     outputs, loss = step(
                         batch,
@@ -1144,7 +1162,11 @@ def step(
             else:
                 # synchronize gradients for the last batch
                 outputs, loss = step(
-                    batch, rank_batch_size, inputs, labels)
+                    batch,
+                    rank_batch_size,
+                    inputs,
+                    labels
+                )
 
             losses.append(loss)
             if first_outputs is None:
diff --git a/python/text_utils/distributed.py b/python/text_utils/distributed.py
index dbcc6a6..2b62ebf 100644
--- a/python/text_utils/distributed.py
+++ b/python/text_utils/distributed.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, Union
+from typing import Any
 
 import torch
 from torch import nn, optim
@@ -29,6 +29,14 @@ def __init__(
         )
         self.device = torch.device(device_index)
 
+    @property
+    def is_distributed(self) -> bool:
+        return not self.is_single_gpu
+
+    @property
+    def is_single_gpu(self) -> bool:
+        return self.world_size == 1
+
     @property
     def is_local_main_process(self) -> bool:
         return self.local_rank == 0
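
Usage sketch (illustrative only, not part of the patch): after a local run with --profile, the two artifacts written to the experiment directory could be inspected roughly as follows; the experiment path below is a placeholder, not one produced by this diff.

    import pstats

    # Summarize the cProfile stats dumped to <experiment>/profile.pstat,
    # sorted by cumulative time.
    stats = pstats.Stats("experiments/my_run/profile.pstat")
    stats.sort_stats("cumulative").print_stats(20)

    # The CUDA memory snapshot (<experiment>/memory_profile.pickle) written by
    # torch.cuda.memory._dump_snapshot() is meant to be loaded into PyTorch's
    # memory visualizer at https://pytorch.org/memory_viz rather than read directly.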