
Commit

clean up
xrsrke committed Jan 10, 2025
1 parent a3a13ce commit e8b114b
Showing 12 changed files with 10 additions and 130 deletions.
4 changes: 0 additions & 4 deletions src/nanotron/fp8/functional.py
@@ -16,10 +16,6 @@ def linear(
name: Optional[str] = None,
):
assert isinstance(weight, NanotronParameter)
from typing import cast

from nanotron import constants
from nanotron.config.fp8_config import FP8Args

assert metadatas is not None, "metadatas must be specified"
assert recipe is not None, "recipe must be specified"
11 changes: 0 additions & 11 deletions src/nanotron/fp8/linear.py
@@ -1,7 +1,6 @@
from dataclasses import dataclass
from typing import Optional, Tuple, Union, cast

import pydevd
import torch
import transformer_engine as te # noqa
from torch import nn
@@ -143,16 +142,6 @@ def backward(ctx, grad_output: torch.Tensor, grad_phony: torch.Tensor) -> Tuple[
from nanotron.config.fp8_config import FP8Args
from nanotron.fp8.utils import is_overflow_underflow_nan

# pydevd.settrace(suspend=False, trace_only_current_thread=True)
if (
constants.CONFIG is not None
and constants.CONFIG.fp8 is not None
and constants.CONFIG.fp8.is_debugging is True
):
pydevd.settrace(suspend=False, trace_only_current_thread=True)

# dist.monitored_barrier(wait_all_ranks=True)

if constants.CONFIG is None:
fp8_config = FP8Args()
else:
2 changes: 1 addition & 1 deletion src/nanotron/fp8/tensor.py
@@ -58,7 +58,7 @@ def __init__(
fp8_meta: Optional[FP8Meta] = None,
sync: bool = False,
) -> None:
raise NotImplementedError()
pass

@staticmethod
# @torch.no_grad()
4 changes: 2 additions & 2 deletions src/nanotron/models/llama.py
@@ -394,7 +394,7 @@ def __init__(
async_communication=tp_linear_async_communication,
contiguous_chunks=qkv_contiguous_chunks,
name=f"model.decoder.{layer_idx}.attention.qkv_proj",
tp_recompute_allgather=parallel_config.tp_recompute_allgather,
# tp_recompute_allgather=parallel_config.tp_recompute_allgather,
)
# TODO(kunhao): We want to have only one version per device and not one version per layer.
if config.rope_interleaved:
@@ -917,7 +917,7 @@ def __init__(
"mode": self.tp_mode,
"async_communication": tp_linear_async_communication,
"name": "model.lm_head",
"tp_recompute_allgather": parallel_config.tp_recompute_allgather,
# "tp_recompute_allgather": parallel_config.tp_recompute_allgather,
},
module_input_keys={"x"},
module_output_keys={"logits"},
4 changes: 0 additions & 4 deletions src/nanotron/optim/clip_grads.py
@@ -34,7 +34,6 @@ def clip_grad_norm(
named_parameters = list(named_parameters)
world_rank = dist.get_rank()

# assert that all params require grad
for _, p in named_parameters:
assert p.requires_grad or isinstance(
p.data, FP8Tensor
@@ -89,9 +88,6 @@ def clip_grad_norm(
device_to_clip_coef_clamped = {device: clip_coef_clamped.to(device) for device in devices}

for name, param in named_parameters:
if "model.decoder.13.pp_block.attn.o_proj.weight" in name:
assert 1 == 1

if grad_accumulator is None:
param.grad.detach().mul_(device_to_clip_coef_clamped[param.grad.device])
else:
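For orientation, the scaling loop this hunk keeps follows the standard clip-by-global-norm pattern: compute one clamped clip coefficient, cache a copy per device, and scale every grad in place. A minimal sketch under that reading (the function name and the norm computation are illustrative; nanotron's `clip_grad_norm` also handles FP8 tensors and an optional grad accumulator):

```python
import torch

def clip_grads_in_place(named_parameters, max_norm: float) -> None:
    # Gather grads and compute the global 2-norm on CPU so grads may live
    # on different devices.
    grads = [p.grad for _, p in named_parameters if p.grad is not None]
    total_norm = torch.norm(
        torch.stack([torch.norm(g.detach(), 2.0).cpu() for g in grads]), 2.0
    )
    # Clamp the clip coefficient at 1.0 so small norms never scale grads up.
    clip_coef_clamped = torch.clamp(max_norm / (total_norm + 1e-6), max=1.0)
    # One cached copy of the coefficient per device, as in the hunk above.
    device_to_clip_coef_clamped = {
        device: clip_coef_clamped.to(device) for device in {g.device for g in grads}
    }
    for _, param in named_parameters:
        if param.grad is not None:
            param.grad.detach().mul_(device_to_clip_coef_clamped[param.grad.device])
```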
20 changes: 2 additions & 18 deletions src/nanotron/optim/gradient_accumulator.py
@@ -10,6 +10,7 @@
import nanotron.distributed as dist
from nanotron import logging
from nanotron.fp8.tensor import FP8Tensor
from nanotron.fp8.utils import is_overflow_underflow_nan
from nanotron.parallel.parameters import NanotronParameter
from nanotron.utils import get_untyped_storage, tensor_from_untyped_storage

@@ -89,8 +90,6 @@ def __init__(
# because we want to do the backward ourself, so here we only skip
# if the parameter isn't fp8, and doesn't require grad

# if not isinstance(param.data, FP8Tensor) and not param.requires_grad:
# continue
if self._is_not_required_master_weights(param):
fp32_params.append((name, param))
continue
@@ -277,13 +276,6 @@ def backward(self, loss: torch.Tensor):

def _accumulate_grad(self, name: str, half_param: NanotronParameter) -> None:
"""Accumulate grad in fp32 and set the fp32 grad to the fp32 grad buffer, so that optimizer can update fp32 weights afterwards"""
if name == "model.decoder.4.pp_block.attn.qkv_proj.weight":
assert 1 == 1

# try:
# assert half_param.grad is not None, f"Expected param {name} to have gradient."
# except AssertionError:
# assert 1 == 1
assert half_param.grad is not None, f"Expected param {name} to have gradient."
from nanotron.fp8.tensor import convert_tensor_from_fp8

@@ -292,20 +284,12 @@ def _accumulate_grad(self, name: str, half_param: NanotronParameter) -> None:
else:
grad = half_param.grad

from nanotron.fp8.utils import is_overflow_underflow_nan

assert is_overflow_underflow_nan(grad) is False, f"name: {name}"
assert is_overflow_underflow_nan(grad) is False, f"Detected overflow/underflow/nan in {name} grad"

fp32_grad = self.get_grad_buffer(name=name)

if self._is_accumulation_sync_step is False:
# WARNING: We assume fp32_grad_bucket is already zeroed
# if not isinstance(half_param.data, FP8Tensor):
# fp32_grad.add_(grad)
# else:
# assert grad.dtype in [torch.int8, torch.uint8]
# # TODO(xrsrke): move .convert_tensor_from_fp8 to .to(dtype), so we have an unified API
# fp32_grad.add_(grad)
fp32_grad.add_(grad)
# In case _is_accumulation_sync_step = True: no need to add half gradients, because it's done in the allreduce hook

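Stripped of the commented-out branches, `_accumulate_grad` reduces to upcast, check, accumulate. A sketch under the assumption that the FP8 grad has already been decoded to a regular tensor (which `convert_tensor_from_fp8` does in the original); `is_overflow_underflow_nan` below is an illustrative stand-in for the helper imported from `nanotron.fp8.utils`:

```python
import torch

def is_overflow_underflow_nan(tensor: torch.Tensor) -> bool:
    # Stand-in: flag any inf or nan entries.
    return bool(torch.isinf(tensor).any() or torch.isnan(tensor).any())

def accumulate_grad(name: str, grad: torch.Tensor, fp32_grad_buffers: dict) -> None:
    # Upcast the incoming half-precision grad to fp32, reject bad values,
    # and add it into the persistent fp32 buffer the optimizer later reads.
    fp32_update = grad.float()
    assert not is_overflow_underflow_nan(fp32_update), (
        f"Detected overflow/underflow/nan in {name} grad"
    )
    fp32_grad_buffers[name].add_(fp32_update)
```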
26 changes: 0 additions & 26 deletions src/nanotron/parallel/parameters.py
@@ -1,6 +1,4 @@
import dataclasses
import hashlib
import os
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union

@@ -17,15 +15,6 @@
logger = logging.get_logger(__name__)


def _generate_random_hash():
# Generate 64 bytes of random data
random_data = os.urandom(64)
# Hash the random data using SHA-256
hash_object = hashlib.sha256(random_data)
# Convert the hash object to a hexadecimal string
return hash_object.hexdigest()


@dataclasses.dataclass
class SlicesPair:
local_slices: Tuple[slice, ...]
@@ -122,7 +111,6 @@ class NanotronParameter(nn.Parameter):

# __torch_function__ = torch._C._disabled_torch_function_impl

NANOTRON_PARAMETER_HASH_ATTRIBUTE_NAME = "__nanotron_hash__"
NANOTRON_PARAMETER_METADATA_ATTRIBUTE_NAME = "__nanotron_metadata__"
NANOTRON_PARAMETER_METADATA_TIED_KEY = "tied"
NANOTRON_PARAMETER_METADATA_SHARDED_KEY = "sharded"
@@ -173,7 +161,6 @@ def __init__(self, tensor: Union[torch.Tensor, "FP8Tensor"]):
# because we need to know a parameter will be in fp8 or not
# so we create master weights of the fp32 parameters before quantizing
self._is_future_fp8 = False
setattr(self, self.NANOTRON_PARAMETER_HASH_ATTRIBUTE_NAME, hash(_generate_random_hash()))

def _set_metadata(self, key: str, value: Any):
metadata = getattr(self, self.NANOTRON_PARAMETER_METADATA_ATTRIBUTE_NAME)
@@ -239,18 +226,9 @@ def create_param_that_share_metadata(cls, tensor: torch.Tensor, param: Union[nn.
new_param = NanotronParameter(tensor)
setattr(new_param, NanotronParameter.NANOTRON_PARAMETER_METADATA_ATTRIBUTE_NAME, metadata)

# NOTE: if the param is a nn.Parameter, then we don't need to sync the hash
if isinstance(param, NanotronParameter):
setattr(
new_param,
NanotronParameter.NANOTRON_PARAMETER_HASH_ATTRIBUTE_NAME,
getattr(param, cls.NANOTRON_PARAMETER_HASH_ATTRIBUTE_NAME),
)

# TODO(xrsrke): sync all the attributes in the param
# to the new parameter? in case, user sets some attributes
# then the new parameter is kinda lost it

return new_param

@property
@@ -319,10 +297,6 @@ def wrap(e):
else:
return tree_map(wrap, outputs)

def __hash__(self):
# Combine the attributes to compute a unique hash value
return getattr(self, self.NANOTRON_PARAMETER_HASH_ATTRIBUTE_NAME)


def sanity_check(root_module: nn.Module):
"""Makes sure that the module is in Nanotronformat
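Taken together, these hunks delete a random per-parameter identity: every parameter was tagged at construction with the SHA-256 of 64 random bytes, `__hash__` returned that tag, and `create_param_that_share_metadata` copied it onto clones (the matching uniqueness check in `trainer.py` is commented out further down in this diff). A condensed sketch of the removed mechanism, with an illustrative class name standing in for `NanotronParameter`:

```python
import hashlib
import os

def _generate_random_hash() -> str:
    # 64 bytes of OS randomness, hashed with SHA-256 and hex-encoded.
    return hashlib.sha256(os.urandom(64)).hexdigest()

class HashTaggedParameter:
    HASH_ATTRIBUTE_NAME = "__nanotron_hash__"

    def __init__(self) -> None:
        # Tag the instance once at construction time.
        setattr(self, self.HASH_ATTRIBUTE_NAME, hash(_generate_random_hash()))

    def __hash__(self) -> int:
        # Identity is the stored tag, not the tensor contents.
        return getattr(self, self.HASH_ATTRIBUTE_NAME)
```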
32 changes: 0 additions & 32 deletions src/nanotron/parallel/tensor_parallel/functional.py
@@ -442,11 +442,6 @@ def column_linear(
name: Optional[str] = None,
recipe: Optional[FP8LinearRecipe] = None,
):
# weight = get_data_from_param(weight)

# if bias is not None:
# bias = get_data_from_param(bias)

if async_communication:
return _ColumnLinearAsyncCommunication.apply(input, weight, bias, group, tp_mode, tp_recompute_allgather)

@@ -456,29 +451,8 @@

input = differentiable_identity(input, group=group)

# if isinstance(weight, FP8Tensor): # i used weight before removing get_data_from_param
if isinstance(weight.data, FP8Tensor):
assert recipe is not None, "recipe must be provided for column_linear"
from nanotron import constants

# if name not in constants.TRACKING_FP8_PARAM:
# constants.TRACKING_FP8_PARAM[name] = weight

if (
constants.CONFIG is not None
and constants.CONFIG.fp8 is not None
and constants.CONFIG.fp8.is_sanity_logging is True
):
from nanotron import logging
from nanotron.logging import log_rank

logger = logging.get_logger(__name__)
log_rank(
f"[iteration_step: {constants.ITERATION_STEP}]name = {name}, doing fp8 kernel",
logger=logger,
level=logging.INFO,
)

return fp8_functional.linear(input, weight, bias, metadatas=metadatas, recipe=recipe, name=name)
else:
return F.linear(input, weight, bias)
@@ -632,18 +606,12 @@ def row_linear(
recipe: Optional[FP8LinearRecipe] = None,
name: Optional[str] = None,
):
# weight = get_data_from_param(weight)
# if bias is not None:
# bias = get_data_from_param(bias)

if async_communication:
return _RowLinearAsyncCommunication.apply(input, weight, bias, group, tp_mode)

# out = F.linear(input, weight, bias)
import nanotron.fp8.functional as fp8_functional
from nanotron.fp8.tensor import FP8Tensor

# if isinstance(weight, FP8Tensor): # i used weight before removing get_data_from_param
if isinstance(weight.data, FP8Tensor):
assert recipe is not None, "recipe must be provided for row_linear"
out = fp8_functional.linear(input, weight, bias, metadatas=metadatas, recipe=recipe, name=name)
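With the sanity logging gone, the branch that both `column_linear` and `row_linear` keep is a plain two-way dispatch on the weight's backing tensor. A sketch using the imports shown in these hunks (the wrapper name is illustrative, and it omits the tensor-parallel communication that wraps it in the real functions):

```python
from typing import Optional

import torch
import torch.nn.functional as F

def fp8_aware_linear(
    input: torch.Tensor,
    weight,
    bias: Optional[torch.Tensor] = None,
    *,
    metadatas=None,
    recipe=None,
    name: Optional[str] = None,
) -> torch.Tensor:
    import nanotron.fp8.functional as fp8_functional
    from nanotron.fp8.tensor import FP8Tensor

    # FP8-backed weights go through nanotron's FP8 kernel; everything else
    # falls back to the regular torch linear.
    if isinstance(weight.data, FP8Tensor):
        assert recipe is not None, "recipe must be provided for an FP8 linear"
        return fp8_functional.linear(
            input, weight, bias, metadatas=metadatas, recipe=recipe, name=name
        )
    return F.linear(input, weight, bias)
```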
3 changes: 0 additions & 3 deletions src/nanotron/parallel/tensor_parallel/nn.py
@@ -42,9 +42,7 @@
)
from nanotron.parallel.tied_parameters import create_tied_parameter

# from nanotron.utils import post_init

# @post_init
class _BaseTensorParallelColumnLinear:
def __init__(
self,
@@ -110,7 +108,6 @@ def extra_repr(self) -> str:
return f"tp_rank={dist.get_rank(self.pg)}, {super().extra_repr()}, unsharded_out_features={self.out_features * self.world_size}"


# @post_init
class _BaseTensorParallelRowLinear:
def __init__(
self,
11 changes: 0 additions & 11 deletions src/nanotron/scaling/parametrization.py
@@ -45,7 +45,6 @@ def __init__(self, config: ModelArgs):
self.num_layers = config.model_config.num_hidden_layers

def _parametrize_column_linear(self, param_name: str, module: nn.Module):
# assert param_name in ["weight", "bias"]
assert any(x in param_name for x in ["weight", "bias"])

if "weight" in param_name:
@@ -54,7 +53,6 @@ def _parametrize_column_linear(self, param_name: str, module: nn.Module):
module.bias.zero_()

def _parametrize_row_linear(self, param_name: str, module: nn.Module):
# assert param_name in ["weight", "bias"]
assert any(x in param_name for x in ["weight", "bias"])

if "weight" in param_name:
@@ -64,13 +62,6 @@ def _parametrize_row_linear(self, param_name: str, module: nn.Module):
module.bias.zero_()

def _parametrize_layer_norm(self, param_name: str, module: nn.Module):
# assert param_name in ["weight", "bias"]

# if "weight" == param_name:
# # TODO @thomasw21: Sometimes we actually want 0
# module.weight.fill_(1)
# elif "bias" == param_name:
# module.bias.zero_()
assert any(x in param_name for x in ["weight", "bias"])
if "weight" in param_name:
# TODO @thomasw21: Sometimes we actually want 0
@@ -79,10 +70,8 @@ def _parametrize_layer_norm(self, param_name: str, module: nn.Module):
module.bias.zero_()

def _parametrize_embedding(self, param_name: str, module: nn.Module):
# assert param_name in ["weight"]
assert "weight" in param_name

# if "weight" == param_name:
if "weight" in param_name:
init.normal_(module.weight, mean=0.0, std=self.std)

10 changes: 5 additions & 5 deletions src/nanotron/trainer.py
@@ -218,10 +218,10 @@ def __init__(
constants.CPU_WEIGHTS[n.replace("module.", "")] = p.data.cpu().clone()

# NOTE: sanity check all hash are different
param_hash = []
for p in self.model.parameters():
assert hash(p) not in param_hash
param_hash.append(hash(p))
# param_hash = []
# for p in self.model.parameters():
# assert hash(p) not in param_hash
# param_hash.append(hash(p))

# NOTE: if we cast model to FP8 before wrapping it with NanotronParameter,
# then we can create a NanotronParameter that has dtype=[torch.int8, torch.uint8]
@@ -585,7 +585,7 @@ def training_step(
)

before_optim_step_sanity_checks(
self.config, self.parallel_context, self.unwrapped_model, self.grad_accumulator
self.config, self.parallel_context, self.unwrapped_model, self.optimizer, self.grad_accumulator
)

# Compute DP average loss and overlap with optimizer step
13 changes: 0 additions & 13 deletions src/nanotron/utils.py
@@ -160,16 +160,3 @@ def find_free_port(min_port: int = 2000, max_port: int = 65000) -> int:
return port
except OSError:
continue


def post_init(cls):
"""Decorator to call __post_init__ method after __init__ method of a class."""
original_init = cls.__init__

def new_init(self, *args, **kwargs):
original_init(self, *args, **kwargs)
if hasattr(self, "post_init"):
self.__post_init__()

cls.__init__ = new_init
return cls
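Worth noting about the deleted `post_init` helper (its commented-out `@post_init` call sites in `tensor_parallel/nn.py` are also dropped in this commit): it checked `hasattr(self, "post_init")` but then called `self.__post_init__()`, so the hook could only fire for classes that happened to define both names. A corrected sketch of what it was evidently meant to do:

```python
def post_init(cls):
    """Class decorator: run __post_init__ (when defined) right after __init__."""
    original_init = cls.__init__

    def new_init(self, *args, **kwargs):
        original_init(self, *args, **kwargs)
        # Check for the same dunder name that is actually called.
        if hasattr(self, "__post_init__"):
            self.__post_init__()

    cls.__init__ = new_init
    return cls
```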
