Commit
fix: FP8 parameters not being updated in optim.step() due to grad_accum
xrsrke committed Jan 9, 2025
1 parent 4723335 commit dd3259b
Showing 6 changed files with 151 additions and 191 deletions.
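For context before the per-file diffs: according to the commit message, enabling gradient accumulation left the FP8 parameters unchanged after `optim.step()`. The rewritten `step()` in `src/nanotron/fp8/optim.py` below assumes that, by the time it runs, both `p.grad` and `p.data` are already plain FP32 tensors maintained by the gradient accumulator, so it drops the FP8-specific dequantize/master-weight paths and defers re-quantizing the FP8 weights to `FP32GradientAccumulator.step` (see the `NOTE` in the diff). A minimal sketch of that simplified update path, with illustrative names only (this is not the nanotron implementation):

```python
# Minimal sketch of an Adam-style step that assumes FP32 in, FP32 out:
# the gradient accumulator keeps p.grad and p.data in float32, and any
# FP8 re-quantization happens outside this function. Names are
# illustrative and simplified (no weight decay, no update clipping).
import torch


def adam_step_fp32(p: torch.nn.Parameter, state: dict, lr: float,
                   betas=(0.9, 0.999), eps: float = 1e-8) -> None:
    grad, data = p.grad, p.data
    assert grad is not None and grad.dtype == torch.float32
    assert data.dtype == torch.float32

    exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
    state["step"] += 1
    beta1, beta2 = betas
    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

    bias_correction1 = 1 - beta1 ** state["step"]
    bias_correction2 = 1 - beta2 ** state["step"]
    denom = (exp_avg_sq / bias_correction2).sqrt().add_(eps)
    update = (exp_avg / bias_correction1) / denom

    # The FP32 working copy is updated in place; quantizing it back into
    # the FP8 parameter is deferred (per the diff, to the accumulator's step).
    p.data = data - lr * update


# Usage example on a toy parameter:
w = torch.nn.Parameter(torch.randn(4, 4))
w.grad = torch.randn_like(w)
state = {"step": 0, "exp_avg": torch.zeros_like(w), "exp_avg_sq": torch.zeros_like(w)}
adam_step_fp32(w, state, lr=1e-3)
```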
286 changes: 117 additions & 169 deletions src/nanotron/fp8/optim.py
@@ -7,7 +7,7 @@
from nanotron import logging

# from nanotron._utils.memory import delete_tensor_from_memory
from nanotron.fp8.constants import FP8_DTYPES, FP8LM_RECIPE
from nanotron.fp8.constants import FP8LM_RECIPE
from nanotron.fp8.dtypes import DTypes
from nanotron.fp8.recipe import FP8OptimRecipe
from nanotron.fp8.tensor import (
@@ -16,7 +16,7 @@
convert_tensor_from_fp8,
convert_tensor_from_fp16,
)
from nanotron.fp8.utils import compute_stas, is_overflow_underflow_nan
from nanotron.fp8.utils import is_overflow_underflow_nan
from nanotron.logging import log_rank

logger = logging.get_logger(__name__)
@@ -157,81 +157,34 @@ def step(self, closure=None):
# assert num_param_has_grads > 0

self._is_overflow = False
loggings = {}

fp8_config = cast(FP8Args, constants.CONFIG.fp8)
non_fp8_accum_dtype = fp8_config.resid_dtype

# from nanotron.helpers import get_accum_grad, set_accum_grad

for i, group in enumerate(self.param_groups):
for p in group["params"]:

if not isinstance(p.data, FP8Tensor) and p.requires_grad is False:
continue
# if not isinstance(p.data, FP8Tensor) and p.requires_grad is False:
# continue

assert p.grad is not None

# p_name = self.params_id_to_param_names[id(p)]
# loggings[p] = {}
state = self.state[p]
if len(state) == 0:
self._init_optim_states(state, p)

# data = get_data_from_param(p)
# IS_FP8 = data.__class__ == FP8Tensor

# NOTE: if use gradient accumulation, after the backward pass
# we set the param.grad to None, so we need to retrieve it from accumulator

# if constants.CONFIG.optimizer.accumulate_grad_in_fp32 is True:
# # fp32_grad = self.grad_accumulator.get_grad_buffer(name=p_name)

# # if "model.decoder.8.pp_block.attn_layer_scale" in p_name:
# # assert 1 == 1

# # if constants.CONFIG.fp8.is_save_grad_for_accum_debugging is True:
# # from nanotron.helpers import create_folder_and_save_tensor

# # create_folder_and_save_tensor(
# # fp32_grad,
# # f"/fsx/phuc/temp/temp3_env_for_fp8/nanotron/debug_accum/{constants.CONFIG.general.run}/aggr_grads/{p_name}.pt",
# # )
# raise NotImplementedError("accumulate_grad_in_fp32 is not implemented")
# else:
# if isinstance(p.data, FP8Tensor):
# if constants.CONFIG.fp8.is_directly_keep_accum_grad_of_fp8 is True:
# # fp32_grad = constants.ACCUM_GRADS[p_name]
# # grad = get_accum_grad(p_name)
# # fp32_grad = (
# # grad.to(self.optim_accum_dtype) if grad.dtype != self.optim_accum_dtype else grad
# # )
# # assert fp32_grad.dtype == torch.float32

# # # constants.ACCUM_GRADS[p_name] = None
# # set_accum_grad(p_name, None)
# raise NotImplementedError("is_directly_keep_accum_grad_of_fp8 is not implemented")
# else:
# assert p.grad.dtype in FP8_DTYPES
# fp32_grad = convert_tensor_from_fp8(p.grad, p.grad.fp8_meta, self.optim_accum_dtype)
# else:
# # grad = get_grad_from_parameter(p)

# # assert grad is not None
# assert p.grad.dtype == non_fp8_accum_dtype

# fp32_grad = p.grad.to(self.optim_accum_dtype)

# NOTE: Case 1: With gradient accumulator => the grad is already in the correct dtype
# Case 2: Without gradient accumulator =>
# 2.1 Non-FP8 parameter => cast the grad to the correct dtype
# 2.2 FP8 parameter => dequantize the grad to the correct dtype
grad = p.grad
if isinstance(p.data, FP8Tensor):
fp32_grad = convert_tensor_from_fp8(grad, grad.fp8_meta, self.optim_accum_dtype)
else:
fp32_grad = grad.to(self.optim_accum_dtype)

# grad = p.grad
# if isinstance(p.data, FP8Tensor):
# fp32_grad = convert_tensor_from_fp8(grad, grad.fp8_meta, self.optim_accum_dtype)
# else:
# fp32_grad = grad.to(self.optim_accum_dtype)
fp32_grad = p.grad
assert fp32_grad.dtype == self.optim_accum_dtype

if is_overflow_underflow_nan(fp32_grad):
@@ -247,22 +200,24 @@
else:
raise ValueError("Overflow, underflow, or NaN detected in the gradients")

if isinstance(p.data, FP8Tensor):
assert p.data.dtype in FP8_DTYPES
assert hash(p) in self.mappping_fp8_to_master_weight, "Can't find master weight for FP8 parameter"
# if isinstance(p.data, FP8Tensor):
# assert p.data.dtype in FP8_DTYPES
# assert hash(p) in self.mappping_fp8_to_master_weight, "Can't find master weight for FP8 parameter"

master_data = self.mappping_fp8_to_master_weight[hash(p)]
if self.master_weight_dtype == DTypes.KFLOAT16:
fp32_data = convert_tensor_from_fp16(master_data, self.optim_accum_dtype)
else:
fp32_data = master_data.to(self.optim_accum_dtype)
else:
assert (
p.data.dtype == non_fp8_accum_dtype
), f"data.dtype={p.data.dtype}, non_fp8_accum_dtype={non_fp8_accum_dtype}"
fp32_data = p.data.to(self.optim_accum_dtype)
# master_data = self.mappping_fp8_to_master_weight[hash(p)]
# if self.master_weight_dtype == DTypes.KFLOAT16:
# fp32_data = convert_tensor_from_fp16(master_data, self.optim_accum_dtype)
# else:
# fp32_data = master_data.to(self.optim_accum_dtype)
# else:
# assert (
# p.data.dtype == non_fp8_accum_dtype
# ), f"data.dtype={p.data.dtype}, non_fp8_accum_dtype={non_fp8_accum_dtype}"
# fp32_data = p.data.to(self.optim_accum_dtype)

assert fp32_data.dtype == self.optim_accum_dtype
# assert fp32_data.dtype == self.optim_accum_dtype
assert p.data.dtype == torch.float32
fp32_data = p.data

exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]

@@ -349,56 +304,48 @@ def step(self, closure=None):
fp32_new_changes_in_p = fp32_new_changes_from_grad + fp32_new_changes_from_weight_decay
new_fp32_data = fp32_data - fp32_new_changes_in_p

if isinstance(p.data, FP8Tensor):
sync_amax_in_weight = fp8_config.sync_amax_in_weight

self.mappping_fp8_to_master_weight[p] = self._create_master_weight(new_fp32_data)
p.data.set_data(new_fp32_data, sync=sync_amax_in_weight)

# NOTE: SANITY CHECK
if constants.CONFIG.fp8.run_fp8_sanity_check is True:
if self.master_weight_dtype == DTypes.KFLOAT16:
_dequant_master_data = convert_tensor_from_fp16(
self.mappping_fp8_to_master_weight[p], DTypes.KFLOAT16, torch.float32
)
torch.testing.assert_allclose(_dequant_master_data, new_fp32_data)

# _quant_new_fp32_data = get_data_from_param(p)
_quant_new_fp32_data = p.data
_dequant_new_fp32_data = convert_tensor_from_fp8(
_quant_new_fp32_data, _quant_new_fp32_data.fp8_meta, torch.float32
)
from nanotron.fp8.constants import FP8_WEIGHT_ATOL_THRESHOLD, FP8_WEIGHT_RTOL_THRESHOLD

torch.testing.assert_allclose(
_dequant_new_fp32_data,
new_fp32_data,
rtol=FP8_WEIGHT_RTOL_THRESHOLD,
atol=FP8_WEIGHT_ATOL_THRESHOLD,
)
p.data = new_fp32_data

# NOTE: move param.data = new_data to FP32GradientAccumulator.step
# if isinstance(p.data, FP8Tensor):
# sync_amax_in_weight = fp8_config.sync_amax_in_weight

# self.mappping_fp8_to_master_weight[p] = self._create_master_weight(new_fp32_data)
# p.data.set_data(new_fp32_data, sync=sync_amax_in_weight)

# # NOTE: SANITY CHECK
# if constants.CONFIG.fp8.run_fp8_sanity_check is True:
# if self.master_weight_dtype == DTypes.KFLOAT16:
# _dequant_master_data = convert_tensor_from_fp16(
# self.mappping_fp8_to_master_weight[p], DTypes.KFLOAT16, torch.float32
# )
# torch.testing.assert_allclose(_dequant_master_data, new_fp32_data)

# # _quant_new_fp32_data = get_data_from_param(p)
# _quant_new_fp32_data = p.data
# _dequant_new_fp32_data = convert_tensor_from_fp8(
# _quant_new_fp32_data, _quant_new_fp32_data.fp8_meta, torch.float32
# )
# from nanotron.fp8.constants import FP8_WEIGHT_ATOL_THRESHOLD, FP8_WEIGHT_RTOL_THRESHOLD

# torch.testing.assert_allclose(
# _dequant_new_fp32_data,
# new_fp32_data,
# rtol=FP8_WEIGHT_RTOL_THRESHOLD,
# atol=FP8_WEIGHT_ATOL_THRESHOLD,
# )

else:
if constants.CONFIG.fp8.stochastic_rounding is True:
raise NotImplementedError("stochastic_rounding is not implemented")
assert non_fp8_accum_dtype is torch.bfloat16, "only support stochastic rounding for bfloat16"
new_fp16 = torch.full_like(new_fp32_data, 0.0, dtype=non_fp8_accum_dtype)
copy_stochastic_(target=new_fp16, source=new_fp32_data)
else:
# new_fp16 = (
# new_fp32_data.to(non_fp8_accum_dtype)
# if new_fp32_data.dtype != non_fp8_accum_dtype
# else new_fp32_data
# )
new_fp16 = new_fp32_data.to(non_fp8_accum_dtype)
# else:
# new_fp16 = new_fp32_data.to(non_fp8_accum_dtype)

# new_fp16.requires_grad = True
# p.data = new_fp16
# new_fp16.requires_grad = True
# p.data = new_fp16

# assert p.data is new_fp16
# assert p.data is new_fp16

# if constants.CONFIG.fp8.run_fp8_sanity_check is True:
# # torch.testing.assert_allclose(get_data_from_param(p), new_fp16)
# torch.testing.assert_allclose(p.data, new_fp16)
# if constants.CONFIG.fp8.run_fp8_sanity_check is True:
# # torch.testing.assert_allclose(get_data_from_param(p), new_fp16)
# torch.testing.assert_allclose(p.data, new_fp16)

exp_avg = self._quantize_optim_state(fp32_exp_avg, self.recipe.exp_avg_dtype)
exp_avg_sq = self._quantize_optim_state(fp32_exp_avg_sq, self.recipe.exp_avg_sq_dtype)
@@ -412,62 +359,63 @@
assert state["exp_avg_sq"] is exp_avg_sq

# NOTE: remove this shit
if constants.is_ready_to_log is True:
loggings[p]["step"] = {"value": step}
loggings[p]["group:lr"] = {"value": lr}
loggings[p]["group:eps"] = {"value": group["eps"]}
loggings[p]["group:beta1"] = {"value": beta1}
loggings[p]["group:beta2"] = {"value": beta2}

loggings[p]["bias_correction1"] = {"value": bias_correction1}
loggings[p]["bias_correction2"] = {"value": bias_correction2}
loggings[p]["fp32_exp_avg"] = compute_stas(fp32_exp_avg)
loggings[p]["fp32_exp_avg_sq"] = compute_stas(fp32_exp_avg_sq)

loggings[p]["normalized_grad"] = compute_stas(normalized_grad)

if fp8_config.adam_atan2 is False:
loggings[p]["denom"] = compute_stas(denom)

loggings[p]["update_lr"] = {"value": update_lr}

loggings[p]["fp32_p"] = compute_stas(fp32_data)
loggings[p]["fp32_new_changes_in_p"] = {
# "abs_total": fp32_new_changes_in_p.abs().sum(),
# "abs_mean": fp32_new_changes_in_p.abs().mean(),
"rms": fp32_new_changes_in_p.pow(2)
.mean()
.sqrt(),
}
loggings[p]["fp32_new_changes_from_grad"] = {
"rms": fp32_new_changes_from_grad.pow(2).mean().sqrt(),
}

p_norm = fp32_data.norm()

loggings[p]["fp32_grad"] = compute_stas(fp32_grad)
loggings[p]["update_lr"] = {"value": update_lr}
loggings[p]["weight_norm_and_normalized_grad_norm_ratio"] = {
"value": p_norm / fp32_new_changes_from_grad.norm()
}
loggings[p]["weight_norm_and_weight_update_norm_ratio"] = {
"value": p_norm / fp32_new_changes_in_p.norm()
}

if weight_decay_factor != 0:
loggings[p]["fp32_new_changes_from_weight_decay"] = {
"rms": fp32_new_changes_from_weight_decay.pow(2).mean().sqrt(),
}
loggings[p]["weight_norm_and_weight_decay_grad_norm_ratio"] = {
"value": p_norm / fp32_weight_decay_grad.norm()
}

if constants.CONFIG.fp8.update_clipping is True:
loggings[p]["grad_rms"] = {"value": rms}
# if constants.is_ready_to_log is True:
# loggings[p]["step"] = {"value": step}
# loggings[p]["group:lr"] = {"value": lr}
# loggings[p]["group:eps"] = {"value": group["eps"]}
# loggings[p]["group:beta1"] = {"value": beta1}
# loggings[p]["group:beta2"] = {"value": beta2}

# loggings[p]["bias_correction1"] = {"value": bias_correction1}
# loggings[p]["bias_correction2"] = {"value": bias_correction2}
# loggings[p]["fp32_exp_avg"] = compute_stas(fp32_exp_avg)
# loggings[p]["fp32_exp_avg_sq"] = compute_stas(fp32_exp_avg_sq)

# loggings[p]["normalized_grad"] = compute_stas(normalized_grad)

# if fp8_config.adam_atan2 is False:
# loggings[p]["denom"] = compute_stas(denom)

# loggings[p]["update_lr"] = {"value": update_lr}

# loggings[p]["fp32_p"] = compute_stas(fp32_data)
# loggings[p]["fp32_new_changes_in_p"] = {
# # "abs_total": fp32_new_changes_in_p.abs().sum(),
# # "abs_mean": fp32_new_changes_in_p.abs().mean(),
# "rms": fp32_new_changes_in_p.pow(2)
# .mean()
# .sqrt(),
# }
# loggings[p]["fp32_new_changes_from_grad"] = {
# "rms": fp32_new_changes_from_grad.pow(2).mean().sqrt(),
# }

# p_norm = fp32_data.norm()

# loggings[p]["fp32_grad"] = compute_stas(fp32_grad)
# loggings[p]["update_lr"] = {"value": update_lr}
# loggings[p]["weight_norm_and_normalized_grad_norm_ratio"] = {
# "value": p_norm / fp32_new_changes_from_grad.norm()
# }
# loggings[p]["weight_norm_and_weight_update_norm_ratio"] = {
# "value": p_norm / fp32_new_changes_in_p.norm()
# }

# if weight_decay_factor != 0:
# loggings[p]["fp32_new_changes_from_weight_decay"] = {
# "rms": fp32_new_changes_from_weight_decay.pow(2).mean().sqrt(),
# }
# loggings[p]["weight_norm_and_weight_decay_grad_norm_ratio"] = {
# "value": p_norm / fp32_weight_decay_grad.norm()
# }

# if constants.CONFIG.fp8.update_clipping is True:
# loggings[p]["grad_rms"] = {"value": rms}

# if constants.is_ready_to_log is True:
# self.loggings = loggings
# self.loggings = self._get_optim_logs()
assert 1 == 1

def zero_grad(self):
for group in self.param_groups:
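The `# NOTE: move param.data = new_data to FP32GradientAccumulator.step` comment in the diff above is the other half of this fix: after `step()` writes the new FP32 values, something still has to quantize them back into the FP8 parameters, and that accumulator code is not among the files shown here. The following is a hedged sketch of what such a step could look like; `set_data(..., sync=...)` mirrors the call the removed code used, while the class and attribute names are assumptions, not the committed implementation.

```python
# Hedged sketch (assumption, not the committed code): re-quantizing the
# updated FP32 values back into the FP8 parameters now happens outside
# optim.step(), per the NOTE in the diff above.
import torch


class FP32GradientAccumulatorSketch:
    def __init__(self, fp8_param_to_fp32_copy, sync_amax_in_weight=False):
        # Maps each model parameter whose .data is an FP8 tensor to the
        # FP32 copy that the optimizer actually updates.
        self.fp8_param_to_fp32_copy = fp8_param_to_fp32_copy
        self.sync_amax_in_weight = sync_amax_in_weight

    @torch.no_grad()
    def step(self):
        for fp8_param, fp32_copy in self.fp8_param_to_fp32_copy.items():
            # Quantize the freshly updated FP32 values back into FP8 storage.
            fp8_param.data.set_data(fp32_copy, sync=self.sync_amax_in_weight)
```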
2 changes: 1 addition & 1 deletion src/nanotron/fp8/utils.py
@@ -237,7 +237,7 @@ def find_fp8_config_by_module_name(target_module_name: str, config: FP8Args) ->
# TODO(xrsrke): remove config.is_quant_all_except_first_and_last
from nanotron.fp8.constants import FP8LM_LINEAR_RECIPE

if config.model is not None:
if hasattr(config, "model") and config.model is not None:
for layer_args in config.model:
if layer_args.module_name == target_module_name.replace("pp_block.", "").replace("module.", ""):
return layer_args
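The one-line change to `find_fp8_config_by_module_name` above presumably guards against config objects that do not define a `model` attribute at all, in which case the old `config.model is not None` check would raise `AttributeError` before the per-module lookup could fall back to the default recipe. A self-contained illustration of the pattern, using hypothetical stand-ins for `FP8Args` and the layer-override entries:

```python
# Illustration of the hasattr guard (hypothetical config classes, not
# nanotron's): only iterate over per-layer overrides when the config
# actually carries them; otherwise fall back to the default recipe.
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class LayerArgs:
    module_name: str
    dtype: str = "fp8e4m3"


@dataclass
class MinimalFP8Args:
    # No `model` field at all, so hasattr(config, "model") is False.
    resid_dtype: str = "bfloat16"


@dataclass
class FullFP8Args:
    resid_dtype: str = "bfloat16"
    model: Optional[List[LayerArgs]] = None


def find_layer_override(target: str, config) -> Optional[LayerArgs]:
    if hasattr(config, "model") and config.model is not None:
        for layer_args in config.model:
            if layer_args.module_name == target.replace("pp_block.", "").replace("module.", ""):
                return layer_args
    return None  # caller falls back to the default recipe


assert find_layer_override("decoder.0.attn", MinimalFP8Args()) is None
assert find_layer_override("decoder.0.pp_block.attn",
                           FullFP8Args(model=[LayerArgs("decoder.0.attn")])) is not None
```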
3 changes: 3 additions & 0 deletions src/nanotron/optim/clip_grads.py
@@ -89,6 +89,9 @@ def clip_grad_norm(
device_to_clip_coef_clamped = {device: clip_coef_clamped.to(device) for device in devices}

for name, param in named_parameters:
if "model.decoder.13.pp_block.attn.o_proj.weight" in name:
assert 1 == 1

if grad_accumulator is None:
param.grad.detach().mul_(device_to_clip_coef_clamped[param.grad.device])
else:
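For readers unfamiliar with the surrounding function: `clip_grad_norm` rescales every gradient by a clip coefficient that has been clamped (typically to at most 1) and copied to each gradient's device via `device_to_clip_coef_clamped`, and the added `assert 1 == 1` appears to be only a debugging anchor for one specific parameter. A standalone, hedged sketch of that clipping pattern without the gradient-accumulator branch:

```python
# Standalone sketch of global-norm gradient clipping with a per-device,
# pre-clamped clip coefficient, mirroring the loop shown in the diff.
# Illustrative only; nanotron's version also handles a gradient
# accumulator and FP8 gradients.
import torch


def clip_grad_norm_sketch(named_parameters, max_norm: float) -> torch.Tensor:
    named_parameters = list(named_parameters)
    grads = [p.grad for _, p in named_parameters if p.grad is not None]
    if not grads:
        return torch.tensor(0.0)

    total_norm = torch.norm(torch.stack([g.detach().norm(2) for g in grads]), 2)
    clip_coef = max_norm / (total_norm + 1e-6)
    clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
    devices = {g.device for g in grads}
    device_to_clip_coef_clamped = {d: clip_coef_clamped.to(d) for d in devices}

    for _, param in named_parameters:
        if param.grad is not None:
            # Scale each gradient in place by the coefficient on its device.
            param.grad.detach().mul_(device_to_clip_coef_clamped[param.grad.device])
    return total_norm
```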
