diff --git a/python/text_utils/api/trainer.py b/python/text_utils/api/trainer.py
index da2d208..2e05ff5 100644
--- a/python/text_utils/api/trainer.py
+++ b/python/text_utils/api/trainer.py
@@ -9,11 +9,11 @@ import shutil
 import time
 import zipfile
 
-from typing import Dict, Optional, Tuple, Any, List, Callable, Union
+from typing import Any, Callable
 
 import torch
 from torch.backends import cuda, cudnn
-from torch import GradScaler, distributed as dist
+from torch import distributed as dist
 from torch import multiprocessing as mp
 from torch import nn
 from torch.optim import lr_scheduler
@@ -97,8 +97,8 @@ def parser(cls, name: str, description: str) -> argparse.ArgumentParser:
 
     def __init__(
         self,
-        cfg: Dict[str, Any],
-        directories: Dict[str, str],
+        cfg: dict[str, Any],
+        directories: dict[str, str],
         info: distributed.DistributedInfo
     ):
         self.cfg = cfg
@@ -112,7 +112,7 @@ def __init__(
         self.epoch_step = 0
         self.epoch = 0
         self.best_val_loss = float("inf")
-        self.best_benchmark: Optional[float] = None
+        self.best_benchmark: float | None = None
 
         self.logger = logging.get_logger("TRAIN")
         if self.info.is_main_process:
@@ -136,7 +136,7 @@ def __init__(
             cuda.matmul.allow_tf32 = True
             cudnn.benchmark = True
 
-        model, sharding_policy = self._model_from_config(self.cfg)
+        model = self._model_from_config(self.cfg)
 
         peft = self.cfg["train"].get("peft", None)
         if peft is not None:
@@ -150,6 +150,8 @@ def __init__(
                 peft
             )
 
+        sharding_policy = self._sharding_policy(model)
+
         mixed_precision = self.cfg["train"].get("mixed_precision", None)
         if mixed_precision == "fp16":
             self.mixed_precision = torch.float16
@@ -175,7 +177,12 @@ def __init__(
             strategy = ShardingStrategy[dist_cfg.get("strategy", "NO_SHARD")]
             if strategy != ShardingStrategy.NO_SHARD:
                 shard_size = dist_cfg.get("shard_size", None)
-                if shard_size is not None:
+                if sharding_policy is not None:
+                    if self.info.is_main_process:
+                        self.logger.info(
+                            "sharding based on custom policy"
+                        )
+                elif shard_size is not None:
                     if self.info.is_main_process:
                         self.logger.info(
                             f"sharding based on number of parameters with "
@@ -187,26 +194,21 @@ def __init__(
                         force_leaf_modules=None,
                         exclude_wrap_modules=None
                     )
-                elif sharding_policy is None:
-                    if self.info.is_main_process:
-                        self.logger.info(
-                            f"sharding strategy is {strategy.name}, but got "
-                            f"no sharding policy, disabling sharding"
-                        )
-                    strategy = ShardingStrategy.NO_SHARD
+                else:
+                    raise ValueError(
+                        "sharding strategy is set, but no custom sharding policy "
+                        "or shard size is specified"
+                    )
         else:
             sharding_policy = None
 
         offload_params = False
-        offload_state_dict = self.info.world_size > 1
-
         self.model = FSDP(
             model,
             auto_wrap_policy=sharding_policy,
             mixed_precision=MixedPrecision(
                 param_dtype=self.mixed_precision,
                 reduce_dtype=self.mixed_precision,
-                buffer_dtype=self.mixed_precision
             ),
             cpu_offload=CPUOffload(offload_params=offload_params),
             limit_all_gathers=True,
@@ -214,12 +216,12 @@ def __init__(
             forward_prefetch=prefetch,
             backward_prefetch=BackwardPrefetch.BACKWARD_PRE if prefetch else BackwardPrefetch.BACKWARD_POST,
             device_id=self.info.device,
-            use_orig_params=compile or (peft is not None),
         )
 
         # set mixed precision to none here for FSDP to avoid autocasting
         # later, because FSDP handles mixed precision itself
         self.mixed_precision = None
+        offload_state_dict = self.info.world_size > 1
         FSDP.set_state_dict_type(
             self.model,
             StateDictType.FULL_STATE_DICT,
@@ -311,7 +313,7 @@ def __init__(
         else:
            self.cooldown_items = 0
 
-        self.cooldown_scheduler: Optional[lr_scheduler.LambdaLR] = None
+        self.cooldown_scheduler: lr_scheduler.LambdaLR | None = None
 
         if "lr_scheduler" in self.cfg["train"]:
             self.step_interval = clamp(
@@ -441,6 +443,7 @@ def _save_checkpoint(
             self.model,
             self.optimizer
         )
+        save["grad_scaler_state_dict"] = self.grad_scaler.state_dict()
         save["loss_fn_state_dict"] = self.loss_fn.state_dict()
         if self.lr_scheduler is not None:
             save["lr_scheduler_state_dict"] = self.lr_scheduler.state_dict()
@@ -465,6 +468,7 @@ def _load_checkpoint(self, path: str):
             optim_state_dict,
         )
         self.optimizer.load_state_dict(optim_state_dict)
+        self.grad_scaler.load_state_dict(checkpoint["grad_scaler_state_dict"])
         if self.lr_scheduler is not None and checkpoint.get("lr_scheduler_state_dict") is not None:
             self.lr_scheduler.load_state_dict(
                 checkpoint["lr_scheduler_state_dict"]
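Note on the two checkpoint hunks above: they persist and restore `self.grad_scaler`, while the import hunk at the top of the file drops `GradScaler` from the top-level `torch` import. The scaler's construction site is not part of this diff; presumably it now goes through the `torch.amp` namespace. A minimal sketch of such a setup, for illustration only:

```python
import torch

# hypothetical: e.g. resolved from cfg["train"]["mixed_precision"]
mixed_precision = torch.float16

# loss scaling is only required for fp16; bf16 and fp32 run unscaled,
# and a disabled GradScaler turns state_dict()/load_state_dict() into no-ops
grad_scaler = torch.amp.GradScaler(
    "cuda",
    enabled=mixed_precision == torch.float16,
)
```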
@@ -510,24 +514,31 @@ def _prepare_peft(
     @classmethod
     def _model_from_config(
         cls,
-        cfg: Dict[str, Any]
-    ) -> Tuple[nn.Module, Optional[ShardingPolicy]]:
+        cfg: dict[str, Any]
+    ) -> nn.Module:
         raise NotImplementedError
 
     @classmethod
-    def _additional_loss_fn(cls) -> Optional[Callable]:
+    def _sharding_policy(
+        cls,
+        model: nn.Module,
+    ) -> ShardingPolicy | None:
+        return None
+
+    @classmethod
+    def _additional_loss_fn(cls) -> Callable | None:
         return None
 
     @classmethod
-    def _additional_optimizer_fn(cls) -> Optional[Callable]:
+    def _additional_optimizer_fn(cls) -> Callable | None:
         return None
 
     @classmethod
-    def _additional_lr_scheduler_fn(cls) -> Optional[Callable]:
+    def _additional_lr_scheduler_fn(cls) -> Callable | None:
         return None
 
     @classmethod
-    def _additional_max_length_scheduler_fn(cls) -> Optional[Callable]:
+    def _additional_max_length_scheduler_fn(cls) -> Callable | None:
         return None
 
     @classmethod
@@ -559,13 +570,13 @@ def _copy_file_to_tmp_dir(cls, path: str, dir: str, info: distributed.Distribute
     @classmethod
     def _prepare_data_sources(
         cls,
-        sources: List[Dict[str, Any]],
+        sources: list[dict[str, Any]],
         info: distributed.DistributedInfo
-    ) -> Tuple[
-        List[Tuple[str, Optional[str]]],
-        List[Optional[Any]],
-        List[Optional[Any]],
-        List[str]
+    ) -> tuple[
+        list[tuple[str, str | None]],
+        list[Any | None],
+        list[Any | None],
+        list[str]
     ]:
         src_paths = []
         src_preprocessings = []
@@ -632,25 +643,25 @@ def _prepare_data_sources(
     @classmethod
     def _data_from_config(
         cls,
-        train_cfg: Dict[str, Any],
-        val_cfg: Union[List[Any], int],
+        train_cfg: dict[str, Any],
+        val_cfg: list[Any] | int,
         num_epochs: int,
-        seed: Optional[int],
+        seed: int | None,
         info: distributed.DistributedInfo
-    ) -> Tuple[
+    ) -> tuple[
         data.TrainLoader,
         data.TrainLoader,
         int,
         int,
         int,
-        Optional[Callable[[int], int]],
-        List[str]
+        Callable[[int], int] | None,
+        list[str]
     ]:
         def prepare_data_loader(
-            pipeline_cfg: Dict[str, Any],
-            sources: List[tuple[str, str | None]],
-            preprocessings: List[Optional[Any]],
-            postprocessings: List[Optional[Any]],
+            pipeline_cfg: dict[str, Any],
+            sources: list[tuple[str, str | None]],
+            preprocessings: list[Any | None],
+            postprocessings: list[Any | None],
             **kwargs: Any,
         ) -> data.TrainLoader:
             pipeline_cfg = copy.deepcopy(pipeline_cfg)
@@ -676,7 +687,9 @@ def prepare_data_loader(
         # adapt config to multi gpu usage
         assert "batch_limit" in train_cfg, "batch_limit must be in data config"
         train_cfg["batch_limit"] = max(
-            1, train_cfg["batch_limit"] // info.world_size)
+            1,
+            train_cfg["batch_limit"] // info.world_size
+        )
 
         # pop some configs not used by the dataloader
         max_length = train_cfg.pop("max_length")
@@ -777,7 +790,7 @@ def prepare_data_loader(
         )
 
     @classmethod
-    def _setup_experiment(cls, work_dir: str, exp_dir: str, config_path: str, cfg: Dict[str, Any]):
+    def _setup_experiment(cls, work_dir: str, exp_dir: str, config_path: str, cfg: dict[str, Any]):
         config_name = os.path.split(config_path)[-1]
         os.makedirs(exp_dir, exist_ok=True)
         # save the resolved config to the experiment directory
@@ -812,9 +825,9 @@ def _train_local_distributed(
         rank: int,
         world_size: int,
         port: int,
-        cfg: Dict[str, Any],
-        directories: Dict[str, str],
-        profile: Optional[str] = None
+        cfg: dict[str, Any],
+        directories: dict[str, str],
+        profile: str | None = None
     ):
         os.environ["MASTER_ADDR"] = "localhost"
         os.environ["MASTER_PORT"] = str(port)
@@ -931,7 +944,7 @@ def train_slurm(cls, work_dir: str, experiment_dir: str, config_path: str):
         dist.destroy_process_group()
 
     @classmethod
-    def train_local(cls, work_dir: str, experiment_dir: str, config_path: str, profile: Optional[str] = None):
+    def train_local(cls, work_dir: str, experiment_dir: str, config_path: str, profile: str | None = None):
         logger = logging.get_logger("LOCAL_INITIALIZATION")
         num_gpus = torch.cuda.device_count()
         assert num_gpus > 0, "need at least one GPU for local training"
@@ -964,8 +977,8 @@ def train_local(cls, work_dir: str, experiment_dir: str, config_path: str, profi
             join=True
         )
 
-    def _prepare_batch(self, batch: data.TrainBatch) -> Tuple[
-        Dict[str, Any],
+    def _prepare_batch(self, batch: data.TrainBatch) -> tuple[
+        dict[str, Any],
         torch.Tensor
     ]:
         raise NotImplementedError("prepare batch not implemented")
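With `_model_from_config` now returning just the model, custom FSDP auto-wrap policies move into the new `_sharding_policy` hook, which takes precedence over `shard_size` in the sharding setup above. A sketch of an override, assuming the patched class is importable as `Trainer` and using a made-up `MyTransformerBlock` as the wrapped layer type:

```python
import functools

from torch import nn
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

from text_utils.api.trainer import Trainer  # class name assumed


class MyTransformerBlock(nn.Module):
    """Stand-in for whatever block type the real model repeats."""

    def forward(self, x):
        return x


class MyTrainer(Trainer):
    @classmethod
    def _model_from_config(cls, cfg: dict) -> nn.Module:
        # toy model; a real subclass would build this from cfg
        return nn.Sequential(MyTransformerBlock(), MyTransformerBlock())

    @classmethod
    def _sharding_policy(cls, model: nn.Module):
        # wrap every MyTransformerBlock instance in its own FSDP unit
        return functools.partial(
            transformer_auto_wrap_policy,
            transformer_layer_cls={MyTransformerBlock},
        )
```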
diff --git a/python/text_utils/modules/optimizer.py b/python/text_utils/modules/optimizer.py
index f983da5..adeb642 100644
--- a/python/text_utils/modules/optimizer.py
+++ b/python/text_utils/modules/optimizer.py
@@ -2,11 +2,6 @@
 from typing import Dict, Any, Iterator, Tuple, Optional, Callable
 
 from torch import nn, optim
-try:
-    from bitsandbytes import optim as optim_8bit
-    _8BIT_OPTIMIZERS = True
-except ImportError:
-    _8BIT_OPTIMIZERS = False
 
 
 def _select_params_and_modules(
@@ -29,12 +24,24 @@ def optimizer_from_config(
 ) -> optim.Optimizer:
     cfg = copy.deepcopy(cfg)
     opt_type = cfg.pop("type")
+
+    if opt_type == "adamw":
+        optim_cls = optim.AdamW
+    elif opt_type == "adam":
+        optim_cls = optim.Adam
+    elif opt_type == "sgd":
+        optim_cls = optim.SGD
+    else:
+        if additional_optimizer_fn is not None:
+            return additional_optimizer_fn(model, cfg)
+        raise ValueError(f"unknown optimizer type {opt_type}")
+
     param_groups: list[dict[str, Any]] = cfg.pop("param_groups", [{"prefix": None}])
     assert len(param_groups) > 0, "param_groups must be non-empty"
-    weight_decay_modules: dict[str, list[str]] = cfg.pop(
+    weight_decay_modules: dict[str, list[str]] | str = cfg.pop(
         "weight_decay_modules",
-        {}
+        "all"
     )
 
     all = set(name for name, p in model.named_parameters() if p.requires_grad)
     params = []
@@ -60,8 +67,16 @@ def optimizer_from_config(
                 names.add(name)
                 param_dict[name] = param
             mod_name = mod.__class__.__name__
-            if mod_name in weight_decay_modules and any(
-                name.endswith(suffix) for suffix in weight_decay_modules[mod_name]
+            if (
+                weight_decay_modules == "all"
+                or (
+                    isinstance(weight_decay_modules, dict)
+                    and mod_name in weight_decay_modules
+                    and any(
+                        name.endswith(suffix)
+                        for suffix in weight_decay_modules[mod_name]
+                    )
+                )
             ):
                 decay.add(name)
             else:
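For reference, the two shapes `weight_decay_modules` now accepts (hypothetical configs; the keys mirror what `optimizer_from_config` pops above):

```python
# new default: every trainable parameter gets weight decay
cfg_all = {
    "type": "adamw",
    "lr": 1e-4,
    "weight_decay": 0.01,
    "weight_decay_modules": "all",
}

# dict form: decay only parameters owned by the named module classes
# whose parameter name ends with one of the given suffixes
cfg_selective = {
    "type": "adamw",
    "lr": 1e-4,
    "weight_decay": 0.01,
    "weight_decay_modules": {
        "Linear": ["weight"],
        "Embedding": ["weight"],
    },
}
```

Note the behavior change: the old default `{}` applied weight decay to nothing unless configured, while the new default `"all"` decays everything.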
@@ -85,23 +100,4 @@
     assert len(unused) == 0, \
         f"parameter groups dont match trainable model parameters: {unused}"
 
-    optim_bits = int(cfg.get("optim_bits", 32))
-    assert optim_bits in [32, 8], f"optim_bits must be 32 or 8, got {optim_bits}"
-    use_8bit = optim_bits == 8
-    if use_8bit:
-        assert _8BIT_OPTIMIZERS, "8-bit optimizers not available"
-    else:
-        cfg.pop("optim_bits", None)
-
-    if opt_type == "adamw":
-        optim_cls = optim_8bit.AdamW if use_8bit else optim.AdamW
-    elif opt_type == "adam":
-        optim_cls = optim_8bit.Adam if use_8bit else optim.Adam
-    elif opt_type == "sgd":
-        optim_cls = optim_8bit.SGD if use_8bit else optim.SGD
-    else:
-        if additional_optimizer_fn is not None:
-            return additional_optimizer_fn(model, cfg)
-        raise ValueError(f"unknown optimizer type {opt_type}")
-
     return optim_cls(params, **cfg)
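A minimal call site, assuming the signature is `optimizer_from_config(model, cfg, additional_optimizer_fn=None)` as the function body above suggests:

```python
from torch import nn

from text_utils.modules.optimizer import optimizer_from_config

model = nn.Sequential(nn.Linear(16, 32), nn.GELU(), nn.Linear(32, 4))

# leftover keys ("lr", "weight_decay") are forwarded to optim.AdamW
optimizer = optimizer_from_config(
    model,
    {"type": "adamw", "lr": 3e-4, "weight_decay": 0.01},
)
```

With the bitsandbytes path removed, a stale `optim_bits` key is no longer consumed; it would now be forwarded to the torch optimizer constructor and fail with a `TypeError`, so configs should drop it.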