diff --git a/src/nanotron/fp8/functional.py b/src/nanotron/fp8/functional.py index 25719fd6..cdf9e58a 100644 --- a/src/nanotron/fp8/functional.py +++ b/src/nanotron/fp8/functional.py @@ -16,10 +16,6 @@ def linear( name: Optional[str] = None, ): assert isinstance(weight, NanotronParameter) - from typing import cast - - from nanotron import constants - from nanotron.config.fp8_config import FP8Args assert metadatas is not None, "metadatas must be specified" assert recipe is not None, "recipe must be specified" diff --git a/src/nanotron/fp8/linear.py b/src/nanotron/fp8/linear.py index 92f552bc..96777e97 100644 --- a/src/nanotron/fp8/linear.py +++ b/src/nanotron/fp8/linear.py @@ -1,7 +1,6 @@ from dataclasses import dataclass from typing import Optional, Tuple, Union, cast -import pydevd import torch import transformer_engine as te # noqa from torch import nn @@ -143,16 +142,6 @@ def backward(ctx, grad_output: torch.Tensor, grad_phony: torch.Tensor) -> Tuple[ from nanotron.config.fp8_config import FP8Args from nanotron.fp8.utils import is_overflow_underflow_nan - # pydevd.settrace(suspend=False, trace_only_current_thread=True) - if ( - constants.CONFIG is not None - and constants.CONFIG.fp8 is not None - and constants.CONFIG.fp8.is_debugging is True - ): - pydevd.settrace(suspend=False, trace_only_current_thread=True) - - # dist.monitored_barrier(wait_all_ranks=True) - if constants.CONFIG is None: fp8_config = FP8Args() else: diff --git a/src/nanotron/fp8/tensor.py b/src/nanotron/fp8/tensor.py index e5d375be..db1b670e 100644 --- a/src/nanotron/fp8/tensor.py +++ b/src/nanotron/fp8/tensor.py @@ -58,7 +58,7 @@ def __init__( fp8_meta: Optional[FP8Meta] = None, sync: bool = False, ) -> None: - raise NotImplementedError() + pass @staticmethod # @torch.no_grad() diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index e873314d..0d69fe62 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -394,7 +394,7 @@ def __init__( async_communication=tp_linear_async_communication, contiguous_chunks=qkv_contiguous_chunks, name=f"model.decoder.{layer_idx}.attention.qkv_proj", - tp_recompute_allgather=parallel_config.tp_recompute_allgather, + # tp_recompute_allgather=parallel_config.tp_recompute_allgather, ) # TODO(kunhao): We want to have only one version per device and not one version per layer. if config.rope_interleaved: @@ -917,7 +917,7 @@ def __init__( "mode": self.tp_mode, "async_communication": tp_linear_async_communication, "name": "model.lm_head", - "tp_recompute_allgather": parallel_config.tp_recompute_allgather, + # "tp_recompute_allgather": parallel_config.tp_recompute_allgather, }, module_input_keys={"x"}, module_output_keys={"logits"}, diff --git a/src/nanotron/optim/clip_grads.py b/src/nanotron/optim/clip_grads.py index 06ac55d4..dd5698a2 100644 --- a/src/nanotron/optim/clip_grads.py +++ b/src/nanotron/optim/clip_grads.py @@ -34,7 +34,6 @@ def clip_grad_norm( named_parameters = list(named_parameters) world_rank = dist.get_rank() - # assert that all params require grad for _, p in named_parameters: assert p.requires_grad or isinstance( p.data, FP8Tensor @@ -89,9 +88,6 @@ def clip_grad_norm( device_to_clip_coef_clamped = {device: clip_coef_clamped.to(device) for device in devices} for name, param in named_parameters: - if "model.decoder.13.pp_block.attn.o_proj.weight" in name: - assert 1 == 1 - if grad_accumulator is None: param.grad.detach().mul_(device_to_clip_coef_clamped[param.grad.device]) else: diff --git a/src/nanotron/optim/gradient_accumulator.py b/src/nanotron/optim/gradient_accumulator.py index 12ce3800..c5dc9990 100644 --- a/src/nanotron/optim/gradient_accumulator.py +++ b/src/nanotron/optim/gradient_accumulator.py @@ -10,6 +10,7 @@ import nanotron.distributed as dist from nanotron import logging from nanotron.fp8.tensor import FP8Tensor +from nanotron.fp8.utils import is_overflow_underflow_nan from nanotron.parallel.parameters import NanotronParameter from nanotron.utils import get_untyped_storage, tensor_from_untyped_storage @@ -89,8 +90,6 @@ def __init__( # because we want to do the backward ourself, so here we only skip # if the parameter isn't fp8, and doesn't require grad - # if not isinstance(param.data, FP8Tensor) and not param.requires_grad: - # continue if self._is_not_required_master_weights(param): fp32_params.append((name, param)) continue @@ -277,13 +276,6 @@ def backward(self, loss: torch.Tensor): def _accumulate_grad(self, name: str, half_param: NanotronParameter) -> None: """Accumulate grad in fp32 and set the fp32 grad to the fp32 grad buffer, so that optimizer can update fp32 weights afterwards""" - if name == "model.decoder.4.pp_block.attn.qkv_proj.weight": - assert 1 == 1 - - # try: - # assert half_param.grad is not None, f"Expected param {name} to have gradient." - # except AssertionError: - # assert 1 == 1 assert half_param.grad is not None, f"Expected param {name} to have gradient." from nanotron.fp8.tensor import convert_tensor_from_fp8 @@ -292,20 +284,12 @@ def _accumulate_grad(self, name: str, half_param: NanotronParameter) -> None: else: grad = half_param.grad - from nanotron.fp8.utils import is_overflow_underflow_nan - - assert is_overflow_underflow_nan(grad) is False, f"name: {name}" + assert is_overflow_underflow_nan(grad) is False, f"Detected overflow/underflow/nan in {name} grad" fp32_grad = self.get_grad_buffer(name=name) if self._is_accumulation_sync_step is False: # WARNING: We assume fp32_grad_bucket is already zeroed - # if not isinstance(half_param.data, FP8Tensor): - # fp32_grad.add_(grad) - # else: - # assert grad.dtype in [torch.int8, torch.uint8] - # # TODO(xrsrke): move .convert_tensor_from_fp8 to .to(dtype), so we have an unified API - # fp32_grad.add_(grad) fp32_grad.add_(grad) # In case _is_accumulation_sync_step = True: no need to add half gradients, because it's done in the allreduce hook diff --git a/src/nanotron/parallel/parameters.py b/src/nanotron/parallel/parameters.py index 79b1ea50..702a1e80 100644 --- a/src/nanotron/parallel/parameters.py +++ b/src/nanotron/parallel/parameters.py @@ -1,6 +1,4 @@ import dataclasses -import hashlib -import os from copy import deepcopy from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union @@ -17,15 +15,6 @@ logger = logging.get_logger(__name__) -def _generate_random_hash(): - # Generate 64 bytes of random data - random_data = os.urandom(64) - # Hash the random data using SHA-256 - hash_object = hashlib.sha256(random_data) - # Convert the hash object to a hexadecimal string - return hash_object.hexdigest() - - @dataclasses.dataclass class SlicesPair: local_slices: Tuple[slice, ...] @@ -122,7 +111,6 @@ class NanotronParameter(nn.Parameter): # __torch_function__ = torch._C._disabled_torch_function_impl - NANOTRON_PARAMETER_HASH_ATTRIBUTE_NAME = "__nanotron_hash__" NANOTRON_PARAMETER_METADATA_ATTRIBUTE_NAME = "__nanotron_metadata__" NANOTRON_PARAMETER_METADATA_TIED_KEY = "tied" NANOTRON_PARAMETER_METADATA_SHARDED_KEY = "sharded" @@ -173,7 +161,6 @@ def __init__(self, tensor: Union[torch.Tensor, "FP8Tensor"]): # because we need to know a parameter will be in fp8 or not # so we create master weights of the fp32 parameters before quantizing self._is_future_fp8 = False - setattr(self, self.NANOTRON_PARAMETER_HASH_ATTRIBUTE_NAME, hash(_generate_random_hash())) def _set_metadata(self, key: str, value: Any): metadata = getattr(self, self.NANOTRON_PARAMETER_METADATA_ATTRIBUTE_NAME) @@ -239,18 +226,9 @@ def create_param_that_share_metadata(cls, tensor: torch.Tensor, param: Union[nn. new_param = NanotronParameter(tensor) setattr(new_param, NanotronParameter.NANOTRON_PARAMETER_METADATA_ATTRIBUTE_NAME, metadata) - # NOTE: if the param is a nn.Parameter, then we don't need to sync the hash - if isinstance(param, NanotronParameter): - setattr( - new_param, - NanotronParameter.NANOTRON_PARAMETER_HASH_ATTRIBUTE_NAME, - getattr(param, cls.NANOTRON_PARAMETER_HASH_ATTRIBUTE_NAME), - ) - # TODO(xrsrke): sync all the attributes in the param # to the new parameter? in case, user sets some attributes # then the new parameter is kinda lost it - return new_param @property @@ -319,10 +297,6 @@ def wrap(e): else: return tree_map(wrap, outputs) - def __hash__(self): - # Combine the attributes to compute a unique hash value - return getattr(self, self.NANOTRON_PARAMETER_HASH_ATTRIBUTE_NAME) - def sanity_check(root_module: nn.Module): """Makes sure that the module is in Nanotronformat diff --git a/src/nanotron/parallel/tensor_parallel/functional.py b/src/nanotron/parallel/tensor_parallel/functional.py index 61dbdd1d..7a46fe98 100644 --- a/src/nanotron/parallel/tensor_parallel/functional.py +++ b/src/nanotron/parallel/tensor_parallel/functional.py @@ -442,11 +442,6 @@ def column_linear( name: Optional[str] = None, recipe: Optional[FP8LinearRecipe] = None, ): - # weight = get_data_from_param(weight) - - # if bias is not None: - # bias = get_data_from_param(bias) - if async_communication: return _ColumnLinearAsyncCommunication.apply(input, weight, bias, group, tp_mode, tp_recompute_allgather) @@ -456,29 +451,8 @@ def column_linear( input = differentiable_identity(input, group=group) - # if isinstance(weight, FP8Tensor): # i used weight before removing get_data_from_param if isinstance(weight.data, FP8Tensor): assert recipe is not None, "recipe must be provided for column_linear" - from nanotron import constants - - # if name not in constants.TRACKING_FP8_PARAM: - # constants.TRACKING_FP8_PARAM[name] = weight - - if ( - constants.CONFIG is not None - and constants.CONFIG.fp8 is not None - and constants.CONFIG.fp8.is_sanity_logging is True - ): - from nanotron import logging - from nanotron.logging import log_rank - - logger = logging.get_logger(__name__) - log_rank( - f"[iteration_step: {constants.ITERATION_STEP}]name = {name}, doing fp8 kernel", - logger=logger, - level=logging.INFO, - ) - return fp8_functional.linear(input, weight, bias, metadatas=metadatas, recipe=recipe, name=name) else: return F.linear(input, weight, bias) @@ -632,18 +606,12 @@ def row_linear( recipe: Optional[FP8LinearRecipe] = None, name: Optional[str] = None, ): - # weight = get_data_from_param(weight) - # if bias is not None: - # bias = get_data_from_param(bias) - if async_communication: return _RowLinearAsyncCommunication.apply(input, weight, bias, group, tp_mode) - # out = F.linear(input, weight, bias) import nanotron.fp8.functional as fp8_functional from nanotron.fp8.tensor import FP8Tensor - # if isinstance(weight, FP8Tensor): # i used weight before removing get_data_from_param if isinstance(weight.data, FP8Tensor): assert recipe is not None, "recipe must be provided for row_linear" out = fp8_functional.linear(input, weight, bias, metadatas=metadatas, recipe=recipe, name=name) diff --git a/src/nanotron/parallel/tensor_parallel/nn.py b/src/nanotron/parallel/tensor_parallel/nn.py index 920dc403..debc8f06 100644 --- a/src/nanotron/parallel/tensor_parallel/nn.py +++ b/src/nanotron/parallel/tensor_parallel/nn.py @@ -42,9 +42,7 @@ ) from nanotron.parallel.tied_parameters import create_tied_parameter -# from nanotron.utils import post_init -# @post_init class _BaseTensorParallelColumnLinear: def __init__( self, @@ -110,7 +108,6 @@ def extra_repr(self) -> str: return f"tp_rank={dist.get_rank(self.pg)}, {super().extra_repr()}, unsharded_out_features={self.out_features * self.world_size}" -# @post_init class _BaseTensorParallelRowLinear: def __init__( self, diff --git a/src/nanotron/scaling/parametrization.py b/src/nanotron/scaling/parametrization.py index 62128db3..819617f6 100644 --- a/src/nanotron/scaling/parametrization.py +++ b/src/nanotron/scaling/parametrization.py @@ -45,7 +45,6 @@ def __init__(self, config: ModelArgs): self.num_layers = config.model_config.num_hidden_layers def _parametrize_column_linear(self, param_name: str, module: nn.Module): - # assert param_name in ["weight", "bias"] assert any(x in param_name for x in ["weight", "bias"]) if "weight" in param_name: @@ -54,7 +53,6 @@ def _parametrize_column_linear(self, param_name: str, module: nn.Module): module.bias.zero_() def _parametrize_row_linear(self, param_name: str, module: nn.Module): - # assert param_name in ["weight", "bias"] assert any(x in param_name for x in ["weight", "bias"]) if "weight" in param_name: @@ -64,13 +62,6 @@ def _parametrize_row_linear(self, param_name: str, module: nn.Module): module.bias.zero_() def _parametrize_layer_norm(self, param_name: str, module: nn.Module): - # assert param_name in ["weight", "bias"] - - # if "weight" == param_name: - # # TODO @thomasw21: Sometimes we actually want 0 - # module.weight.fill_(1) - # elif "bias" == param_name: - # module.bias.zero_() assert any(x in param_name for x in ["weight", "bias"]) if "weight" in param_name: # TODO @thomasw21: Sometimes we actually want 0 @@ -79,10 +70,8 @@ def _parametrize_layer_norm(self, param_name: str, module: nn.Module): module.bias.zero_() def _parametrize_embedding(self, param_name: str, module: nn.Module): - # assert param_name in ["weight"] assert "weight" in param_name - # if "weight" == param_name: if "weight" in param_name: init.normal_(module.weight, mean=0.0, std=self.std) diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 5315d6dc..9b20f615 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -218,10 +218,10 @@ def __init__( constants.CPU_WEIGHTS[n.replace("module.", "")] = p.data.cpu().clone() # NOTE: sanity check all hash are different - param_hash = [] - for p in self.model.parameters(): - assert hash(p) not in param_hash - param_hash.append(hash(p)) + # param_hash = [] + # for p in self.model.parameters(): + # assert hash(p) not in param_hash + # param_hash.append(hash(p)) # NOTE: if we cast model to FP8 before wrapping it with NanotronParameter, # then we can create a NanotronParameter that has dtype=[torch.int8, torch.uint8] @@ -585,7 +585,7 @@ def training_step( ) before_optim_step_sanity_checks( - self.config, self.parallel_context, self.unwrapped_model, self.grad_accumulator + self.config, self.parallel_context, self.unwrapped_model, self.optimizer, self.grad_accumulator ) # Compute DP average loss and overlap with optimizer step diff --git a/src/nanotron/utils.py b/src/nanotron/utils.py index 5fc0a0c5..b3831801 100644 --- a/src/nanotron/utils.py +++ b/src/nanotron/utils.py @@ -160,16 +160,3 @@ def find_free_port(min_port: int = 2000, max_port: int = 65000) -> int: return port except OSError: continue - - -def post_init(cls): - """Decorator to call __post_init__ method after __init__ method of a class.""" - original_init = cls.__init__ - - def new_init(self, *args, **kwargs): - original_init(self, *args, **kwargs) - if hasattr(self, "post_init"): - self.__post_init__() - - cls.__init__ = new_init - return cls