add gpu memory profiling, add single gpu option
bastiscode committed Aug 25, 2024
1 parent 4459076 commit 3b47e6e
Showing 2 changed files with 44 additions and 14 deletions.
48 changes: 35 additions & 13 deletions python/text_utils/api/trainer.py
@@ -88,10 +88,9 @@ def parser(cls, name: str, description: str) -> argparse.ArgumentParser:
         )
         parser.add_argument(
             "--profile",
-            type=str,
-            default=None,
-            help="Run cProfile profile on main process and output stats to this file "
-            "(only respected if platform=local)"
+            action="store_true",
+            help="Run cProfile profile on main process and output stats to 'profile.pstat' "
+            "in experiment directory (only works for platform=local)"
         )
         return parser
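For reference, the file produced by the new --profile flag is ordinary cProfile output, so it can be inspected with the standard-library pstats module. A minimal sketch, assuming an experiment directory of "experiments/my_run" (the path is illustrative):

import pstats

# load the cProfile stats written by the trainer (illustrative path)
stats = pstats.Stats("experiments/my_run/profile.pstat")
# print the 20 most expensive calls, sorted by cumulative time
stats.sort_stats("cumulative").print_stats(20)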

@@ -166,10 +165,14 @@ def __init__(
         assert dist_type in {"DDP", "FSDP"}, \
             f"distributed training type must be either DDP or FSDP, but got {dist_type}"
 
-        if dist_type == "DDP":
+        if self.info.is_single_gpu:
+            self.model = model.to(self.info.device)
+            dist_type = "single GPU"
+        elif dist_type == "DDP":
             self.model = DDP(
                 model.to(self.info.device),
-                static_graph=compile
+                static_graph=compile,
+                gradient_as_bucket_view=True
             )
         else:
             offload_params = dist_cfg.get("offload", False)
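To make the new branching explicit: a single-GPU run now skips the DDP wrapper entirely, while multi-GPU runs additionally enable gradient_as_bucket_view so gradients alias DDP's communication buckets instead of being copied. A condensed sketch of that logic with the surrounding trainer state stripped away (function name and arguments are illustrative):

import torch
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP

def wrap_for_training(model: nn.Module, device: torch.device, world_size: int, compile_model: bool) -> nn.Module:
    if world_size == 1:
        # single GPU: no process group, no wrapper, just move to the device
        return model.to(device)
    # multi GPU: wrap in DDP; gradient_as_bucket_view saves one gradient copy
    return DDP(
        model.to(device),
        static_graph=compile_model,
        gradient_as_bucket_view=True,
    )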
@@ -234,7 +237,7 @@ def __init__(
                 )
             )
 
-        self.model: DDP | FSDP = torch.compile(
+        self.model: nn.Module | DDP | FSDP = torch.compile(
             self.model,
             fullgraph=True,
             disable=not compile
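The widened annotation reflects that torch.compile is now applied to a plain nn.Module in the single-GPU case; with disable=True the call is a no-op, so the same call site works whether or not compilation is requested. A tiny sketch of that pattern (module and sizes are made up):

import torch
from torch import nn

model = nn.Linear(16, 16)
# disable=True makes torch.compile a no-op, so downstream code can call the
# returned module the same way in both configurations
maybe_compiled = torch.compile(model, fullgraph=True, disable=True)
out = maybe_compiled(torch.randn(2, 16))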
@@ -849,7 +852,7 @@ def _train_local_distributed(
         port: int,
         cfg: dict[str, Any],
         directories: dict[str, str],
-        profile: str | None = None
+        profile: bool
     ):
         logging.setup_logging()
         os.environ["MASTER_ADDR"] = "localhost"
@@ -874,16 +877,25 @@
 
         assert dist.is_initialized(), "failed to initialize process group"
 
-        if info.is_main_process and profile is not None:
+        if info.is_main_process and profile:
             import cProfile
+            torch.cuda.memory._record_memory_history()
             cProfile.runctx(
                 "cls(cfg, directories, info).run()",
                 globals(),
                 locals(),
-                filename=profile
+                filename=os.path.join(
+                    directories["experiment"],
+                    "profile.pstat"
+                )
             )
+            torch.cuda.memory._dump_snapshot(os.path.join(
+                directories["experiment"],
+                "memory_profile.pickle"
+            ))
         else:
             cls(cfg, directories, info).run()
 
         dist.destroy_process_group()
 
     @classmethod
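For orientation, the two torch.cuda.memory calls added above follow PyTorch's CUDA memory snapshot workflow: start recording, run the workload, then dump a pickle that can be loaded into the viewer at https://pytorch.org/memory_viz. A standalone sketch of the same pattern, with a throwaway model in place of the training run (requires a CUDA device):

import torch

# start recording allocation events with stack traces
torch.cuda.memory._record_memory_history()

# any CUDA workload; a small model stands in for the trainer here
model = torch.nn.Linear(1024, 1024).cuda()
for _ in range(10):
    loss = model(torch.randn(64, 1024, device="cuda")).sum()
    loss.backward()

# write the recorded history to a pickle for later inspection
torch.cuda.memory._dump_snapshot("memory_profile.pickle")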
@@ -968,7 +980,13 @@ def train_slurm(cls, work_dir: str, experiment_dir: str, config_path: str):
         dist.destroy_process_group()
 
     @classmethod
-    def train_local(cls, work_dir: str, experiment_dir: str, config_path: str, profile: str | None = None):
+    def train_local(
+        cls,
+        work_dir: str,
+        experiment_dir: str,
+        config_path: str,
+        profile: bool
+    ):
         logging.setup_logging()
         logger = logging.get_logger("LOCAL_INITIALIZATION")
         num_gpus = torch.cuda.device_count()
@@ -1133,7 +1151,7 @@ def step(
             ), torch.autograd.set_detect_anomaly(
                 os.environ.get("TORCH_SET_DETECT_ANOMALY", "") != ""
             ):
-                if i < len(batches) - 1:
+                if i < len(batches) - 1 and self.info.is_distributed:
                     with self.model.no_sync():
                         outputs, loss = step(
                             batch,
@@ ... @@
                 else:
                     # synchronize gradients for the last batch
                     outputs, loss = step(
-                        batch, rank_batch_size, inputs, labels)
+                        batch,
+                        rank_batch_size,
+                        inputs,
+                        labels
+                    )
 
                 losses.append(loss)
                 if first_outputs is None:
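The is_distributed guard added above matters because no_sync only exists on the DDP wrapper; in the new single-GPU path self.model is a plain nn.Module and there is nothing to synchronize. A condensed sketch of the accumulation pattern under that assumption (accumulate, batches, and info are stand-ins for the trainer's internals):

import contextlib
from torch import nn

def accumulate(model: nn.Module, batches: list, info) -> None:
    for i, batch in enumerate(batches):
        # skip the gradient all-reduce for every micro-batch except the last,
        # but only when the model is actually wrapped in DDP
        skip_sync = i < len(batches) - 1 and info.is_distributed
        ctx = model.no_sync() if skip_sync else contextlib.nullcontext()
        with ctx:
            loss = model(batch).sum()
            loss.backward()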
10 changes: 9 additions & 1 deletion python/text_utils/distributed.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, Union
+from typing import Any
 
 import torch
 from torch import nn, optim
@@ -29,6 +29,14 @@ def __init__(
         )
         self.device = torch.device(device_index)
 
+    @property
+    def is_distributed(self) -> bool:
+        return not self.is_single_gpu
+
+    @property
+    def is_single_gpu(self) -> bool:
+        return self.world_size == 1
+
     @property
     def is_local_main_process(self) -> bool:
         return self.local_rank == 0
