add fp8 optim init
xrsrke committed Nov 22, 2024
1 parent 2864391 commit 1800efe
Showing 8 changed files with 158 additions and 300 deletions.
3 changes: 2 additions & 1 deletion src/nanotron/config/fp8_config.py
@@ -3,6 +3,7 @@

import torch

from nanotron.fp8.constants import FP8LM_OPTIM_RECIPE
from nanotron.fp8.recipe import FP8LinearRecipe, FP8OptimRecipe


@@ -24,7 +25,7 @@ class FP8Args:
accum_dtype: torch.dtype = torch.bfloat16

model: Optional[List[FP8LayerArgs]] = None
optim: Optional[FP8OptimRecipe] = None
optim: Optional[FP8OptimRecipe] = FP8LM_OPTIM_RECIPE

run_fp8_sanity_check: bool = False

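This makes the FP8LM optimizer recipe the default whenever a config omits fp8.optim, instead of leaving it None. A minimal, hedged sketch of the defaulting pattern follows; it is not nanotron's actual classes, and apart from master_weight_dtype (which the trainer sanity check below reads) the field names are illustrative assumptions.

# Hedged sketch of the defaulting pattern, not nanotron's real FP8OptimRecipe.
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass(frozen=True)
class OptimRecipeSketch:
    master_weight_dtype: torch.dtype = torch.float32   # referenced by the trainer check below
    exp_avg_dtype: torch.dtype = torch.bfloat16        # assumed moment dtypes
    exp_avg_sq_dtype: torch.dtype = torch.bfloat16


FP8LM_OPTIM_RECIPE_SKETCH = OptimRecipeSketch()        # stands in for FP8LM_OPTIM_RECIPE


@dataclass
class FP8ArgsSketch:
    # Defaulting to the recipe rather than None lets downstream code rely on
    # fp8.optim always being populated, even when the YAML config omits it.
    optim: Optional[OptimRecipeSketch] = FP8LM_OPTIM_RECIPE_SKETCH


args = FP8ArgsSketch()                                  # config with no explicit optim section
assert args.optim.master_weight_dtype == torch.float32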
1 change: 1 addition & 0 deletions src/nanotron/fp8/constant_recipe.py
@@ -3,6 +3,7 @@
from nanotron.parallel.tensor_parallel.nn import TensorParallelEmbedding

MODULE_NAMES_THAT_NOT_FP8 = [
"rotary_embedding",
"token_embedding",
"input_layernorm",
"post_attention_layernorm",
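The added "rotary_embedding" entry keeps rotary embeddings in high precision alongside the token embedding and the layernorms. A hedged sketch of how a name-based exclusion list like this is typically applied when choosing which leaf modules to quantize (the helper is illustrative, not nanotron's):

# Illustrative helper, not nanotron's code: skip quantization for any module
# whose qualified name contains one of the excluded substrings.
MODULE_NAMES_THAT_NOT_FP8 = [
    "rotary_embedding",
    "token_embedding",
    "input_layernorm",
    "post_attention_layernorm",
]


def should_quantize(module_name: str) -> bool:
    return not any(excluded in module_name for excluded in MODULE_NAMES_THAT_NOT_FP8)


assert should_quantize("model.decoder.0.attn.qkv_proj")
assert not should_quantize("model.decoder.0.attn.rotary_embedding")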
359 changes: 80 additions & 279 deletions src/nanotron/fp8/optim.py

Large diffs are not rendered by default.
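The rewritten FP8AdamW itself is therefore not shown on this page. Purely as a hedged illustration of the master-weight pattern that the trainer's sanity checks rely on (mappping_fp8_to_master_weight, master_weight_dtype), and not the contents of optim.py, a simplified update step might look like the sketch below; quantize is a crude stand-in and plain SGD replaces the real Adam math.

import torch


def quantize(t: torch.Tensor) -> torch.Tensor:
    # Stand-in for FP8 quantization: an ordinary low-precision cast.
    return t.to(torch.bfloat16)


def master_weight_step(params, mapping, lr=1e-3):
    """One simplified update: modify the fp32 master copy, then requantize."""
    for p in params:
        master = mapping[p]                 # high-precision master weight
        master -= lr * p.grad.float()       # plain SGD stands in for Adam
        p.data = quantize(master)           # refresh the low-precision parameter


# Tiny usage example with ordinary tensors standing in for FP8 ones.
w = torch.nn.Parameter(torch.randn(4, dtype=torch.bfloat16))
w.grad = torch.zeros_like(w)
masters = {w: w.detach().float().clone()}
master_weight_step([w], masters)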

2 changes: 1 addition & 1 deletion src/nanotron/fp8/utils.py
@@ -233,7 +233,7 @@ def convert_logs_to_flat_logs(logs, prefix):
def find_fp8_config_by_module_name(config: Config, target_module_name: str) -> Optional[FP8LayerArgs]:
if hasattr(config, "fp8") and hasattr(config.fp8, "model"):
# NOTE: either model or is_quant_all_except_first_and_last must be specified, not both
assert config.fp8.model is not None or config.fp8.is_quant_all_except_first_and_last is not None
# assert config.fp8.model is not None or config.fp8.is_quant_all_except_first_and_last is not None

if config.fp8.model is not None:
for layer_args in config.fp8.model:
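With the assert relaxed to a comment, a config that specifies neither fp8.model nor is_quant_all_except_first_and_last no longer aborts the lookup; the function simply returns None for every module. A hedged sketch of the contract (the module_name field and the rest are assumptions, not nanotron's FP8LayerArgs):

# Illustrative only -- the point is the contract: return the per-layer recipe
# if the module name is listed, otherwise None, which callers read as
# "leave this module in high precision".
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class LayerArgsSketch:
    module_name: str  # assumed field; the real recipe also carries dtype/scaling settings


def find_recipe(layers: Optional[List[LayerArgsSketch]], target: str) -> Optional[LayerArgsSketch]:
    if not layers:
        return None
    for layer_args in layers:
        if layer_args.module_name == target:
            return layer_args
    return None


layers = [LayerArgsSketch("model.decoder.0.mlp.down_proj")]
assert find_recipe(layers, "model.decoder.0.mlp.down_proj") is not None
assert find_recipe(layers, "model.token_embedding") is None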
10 changes: 9 additions & 1 deletion src/nanotron/parallel/parameters.py
@@ -171,7 +171,11 @@ def __new__(cls, tensor: torch.Tensor, requires_grad: bool = True):

def __init__(self, tensor: Union[torch.Tensor, "FP8Tensor"]):
self._data = tensor
setattr(self, self.NANOTRON_PARAMETER_HASH_ATTRIBUTE_NAME, _generate_random_hash())
# NOTE: whether we will quantize this parameter
# because we need to know a parameter will be in fp8 or not
# so we create master weights of the fp32 parameters before quantizing
self._is_future_fp8 = False
setattr(self, self.NANOTRON_PARAMETER_HASH_ATTRIBUTE_NAME, hash(_generate_random_hash()))

def _set_metadata(self, key: str, value: Any):
metadata = getattr(self, self.NANOTRON_PARAMETER_METADATA_ATTRIBUTE_NAME)
@@ -327,6 +331,10 @@ def wrap(e):
else:
return tree_map(wrap, outputs)

def __hash__(self):
# Combine the attributes to compute a unique hash value
return getattr(self, self.NANOTRON_PARAMETER_HASH_ATTRIBUTE_NAME)


def sanity_check(root_module: nn.Module):
"""Makes sure that the module is in Nanotronformat
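Construction now also sets _is_future_fp8 (so master weights can be created for parameters that will later be quantized) and stores an integer hash, with __hash__ returning it. A hedged, self-contained sketch of the idea, not nanotron's NanotronParameter: a random identity hash fixed at construction lets the wrapper key dictionaries such as an fp8-parameter-to-master-weight map even after its data is swapped during quantization.

# Toy stand-in for the hashing scheme, for illustration only.
import secrets

import torch


class ParamWithStableHash:
    def __init__(self, tensor: torch.Tensor):
        self._data = tensor
        self._is_future_fp8 = False                 # set later if a recipe targets this parameter
        self._hash = hash(secrets.token_hex(16))    # random string hashed to an int, fixed for life

    def __hash__(self) -> int:
        return self._hash


p = ParamWithStableHash(torch.randn(4, 4))
master_weights = {p: p._data.float().clone()}        # keyed before any quantization
p._data = p._data.to(torch.int8)                      # crude stand-in for FP8 quantization
assert p in master_weights                            # the identity hash still resolves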
64 changes: 51 additions & 13 deletions src/nanotron/trainer.py
@@ -185,6 +185,16 @@ def __init__(
else ParametrizationMethod.STANDARD
)

from nanotron.fp8.utils import find_fp8_config_by_module_name, get_leaf_modules

for module_name, module in get_leaf_modules(self.model):
if any(p.numel() > 0 for p in module.parameters()) is False:
continue

recipe = find_fp8_config_by_module_name(constants.CONFIG, module_name)
if recipe is not None:
module.weight._is_future_fp8 = True

# Init optimizer
self.optimizer, self.grad_accumulator = init_optimizer_and_grad_accumulator(
parametrization_method=parametrization_method,
@@ -193,8 +203,38 @@
parallel_context=self.parallel_context,
)

# self._convert_model_to_fp8(self.model)
# NOTE: sanity check all hash are different
param_hash = []
for p in self.model.parameters():
assert hash(p) not in param_hash
param_hash.append(hash(p))

# NOTE: if we cast model to FP8 before wrapping it with NanotronParameter,
# then we can create a NanotronParameter that has dtype=[torch.int8, torch.uint8]
# which then it allows us to assign [torch.int8, torch.uint8] gradients to the parameter
# otherwise, it would raise:
# "attempting to assign a gradient with dtype
# 'unsigned char' to a tensor with dtype 'float'.
# Please ensure that the gradient and the tensor have the same dtype"
# NOTE: the reason that we cast after initializing the optimizer is that
# we want to create some master weights for fp8 parameters, before quantizing them
assert 1 == 1
self._convert_model_to_fp8(self.model)
assert 1 == 1

# TODO(xrsrke): sanity check that _is_future_fp8 is consistent with the dtype of the parameter

if constants.CONFIG.model.dtype is torch.int8:
from nanotron.fp8.optim import FP8AdamW
from nanotron.fp8.tensor import FP8Tensor

assert self.optimizer.optimizer.__class__ == FP8AdamW
num_fp8_params = sum(1 for p in self.model.parameters() if isinstance(p.data, FP8Tensor))
master_weights_mapping = self.optimizer.optimizer.mappping_fp8_to_master_weight
assert num_fp8_params == len(master_weights_mapping)
assert all(
constants.CONFIG.fp8.optim.master_weight_dtype == p.dtype for p in master_weights_mapping.values()
)

if self.init_checkpoint_path is not None:
load_optimizer(
@@ -288,7 +328,16 @@ def print_sanity_params(model):
}
if self.config.model.dtype is torch.int8:
for name, module in get_leaf_modules(model):
if isinstance(module, (TensorParallelColumnLinear, TensorParallelRowLinear)):
if any(p.numel() > 0 for p in module.parameters()) is False:
continue

from nanotron import constants
from nanotron.fp8.utils import find_fp8_config_by_module_name

recipe = find_fp8_config_by_module_name(constants.CONFIG, name)

# if isinstance(module, (TensorParallelColumnLinear, TensorParallelRowLinear)):
if recipe is not None:
print(f"Converting {name} to FP8")
module.__class__ = TP_LINEAR_CLS_TO_FP8_LINEAR_CLS[module.__class__]
# TODO(xrsrke): retrieve custom recipe
@@ -740,17 +789,6 @@ def init_model(self) -> Union[NanotronModel, DistributedDataParallel]:
model = self._init_model_instance()
model = self._load_model_checkpoint(model)

# NOTE: if we cast model to FP8 before wrapping it with NanotronParameter,
# then we can create a NanotronParameter that has dtype=[torch.int8, torch.uint8]
# which then it allows us to assign [torch.int8, torch.uint8] gradients to the parameter
# otherwise, it would raise:
# "attempting to assign a gradient with dtype
# 'unsigned char' to a tensor with dtype 'float'.
# Please ensure that the gradient and the tensor have the same dtype"
assert 1 == 1
self._convert_model_to_fp8(model)
assert 1 == 1

return model

def _init_model_instance(self) -> NanotronModel:
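The conversion call and its explanatory note move from init_model into __init__, after the optimizer is built. A hedged sketch of that ordering (the callables are placeholders for nanotron internals; only _is_future_fp8 and the optimizer-before-quantization order come from the diff above):

# Ordering sketch only -- not the trainer's actual code.
def init_fp8_training(model, find_recipe, build_optimizer, convert_model_to_fp8):
    # 1) Flag parameters that a per-module recipe says will become FP8.
    for name, module in model.named_modules():
        if hasattr(module, "weight") and find_recipe(name) is not None:
            module.weight._is_future_fp8 = True

    # 2) Build the optimizer while weights are still high precision, so FP8
    #    parameters get fp32/bf16 master weights before being quantized.
    optimizer = build_optimizer(model.parameters())

    # 3) Quantize afterwards; once parameter storage is int8/uint8, low-precision
    #    gradients can be assigned without a dtype-mismatch error.
    convert_model_to_fp8(model)
    return optimizer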
6 changes: 6 additions & 0 deletions tests/fp8/test_fp8_parameter.py
@@ -143,3 +143,9 @@ def _test_create_sharded_fp8_parameter(parallel_context: ParallelContext, dtype:
# on a FP8Parameter

# TODO(xrsrke): test CPU parameter


# TODO(xrsrke): test convert model to FP8
# include the FP8's NanotronParameter's dtype and requires_grad

# TODO(xrsrke): test set FP8 gradients to FP8 NanotronParameter
13 changes: 8 additions & 5 deletions tests/test_parameter.py
@@ -22,9 +22,11 @@ def _test_random_hash_nanotron_parameter(parallel_context: ParallelContext):
contiguous_chunks=(8, 8),
)
param = create_sharded_parameter_from_config(parameter=param, pg=parallel_context.tp_pg, split_config=split_config)
hash = getattr(param, NanotronParameter.NANOTRON_PARAMETER_HASH_ATTRIBUTE_NAME)
# hash = getattr(param, NanotronParameter.NANOTRON_PARAMETER_HASH_ATTRIBUTE_NAME)

assert type(hash) == str
# assert type(hash) == str
assert hash(param) is not None
assert type(hash(param)) == int


def test_nanotron_parameter_does_not_override_some_parameter_variable():
@@ -105,9 +107,10 @@ def _test_create_param_that_share_metadata(parallel_context: ParallelContext):
assert p1_k == p2_k
assert p1_v == p2_v

orig_hash = getattr(orig_param, NanotronParameter.NANOTRON_PARAMETER_HASH_ATTRIBUTE_NAME)
new_hash = getattr(new_param, NanotronParameter.NANOTRON_PARAMETER_HASH_ATTRIBUTE_NAME)
# orig_hash = getattr(orig_param, NanotronParameter.NANOTRON_PARAMETER_HASH_ATTRIBUTE_NAME)
# new_hash = getattr(new_param, NanotronParameter.NANOTRON_PARAMETER_HASH_ATTRIBUTE_NAME)

assert new_hash == orig_hash
# assert new_hash == orig_hash
assert hash(new_param) == hash(orig_param)

parallel_context.destroy()
