
update fsdp import
samsja committed Feb 1, 2025
1 parent 9d9ea1f commit a3547aa
Showing 5 changed files with 17 additions and 16 deletions.
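For context: the commit replaces PyTorch's private module paths with their public counterparts. A minimal before/after sketch of the import change, assuming a PyTorch release that already exposes the public modules (torch.distributed.tensor for DTensor, torch.distributed.fsdp for the fully_shard API); the comments are illustrative and not part of the diff below.

    # Before this commit: private module paths.
    # from torch.distributed._tensor.api import DTensor
    # from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy, CPUOffloadPolicy

    # After this commit: the public equivalents.
    from torch.distributed.tensor import DTensor
    from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy, CPUOffloadPolicy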
2 changes: 1 addition & 1 deletion src/zeroband/checkpoint.py
@@ -28,7 +28,7 @@
from torch.distributed.checkpoint.stateful import Stateful
import warnings
import logging
-from torch.distributed._tensor.api import DTensor
+from torch.distributed.tensor import DTensor
from zeroband.utils.state_dict_send_recv import (
_get_sendable_state_dict,
recv_state_dict,
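The DTensor import in checkpoint code is typically used to detect sharded tensors before serializing or sending them. A hedged sketch of that pattern, assuming the public torch.distributed.tensor module; the helper name and the full_tensor() materialization are illustrative and not taken from this file.

    import torch
    from torch.distributed.tensor import DTensor

    def to_plain_tensor(value: torch.Tensor) -> torch.Tensor:
        # DTensor carries sharding metadata; gather it into a regular
        # torch.Tensor before pickling or sending it over the wire.
        if isinstance(value, DTensor):
            return value.full_tensor()
        return value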
2 changes: 1 addition & 1 deletion src/zeroband/diloco.py
@@ -8,7 +8,7 @@
from zeroband.utils.logging import get_logger
from zeroband.config import DilocoConfig
import torch.distributed as dist
-from torch.distributed._tensor.api import DTensor
+from torch.distributed.tensor import DTensor
from functools import lru_cache


23 changes: 12 additions & 11 deletions src/zeroband/train.py
@@ -5,7 +5,7 @@

import torch
import torch.distributed as dist
-from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy, CPUOffloadPolicy  # type: ignore
+from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy, CPUOffloadPolicy  # type: ignore
from torch.autograd.profiler import record_function

from zeroband.checkpoint import CkptManager, TrainingProgress
@@ -69,10 +69,9 @@ def log_hash_training_state(
logger.debug(f"outer diloco optimizer hash {id} : {outer_optimizer_hash}")
logger.debug(f"outer diloco model hash {id} : {outer_model_hash}")

-metrics.update({
-    f"outer_optimizer_hash_{id}": outer_optimizer_hash,
-    f"outer_model_hash_{id}": outer_model_hash
-})
+metrics.update(
+    {f"outer_optimizer_hash_{id}": outer_optimizer_hash, f"outer_model_hash_{id}": outer_model_hash}
+)
if world_info.rank == 0:
assert metric_logger is not None
metric_logger.log(metrics)
@@ -139,13 +138,11 @@ def train(config: Config):
apply_ac_ckpt(model, num)

elastic_device_mesh = ElasticDeviceMesh(
-    enable=config.diloco is not None,
-    live_recovery_rank_src=config.ckpt.live_recovery_rank_src
+    enable=config.diloco is not None, live_recovery_rank_src=config.ckpt.live_recovery_rank_src
)

mp_policy = MixedPrecisionPolicy(
-    param_dtype=torch.bfloat16,
-    reduce_dtype=torch.float32 if config.train.reduce_fp32 else None
+    param_dtype=torch.bfloat16, reduce_dtype=torch.float32 if config.train.reduce_fp32 else None
)

offload_policy = CPUOffloadPolicy(pin_memory=True) if config.train.fsdp_cpu_offload else None
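These policy objects are what the public fully_shard API consumes. A minimal sketch of how they are typically wired together, assuming an initialized process group; the wrapper function, its arguments, and the decision to shard the whole model in one call are illustrative assumptions, not code from this repository.

    import torch
    import torch.nn as nn
    from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy, CPUOffloadPolicy

    def shard_model(model: nn.Module, cpu_offload: bool = False) -> nn.Module:
        # Mirror the hunk above: bf16 parameters for compute, fp32 gradient reduction.
        mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)
        if cpu_offload:
            # Keep sharded parameters in pinned host memory between uses.
            fully_shard(model, mp_policy=mp_policy, offload_policy=CPUOffloadPolicy(pin_memory=True))
        else:
            fully_shard(model, mp_policy=mp_policy)
        return model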
@@ -366,9 +363,13 @@ def train(config: Config):
with record_function("Inner allreduce"):
logger.debug("loss allreduce()")
# Launch both allreduces at the same time to hide latency
-loss_allreduce = dist.all_reduce(tensor=loss_batch, op=dist.ReduceOp.AVG, group=elastic_device_mesh.local_pg, async_op=True)
+loss_allreduce = dist.all_reduce(
+    tensor=loss_batch, op=dist.ReduceOp.AVG, group=elastic_device_mesh.local_pg, async_op=True
+)
if config.optim.z_loss:
-    z_loss_allreduce = dist.all_reduce(tensor=z_loss_batch, op=dist.ReduceOp.AVG, group=elastic_device_mesh.local_pg, async_op=True)
+    z_loss_allreduce = dist.all_reduce(
+        tensor=z_loss_batch, op=dist.ReduceOp.AVG, group=elastic_device_mesh.local_pg, async_op=True
+    )

assert isinstance(loss_allreduce, torch.distributed.Work)
loss_allreduce.wait()
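The reformatted hunk keeps the existing overlap trick: both reductions are launched with async_op=True before either is awaited, so the two collectives can run concurrently instead of back to back. A standalone sketch of the pattern, assuming an initialized default process group and a backend that supports ReduceOp.AVG (e.g. NCCL); the function name and signature are illustrative, not part of the diff.

    import torch
    import torch.distributed as dist

    def average_losses(loss_batch: torch.Tensor, z_loss_batch: torch.Tensor | None = None) -> None:
        # Launch both all-reduces before waiting on either, hiding the latency
        # of the second collective behind the first.
        loss_work = dist.all_reduce(loss_batch, op=dist.ReduceOp.AVG, async_op=True)
        z_loss_work = None
        if z_loss_batch is not None:
            z_loss_work = dist.all_reduce(z_loss_batch, op=dist.ReduceOp.AVG, async_op=True)

        loss_work.wait()
        if z_loss_work is not None:
            z_loss_work.wait()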
4 changes: 2 additions & 2 deletions src/zeroband/utils/__init__.py
@@ -3,7 +3,7 @@
import time
import torch
from torch.distributed.fsdp import ShardingStrategy
-from torch.distributed._tensor.api import DTensor
+from torch.distributed.tensor import DTensor
from distributed_shampoo import DistributedShampoo


@@ -193,4 +193,4 @@ def __init__(self):
self.pad_token_id = 2

def __len__(self):
-return self.vocab_size
+return self.vocab_size
2 changes: 1 addition & 1 deletion src/zeroband/utils/state_dict_send_recv.py
@@ -2,7 +2,7 @@
import pickle
import torch
from torch.distributed import ProcessGroup
-from torch.distributed._tensor.api import DTensor
+from torch.distributed.tensor import DTensor


def _object_to_tensor(obj):
