Commit

fixes to optimizer and trainer
bastiscode committed Jun 26, 2024
1 parent 0977f76 commit 78401eb
Showing 2 changed files with 86 additions and 77 deletions.
python/text_utils/api/trainer.py (111 changes: 62 additions & 49 deletions)
@@ -9,11 +9,11 @@
import shutil
import time
import zipfile
from typing import Dict, Optional, Tuple, Any, List, Callable, Union
from typing import Any, Callable

import torch
from torch.backends import cuda, cudnn
from torch import GradScaler, distributed as dist
from torch import distributed as dist
from torch import multiprocessing as mp
from torch import nn
from torch.optim import lr_scheduler
@@ -97,8 +97,8 @@ def parser(cls, name: str, description: str) -> argparse.ArgumentParser:

def __init__(
self,
cfg: Dict[str, Any],
directories: Dict[str, str],
cfg: dict[str, Any],
directories: dict[str, str],
info: distributed.DistributedInfo
):
self.cfg = cfg
@@ -112,7 +112,7 @@ def __init__(
self.epoch_step = 0
self.epoch = 0
self.best_val_loss = float("inf")
self.best_benchmark: Optional[float] = None
self.best_benchmark: float | None = None
self.logger = logging.get_logger("TRAIN")

if self.info.is_main_process:
@@ -136,7 +136,7 @@ def __init__(
cuda.matmul.allow_tf32 = True
cudnn.benchmark = True

model, sharding_policy = self._model_from_config(self.cfg)
model = self._model_from_config(self.cfg)

peft = self.cfg["train"].get("peft", None)
if peft is not None:
@@ -150,6 +150,8 @@ def __init__(
peft
)

sharding_policy = self._sharding_policy(model)

mixed_precision = self.cfg["train"].get("mixed_precision", None)
if mixed_precision == "fp16":
self.mixed_precision = torch.float16
@@ -175,7 +177,12 @@ def __init__(
strategy = ShardingStrategy[dist_cfg.get("strategy", "NO_SHARD")]
if strategy != ShardingStrategy.NO_SHARD:
shard_size = dist_cfg.get("shard_size", None)
if shard_size is not None:
if sharding_policy is not None:
if self.info.is_main_process:
self.logger.info(
"sharding based on custom policy"
)
elif shard_size is not None:
if self.info.is_main_process:
self.logger.info(
f"sharding based on number of parameters with "
@@ -187,39 +194,34 @@ def __init__(
force_leaf_modules=None,
exclude_wrap_modules=None
)
elif sharding_policy is None:
if self.info.is_main_process:
self.logger.info(
f"sharding strategy is {strategy.name}, but got "
f"no sharding policy, disabling sharding"
)
strategy = ShardingStrategy.NO_SHARD
else:
raise ValueError(
"sharding strategy is set, but no custom sharding policy "
"or shard size is specified"
)
else:
sharding_policy = None
offload_params = False

offload_state_dict = self.info.world_size > 1

self.model = FSDP(
model,
auto_wrap_policy=sharding_policy,
mixed_precision=MixedPrecision(
param_dtype=self.mixed_precision,
reduce_dtype=self.mixed_precision,
buffer_dtype=self.mixed_precision
),
cpu_offload=CPUOffload(offload_params=offload_params),
limit_all_gathers=True,
sharding_strategy=strategy,
forward_prefetch=prefetch,
backward_prefetch=BackwardPrefetch.BACKWARD_PRE if prefetch else BackwardPrefetch.BACKWARD_POST,
device_id=self.info.device,
use_orig_params=compile or (peft is not None),
)
# set mixed precision to none here for FSDP to avoid autocasting
# later, because FSDP handles mixed precision itself
self.mixed_precision = None

offload_state_dict = self.info.world_size > 1
FSDP.set_state_dict_type(
self.model,
StateDictType.FULL_STATE_DICT,
@@ -311,7 +313,7 @@ def __init__(
else:
self.cooldown_items = 0

self.cooldown_scheduler: Optional[lr_scheduler.LambdaLR] = None
self.cooldown_scheduler: lr_scheduler.LambdaLR | None = None

if "lr_scheduler" in self.cfg["train"]:
self.step_interval = clamp(
@@ -441,6 +443,7 @@ def _save_checkpoint(
self.model,
self.optimizer
)
save["grad_scaler_state_dict"] = self.grad_scaler.state_dict()
save["loss_fn_state_dict"] = self.loss_fn.state_dict()
if self.lr_scheduler is not None:
save["lr_scheduler_state_dict"] = self.lr_scheduler.state_dict()
@@ -465,6 +468,7 @@ def _load_checkpoint(self, path: str):
optim_state_dict,
)
self.optimizer.load_state_dict(optim_state_dict)
self.grad_scaler.load_state_dict(checkpoint["grad_scaler_state_dict"])
if self.lr_scheduler is not None and checkpoint.get("lr_scheduler_state_dict") is not None:
self.lr_scheduler.load_state_dict(
checkpoint["lr_scheduler_state_dict"]
@@ -510,24 +514,31 @@ def _prepare_peft(
@classmethod
def _model_from_config(
cls,
cfg: Dict[str, Any]
) -> Tuple[nn.Module, Optional[ShardingPolicy]]:
cfg: dict[str, Any]
) -> nn.Module:
raise NotImplementedError

@classmethod
def _additional_loss_fn(cls) -> Optional[Callable]:
def _sharding_policy(
cls,
model: nn.Module,
) -> ShardingPolicy | None:
return None

@classmethod
def _additional_loss_fn(cls) -> Callable | None:
return None

@classmethod
def _additional_optimizer_fn(cls) -> Optional[Callable]:
def _additional_optimizer_fn(cls) -> Callable | None:
return None

@classmethod
def _additional_lr_scheduler_fn(cls) -> Optional[Callable]:
def _additional_lr_scheduler_fn(cls) -> Callable | None:
return None

@classmethod
def _additional_max_length_scheduler_fn(cls) -> Optional[Callable]:
def _additional_max_length_scheduler_fn(cls) -> Callable | None:
return None

@classmethod
@@ -559,13 +570,13 @@ def _copy_file_to_tmp_dir(cls, path: str, dir: str, info: distributed.Distribute
@classmethod
def _prepare_data_sources(
cls,
sources: List[Dict[str, Any]],
sources: list[dict[str, Any]],
info: distributed.DistributedInfo
) -> Tuple[
List[Tuple[str, Optional[str]]],
List[Optional[Any]],
List[Optional[Any]],
List[str]
) -> tuple[
list[tuple[str, str | None]],
list[Any | None],
list[Any | None],
list[str]
]:
src_paths = []
src_preprocessings = []
@@ -632,25 +643,25 @@ def _prepare_data_sources(
@classmethod
def _data_from_config(
cls,
train_cfg: Dict[str, Any],
val_cfg: Union[List[Any], int],
train_cfg: dict[str, Any],
val_cfg: list[Any] | int,
num_epochs: int,
seed: Optional[int],
seed: int | None,
info: distributed.DistributedInfo
) -> Tuple[
) -> tuple[
data.TrainLoader,
data.TrainLoader,
int,
int,
int,
Optional[Callable[[int], int]],
List[str]
Callable[[int], int] | None,
list[str]
]:
def prepare_data_loader(
pipeline_cfg: Dict[str, Any],
sources: List[tuple[str, str | None]],
preprocessings: List[Optional[Any]],
postprocessings: List[Optional[Any]],
pipeline_cfg: dict[str, Any],
sources: list[tuple[str, str | None]],
preprocessings: list[Any | None],
postprocessings: list[Any | None],
**kwargs: Any,
) -> data.TrainLoader:
pipeline_cfg = copy.deepcopy(pipeline_cfg)
@@ -676,7 +687,9 @@ def prepare_data_loader(
# adapt config to multi gpu usage
assert "batch_limit" in train_cfg, "batch_limit must be in data config"
train_cfg["batch_limit"] = max(
1, train_cfg["batch_limit"] // info.world_size)
1,
train_cfg["batch_limit"] // info.world_size
)

# pop some configs not used by the dataloader
max_length = train_cfg.pop("max_length")
@@ -777,7 +790,7 @@ def prepare_data_loader(
)

@classmethod
def _setup_experiment(cls, work_dir: str, exp_dir: str, config_path: str, cfg: Dict[str, Any]):
def _setup_experiment(cls, work_dir: str, exp_dir: str, config_path: str, cfg: dict[str, Any]):
config_name = os.path.split(config_path)[-1]
os.makedirs(exp_dir, exist_ok=True)
# save the resolved config to the experiment directory
@@ -812,9 +825,9 @@ def _train_local_distributed(
rank: int,
world_size: int,
port: int,
cfg: Dict[str, Any],
directories: Dict[str, str],
profile: Optional[str] = None
cfg: dict[str, Any],
directories: dict[str, str],
profile: str | None = None
):
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(port)
@@ -931,7 +944,7 @@ def train_slurm(cls, work_dir: str, experiment_dir: str, config_path: str):
dist.destroy_process_group()

@classmethod
def train_local(cls, work_dir: str, experiment_dir: str, config_path: str, profile: Optional[str] = None):
def train_local(cls, work_dir: str, experiment_dir: str, config_path: str, profile: str | None = None):
logger = logging.get_logger("LOCAL_INITIALIZATION")
num_gpus = torch.cuda.device_count()
assert num_gpus > 0, "need at least one GPU for local training"
@@ -964,8 +977,8 @@ def train_local(cls, work_dir: str, experiment_dir: str, config_path: str, profi
join=True
)

def _prepare_batch(self, batch: data.TrainBatch) -> Tuple[
Dict[str, Any],
def _prepare_batch(self, batch: data.TrainBatch) -> tuple[
dict[str, Any],
torch.Tensor
]:
raise NotImplementedError("prepare batch not implemented")
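Note on the trainer change above: `_model_from_config` now returns only the `nn.Module`, and the FSDP auto-wrap policy is supplied by the new `_sharding_policy` classmethod, which the constructor checks before falling back to size-based wrapping via `shard_size`. Below is a minimal sketch of how a downstream trainer might implement the two hooks; the base-class name `Trainer`, its import path, and the toy transformer config keys are assumptions for illustration and do not appear in this commit.

```python
import functools
from typing import Any

from torch import nn
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

# Assumed import: only the module path python/text_utils/api/trainer.py is
# shown in the diff; the actual class name is not visible here.
from text_utils.api.trainer import Trainer


class TransformerTrainer(Trainer):
    @classmethod
    def _model_from_config(cls, cfg: dict[str, Any]) -> nn.Module:
        # Build and return the model only; the sharding policy is no longer
        # part of this method's return value.
        model_cfg = cfg["model"]  # hypothetical config layout
        layer = nn.TransformerEncoderLayer(
            d_model=model_cfg["dim"],
            nhead=model_cfg["heads"],
            batch_first=True,
        )
        return nn.TransformerEncoder(layer, num_layers=model_cfg["layers"])

    @classmethod
    def _sharding_policy(cls, model: nn.Module):
        # Wrap each transformer layer in its own FSDP unit. Returning None
        # instead keeps the shard_size / NO_SHARD fallback in the constructor.
        return functools.partial(
            transformer_auto_wrap_policy,
            transformer_layer_cls={nn.TransformerEncoderLayer},
        )
```

With a custom policy in place, the constructor logs "sharding based on custom policy" and skips the parameter-count based wrapping branch.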
python/text_utils/modules/optimizer.py (52 changes: 24 additions & 28 deletions)
@@ -2,11 +2,6 @@
from typing import Dict, Any, Iterator, Tuple, Optional, Callable

from torch import nn, optim
try:
from bitsandbytes import optim as optim_8bit
_8BIT_OPTIMIZERS = True
except ImportError:
_8BIT_OPTIMIZERS = False


def _select_params_and_modules(
@@ -29,12 +24,24 @@ def optimizer_from_config(
) -> optim.Optimizer:
cfg = copy.deepcopy(cfg)
opt_type = cfg.pop("type")

if opt_type == "adamw":
optim_cls = optim.AdamW
elif opt_type == "adam":
optim_cls = optim.Adam
elif opt_type == "sgd":
optim_cls = optim.SGD
else:
if additional_optimizer_fn is not None:
return additional_optimizer_fn(model, cfg)
raise ValueError(f"unknown optimizer type {opt_type}")

param_groups: list[dict[str, Any]] = cfg.pop("param_groups", [{"prefix": None}])
assert len(param_groups) > 0, "param_groups must be non-empty"

weight_decay_modules: dict[str, list[str]] = cfg.pop(
weight_decay_modules: dict[str, list[str]] | str = cfg.pop(
"weight_decay_modules",
{}
"all"
)
all = set(name for name, p in model.named_parameters() if p.requires_grad)
params = []
@@ -60,8 +67,16 @@ def optimizer_from_config(
names.add(name)
param_dict[name] = param
mod_name = mod.__class__.__name__
if mod_name in weight_decay_modules and any(
name.endswith(suffix) for suffix in weight_decay_modules[mod_name]
if (
weight_decay_modules == "all"
or (
isinstance(weight_decay_modules, dict)
and mod_name in weight_decay_modules
and any(
name.endswith(suffix)
for suffix in weight_decay_modules[mod_name]
)
)
):
decay.add(name)
else:
@@ -85,23 +100,4 @@ def optimizer_from_config(
assert len(unused) == 0, \
f"parameter groups dont match trainable model parameters: {unused}"

optim_bits = int(cfg.get("optim_bits", 32))
assert optim_bits in [32, 8], f"optim_bits must be 32 or 8, got {optim_bits}"
use_8bit = optim_bits == 8
if use_8bit:
assert _8BIT_OPTIMIZERS, "8-bit optimizers not available"
else:
cfg.pop("optim_bits", None)

if opt_type == "adamw":
optim_cls = optim_8bit.AdamW if use_8bit else optim.AdamW
elif opt_type == "adam":
optim_cls = optim_8bit.Adam if use_8bit else optim.Adam
elif opt_type == "sgd":
optim_cls = optim_8bit.SGD if use_8bit else optim.SGD
else:
if additional_optimizer_fn is not None:
return additional_optimizer_fn(model, cfg)
raise ValueError(f"unknown optimizer type {opt_type}")

return optim_cls(params, **cfg)
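Note on the optimizer change above: the optimizer class is now resolved from `type` before the parameter groups are built, the bitsandbytes 8-bit path (`optim_bits`) is removed, and `weight_decay_modules` defaults to the string `"all"` instead of an empty dict, so by default every trainable parameter lands in the weight-decay group unless a per-module suffix map is given. Below is a standalone sketch of the new decay rule, not the repository's code; the function name and the example module are illustrative only.

```python
from torch import nn


def should_decay(
    weight_decay_modules: dict[str, list[str]] | str,
    module: nn.Module,
    param_name: str,
) -> bool:
    # "all": every trainable parameter receives weight decay.
    if weight_decay_modules == "all":
        return True
    # Dict form: decay only parameters of the listed module classes whose
    # names end with one of the given suffixes.
    mod_name = module.__class__.__name__
    return (
        isinstance(weight_decay_modules, dict)
        and mod_name in weight_decay_modules
        and any(param_name.endswith(s) for s in weight_decay_modules[mod_name])
    )


if __name__ == "__main__":
    linear = nn.Linear(4, 4)
    rule = {"Linear": ["weight"]}  # decay Linear weights, but not biases
    print(should_decay(rule, linear, "encoder.0.weight"))  # True
    print(should_decay(rule, linear, "encoder.0.bias"))    # False
    print(should_decay("all", linear, "encoder.0.bias"))   # True
```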
