Commit

add

xrsrke committed Jan 11, 2025
1 parent e8b114b commit ebea115
Showing 5 changed files with 103 additions and 58 deletions.
1 change: 1 addition & 0 deletions src/nanotron/config/config.py
@@ -316,6 +316,7 @@ class OptimizerArgs:
clip_grad: Optional[float]
accumulate_grad_in_fp32: bool
learning_rate_scheduler: LRSchedulerArgs
master_weight_dtype: torch.dtype = torch.float32


@dataclass
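For context, a small illustrative sketch (not part of this commit): an update smaller than bf16's precision is rounded away on a bf16 master copy but preserved on an fp32 one, which is why the field above defaults to torch.float32.

import torch

# Illustrative only: compare a tiny weight update applied to bf16 vs. fp32 master copies.
w_bf16 = torch.tensor([1.0], dtype=torch.bfloat16)
w_fp32 = torch.tensor([1.0], dtype=torch.float32)
update = 1e-4  # roughly lr * grad for a unit-scale weight

print(w_bf16 + update)  # tensor([1.], dtype=torch.bfloat16) -- the update is rounded away
print(w_fp32 + update)  # tensor([1.0001]) -- the update survives
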
16 changes: 16 additions & 0 deletions src/nanotron/fp8/distributed.py
@@ -9,6 +9,22 @@
from nanotron.parallel.parameters import NanotronParameter


def post_scaling_all_reduce_mean(
tensor: torch.Tensor, group: Optional[dist.ProcessGroup] = None, async_op=False
) -> torch.Tensor:
dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group, async_op=async_op)
tensor.div_(dist.get_world_size())
return tensor


def post_scaling_all_reduce_coalesced_mean(
tensor: torch.Tensor, group: Optional[dist.ProcessGroup] = None, async_op=False
) -> torch.Tensor:
dist.all_reduce_coalesced(tensor, op=dist.ReduceOp.SUM, group=group, async_op=async_op)
tensor.div_(dist.get_world_size())
return tensor


def all_reduce(
tensor: Union[torch.Tensor, NanotronParameter],
op: dist.ReduceOp = dist.ReduceOp.SUM,
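An illustrative usage sketch for the new helpers (not from this commit); it assumes an initialized process group, and `grads` and `parallel_context` are placeholder names for the caller's gradient list and nanotron ParallelContext:

import torch
from nanotron.fp8.distributed import post_scaling_all_reduce_mean

# Illustrative only: average dense gradient tensors across the data-parallel
# group by summing first and dividing by the world size afterwards.
def average_grads(grads: list[torch.Tensor], parallel_context) -> None:
    for grad in grads:
        post_scaling_all_reduce_mean(grad, group=parallel_context.dp_pg)

Scaling after the reduction (rather than dividing each rank's tensor before summing) presumably keeps small low-precision values from being attenuated before they are accumulated; that reading is inferred from the helper's name, not stated in the diff.
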
63 changes: 37 additions & 26 deletions src/nanotron/helpers.py
@@ -20,7 +20,6 @@
from nanotron import logging
from nanotron.config import Config, DatasetStageArgs, LRSchedulerArgs, OptimizerArgs, ParallelismArgs
from nanotron.distributed import ProcessGroup
from nanotron.fp8.optim import FP8AdamW
from nanotron.logging import LogItem, log_rank
from nanotron.models.base import NanotronModel
from nanotron.optim.base import BaseOptimizer, Optimizer
@@ -295,6 +294,7 @@ def merge_named_param_groups(
def init_optimizer_and_grad_accumulator(
parametrization_method: ParametrizationMethod,
model: nn.Module,
master_weight_dtype: torch.dtype,
optimizer_args: OptimizerArgs,
parallel_context: ParallelContext,
) -> Tuple[BaseOptimizer, GradientAccumulator]:
@@ -327,34 +327,44 @@ def basic_optimizer_builder(named_param_groups):
optimizer = None

if optimizer_args.optimizer_factory.name == "adamW":
from nanotron import constants

def optimizer(param_groups):
# if has_fp8_params(param_groups) is False:
if constants.CONFIG.model.dtype != torch.int8:
return torch.optim.AdamW(
param_groups,
lr=optimizer_args.learning_rate_scheduler.learning_rate,
weight_decay=optimizer_args.weight_decay,
eps=optimizer_args.optimizer_factory.adam_eps,
betas=(
optimizer_args.optimizer_factory.adam_beta1,
optimizer_args.optimizer_factory.adam_beta2,
),
fused=optimizer_args.optimizer_factory.torch_adam_is_fused,
)
else:
return FP8AdamW(
param_groups,
lr=optimizer_args.learning_rate_scheduler.learning_rate,
weight_decay=optimizer_args.weight_decay,
eps=optimizer_args.optimizer_factory.adam_eps,
betas=(
optimizer_args.optimizer_factory.adam_beta1,
optimizer_args.optimizer_factory.adam_beta2,
),
recipe=constants.CONFIG.fp8.optim,
)
# if constants.CONFIG.model.dtype != torch.int8:
# return torch.optim.AdamW(
# param_groups,
# lr=optimizer_args.learning_rate_scheduler.learning_rate,
# weight_decay=optimizer_args.weight_decay,
# eps=optimizer_args.optimizer_factory.adam_eps,
# betas=(
# optimizer_args.optimizer_factory.adam_beta1,
# optimizer_args.optimizer_factory.adam_beta2,
# ),
# fused=optimizer_args.optimizer_factory.torch_adam_is_fused,
# )
# else:
# return FP8AdamW(
# param_groups,
# lr=optimizer_args.learning_rate_scheduler.learning_rate,
# weight_decay=optimizer_args.weight_decay,
# eps=optimizer_args.optimizer_factory.adam_eps,
# betas=(
# optimizer_args.optimizer_factory.adam_beta1,
# optimizer_args.optimizer_factory.adam_beta2,
# ),
# recipe=constants.CONFIG.fp8.optim,
# )
return torch.optim.AdamW(
param_groups,
lr=optimizer_args.learning_rate_scheduler.learning_rate,
weight_decay=optimizer_args.weight_decay,
eps=optimizer_args.optimizer_factory.adam_eps,
betas=(
optimizer_args.optimizer_factory.adam_beta1,
optimizer_args.optimizer_factory.adam_beta2,
),
fused=optimizer_args.optimizer_factory.torch_adam_is_fused,
)

elif optimizer_args.optimizer_factory.name == "sgd":

@@ -384,6 +394,7 @@ def grad_optimizer_builder(named_param_groups):
gradient_accumulator_builder=lambda named_params: FP32GradientAccumulator(
named_parameters=named_params,
grad_buckets_named_params=named_parameters,
master_dtype=master_weight_dtype,
),
named_params_or_groups=named_param_groups,
optimizer_builder=basic_optimizer_builder,
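A minimal call sketch (not from this commit; the trainer wiring further down does the same thing) showing the new master_weight_dtype argument; `parametrization_method`, `model`, `optimizer_args`, and `parallel_context` are assumed to already exist in the caller's scope:

import torch
from nanotron.helpers import init_optimizer_and_grad_accumulator

# Illustrative only: the master-weight dtype now flows explicitly from the
# config into the gradient accumulator through this helper.
optimizer, grad_accumulator = init_optimizer_and_grad_accumulator(
    parametrization_method=parametrization_method,
    model=model,
    master_weight_dtype=torch.float32,  # or e.g. torch.bfloat16 to shrink the master-weight buffer
    optimizer_args=optimizer_args,
    parallel_context=parallel_context,
)
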
46 changes: 25 additions & 21 deletions src/nanotron/optim/gradient_accumulator.py
@@ -62,6 +62,7 @@ def __init__(
self,
named_parameters: Iterator[Tuple[str, NanotronParameter]],
grad_buckets_named_params: Optional[Iterator[Tuple[str, NanotronParameter]]] = None,
master_dtype: torch.dtype = torch.float32,
):
"""Create a gradient accumulator that will accumulate gradients in fp32.
@@ -84,14 +85,14 @@ def __init__(
# Assign big buffer for weights + grad in fp32
segment_index = {}
length = 0
fp32_params = []
master_params = []
for name, param in named_parameters:
# NOTE: FP8 Parameter by default has requires_grad=False,
# because we want to do the backward ourself, so here we only skip
# if the parameter isn't fp8, and doesn't require grad

if self._is_not_required_master_weights(param):
fp32_params.append((name, param))
master_params.append((name, param))
continue

start = length
@@ -100,21 +101,21 @@ def __init__(
segment_index[name] = (start, end_weight, param)
length = end_weight

big_flat_buffer = torch.empty(length, dtype=torch.float, device="cuda")
big_flat_buffer = torch.empty(length, dtype=master_dtype, device="cuda")
self.parameters = {
name: {
"fp32": big_flat_buffer[start_weight:end_weight].view_as(param),
"master": big_flat_buffer[start_weight:end_weight].view_as(param),
"half": param,
}
for name, (start_weight, end_weight, param) in segment_index.items()
}
self.parameters.update(
{
name: {
"fp32": param,
"master": param,
"half": param,
}
for name, param in fp32_params
for name, param in master_params
}
)

@@ -124,15 +125,15 @@ def __init__(

with torch.inference_mode():
for _, elt in self.parameters.items():
fp32_param = elt["fp32"]
master_param = elt["master"]
half_param = elt["half"]

# Check that fp32 weights have the same memory representation as half precision weights
assert fp32_param.stride() == half_param.stride()
assert master_param.stride() == half_param.stride()

# Copy weights from half precision to full precision
if not isinstance(half_param.data, FP8Tensor):
fp32_param.copy_(half_param)
master_param.copy_(half_param)
else:
from nanotron import constants

@@ -147,13 +148,13 @@ def find_param_name(param, named_parameters):
p_data = constants.CPU_WEIGHTS[p_name]
assert p_data.dtype == torch.float32, f"Expected {p_name} to be float32, but got {p_data.dtype}"

fp32_param.copy_(constants.CPU_WEIGHTS[p_name])
master_param.copy_(constants.CPU_WEIGHTS[p_name])

del constants.CPU_WEIGHTS[p_name]
del p_name

# Set requires_grad=True
fp32_param.requires_grad = True
master_param.requires_grad = True

self._is_accumulation_sync_step = False
# We need the last allreduce handle to make sure it finishes before the optimizer step
@@ -298,7 +299,7 @@ def _accumulate_grad(self, name: str, half_param: NanotronParameter) -> None:

# In the case an optimizer decides to set it to None, we need to re-assign previous buffer
if name in self.parameters:
fp32_param = self.parameters[name]["fp32"]
master_param = self.parameters[name]["master"]
if hasattr(self, "param_name_to_offsets"):
if name not in self.param_name_to_offsets:
# When `name` isn't in `param_name_to_offsets` it means the slice is empty.
@@ -307,8 +308,8 @@ def _accumulate_grad(self, name: str, half_param: NanotronParameter) -> None:
grad = fp32_grad.view(-1)[start_offset:end_offset]
else:
grad = fp32_grad
fp32_param.grad = grad
assert is_overflow_underflow_nan(fp32_param.grad) is False
master_param.grad = grad
assert is_overflow_underflow_nan(master_param.grad) is False

@contextmanager
def no_sync(self):
@@ -331,15 +332,15 @@ def step(self):
We need to update only the parameters that were updated by the optimizer.
"""
for name in self.parameters.keys():
fp32_param = self.parameters[name]["fp32"]
master_param = self.parameters[name]["master"]
half_param = self.parameters[name]["half"]

# TODO @nouamane: should we use a fused kernel to copy?
# Copy weights from full precision to half precision
if half_param.data.__class__ == FP8Tensor:
half_param.data.set_data(fp32_param, sync=False)
half_param.data.set_data(master_param, sync=False)
else:
half_param.copy_(fp32_param)
half_param.copy_(master_param)

def zero_grad(self):
# Full precision gradients are reset to zero/none after the underlying `optimiser.step`, so no need to reset.
@@ -355,22 +356,22 @@ def zero_grad(self):
self._contiguous_fp32_grad_buffer.zero_()

def get_parameter_for_optimizer(self, name: str) -> NanotronParameter:
return self.parameters[name]["fp32"]
return self.parameters[name]["master"]

def get_grad_buffer(self, name: str) -> torch.Tensor:
"""Returns the gradient of the parameter from the appropriate grad bucket."""
return self.fp32_grad_buffers[name]["fp32_grad"]

def state_dict(self) -> Dict[str, torch.Tensor]:
# We consider `fp32` parameters as a state of the gradient accumulator
return {name: elt["fp32"] for name, elt in self.parameters.items()}
# We consider master parameters as a state of the gradient accumulator
return {name: elt["master"] for name, elt in self.parameters.items()}

def load_state_dict(self, state_dict: Dict[str, torch.Tensor]):
assert set(state_dict.keys()) == set(self.parameters.keys())

with torch.inference_mode():
for name, elt in self.parameters.items():
elt["fp32"].copy_(state_dict[name])
elt["master"].copy_(state_dict[name])


@dataclasses.dataclass
@@ -405,6 +406,9 @@ def get_fp32_accum_hook(
# s = torch.cuda.Stream()

def fp32_accum_hook(state: FP32GradBucketManager, bucket: GradBucket) -> torch.futures.Future[torch.Tensor]:
import pydevd

pydevd.settrace(suspend=True, trace_only_current_thread=True)
# nonlocal s
# DDP groups grads in GradBuckets. This hook is called throughout the bwd pass, once each bucket is ready to overlap communication with computation.
# See https://pytorch.org/docs/stable/ddp_comm_hooks.html#what-does-a-communication-hook-operate-on for more details.
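A hedged construction sketch (not from this commit) for the renamed master-copy plumbing; `model` is a placeholder for a module whose parameters are NanotronParameters on the expected device:

import torch
from nanotron.optim.gradient_accumulator import FP32GradientAccumulator

# Illustrative only: master copies are allocated in `master_dtype` (fp32 by default),
# and the optimizer steps on the views returned by get_parameter_for_optimizer()
# rather than on the half-precision parameters themselves.
accumulator = FP32GradientAccumulator(
    named_parameters=model.named_parameters(),
    master_dtype=torch.bfloat16,
)
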
35 changes: 24 additions & 11 deletions src/nanotron/trainer.py
@@ -36,7 +36,6 @@
)
from nanotron.constants import MODEL_CONFIG_FILE_NAME
from nanotron.dataloader import sanity_check_dataloader
from nanotron.fp8.parallel import DistributedDataParallel as FP8DistributedDataParallel
from nanotron.fp8.tensor import FP8Tensor
from nanotron.fp8.utils import convert_model_to_fp8
from nanotron.helpers import (
@@ -240,9 +239,17 @@ def __init__(
self.optimizer, self.grad_accumulator = init_optimizer_and_grad_accumulator(
parametrization_method=parametrization_method,
model=self.model,
master_weight_dtype=self.config.optimizer.master_weight_dtype,
optimizer_args=self.config.optimizer,
parallel_context=self.parallel_context,
)
# NOTE: quantize optimizer states
# add hook to dequantize optimizer states before .step()
# add hook step to recompute lr
# add post_step hook to quantize optimizer states

assert 1 == 1

if self.init_checkpoint_path is not None:
load_optimizer(
optimizer=self.optimizer,
Expand Down Expand Up @@ -883,16 +890,22 @@ def _init_model(
# Check that the model has at least one grad. Necessary for DDP
check_model_has_grad(model=model, parallel_context=parallel_context)
# TODO @thomasw21: DDP doesn't support broadcasting complex buffers (and we don't really need that broadcasting anyway)
if self.config.model.dtype == torch.int8:
raise NotImplementedError
model = FP8DistributedDataParallel(model, self.parallel_context)
else:
model = DistributedDataParallel(
model,
process_group=parallel_context.dp_pg,
broadcast_buffers=False,
bucket_cap_mb=config.model.ddp_bucket_cap_mb,
)
# if self.config.model.dtype == torch.int8:
# raise NotImplementedError
# model = FP8DistributedDataParallel(model, self.parallel_context)
# else:
# model = DistributedDataParallel(
# model,
# process_group=parallel_context.dp_pg,
# broadcast_buffers=False,
# bucket_cap_mb=config.model.ddp_bucket_cap_mb,
# )
model = DistributedDataParallel(
model,
process_group=parallel_context.dp_pg,
broadcast_buffers=False,
bucket_cap_mb=config.model.ddp_bucket_cap_mb,
)

# Sanity check the model, all parameters must be NanotronParameter (either tied or sharded)
sanity_check(root_module=model)
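The NOTE block near the top of this trainer hunk sketches planned work on quantized optimizer states; below is a hedged illustration of that hook pattern using PyTorch's generic optimizer step hooks. The dequantize/quantize helpers are hypothetical placeholders, not nanotron APIs.

import torch

# Hypothetical placeholders -- not part of nanotron or this commit.
def dequantize_optimizer_states(optim: torch.optim.Optimizer) -> None:
    pass

def quantize_optimizer_states(optim: torch.optim.Optimizer) -> None:
    pass

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# Dequantize states just before each optimizer.step(), re-quantize just after.
optimizer.register_step_pre_hook(lambda opt, args, kwargs: dequantize_optimizer_states(opt))
optimizer.register_step_post_hook(lambda opt, args, kwargs: quantize_optimizer_states(opt))
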
