diff --git a/open_diloco/train_fsdp.py b/open_diloco/train_fsdp.py
index 4d5ef3e..310ee32 100644
--- a/open_diloco/train_fsdp.py
+++ b/open_diloco/train_fsdp.py
@@ -91,6 +91,7 @@ class HvConfig(BaseConfig):
     world_rank: int
     galaxy_size: int
     fail_rank_drop: bool = False  # fail if we lose a diloco worker
+
     @model_validator(mode="before")
     def cast_str_to_list(cls, values: dict[str, Any]) -> dict[str, Any]:
@@ -179,7 +180,7 @@ def train(config: Config):
     local_rank = int(os.environ["LOCAL_RANK"])
     world_size = int(os.environ["WORLD_SIZE"])
    rank = int(os.environ["RANK"])
-
+    world_rank_list = list(range(config.hv.galaxy_size)) if config.hv is not None else []
     world_messenger_hv = config.hv is not None and local_rank == 0
     # batch_size is the total batch size for all GPUs
@@ -357,7 +358,7 @@ def scheduler_fn(opt):
         max_num_peers = 0
 
     log_activations = {}
-
+    log_drop = True
     for step, batch in enumerate(iterable=train_dataloader, start=start_step * gradient_accumulation_steps):
         real_step = (step + 1) // gradient_accumulation_steps
         is_accumulating = bool((step + 1) % gradient_accumulation_steps)
@@ -449,12 +450,12 @@ def scheduler_fn(opt):
                 metrics.update(log_activations)
                 log_activations = {}
 
-            if world_messenger_hv and num_peers < max_num_peers:
-                log(message=f"Lost a diloco worker, num_peers: {num_peers}, galaxy_size: {config.hv.galaxy_size}")
-                if config.hv.fail_rank_drop:
-                    raise ValueError(
-                        f"Lost a diloco worker, num_peers: {num_peers}, galaxy_size: {config.hv.galaxy_size}"
-                    )
+            # if world_messenger_hv and num_peers < max_num_peers:
+            #     log(message=f"Lost a diloco worker, num_peers: {num_peers}, galaxy_size: {config.hv.galaxy_size}")
+            #     if config.hv.fail_rank_drop:
+            #         raise ValueError(
+            #             f"Lost a diloco worker, num_peers: {num_peers}, galaxy_size: {config.hv.galaxy_size}"
+            #         )
 
             current_time = time.time()
@@ -510,7 +511,13 @@ def scheduler_fn(opt):
         if config.max_steps is not None and real_step >= config.max_steps:
             break
-
+
+        if config.hv is not None and real_step >= int(config.total_steps) // 2:
+            if log_drop:
+                log(f"Dropping worker world ranks {world_rank_list[config.hv.galaxy_size // 2:]}")
+                log_drop = False
+            if config.hv.world_rank in world_rank_list[config.hv.galaxy_size // 2:]:
+                break
 
     log("Training completed.")
     if rank == 0:
         metric_logger.finish()