refactor fp8 linear, tp, parameter tests
xrsrke committed Nov 22, 2024
1 parent 1800efe commit c5bcbe7
Showing 12 changed files with 627 additions and 139 deletions.
18 changes: 11 additions & 7 deletions src/nanotron/fp8/linear.py
@@ -50,19 +50,23 @@ def __init__(
# TODO(xrsrke): take initialization dtype from recipe
# NOTE: initialize in float32
super().__init__(in_features, out_features, bias, device, dtype=torch.float32)
self._quantize_weights()
self._set_and_quantize_weights(self.weight.data)

assert self.bias is None
# assert self.bias is None
# if self.bias is not None:
# self.bias = nn.Parameter(self.bias.to(recipe.accum_dtype))
# assert self.bias.dtype == recipe.accum_dtype

# self.metadatas = FP8LinearMeta()
# self.recipe = recipe

def _quantize_weights(self, recipe: FP8LinearRecipe = FP8LM_LINEAR_RECIPE):
def _set_and_quantize_weights(self, data: Optional[torch.Tensor], recipe: FP8LinearRecipe = FP8LM_LINEAR_RECIPE):
"""
data: if set to None, then we quantize the module's current weights, otherwise, we quantize
the provided tensor
"""
# quant_w = FP8Parameter(self.weight.data, dtype=recipe.weight.dtype, interval=recipe.weight.interval)
quant_w = FP8Tensor(self.weight.data, dtype=recipe.weight.dtype, interval=recipe.weight.interval)
quant_w = FP8Tensor(data, dtype=recipe.weight.dtype, interval=recipe.weight.interval)

# assert quant_w.dtype in [torch.uint8, torch.int8], f"got {self.weight.data.dtype}"
# self.weight = quant_w
@@ -73,8 +77,8 @@ def _quantize_weights(self, recipe: FP8LinearRecipe = FP8LM_LINEAR_RECIPE):
# setattr(self, "weight", NanotronParameter(tensor=quant_w))
setattr(self, "weight", NanotronParameter.create_param_that_share_metadata(quant_w, self.weight))

if self.name == "model.decoder.0.attention.qkv_proj":
assert 1 == 1
# if self.name == "model.decoder.0.attention.qkv_proj":
# assert 1 == 1

# NOTE: assume each time we requantize the weights, we reset the metadata
self.metadatas = FP8LinearMeta()
@@ -88,7 +92,7 @@ def forward(self, input: Union[FP8Tensor, torch.Tensor]) -> torch.Tensor:
# weight=get_data_from_param(self.weight),
# bias=None if self.bias is None else get_data_from_param(self.bias),
weight=self.weight,
bias=None,
bias=self.bias,
metadatas=self.metadatas,
recipe=self.recipe,
)
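The hunk above replaces _quantize_weights with _set_and_quantize_weights: the FP32 data is wrapped into an FP8Tensor using the recipe's weight dtype and amax interval, re-wrapped into a NanotronParameter that shares the old parameter's metadata, and the FP8LinearMeta is reset. A rough round-trip sketch of that quantization call, not part of the commit; the import paths for convert_tensor_from_fp8 and FP8LM_LINEAR_RECIPE are assumptions, and an FP8-capable CUDA device is required:

import torch
from nanotron.fp8.tensor import FP8Tensor, convert_tensor_from_fp8  # assumed import path
from nanotron.fp8.recipes import FP8LM_LINEAR_RECIPE  # assumed import path

recipe = FP8LM_LINEAR_RECIPE
w = torch.randn(64, 64, device="cuda")
# Quantize with the same call used in _set_and_quantize_weights above
quant_w = FP8Tensor(w, dtype=recipe.weight.dtype, interval=recipe.weight.interval)
# Dequantize for inspection, as the optimizer's sanity check does
w_approx = convert_tensor_from_fp8(quant_w, quant_w.fp8_meta, torch.float32)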
61 changes: 32 additions & 29 deletions src/nanotron/fp8/optim.py
@@ -59,7 +59,7 @@ def __init__(
# NOTE: torch.Tensor is bias
self.fp8_weights: List[Union[FP8Parameter, torch.Tensor]] = []
# NOTE: create master weights for FP8 Parameter
self.mappping_fp8_to_master_weight: Dict[FP8Tensor, Union[FP16Tensor, torch.Tensor]] = {}
self.mappping_fp8_to_master_weight: Dict[str, Union[FP16Tensor, torch.Tensor]] = {}

for group in self.param_groups:
for p in group["params"]:
@@ -259,23 +259,19 @@ def step(self, closure=None):
assert fp32_grad.dtype == self.optim_accum_dtype

if isinstance(p.data, FP8Tensor):
assert data.dtype in FP8_DTYPES
assert p in self.mappping_fp8_to_master_weight, "FP8Tensor should have a master weight"
assert p.data.dtype in FP8_DTYPES
assert hash(p) in self.mappping_fp8_to_master_weight, "Can't find master weight for FP8 parameter"

master_data = self.mappping_fp8_to_master_weight[p]
master_data = self.mappping_fp8_to_master_weight[hash(p)]
if self.master_weight_dtype == DTypes.KFLOAT16:
fp32_data = convert_tensor_from_fp16(master_data, self.optim_accum_dtype)
else:
fp32_data = (
master_data.to(self.optim_accum_dtype)
if master_data.dtype != self.optim_accum_dtype
else master_data
)
fp32_data = master_data.to(self.optim_accum_dtype)
else:
assert (
data.dtype == non_fp8_accum_dtype
), f"data.dtype={data.dtype}, non_fp8_accum_dtype={non_fp8_accum_dtype}"
fp32_data = data.to(self.optim_accum_dtype) if data.dtype != self.optim_accum_dtype else data
p.data.dtype == non_fp8_accum_dtype
), f"data.dtype={p.data.dtype}, non_fp8_accum_dtype={non_fp8_accum_dtype}"
fp32_data = p.data.to(self.optim_accum_dtype)

assert fp32_data.dtype == self.optim_accum_dtype

@@ -284,7 +280,7 @@

if constants.CONFIG.fp8.skip_param_update_if_nan is True:
log_rank(
f"[Optim] param_name={p_name}, skipping update due to overflow/underflow/nan", # noqa
f"[Optim] param_name, skipping update due to overflow/underflow/nan", # noqa
logger=logger,
level=logging.INFO,
)
@@ -343,7 +339,7 @@ def step(self, closure=None):
# NOTE: only scale down the lr, not scale it up
update_lr = lr / torch.max(torch.tensor(1.0, dtype=self.optim_accum_dtype, device="cuda"), rms)
log_rank(
f"[Gradient clipping] param_name={p_name}, grad_norm: {fp32_grad.norm(p=2)}, RMS is {rms}, original lr is {lr}, new lr is {update_lr}", # noqa
f"[Gradient clipping] param_name=, grad_norm: {fp32_grad.norm(p=2)}, RMS is {rms}, original lr is {lr}, new lr is {update_lr}", # noqa
logger=logger,
level=logging.INFO,
rank=0,
@@ -353,7 +349,10 @@
else:
update_lr = lr

weight_decay_factor = group["weight_decay"] if data.ndim >= 2 else 0.0
# NOTE: keep weight decay for biases
# TODO(xrsrke): we should explicitly set weight_decay_factor to 0 for biases
# in optimizer's param_groups
weight_decay_factor = group["weight_decay"] if p.data.ndim >= 2 else 0.0

if weight_decay_factor != 0:
fp32_new_changes_from_grad = update_lr * normalized_grad
@@ -374,7 +373,7 @@ def step(self, closure=None):
fp32_new_changes_in_p = fp32_new_changes_from_grad + fp32_new_changes_from_weight_decay
new_fp32_data = fp32_data - fp32_new_changes_in_p

if IS_FP8:
if isinstance(p.data, FP8Tensor):
sync_amax_in_weight = fp8_config.sync_amax_in_weight

self.mappping_fp8_to_master_weight[p] = self._create_master_weight(new_fp32_data)
@@ -388,7 +387,8 @@ def step(self, closure=None):
)
torch.testing.assert_allclose(_dequant_master_data, new_fp32_data)

_quant_new_fp32_data = get_data_from_param(p)
# _quant_new_fp32_data = get_data_from_param(p)
_quant_new_fp32_data = p.data
_dequant_new_fp32_data = convert_tensor_from_fp8(
_quant_new_fp32_data, _quant_new_fp32_data.fp8_meta, torch.float32
)
@@ -403,26 +403,29 @@

else:
if constants.CONFIG.fp8.stochastic_rounding is True:
raise NotImplementedError("stochastic_rounding is not implemented")
assert non_fp8_accum_dtype is torch.bfloat16, "only support stochastic rounding for bfloat16"
new_fp16 = torch.full_like(new_fp32_data, 0.0, dtype=non_fp8_accum_dtype)
copy_stochastic_(target=new_fp16, source=new_fp32_data)
else:
new_fp16 = (
new_fp32_data.to(non_fp8_accum_dtype)
if new_fp32_data.dtype != non_fp8_accum_dtype
else new_fp32_data
)
# new_fp16 = (
# new_fp32_data.to(non_fp8_accum_dtype)
# if new_fp32_data.dtype != non_fp8_accum_dtype
# else new_fp32_data
# )
new_fp16 = new_fp32_data.to(non_fp8_accum_dtype)

new_fp16.requires_grad = True
p.data = new_fp16
# new_fp16.requires_grad = True
# p.data = new_fp16

assert get_data_from_param(p) is new_fp16
# assert p.data is new_fp16

if constants.CONFIG.fp8.run_fp8_sanity_check is True:
torch.testing.assert_allclose(get_data_from_param(p), new_fp16)
# if constants.CONFIG.fp8.run_fp8_sanity_check is True:
# # torch.testing.assert_allclose(get_data_from_param(p), new_fp16)
# torch.testing.assert_allclose(p.data, new_fp16)

exp_avg = self._create_optim_state(fp32_exp_avg, self.recipe.exp_avg_dtype)
exp_avg_sq = self._create_optim_state(fp32_exp_avg_sq, self.recipe.exp_avg_sq_dtype)
exp_avg = self._quantize_optim_state(fp32_exp_avg, self.recipe.exp_avg_dtype)
exp_avg_sq = self._quantize_optim_state(fp32_exp_avg_sq, self.recipe.exp_avg_sq_dtype)

state["step"] = step
state["exp_avg"] = exp_avg
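The stochastic-rounding branch above now calls copy_stochastic_, whose implementation is not shown in this diff. For orientation, a minimal PyTorch sketch of the usual FP32-to-BF16 stochastic rounding (add uniform noise to the 16 low bits that truncation would discard, then truncate) could look like the following; it is an illustration under those assumptions, not the repository's helper:

import torch

def copy_stochastic_(target: torch.Tensor, source: torch.Tensor) -> None:
    # Reinterpret the FP32 bits as int32, add uniform noise to the 16 bits that
    # BF16 truncation drops, then clear them; the top 16 bits are the BF16 pattern.
    # Finite inputs assumed (int32 addition wraps on inf/NaN bit patterns).
    assert target.dtype == torch.bfloat16 and source.dtype == torch.float32
    bits = source.view(torch.int32)
    noise = torch.randint_like(bits, 0, 1 << 16)
    rounded = (bits + noise) & -65536           # -65536 == 0xFFFF0000, clears the low 16 bits
    target.copy_(rounded.view(torch.float32))   # exact cast to bf16: low bits are already zero

new_fp32 = torch.randn(8, 8)
new_bf16 = torch.empty_like(new_fp32, dtype=torch.bfloat16)
copy_stochastic_(target=new_bf16, source=new_fp32)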
10 changes: 0 additions & 10 deletions src/nanotron/fp8/parameter.py
@@ -1,4 +1,3 @@
from dataclasses import dataclass
from typing import Optional, Union

import torch
@@ -11,15 +10,6 @@
from nanotron.fp8.tensor import FP8Tensor, update_scaling_factor


@dataclass
class FP8GradMeta:
"""FP8 metadata for FP8Linear."""

input_grad: FP8Meta
weight_grad: FP8Meta
output_grad: FP8Meta


class FP8Parameter(nn.Parameter):
"""
A custom FP8 parameter class that allows
7 changes: 4 additions & 3 deletions src/nanotron/fp8/utils.py
@@ -11,7 +11,6 @@
from nanotron.fp8.dtypes import DTypes
from nanotron.fp8.linear import FP8Linear
from nanotron.fp8.meta import FP8Meta
from nanotron.fp8.parameter import FP8Parameter
from nanotron.logging import log_rank

logger = logging.get_logger(__name__)
@@ -62,10 +61,12 @@ def convert_linear_to_fp8(linear: nn.Linear, accum_qtype: DTypes = FP8LM_RECIPE.
fp8_linear = FP8Linear(
in_features, out_features, bias=is_bias, device=linear.weight.device, accum_qtype=accum_qtype
)
fp8_linear.weight = FP8Parameter(linear.weight.data.clone(), FP8LM_RECIPE.linear.weight.dtype)
# TODO(xrsrke): do we need clone?
fp8_linear._set_and_quantize_weights(linear.weight.data.clone())
# fp8_linear.weight = FP8Parameter(linear.weight.data.clone(), FP8LM_RECIPE.linear.weight.dtype)

if is_bias:
fp8_linear.bias.orig_data = linear.bias.data.clone()
# fp8_linear.bias.orig_data = linear.bias.data.clone()
fp8_linear.bias.data = linear.bias.data.to(QTYPE_TO_DTYPE[accum_qtype])

return fp8_linear
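With these changes, convert_linear_to_fp8 quantizes through _set_and_quantize_weights and copies the bias in the accumulation dtype instead of assigning an FP8Parameter. A hedged usage sketch, assuming the function is importable from nanotron.fp8.utils (the module shown above) and an FP8-capable GPU:

import torch
from torch import nn
from nanotron.fp8.utils import convert_linear_to_fp8  # assumed import path

linear = nn.Linear(128, 128, bias=True, device="cuda")
fp8_linear = convert_linear_to_fp8(linear)   # default accum_qtype from FP8LM_RECIPE
x = torch.randn(16, 128, device="cuda")
y = fp8_linear(x)                            # runs the FP8 forward path from linear.py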
54 changes: 31 additions & 23 deletions src/nanotron/parallel/parameters.py
@@ -231,7 +231,7 @@ def get_sharded_info(self) -> ShardedInfo:
]

@classmethod
def create_param_that_share_metadata(cls, tensor: torch.Tensor, param: "NanotronParameter"):
def create_param_that_share_metadata(cls, tensor: torch.Tensor, param: Union[nn.Parameter, "NanotronParameter"]):
"""Create a new parameter that shares the metadata and hash of the given parameter"""
# TODO(xrsrke): support deepcopy for tied parameter's metadata, because it includes an all-reduce
# which if we do deepcopy, it raises an error
@@ -240,11 +240,19 @@ def create_param_that_share_metadata(cls, tensor: torch.Tensor, param: "Nanotron
# Copy metadata to the new parameter
new_param = NanotronParameter(tensor)
setattr(new_param, NanotronParameter.NANOTRON_PARAMETER_METADATA_ATTRIBUTE_NAME, metadata)
setattr(
new_param,
NanotronParameter.NANOTRON_PARAMETER_HASH_ATTRIBUTE_NAME,
getattr(param, cls.NANOTRON_PARAMETER_HASH_ATTRIBUTE_NAME),
)

# NOTE: if the param is a nn.Parameter, then we don't need to sync the hash
if isinstance(param, NanotronParameter):
setattr(
new_param,
NanotronParameter.NANOTRON_PARAMETER_HASH_ATTRIBUTE_NAME,
getattr(param, cls.NANOTRON_PARAMETER_HASH_ATTRIBUTE_NAME),
)

# TODO(xrsrke): sync all the attributes in the param
# to the new parameter? in case, user sets some attributes
# then the new parameter is kinda lost it

return new_param

@property
@@ -349,26 +357,26 @@ def sanity_check(root_module: nn.Module):
)


def get_data_from_param(p: NanotronParameter):
from nanotron.fp8.parameter import FP8Parameter
# def get_data_from_param(p: NanotronParameter):
# from nanotron.fp8.parameter import FP8Parameter

assert p.__class__ in [NanotronParameter, FP8Parameter]
# NOTE: this return the data that gradients can flow into
return p.data
# assert p.__class__ in [NanotronParameter, FP8Parameter]
# # NOTE: this return the data that gradients can flow into
# return p.data


def get_grad_from_parameter(p: NanotronParameter):
assert p.__class__ == NanotronParameter
assert (p.grad is not None and p.data.grad is not None) is False
# def get_grad_from_parameter(p: NanotronParameter):
# assert p.__class__ == NanotronParameter
# assert (p.grad is not None and p.data.grad is not None) is False

from nanotron import constants
# from nanotron import constants

if constants.CONFIG is not None and constants.CONFIG.tokens.batch_accumulation_per_replica > 1:
if hasattr(p, "grad"):
grad = p.grad if p.grad is not None else p.data.grad
else:
grad = p.__accum_grad if p.__accum_grad is not None else p.data.__accum_grad
else:
grad = p.grad if p.grad is not None else p.data.grad
# if constants.CONFIG is not None and constants.CONFIG.tokens.batch_accumulation_per_replica > 1:
# if hasattr(p, "grad"):
# grad = p.grad if p.grad is not None else p.data.grad
# else:
# grad = p.__accum_grad if p.__accum_grad is not None else p.data.__accum_grad
# else:
# grad = p.grad if p.grad is not None else p.data.grad

return grad
# return grad
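
create_param_that_share_metadata now also accepts a plain nn.Parameter and copies the hash attribute only when the source is already a NanotronParameter. Preserving that hash is what lets the optimizer key its master-weight map by hash(p) (see optim.py above) even after the FP8 weight has been re-wrapped into a brand-new parameter object. A self-contained toy illustration of the idea, using made-up classes rather than nanotron's:

import torch
from torch import nn

class HashedParam(nn.Parameter):
    # Toy stand-in for NanotronParameter: the hash is an explicit attribute that
    # can be copied when the data is re-wrapped into a new parameter object.
    def __hash__(self):
        return self._stable_hash

old = HashedParam(torch.randn(4, 4))
old._stable_hash = 12345
master_weights = {hash(old): old.data.clone()}   # keyed by the stable hash

new = HashedParam(old.data.to(torch.bfloat16))   # re-created parameter, different object
new._stable_hash = old._stable_hash              # hash copied over, as in the classmethod above
assert hash(new) in master_weights               # the master copy is still reachable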
16 changes: 15 additions & 1 deletion src/nanotron/trainer.py
@@ -341,7 +341,7 @@ def print_sanity_params(model):
print(f"Converting {name} to FP8")
module.__class__ = TP_LINEAR_CLS_TO_FP8_LINEAR_CLS[module.__class__]
# TODO(xrsrke): retrieve custom recipe
module._quantize_weights()
module._set_and_quantize_weights()

assert isinstance(module.weight, NanotronParameter)
assert isinstance(module.weight.data, FP8Tensor)
@@ -646,6 +646,20 @@ def training_step(
loss_avg = None
handle = None

# NOTE: sanity check that non-FP8 parameters' gradients have
# the same dtype as the residual stream
for p in self.model.parameters():
from nanotron import constants
from nanotron.fp8.tensor import FP8Tensor

if isinstance(p.data, FP8Tensor):
assert p.grad.dtype in [torch.uint8, torch.int8], f"got {p.data.dtype}"
else:
if p.requires_grad is False:
continue

assert p.grad.dtype == constants.CONFIG.fp8.resid_dtype

# Apply gradient
self.optimizer.step()
self.optimizer.zero_grad()
File renamed without changes.