diff --git a/.gitignore b/.gitignore index cbc04eaf..2da7cc5a 100644 --- a/.gitignore +++ b/.gitignore @@ -163,3 +163,8 @@ cython_debug/ checkpoints/ wandb/ + +*.csv +*.html +src/nanotron/.test_cache/ +log/ diff --git a/benchmark/fp8_tp_speed.py b/benchmark/fp8_tp_speed.py index bba7c1f3..dbd58f13 100644 --- a/benchmark/fp8_tp_speed.py +++ b/benchmark/fp8_tp_speed.py @@ -1,5 +1,4 @@ import argparse -import itertools import pandas as pd import torch @@ -11,16 +10,13 @@ # H100 SXM specs: bottom of https://www.nvidia.com/en-us/data-center/h100/ h100_peak_flops_float32 = 67e12 +# NOTE: without sparsity h100_peak_flops_fp16_tc = 989e12 h100_peak_tops_float8_tc = 1979e12 - -# def color_scale(val, min_val, max_val): -# """Generate a color scale from white to dark blue based on value.""" -# if pd.isna(val): -# return 'background-color: white' -# normalized = (val - min_val) / (max_val - min_val) if max_val != min_val else 0 -# return f'background-color: rgba(0, 0, 139, {normalized:.2f}); color: {"white" if normalized > 0.5 else "black"}' +# # NOTE: with sparity +# h100_peak_flops_fp16_tc = 1979e12 +# h100_peak_tops_float8_tc = 3958e12 def color_scale(val, min_val, max_val, metric_type="default"): @@ -43,55 +39,6 @@ def color_scale(val, min_val, max_val, metric_type="default"): return f"{color}; color: {text_color}" -# [Previous benchmark_fn_in_sec and run functions remain the same...] - -# def create_html_table(df, exp_number, tp_size): -# # Style the dataframe -# styled_df = df.style.format({ -# 'FP8_time_ms': '{:.2f}', -# 'BF16_time_ms': '{:.2f}', -# 'FP8_TFLOPS': '{:.2f}', -# 'BF16_TFLOPS': '{:.2f}', -# 'FP8_efficiency_%': '{:.2f}', -# 'BF16_efficiency_%': '{:.2f}', -# 'Speedup': '{:.2f}' -# }) - -# # Apply color scaling to specific columns -# styled_df = styled_df.apply(lambda x: pd.Series([ -# color_scale(v, x.min(), x.max(), 'time') if col.endswith('time_ms') -# else color_scale(v, x.min(), x.max(), 'performance') if col.endswith('TFLOPS') -# else color_scale(v, x.min(), x.max(), 'efficiency') if col.endswith('efficiency_%') -# else color_scale(v, x.min(), x.max()) if col == 'Speedup' -# else '' for v in x -# ]), axis=0) - -# # Generate HTML -# html = f''' -# -# -# -# -# -#
-#                 Benchmark Results (TP_SIZE={tp_size})
-#                 Experiment: {exp_number}
-# {styled_df.to_html()} -# -# -# ''' - -# with open(f'{exp_number}_benchmark_results_tp{tp_size}.html', 'w') as f: -# f.write(html) - - def create_html_table(df, exp_number, tp_size): def style_df(df): # Create an empty DataFrame with the same shape as the input @@ -199,6 +146,23 @@ def run_linear(input, M, N, K, parallel_context, include_backward=False): sharded_output.sum().backward() +# def parse_args(): +# parser = argparse.ArgumentParser(description="Run profiling experiments with configurable dimensions") +# parser.add_argument("--exp_number", type=str, help="Experiment number") +# parser.add_argument("--tp_size", type=int, default=1, help="Tensor Parallel size") +# # parser.add_argument( +# # "--dimensions", +# # type=str, +# # default="4096,16384,32768,28672,49152", +# # help="Comma-separated list of dimensions to test. Used when M/K/N are not explicitly provided", +# # ) +# # Add the missing argument definitions +# parser.add_argument("--m_size", type=int, help="Explicitly set M dimension") +# parser.add_argument("--k_size", type=int, help="Explicitly set K dimension") +# parser.add_argument("--n_size", type=int, help="Explicitly set N dimension") +# return parser.parse_args() + + def parse_args(): parser = argparse.ArgumentParser(description="Run profiling experiments with configurable dimensions") parser.add_argument("--exp_number", type=str, help="Experiment number") @@ -207,14 +171,17 @@ def parse_args(): "--dimensions", type=str, default="4096,16384,32768,28672,49152", - help="Comma-separated list of dimensions to test", + help="Comma-separated list of dimensions to test. Used when M/K/N are not explicitly provided", ) + parser.add_argument("--m_size", type=int, help="Explicitly set M dimension") + parser.add_argument("--k_size", type=int, help="Explicitly set K dimension") + parser.add_argument("--n_size", type=int, help="Explicitly set N dimension") return parser.parse_args() -def benchmark_linear_operations(M, N, K, parallel_context, include_backward): - input = torch.randn(M, K, device="cuda", requires_grad=False) - bfloat16_input = torch.randn(M, K, device="cuda", requires_grad=False, dtype=torch.bfloat16) +def benchmark_linear_operations(M, N, K, parallel_context, include_backward, requires_grad): + input = torch.randn(M, K, device="cuda", requires_grad=requires_grad) + bfloat16_input = torch.randn(M, K, device="cuda", requires_grad=requires_grad, dtype=torch.bfloat16) # Benchmark FP8 fp8_time = benchmark_fn_in_sec(run_fp8_linear, input, M, N, K, parallel_context, include_backward) @@ -223,13 +190,18 @@ def benchmark_linear_operations(M, N, K, parallel_context, include_backward): bfloat16_time = benchmark_fn_in_sec(run_linear, bfloat16_input, M, N, K, parallel_context, include_backward) # Calculate FLOPS - # Each linear operation performs 2*M*N*K FLOPs (multiply-add) total_flops = 2 * M * N * K // parallel_context.tensor_parallel_size + if include_backward: + # Gradient with Respect to Parameters + total_flops += 2 * M * N * K // parallel_context.tensor_parallel_size + if requires_grad: + # Gradient with Respect to Inputs + total_flops += 2 * M * N * K // parallel_context.tensor_parallel_size fp8_tflops = (total_flops / fp8_time) / 1e12 bfloat16_tflops = (total_flops / bfloat16_time) / 1e12 - # Calculate efficiency compared to peak performance + # Calculate efficiency fp8_efficiency = (fp8_tflops / (h100_peak_tops_float8_tc / 1e12)) * 100 bfloat16_efficiency = (bfloat16_tflops / (h100_peak_flops_fp16_tc / 1e12)) * 100 @@ -238,12 +210,13 @@ def 
benchmark_linear_operations(M, N, K, parallel_context, include_backward): "N": N, "K": K, "Include_Backward": include_backward, + "Input_Requires_Grad": requires_grad, "FP8_time_ms": fp8_time * 1000, "BF16_time_ms": bfloat16_time * 1000, "FP8_TFLOPS": fp8_tflops, "BF16_TFLOPS": bfloat16_tflops, - "FP8_efficiency_%": fp8_efficiency, - "BF16_efficiency_%": bfloat16_efficiency, + "FP8_MFU%": fp8_efficiency, + "BF16_MFU%": bfloat16_efficiency, "Speedup": bfloat16_time / fp8_time, } @@ -252,79 +225,61 @@ def benchmark_linear_operations(M, N, K, parallel_context, include_backward): torch.backends.cudnn.benchmark = True args = parse_args() - dimensions = [int(d.strip()) for d in args.dimensions.split(",")] - TP_SIZE = args.tp_size - EXP_NUMBER = args.exp_number - - results = [] - total = len(list(itertools.product(dimensions, dimensions, dimensions))) - parallel_context = ParallelContext(data_parallel_size=1, pipeline_parallel_size=1, tensor_parallel_size=TP_SIZE) + parallel_context = ParallelContext( + data_parallel_size=1, pipeline_parallel_size=1, tensor_parallel_size=args.tp_size + ) - # Run benchmarks and collect results results = [] - i = 0 - for M, N, K in itertools.product(dimensions, dimensions, dimensions): - i += 1 - # Run forward-only case - result = benchmark_linear_operations(M, N, K, parallel_context, include_backward=False) + # combinations = list(itertools.product( + # dimensions, # M + # dimensions, # N + # dimensions, # K + # [False, True], # include_backward + # [False, True] # requires_grad + # )) + # Pair dimensions index-wise + # combinations = [] + # for i in range(len(dimensions)): + # M = dimensions[i] + # N = dimensions[i] + # K = dimensions[i] + # # For each dimension pair, test all combinations of include_backward and requires_grad + # for include_backward in [False, True]: + # for requires_grad in [False, True]: + # combinations.append((M, N, K, include_backward, requires_grad)) + + combinations = [] + # Check if explicit M, K, N dimensions are provided + if all(dim is not None for dim in [args.m_size, args.k_size, args.n_size]): + # Use explicitly provided dimensions + for include_backward in [False, True]: + # for requires_grad in [False, True]: + for requires_grad in [False]: + combinations.append((args.m_size, args.n_size, args.k_size, include_backward, requires_grad)) + else: + # Use dimensions from the --dimensions argument + dimensions = [int(d.strip()) for d in args.dimensions.split(",")] + for i in range(len(dimensions)): + M = dimensions[i] + N = dimensions[i] + K = dimensions[i] + for include_backward in [False, True]: + # for requires_grad in [False, True]: + for requires_grad in [False]: + combinations.append((M, N, K, include_backward, requires_grad)) + + total = len(combinations) + + for i, (M, N, K, include_backward, requires_grad) in enumerate(combinations, 1): + result = benchmark_linear_operations(M, N, K, parallel_context, include_backward, requires_grad) results.append(result) - print(f"Experiment {i}/{total} complete (Forward-only)") - - # Run forward+backward case - result = benchmark_linear_operations(M, N, K, parallel_context, include_backward=True) - results.append(result) - print(f"Experiment {i}/{total} complete (Forward+Backward)") + print(f"Experiment {i}/{total} complete") df = pd.DataFrame(results) - df = df.round(2) # Round to 2 decimal places - df = df.sort_values(by=["M", "N", "K", "Include_Backward"]) + df = df.round(2) + df = df.sort_values(by=["M", "N", "K", "Include_Backward", "Input_Requires_Grad"]) print(df) - - # # Define 
columns to color and their respective color scales - # color_columns = { - # 'FP8_time_ms': 'Reds', - # 'BF16_time_ms': 'Blues', - # 'FP8_TFLOPS': 'Greens', - # 'BF16_TFLOPS': 'Purples', - # 'FP8_efficiency_%': 'Oranges', - # 'BF16_efficiency_%': 'Viridis', - # 'Speedup': 'RdYlBu' - # } - - # # Create the table - # fig = go.Figure(data=[go.Table( - # header=dict( - # values=list(df.columns), - # fill_color='lightgrey', - # align='left', - # font=dict(size=12, color='black') - # ), - # cells=dict( - # values=[df[col] for col in df.columns], - # align='left', - # font=dict(size=11), - # # Format cells with colors for numeric columns - # fill_color=[ - # 'white' if col not in color_columns else - # [f'rgba({int(255*i)}, {int(255*i)}, {int(255*i)}, 0.5)' - # for i in (df[col]-df[col].min())/(df[col].max()-df[col].min())] - # for col in df.columns - # ] - # ) - # )]) - - # # Update layout - # fig.update_layout( - # title=f'Benchmark Results (TP_SIZE={TP_SIZE})', - # width=1200, - # height=800, - # margin=dict(l=20, r=20, t=40, b=20) - # ) - - # # Save the interactive HTML file - # fig.write_html(f'{EXP_NUMBER}_benchmark_results_tp{TP_SIZE}.html') - create_html_table(df, EXP_NUMBER, TP_SIZE) - - print(f"\nResults have been saved to {EXP_NUMBER}_benchmark_results_tp{TP_SIZE}.html") + create_html_table(df, args.exp_number, args.tp_size) + print(f"\nResults have been saved to {args.exp_number}_benchmark_results_tp{args.tp_size}.html") diff --git a/examples/config_fp8_llama.yaml b/examples/config_fp8_llama.yaml new file mode 100644 index 00000000..fbbad98f --- /dev/null +++ b/examples/config_fp8_llama.yaml @@ -0,0 +1,109 @@ +checkpoints: + checkpoint_interval: 10 + checkpoints_path: checkpoints + checkpoints_path_is_shared_file_system: false + resume_checkpoint_path: null + save_initial_state: false +data_stages: +- data: + dataset: + dataset_overwrite_cache: false + dataset_processing_num_proc_per_process: 1 + hf_dataset_config_name: null + hf_dataset_or_datasets: stas/openwebtext-10k + hf_dataset_splits: train + text_column_name: text + num_loading_workers: 1 + seed: 42 + name: Stable Training Stage + start_training_step: 1 +- data: + dataset: + dataset_overwrite_cache: false + dataset_processing_num_proc_per_process: 1 + hf_dataset_config_name: null + hf_dataset_or_datasets: stas/openwebtext-10k + hf_dataset_splits: train + text_column_name: text + num_loading_workers: 1 + seed: 42 + name: Annealing Phase + start_training_step: 10 +general: + benchmark_csv_path: null + consumed_train_samples: null + ignore_sanity_checks: true + project: debug + run: tiny_llama_%date_%jobid + seed: 42 + step: null +lighteval: null +logging: + iteration_step_info_interval: 1 + log_level: info + log_level_replica: info +model: + ddp_bucket_cap_mb: 25 + dtype: int8 + init_method: + std: 0.025 + make_vocab_size_divisible_by: 1 + model_config: + bos_token_id: 1 + eos_token_id: 2 + hidden_act: silu + hidden_size: 16 + initializer_range: 0.02 + intermediate_size: 64 + is_llama_config: true + max_position_embeddings: 256 + num_attention_heads: 4 + num_hidden_layers: 2 + num_key_value_heads: 4 + pad_token_id: null + pretraining_tp: 1 + rms_norm_eps: 1.0e-05 + rope_scaling: null + tie_word_embeddings: true + use_cache: true + vocab_size: 256 +optimizer: + accumulate_grad_in_fp32: true + clip_grad: 1.0 + learning_rate_scheduler: + learning_rate: 0.0003 + lr_decay_starting_step: null + lr_decay_steps: 13 + lr_decay_style: cosine + lr_warmup_steps: 2 + lr_warmup_style: linear + min_decay_lr: 1.0e-05 + optimizer_factory: + 
adam_beta1: 0.9 + adam_beta2: 0.95 + adam_eps: 1.0e-08 + name: adamW + torch_adam_is_fused: true + weight_decay: 0.01 + zero_stage: 0 +parallelism: + dp: 2 + expert_parallel_size: 1 + pp: 1 + pp_engine: 1f1b + tp: 4 + tp_linear_async_communication: false + tp_mode: ALL_REDUCE +profiler: null +tokenizer: + tokenizer_max_length: null + tokenizer_name_or_path: robot-test/dummy-tokenizer-wordlevel + tokenizer_revision: null +tokens: + batch_accumulation_per_replica: 1 + limit_test_batches: 0 + limit_val_batches: 0 + micro_batch_size: 2 + sequence_length: 256 + train_steps: 15 + val_check_interval: -1 diff --git a/examples/config_tiny_llama.yaml b/examples/config_tiny_llama.yaml index 58645e2d..0fd639d5 100644 --- a/examples/config_tiny_llama.yaml +++ b/examples/config_tiny_llama.yaml @@ -44,7 +44,7 @@ logging: log_level_replica: info model: ddp_bucket_cap_mb: 25 - dtype: bfloat16 + dtype: float8 init_method: std: 0.025 make_vocab_size_divisible_by: 1 diff --git a/examples/llama/tests/test_conversion.py b/examples/llama/tests/test_conversion.py index b5ce3529..8fa02a81 100644 --- a/examples/llama/tests/test_conversion.py +++ b/examples/llama/tests/test_conversion.py @@ -15,6 +15,7 @@ from nanotron.models.base import init_on_device_and_dtype from nanotron.models.llama import LlamaForTraining from nanotron.parallel import ParallelContext +from nanotron.testing.utils import TestContext from nanotron.trainer import mark_tied_parameters from examples.llama.convert_hf_to_nanotron import convert_checkpoint_and_save as convert_hf_to_nt_and_save @@ -22,7 +23,6 @@ from examples.llama.convert_nanotron_to_hf import convert_checkpoint_and_save as convert_nt_to_hf_and_save from examples.llama.convert_nanotron_to_hf import convert_nt_to_hf, get_hf_config from examples.llama.convert_weights import load_nanotron_model, make_parallel_config -from tests.helpers.context import TestContext from tests.helpers.utils import init_distributed CONFIG = NanotronLlamaConfig( @@ -141,7 +141,7 @@ def _test_hf_to_nt(parallel_context: ParallelContext, input_ids: torch.Tensor): logits_nt = model_nt.model(input_ids, input_mask).permute(1, 0, 2) logits_hf = model_hf(input_ids).logits assert logits_nt.size() == logits_hf.size() - torch.testing.assert_allclose(logits_hf, logits_nt, atol=ATOL) + torch.testing.assert_allclose(logits_hf, logits_nt, atol=ATOL) def test_hf_to_nt(input_ids: torch.Tensor): diff --git a/scripts/01_standalone_fp8_tensor.py b/scripts/01_standalone_fp8_tensor.py new file mode 100644 index 00000000..6de99488 --- /dev/null +++ b/scripts/01_standalone_fp8_tensor.py @@ -0,0 +1,73 @@ +import torch +import torch.utils._pytree as pytree + + +class QuantTensor(torch.Tensor): + @staticmethod + def __new__(cls, data: torch.Tensor): + return torch.Tensor._make_wrapper_subclass( + cls, + data.shape, + device=data.device, + ) + + # @torch._dynamo.disable + def __init__(self, data: torch.Tensor): + self._data = data + + def __tensor_flatten__(self): + return ["_data"], [] + + @classmethod + def __tensor_unflatten__(cls, tensor_data_dict, tensor_attributes, outer_size=None, outer_stride=None): + return cls(tensor_data_dict["_data"], *tensor_attributes) + + def __repr__(self): + return f"{self.__class__.__name__}(data={self._data})" + + # @classmethod + # def __torch_function__(cls, func, types, args=(), kwargs=None): + # kwargs = kwargs or dict() + + # if func is F.linear: + # return _BitNetTrainingLinear.apply(*args, **kwargs) + + # with torch._C.DisableTorchFunctionSubclass(): + # return func(*args, **kwargs) + + # 
adapted from FP8 implementation of WeightWithDynamicFloat8CastTensor + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + out = func( + *pytree.tree_map_only(cls, lambda x: x._data, args), + **pytree.tree_map_only(cls, lambda x: x._data, kwargs), + ) + + if func is torch.ops.aten.copy_.default: + # return original object + return args[0] + elif func in { + torch.ops.aten.t.default, + torch.ops.aten.detach.default, + torch.ops.aten.empty_like.default, + torch.ops.aten.new_zeros.default, + torch.ops.aten.slice.Tensor, + torch.ops.aten.view.default, + torch.ops.aten.as_strided.default, + torch.ops.aten._to_copy.default, + torch.ops.aten._pin_memory.default, + torch.ops.aten.split.Tensor, + torch.ops.aten.clone.default, + }: + # return new wrapped object + return pytree.tree_map_only(torch.Tensor, lambda x: cls(x), out) + else: + # return new unwrapped object + return out + + +if __name__ == "__main__": + tensor = torch.randn(2, 2, device="cuda") + quant_tensor = QuantTensor(tensor) + + assert 1 == 1 diff --git a/src/nanotron/config/utils_config.py b/src/nanotron/config/utils_config.py index f4c07146..bf8407fc 100644 --- a/src/nanotron/config/utils_config.py +++ b/src/nanotron/config/utils_config.py @@ -75,8 +75,9 @@ def serialize(data) -> dict: torch.complex128: "complex128", torch.float16: "float16", torch.bfloat16: "bfloat16", - torch.uint8: "uint8", - torch.int8: "int8", + # torch.uint8: "uint8", + # torch.int8: "int8", + torch.int8: "float8", torch.int16: "int16", torch.int32: "int32", torch.int64: "int64", diff --git a/src/nanotron/fp8/_tensor.py b/src/nanotron/fp8/_tensor.py new file mode 100644 index 00000000..0d288850 --- /dev/null +++ b/src/nanotron/fp8/_tensor.py @@ -0,0 +1,90 @@ +import torch + +from nanotron.fp8.meta import FP8Meta + + +# This is within core, the end user never have to look at this +class _WrapperTensor(torch.Tensor): + @staticmethod + def __new__(cls, *args, **kwargs): + t, kwargs = cls.get_wrapper_properties(*args, **kwargs) + if "size" not in kwargs: + size = t.size() + else: + size = kwargs["size"] + del kwargs["size"] + if "dtype" not in kwargs: + kwargs["dtype"] = t.dtype + if "layout" not in kwargs: + kwargs["layout"] = t.layout + if "device" not in kwargs: + kwargs["device"] = t.device + if "requires_grad" not in kwargs: + kwargs["requires_grad"] = False + # Ignore memory_format and pin memory for now as I don't know how to + # safely access them on a Tensor (if possible??) + + wrapper = torch.Tensor._make_wrapper_subclass(cls, size, **kwargs) + wrapper._validate_methods() + return wrapper + + @classmethod + def get_wrapper_properties(cls, *args, **kwargs): + # Should return both an example Tensor and a dictionaly of kwargs + # to override any of that example Tensor's properly. + # This is very similar to the `t.new_*(args)` API + raise NotImplementedError("You need to implement get_wrapper_properties") + + def _validate_methods(self): + # Skip this if not in debug mode? + # Changing these on the python side is wrong as it would not be properly reflected + # on the c++ side + # This doesn't catch attributes set in the __init__ + forbidden_overrides = ["size", "stride", "dtype", "layout", "device", "requires_grad"] + for el in forbidden_overrides: + if getattr(self.__class__, el) is not getattr(torch.Tensor, el): + raise RuntimeError( + f"Subclass {self.__class__.__name__} is overwriting the " + f"property {el} but this is not allowed as such change would " + "not be reflected to c++ callers." 
+ ) + + def __repr__(self): + return f"{self.__class__.__name__}({self.__dict__})" + + +from torch.utils._pytree import tree_map + + +class _FP8Tensor(_WrapperTensor): + @classmethod + def get_wrapper_properties(cls, diag): + # return diag, {"size": diag.size() + diag.size()} + return diag, {} + + def __init__(self, data: torch.Tensor, fp8_meta: FP8Meta): + self._tensor = data + + @property + def data(self): + return self._tensor + + @data.setter + def data(self, data): + self._tensor = data + + @classmethod + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): + def unwrap(e): + return torch.diag(e._diag) if isinstance(e, _FP8Tensor) else e + + def wrap(e): + return _FP8Tensor(torch.diag(e)) if isinstance(e, torch.Tensor) else e + + rs = tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))) + return rs + + +class FP8E4M3Tensor(_FP8Tensor): + def __init__(self): + pass diff --git a/src/nanotron/fp8/constant_recipe.py b/src/nanotron/fp8/constant_recipe.py new file mode 100644 index 00000000..4173ca7b --- /dev/null +++ b/src/nanotron/fp8/constant_recipe.py @@ -0,0 +1,12 @@ +from torch import nn + +from nanotron.parallel.tensor_parallel.nn import TensorParallelEmbedding + +MODULE_NAMES_THAT_NOT_FP8 = [ + "token_embedding", + "input_layernorm", + "post_attention_layernorm", + "final_layer_norm", + "lm_head", +] +MODULES_THAT_IN_FLOAT16 = [TensorParallelEmbedding, nn.LayerNorm] diff --git a/src/nanotron/fp8/linear.py b/src/nanotron/fp8/linear.py index 0bc31b09..4936e42c 100644 --- a/src/nanotron/fp8/linear.py +++ b/src/nanotron/fp8/linear.py @@ -12,7 +12,6 @@ from nanotron.fp8.parameter import FP8Parameter from nanotron.fp8.recipe import FP8LinearRecipe from nanotron.fp8.tensor import FP8Tensor -from nanotron.parallel.parameters import get_data_from_param @dataclass @@ -55,7 +54,10 @@ def __init__( assert quant_w.dtype in [torch.uint8, torch.int8], f"got {self.weight.data.dtype}" self.weight = quant_w - assert self.weight.data.dtype in [torch.uint8, torch.int8], f"got {self.weight.data.dtype}" + if self.name == "model.decoder.0.attention.qkv_proj": + assert 1 == 1 + + assert self.weight.data.dtype in [torch.uint8, torch.int8], f"got {self.weight.data.dtype}, name: {self.name}" if self.bias is not None: self.bias = nn.Parameter(self.bias.to(recipe.accum_dtype)) @@ -66,6 +68,7 @@ def __init__( def forward(self, input: Union[FP8Tensor, torch.Tensor]) -> torch.Tensor: import nanotron.fp8.functional as F + from nanotron.parallel.parameters import get_data_from_param return F.linear( input=input, diff --git a/src/nanotron/fp8/meta.py b/src/nanotron/fp8/meta.py index f9f3afbe..3064cf1f 100644 --- a/src/nanotron/fp8/meta.py +++ b/src/nanotron/fp8/meta.py @@ -30,6 +30,7 @@ class FP8Meta: # TODO(xrsrke): change to Literal[torch.int8, torch.uint8] dtype: DTypes interval: int + sync_amax: bool = False @property def te_dtype(self) -> tex.DType: diff --git a/src/nanotron/fp8/tensor.py b/src/nanotron/fp8/tensor.py index 80fbc533..368ae6fc 100644 --- a/src/nanotron/fp8/tensor.py +++ b/src/nanotron/fp8/tensor.py @@ -78,6 +78,16 @@ def __new__( return obj + def __init__( + self, + tensor: torch.Tensor, + dtype: Optional[DTypes] = None, + interval: Optional[int] = 1, + fp8_meta: Optional[FP8Meta] = None, + sync: bool = False, + ) -> None: + pass + @staticmethod # @torch.no_grad() def _get_metadata(tensor: torch.Tensor, dtype: DTypes, interval: int, sync: bool) -> "FP8Meta": @@ -188,6 +198,20 @@ def clone(self) -> FP8Tensor: tensor.fp8_meta = deepcopy(self.fp8_meta) return tensor + 
# def __torch_function__(self, func, types, args=(), kwargs=None): + # return super().__torch_function__(func, types, args, kwargs) + + # @classmethod + # def __torch_function__(cls, func, types, args, kwargs=None): + # kwargs = kwargs or {} + # if func is torch.transpose: + # assert type(args[0]) == cls + # assert type(args[1]) == type(args[2]) == int + # # return CustomMaskedSum.apply(*args, **kwargs) + # assert 1 == 1 + # else: + # super().__torch_function__(func, types, args, kwargs) + class FP8Tensor(LowPrecisionTensor): """FP8 Tensor.""" diff --git a/src/nanotron/fp8/utils.py b/src/nanotron/fp8/utils.py index b1934aa0..0dc7c3d3 100644 --- a/src/nanotron/fp8/utils.py +++ b/src/nanotron/fp8/utils.py @@ -313,10 +313,10 @@ def is_convert_to_fp16(module) -> bool: IS_CONVERT_TO_FLOAT16 = False name_of_modules_not_in_fp16 = get_modules_not_in_fp16() - if hasattr(module, "name") and "lm_head" in module.name: - assert 1 == 1 + # if hasattr(module, "name") and "lm_head" in module.name: + # assert 1 == 1 - if constants.CONFIG.fp8.model is None: + if constants.CONFIG is not None and constants.CONFIG.fp8.model is None: if any(isinstance(module, m) for m in MODULES_THAT_IN_FLOAT16): IS_CONVERT_TO_FLOAT16 = True else: diff --git a/src/nanotron/models/base.py b/src/nanotron/models/base.py index 14ac6908..78fd3965 100644 --- a/src/nanotron/models/base.py +++ b/src/nanotron/models/base.py @@ -71,7 +71,7 @@ def get_embeddings_lm_head_tied_names(self) -> list[str]: Example for GPT2 model: ["model.token_position_embeddings.pp_block.token_embedding.weight", "model.lm_head.pp_block.weight"] """ return [] - + def get_named_params_without_weight_decay(self) -> List[str]: """Return a list of named parameters that should not have weight decay applied to them.""" return [] @@ -237,11 +237,118 @@ def build_model( return model +old_register_parameter = nn.Module.register_parameter +old_register_buffer = nn.Module.register_buffer + + +def _register_empty_parameter_for_fp8(module, name, param): + old_register_parameter(module, name, param) + + # from nanotron.fp8.constant_recipe import MODULES_THAT_IN_FLOAT16 + # from nanotron.fp8.utils import get_modules_not_in_fp16 + + # MODULES_THAT_IN_FLOAT16 = [TensorParallelEmbedding, nn.LayerNorm] + # name_of_modules_not_in_fp16 = get_modules_not_in_fp16() + + if param is not None: + from nanotron.fp8.utils import is_convert_to_fp16 + + # IS_CONVERT_TO_FLOAT16 = False + # if constants.CONFIG.fp8.model is None: + # if any(isinstance(module, m) for m in MODULES_THAT_IN_FLOAT16): + # IS_CONVERT_TO_FLOAT16 = True + # else: + # if any(isinstance(module, m) for m in MODULES_THAT_IN_FLOAT16) or ( + # hasattr(module, "name") and module.name not in name_of_modules_not_in_fp16 + # ): + # IS_CONVERT_TO_FLOAT16 = True + is_convert_to_float16 = is_convert_to_fp16(module) + + if is_convert_to_float16: + import nanotron + from nanotron import constants + + if constants.CONFIG is not None: + from typing import cast + + from nanotron.config.fp8_config import FP8Args + + fp8_config = cast(FP8Args, constants.CONFIG.fp8) + resid_dtype = fp8_config.resid_dtype + else: + resid_dtype = nanotron.fp8.constants.FP8LM_LINEAR_RECIPE.accum_dtype + + param.data = param.data.to(torch.device("cuda"), resid_dtype) + else: + param.data = param.data.to(torch.device("cuda")) + + +def _register_empty_buffer_for_fp8(module, name, buffer, persistent=True): + old_register_buffer(module, name, buffer, persistent=persistent) + + # from nanotron.fp8.constant_recipe import MODULES_THAT_IN_FLOAT16 + # from 
nanotron.fp8.utils import get_modules_not_in_fp16 + + # name_of_modules_not_in_fp16 = get_modules_not_in_fp16() + + if buffer is not None: + # IS_CONVERT_TO_FLOAT16 = False + + # # NOTE: convert all modules in FP8 except MODULES_THAT_IN_FLOAT16 + # # if fp8.model is None + # if constants.CONFIG.fp8.model is None: + # if any(isinstance(module, m) for m in MODULES_THAT_IN_FLOAT16): + # IS_CONVERT_TO_FLOAT16 = True + # else: + # if any(isinstance(module, m) for m in MODULES_THAT_IN_FLOAT16) or ( + # hasattr(module, "name") and module.name not in name_of_modules_not_in_fp16 + # ): + # IS_CONVERT_TO_FLOAT16 = True + + from nanotron.fp8.utils import is_convert_to_fp16 + + is_convert_to_float16 = is_convert_to_fp16(module) + + if is_convert_to_float16: + + # from nanotron import constants + from nanotron.fp8 import constants + + # fp8_config = cast(FP8Args, constants.CONFIG.fp8) + # fp8_config = cast(FP8Args, constants.CONFIG.fp8) + # module._buffers[name] = module._buffers[name].to(torch.device("cuda"), torch.float16) + module._buffers[name] = module._buffers[name].to( + torch.device("cuda"), constants.FP8LM_LINEAR_RECIPE.accum_dtype + ) + else: + buffer.data = buffer.data.to(torch.device("cuda")) + + +def _register_empty_parameter(module, name, param, device, dtype): + old_register_parameter(module, name, param) + if param is not None: + if isinstance(param, DTypeInvariantTensor): + # if param is DTypeInvariantTensor we should avoid updating it + param.data = param.data.to(device) + else: + param.data = param.data.to(device, dtype) + + +def _register_empty_buffer(module, name, buffer, device, dtype, persistent=True): + old_register_buffer(module, name, buffer, persistent=persistent) + if buffer is not None: + if isinstance(buffer, DTypeInvariantTensor): + # if buffer is DTypeInvariantTensor we should avoid updating it + buffer.data = buffer.data.to(device) + else: + module._buffers[name] = module._buffers[name].to(device, dtype) + + # TODO @thomasw21: Should this option override user defined options? Maybe not ... right now it does. @contextmanager def init_on_device_and_dtype( device: torch.device = torch.device("cpu"), - dtype: torch.dtype = torch.float, + dtype: torch.dtype = torch.float32, ): """ A context manager under which models are initialized with all parameters on the specified device. 
@@ -258,28 +365,32 @@ def init_on_device_and_dtype( from accelerate import init_on_device with init_on_device_and_dtype(device=torch.device("cuda")): tst = nn.Liner(100, 100) # on `cuda` device + + NOTE: in order to initialize an hybrid fp8 properly, you should use this context manager ``` """ + from functools import wraps + + def method_partial(func, *args, **kwargs): + @wraps(func) + def wrapper(self, *fargs, **fkwargs): + return func(self, *args, *fargs, **kwargs, **fkwargs) + + return wrapper + old_register_parameter = nn.Module.register_parameter old_register_buffer = nn.Module.register_buffer - def register_empty_parameter(module, name, param): - old_register_parameter(module, name, param) - if param is not None: - if isinstance(param, DTypeInvariantTensor): - # if param is DTypeInvariantTensor we should avoid updating it - param.data = param.data.to(device) - else: - param.data = param.data.to(device, dtype) - - def register_empty_buffer(module, name, buffer, persistent=True): - old_register_buffer(module, name, buffer, persistent=persistent) - if buffer is not None: - if isinstance(buffer, DTypeInvariantTensor): - # if buffer is DTypeInvariantTensor we should avoid updating it - buffer.data = buffer.data.to(device) - else: - module._buffers[name] = module._buffers[name].to(device, dtype) + register_empty_parameter = ( + _register_empty_parameter_for_fp8 + if dtype == torch.int8 + else method_partial(_register_empty_parameter, device=device, dtype=dtype) + ) + register_empty_buffer = ( + _register_empty_buffer_for_fp8 + if dtype == torch.int8 + else method_partial(_register_empty_buffer, device=device, dtype=dtype) + ) # Patch tensor creation tensor_constructors_to_patch = { @@ -289,8 +400,10 @@ def register_empty_buffer(module, name, buffer, persistent=True): def patch_tensor_constructor(fn): def wrapper(*args, **kwargs): + # NOTE: nanotron automatically sets the device and dtype of the tensor + # but for FP8 training, we initializes with float16 first kwargs["device"] = device - kwargs["dtype"] = dtype + kwargs["dtype"] = torch.float32 if dtype == torch.int8 else dtype return fn(*args, **kwargs) return wrapper @@ -308,6 +421,77 @@ def wrapper(*args, **kwargs): setattr(torch, torch_function_name, old_torch_function) +# TODO @thomasw21: Should this option override user defined options? Maybe not ... right now it does. +# @contextmanager +# def init_on_device_and_dtype( +# device: torch.device = torch.device("cpu"), +# dtype: torch.dtype = torch.float, +# ): +# """ +# A context manager under which models are initialized with all parameters on the specified device. +# Args: +# device (`torch.device` defaults to `cpu`): +# Device to initialize all parameters on. +# dtype (`torch.dtype` defaults to `torch.float`): +# Dtype to initialize all parameters on. +# include_buffers (`bool`, defaults to `False`): +# Whether or not to also default all buffers constructors given previous arguments. 
+# Example: +# ```python +# import torch.nn as nn +# from accelerate import init_on_device +# with init_on_device_and_dtype(device=torch.device("cuda")): +# tst = nn.Liner(100, 100) # on `cuda` device +# ``` +# """ +# old_register_parameter = nn.Module.register_parameter +# old_register_buffer = nn.Module.register_buffer + +# def register_empty_parameter(module, name, param): +# old_register_parameter(module, name, param) +# if param is not None: +# if isinstance(param, DTypeInvariantTensor): +# # if param is DTypeInvariantTensor we should avoid updating it +# param.data = param.data.to(device) +# else: +# param.data = param.data.to(device, dtype) + +# def register_empty_buffer(module, name, buffer, persistent=True): +# old_register_buffer(module, name, buffer, persistent=persistent) +# if buffer is not None: +# if isinstance(buffer, DTypeInvariantTensor): +# # if buffer is DTypeInvariantTensor we should avoid updating it +# buffer.data = buffer.data.to(device) +# else: +# module._buffers[name] = module._buffers[name].to(device, dtype) + +# # Patch tensor creation +# tensor_constructors_to_patch = { +# torch_function_name: getattr(torch, torch_function_name) +# for torch_function_name in ["empty", "zeros", "ones", "full"] +# } + +# def patch_tensor_constructor(fn): +# def wrapper(*args, **kwargs): +# kwargs["device"] = device +# kwargs["dtype"] = dtype +# return fn(*args, **kwargs) + +# return wrapper + +# try: +# nn.Module.register_parameter = register_empty_parameter +# nn.Module.register_buffer = register_empty_buffer +# for torch_function_name in tensor_constructors_to_patch.keys(): +# setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name))) +# yield +# finally: +# nn.Module.register_parameter = old_register_parameter +# nn.Module.register_buffer = old_register_buffer +# for torch_function_name, old_torch_function in tensor_constructors_to_patch.items(): +# setattr(torch, torch_function_name, old_torch_function) + + def check_model_has_grad(model: NanotronModel, parallel_context: "ParallelContext"): """Check that there's at least a parameter in current PP rank that has a gradient.""" for param in model.parameters(): diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index 88fb6bcb..e6e74ecb 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -35,10 +35,10 @@ from nanotron.parallel.pipeline_parallel.p2p import P2P from nanotron.parallel.tensor_parallel.functional import sharded_cross_entropy from nanotron.parallel.tensor_parallel.nn import ( - TensorParallelColumnLinear, + FP8TensorParallelColumnLinear, + FP8TensorParallelRowLinear, TensorParallelEmbedding, TensorParallelLinearMode, - TensorParallelRowLinear, ) from nanotron.random import RandomStates from nanotron.scaling.parametrization import SpectralMupParametrizator, StandardParametrizator @@ -207,6 +207,7 @@ def __init__( config: LlamaConfig, parallel_config: Optional[ParallelismArgs], tp_pg: dist.ProcessGroup, + layer_idx: int, ): super().__init__() @@ -220,7 +221,8 @@ def __init__( config.intermediate_size, # shape of gate_linear config.intermediate_size, # shape of up_linear ) - self.gate_up_proj = TensorParallelColumnLinear( + # self.gate_up_proj = TensorParallelColumnLinear( + self.gate_up_proj = FP8TensorParallelColumnLinear( config.hidden_size, 2 * config.intermediate_size, pg=tp_pg, @@ -228,15 +230,18 @@ def __init__( bias=False, async_communication=tp_linear_async_communication, contiguous_chunks=gate_up_contiguous_chunks, - 
tp_recompute_allgather=parallel_config.tp_recompute_allgather, + name=f"model.decoder.{layer_idx}.mlp.gate_up_proj", + # tp_recompute_allgather=parallel_config.tp_recompute_allgather, ) - self.down_proj = TensorParallelRowLinear( + # self.down_proj = TensorParallelRowLinear( + self.down_proj = FP8TensorParallelRowLinear( config.intermediate_size, config.hidden_size, pg=tp_pg, mode=tp_mode, bias=False, async_communication=tp_linear_async_communication and tp_mode is TensorParallelLinearMode.REDUCE_SCATTER, + name=f"model.decoder.{layer_idx}.mlp.down_proj", ) self.split_silu_mul = GLUActivation(config.hidden_act) @@ -381,7 +386,8 @@ def __init__( config.num_key_value_heads * self.d_qk, # shape of k config.num_key_value_heads * self.d_qk, # shape of v ) - self.qkv_proj = TensorParallelColumnLinear( + # self.qkv_proj = TensorParallelColumnLinear( + self.qkv_proj = FP8TensorParallelColumnLinear( self.d_model, config.num_attention_heads * self.d_qk + 2 * config.num_key_value_heads * self.d_qk, pg=tp_pg, @@ -389,7 +395,8 @@ def __init__( bias=False, async_communication=tp_linear_async_communication, contiguous_chunks=qkv_contiguous_chunks, - tp_recompute_allgather=parallel_config.tp_recompute_allgather, + name=f"model.decoder.{layer_idx}.attention.qkv_proj", + # tp_recompute_allgather=parallel_config.tp_recompute_allgather, ) # TODO(kunhao): We want to have only one version per device and not one version per layer. if config.rope_interleaved: @@ -411,13 +418,15 @@ def __init__( dim=self.d_qk, base=config.rope_theta, interleaved=config.rope_interleaved ) - self.o_proj = TensorParallelRowLinear( + # self.o_proj = TensorParallelRowLinear( + self.o_proj = FP8TensorParallelRowLinear( config.num_attention_heads * self.d_qk, self.d_model, pg=tp_pg, mode=tp_mode, bias=False, async_communication=tp_linear_async_communication, + name=f"model.decoder.{layer_idx}.attention.o_proj", ) self.attention = CoreAttention( @@ -710,7 +719,7 @@ def __init__( ) self.post_attention_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.mlp = MLP(config=config, parallel_config=parallel_config, tp_pg=tp_pg) + self.mlp = MLP(config=config, parallel_config=parallel_config, tp_pg=tp_pg, layer_idx=layer_idx) self.recompute_layer = parallel_config.recompute_layer @@ -856,7 +865,8 @@ def __init__( self.lm_head = PipelineBlock( p2p=self.p2p, # Understand that this means that we return sharded logits that are going to need to be gathered - module_builder=TensorParallelColumnLinear, + # module_builder=TensorParallelColumnLinear, + module_builder=FP8TensorParallelColumnLinear, module_kwargs={ "in_features": config.hidden_size, "out_features": config.vocab_size, @@ -865,7 +875,7 @@ def __init__( # TODO @thomasw21: refactor so that we store that default in a single place. 
"mode": self.tp_mode, "async_communication": tp_linear_async_communication, - "tp_recompute_allgather": parallel_config.tp_recompute_allgather, + # "tp_recompute_allgather": parallel_config.tp_recompute_allgather, }, module_input_keys={"x"}, module_output_keys={"logits"}, @@ -920,7 +930,8 @@ def get_block_compute_costs(self): LlamaDecoderLayer: 4 * model_config.num_attention_heads * d_qkv * model_config.hidden_size + 3 * d_ff * model_config.hidden_size, # This is the last lm_head - TensorParallelColumnLinear: model_config.vocab_size * model_config.hidden_size, + # TensorParallelColumnLinear: model_config.vocab_size * model_config.hidden_size, + FP8TensorParallelColumnLinear: model_config.vocab_size * model_config.hidden_size, } return block_compute_costs diff --git a/src/nanotron/testing/parallel.py b/src/nanotron/testing/parallel.py new file mode 100644 index 00000000..1dbe8a3c --- /dev/null +++ b/src/nanotron/testing/parallel.py @@ -0,0 +1,159 @@ +import os +import re +from inspect import signature +from typing import Callable + +import torch.cuda +import torch.multiprocessing as mp +from nanotron.parallel import ParallelContext +from packaging import version + + +def global_wrapper(rank, func, tp, pp, dp, port, kwargs): + def setup_dist_env(rank, world_size, port): + os.environ["WORLD_SIZE"] = str(world_size) + os.environ["RANK"] = str(rank) + # NOTE: since we do unit tests in a + # single node => this is fine! + os.environ["LOCAL_RANK"] = str(rank) + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = str(port) + + world_size = tp * pp * dp + setup_dist_env(rank, world_size, port) + parallel_context = ParallelContext(data_parallel_size=dp, pipeline_parallel_size=pp, tensor_parallel_size=tp) + func(parallel_context, **kwargs) + + +def init_distributed(tp: int, dp: int, pp: int): + def _init_distributed(func): + def wrapper(**kwargs): + from nanotron.utils import find_free_port + + world_size = tp * pp * dp + port = find_free_port() + + # Note that kwargs needs to be passed as part of args in a way that can be unpacked + args = (func, tp, pp, dp, port, kwargs) + mp.spawn(global_wrapper, args=args, nprocs=world_size) + + return wrapper + + return _init_distributed + + +def rerun_if_address_is_in_use(max_try: int = 500): + """ + This function reruns a wrapped function if "address already in use" occurs + in testing spawned with torch.multiprocessing + + Credits: https://github.com/hpcaitech/ColossalAI/blob/adae123df3badfb15d044bd416f0cf29f250bc86/colossalai/testing/utils.py#L157 + + Usage:: + + @rerun_if_address_is_in_use() + def test_something(): + ... + + """ + # check version + torch_version = version.parse(torch.__version__) + assert torch_version.major >= 1 + + # only torch >= 1.8 has ProcessRaisedException + if torch_version >= version.parse("1.8.0"): + exception = torch.multiprocessing.ProcessRaisedException + else: + exception = Exception + + func_wrapper = rerun_on_exception(exception_type=exception, pattern=".*Address already in use.*", max_try=max_try) + return func_wrapper + + +def rerun_on_exception(exception_type: Exception = Exception, pattern: str = None, max_try: int = 10) -> Callable: + """ + A decorator on a function to re-run when an exception occurs. 
+ + Credits: https://github.com/hpcaitech/ColossalAI/blob/adae123df3badfb15d044bd416f0cf29f250bc86/colossalai/testing/utils.py#L71 + + Usage:: + + # rerun for all kinds of exception + @rerun_on_exception() + def test_method(): + print('hey') + raise RuntimeError('Address already in use') + + # rerun for RuntimeError only + @rerun_on_exception(exception_type=RuntimeError) + def test_method(): + print('hey') + raise RuntimeError('Address already in use') + + # rerun for maximum 10 times if Runtime error occurs + @rerun_on_exception(exception_type=RuntimeError, max_try=10) + def test_method(): + print('hey') + raise RuntimeError('Address already in use') + + # rerun for infinite times if Runtime error occurs + @rerun_on_exception(exception_type=RuntimeError, max_try=None) + def test_method(): + print('hey') + raise RuntimeError('Address already in use') + + # rerun only the exception message is matched with pattern + # for infinite times if Runtime error occurs + @rerun_on_exception(exception_type=RuntimeError, pattern="^Address.*$") + def test_method(): + print('hey') + raise RuntimeError('Address already in use') + + Args: + exception_type (Exception, Optional): The type of exception to detect for rerun + pattern (str, Optional): The pattern to match the exception message. + If the pattern is not None and matches the exception message, + the exception will be detected for rerun + max_try (int, Optional): Maximum reruns for this function. The default value is 5. + If max_try is None, it will rerun forever if exception keeps occurring + """ + + def _match_lines(lines, pattern): + for line in lines: + if re.match(pattern, line): + return True + return False + + def _wrapper(func): + def _run_until_success(*args, **kwargs): + try_count = 0 + assert max_try is None or isinstance( + max_try, int + ), f"Expected max_try to be None or int, but got {type(max_try)}" + + while max_try is None or try_count < max_try: + try: + try_count += 1 + ret = func(*args, **kwargs) + return ret + except exception_type as e: + error_lines = str(e).split("\n") + if try_count < max_try and (pattern is None or _match_lines(error_lines, pattern)): + + print("Exception is caught, retrying...") + # when pattern is not specified, we always skip the exception + # when pattern is specified, we only skip when pattern is matched + continue + else: + print("Maximum number of attempts is reached or pattern is not matched, no more retrying...") + raise e + + # Override signature + # otherwise pytest.mark.parameterize will raise the following error: + # function does not use argument xxx + sig = signature(func) + _run_until_success.__signature__ = sig + + return _run_until_success + + return _wrapper diff --git a/src/nanotron/testing/utils.py b/src/nanotron/testing/utils.py new file mode 100644 index 00000000..77fa5d70 --- /dev/null +++ b/src/nanotron/testing/utils.py @@ -0,0 +1,21 @@ +import shutil +import uuid +from functools import lru_cache +from pathlib import Path + + +class TestContext: + def __init__(self): + self._random_string = str(uuid.uuid1()) + self._root_dir = Path(__file__).parent.parent / ".test_cache" + self._root_dir.mkdir(parents=True, exist_ok=True) + + @lru_cache(maxsize=1) + def get_auto_remove_tmp_dir(self): + path = self._root_dir / self._random_string + path.mkdir(parents=True, exist_ok=True) + return path + + def __del__(self): + path = self.get_auto_remove_tmp_dir() + shutil.rmtree(path) diff --git a/tests/fp8/test_new_tensor.py b/tests/fp8/test_new_tensor.py new file mode 100644 index 
00000000..05912e11 --- /dev/null +++ b/tests/fp8/test_new_tensor.py @@ -0,0 +1,452 @@ +from copy import deepcopy +from typing import cast + +import numpy as np +import pytest +import torch +from nanotron.fp8.constants import ( + FP8_ATOL_THRESHOLD, + FP8_RTOL_THRESHOLD, + FP16_ATOL_THRESHOLD, + FP16_RTOL_THRESHOLD, + QTYPE_TO_DTYPE, +) +from nanotron.fp8.dtypes import DTypes +from nanotron.fp8.meta import FP8Meta +from nanotron.fp8.tensor import FP8Tensor, FP16Tensor +from nanotron.testing.utils import TestContext + + +@pytest.mark.parametrize( + "tensor_cls, dtype", [(FP8Tensor, DTypes.FP8E4M3), (FP8Tensor, DTypes.FP8E5M2), (FP16Tensor, DTypes.KFLOAT16)] +) +@pytest.mark.parametrize("interval", [1, 5]) +def test_fp8_and_fp16_metadata(tensor_cls, dtype, interval): + import transformer_engine as te # noqa + import transformer_engine_torch as tex + + tensor = torch.randn((4, 4), dtype=torch.float32, device="cuda") + ref_tensor = deepcopy(tensor) + + fp8_tensor = tensor_cls(tensor, dtype=dtype, interval=interval) + fp8_meta = cast(FP8Meta, fp8_tensor.fp8_meta) + + # TODO(xrsrke): remove the fixed 1 factor + # it couples with the current implementation of FP8Meta + # because we initialize scale with 1 + assert fp8_meta.amax == ref_tensor.abs().max() + assert isinstance(fp8_meta.inverse_scale, torch.Tensor) + assert fp8_meta.scale != 0.1 and fp8_meta.scale != 1.0 + assert isinstance(fp8_meta.te_dtype, tex.DType) + assert fp8_meta.interval == interval + + +@pytest.mark.parametrize("size", [4, 8, 16, 64]) +@pytest.mark.parametrize("dtype", [DTypes.FP8E4M3, DTypes.FP8E5M2]) +def test_quantize_and_dequantize_tensor_in_fp8(size, dtype): + tensor = torch.randn((size, size), dtype=torch.float32, device="cuda") + ref_tensor = deepcopy(tensor) + fp8_tensor = FP8Tensor(tensor, dtype=dtype) + + assert not np.array_equal(fp8_tensor.cpu().numpy(), ref_tensor.cpu().numpy()) + + tensor = fp8_tensor.to(torch.float32) + # NOTE: sometimes type(tensor) is FP8Tensor, but it still passes, so we directly check the class name + # to make sure it's a torch.Tensor + assert tensor.__class__ == torch.Tensor + assert tensor.dtype == ref_tensor.dtype + + torch.testing.assert_close(tensor, ref_tensor, rtol=FP8_RTOL_THRESHOLD, atol=FP8_ATOL_THRESHOLD) + + +# @pytest.mark.parametrize("interval", [1, 5]) +@pytest.mark.parametrize("dtype", [DTypes.FP8E4M3, DTypes.FP8E5M2]) +def test_create_fp8_tensor_from_metadata(dtype): + INTERVAL = 5 + TOTAL_STEPS, REMAINING_STEPS = 20, 16 + tensor = torch.randn((16, 16), dtype=torch.float32, device="cuda") + fp8_tensor = FP8Tensor(tensor, dtype=dtype, interval=INTERVAL) + + new_values = [torch.randn((16, 16), dtype=torch.float32, device="cuda") for i in range(TOTAL_STEPS)] + + for i in range(TOTAL_STEPS): + if TOTAL_STEPS - REMAINING_STEPS == i: + current_tensor = fp8_tensor.orig_data + fp8_meta = deepcopy(fp8_tensor.fp8_meta) + + fp8_tensor.data = new_values[i] + + resumed_fp8_tensor = FP8Tensor.from_metadata(current_tensor, fp8_meta) + for i in range(TOTAL_STEPS - REMAINING_STEPS, TOTAL_STEPS): + resumed_fp8_tensor.data = new_values[i] + + # NOTE: we expect a resume tensor to have the state trajectory of the original tensor + assert resumed_fp8_tensor == fp8_tensor + + +@pytest.mark.parametrize("size", [4, 8, 16, 64]) +def test_quantize_and_dequantize_tensor_in_fp16(size): + tensor = torch.randn((size, size), dtype=torch.float32, device="cuda") + ref_tensor = deepcopy(tensor) + + fp16_tensor = FP16Tensor(tensor, dtype=DTypes.KFLOAT16) + + assert not 
np.array_equal(fp16_tensor.cpu().numpy(), ref_tensor.cpu().numpy()) + + # tensor = convert_tensor_from_fp16(fp16_tensor, torch.float32) + tensor = fp16_tensor.to(torch.float32) + # NOTE: sometimes type(tensor) is FP16Tensor, but it still passes + assert tensor.__class__ == torch.Tensor + assert tensor.dtype == ref_tensor.dtype + + # NOTE: this tolerance is from FP8-LM's implementation + # reference: https://github.com/Azure/MS-AMP/blob/9ac98df5371f3d4174d8f103a1932b3a41a4b8a3/tests/common/tensor/test_cast.py#L35 + torch.testing.assert_close(tensor, ref_tensor, rtol=FP16_RTOL_THRESHOLD, atol=FP16_ATOL_THRESHOLD) + + +@pytest.mark.parametrize( + "tensor_cls, dtype", [(FP8Tensor, DTypes.FP8E4M3), (FP8Tensor, DTypes.FP8E5M2), (FP16Tensor, DTypes.KFLOAT16)] +) +def test_fp8_and_fp16_tensor_repr(tensor_cls, dtype): + tensor = torch.randn((64, 64), dtype=torch.float32, device="cuda:0") + fp8_tensor = tensor_cls(tensor, dtype) + + # NOTE: in some cases, it causes an infinite loop + # in repr(tensor), so just check if it doesn't loop + assert isinstance(repr(fp8_tensor), str) + + +@pytest.mark.parametrize( + "tensor_cls, dtype, expected_dtype", + [ + (FP8Tensor, DTypes.FP8E4M3, torch.uint8), + (FP8Tensor, DTypes.FP8E5M2, torch.uint8), + (FP16Tensor, DTypes.KFLOAT16, torch.float16), + ], +) +def test_fp8_and_fp16_tensor_attrs(tensor_cls, dtype, expected_dtype): + tensor = torch.randn((64, 64), dtype=torch.float32, device="cuda:0") + ref_tensor = tensor.detach().clone() + + fp8_tensor = tensor_cls(tensor, dtype) + + assert isinstance(fp8_tensor, tensor_cls) + assert isinstance(fp8_tensor.fp8_meta, FP8Meta) + assert fp8_tensor.dtype == expected_dtype + assert fp8_tensor.device == ref_tensor.device + assert fp8_tensor.shape == ref_tensor.shape + assert fp8_tensor.numel() == ref_tensor.numel() + assert fp8_tensor.device == ref_tensor.device + + +# @pytest.mark.parametrize( +# "tensor_cls, dtype", +# [ +# (FP8Tensor, DTypes.FP8E4M3), +# (FP8Tensor, DTypes.FP8E5M2), +# (FP16Tensor, DTypes.KFLOAT16), +# ], +# ) +# @pytest.mark.parametrize( +# "scale", +# [ +# torch.ones(1, device="cuda:0").squeeze() * 2, # an random scalar +# torch.ones(1, device="cuda:0") * 2, +# torch.ones(4, 4, device="cuda:0") * 2, +# ], +# ) +# def test_multiple_fp8_tensor(tensor_cls, dtype, scale): +# RTOL, ATOL = ( +# (FP8_RTOL_THRESHOLD, FP8_ATOL_THRESHOLD) +# if tensor_cls == FP8Tensor +# else (FP16_RTOL_THRESHOLD, FP16_ATOL_THRESHOLD) +# ) +# tensor = torch.randn((4, 4), dtype=torch.float32, device="cuda:0") +# ref_tensor = tensor.detach().clone() + +# fp8_tensor = tensor_cls(deepcopy(tensor), dtype) +# ref_fp8_tensor = fp8_tensor.clone() + +# with fail_if_expect_to_fail(expect_to_fail=scale.ndim > 1): +# fp8_tensor.mul_(scale) + +# assert torch.equal(fp8_tensor, ref_fp8_tensor) +# assert fp8_tensor.fp8_meta.scale != ref_fp8_tensor.fp8_meta.scale + +# if isinstance(fp8_tensor, FP8Tensor): +# # NOTE: with the current implementation, we only scale the metadata +# # not the tensor itself, so we expect the tensor to be the same +# tensor = convert_tensor_from_fp8(fp8_tensor, fp8_tensor.fp8_meta, torch.float32) +# else: +# tensor = convert_tensor_from_fp16(fp8_tensor, torch.float32) + +# torch.testing.assert_allclose(tensor, ref_tensor * scale, rtol=RTOL, atol=ATOL) + + +# @pytest.mark.parametrize( +# "tensor_cls, dtype", +# [ +# (FP8Tensor, DTypes.FP8E4M3), +# (FP8Tensor, DTypes.FP8E5M2), +# (FP16Tensor, DTypes.KFLOAT16), +# ], +# ) +# @pytest.mark.parametrize( +# "scale", +# [ +# torch.ones(1, device="cuda:0").squeeze() * 2, # an 
random scalar +# torch.ones(1, device="cuda:0") * 2, +# torch.ones(4, 4, device="cuda:0") * 2, +# ], +# ) +# def test_divide_fp8_tensor(tensor_cls, dtype, scale): +# # NOTE: the reason that we use 2 as the scale is because +# # the purpose of this test is to test whether we scale the magnitude +# # of the tensor, so if use other values from normal distribution, +# # some values could lead to quantization error, and for this test we don't +# # test the quantization error +# tensor = torch.randn((64, 64), dtype=torch.float32, device="cuda:0") +# ref_tensor = deepcopy(tensor) + +# fp8_tensor = tensor_cls(deepcopy(tensor), dtype) +# ref_fp8_tensor = fp8_tensor.clone() + +# with fail_if_expect_to_fail(expect_to_fail=scale.ndim > 1): +# fp8_tensor.div_(scale) + +# assert torch.equal(fp8_tensor, ref_fp8_tensor) +# assert fp8_tensor.fp8_meta.scale != ref_fp8_tensor.fp8_meta.scale + +# if isinstance(fp8_tensor, FP8Tensor): +# tensor = convert_tensor_from_fp8(fp8_tensor, fp8_tensor.fp8_meta, torch.float32) +# # NOTE: use the same tolerance as test_quantize_and_dequantize_tensor_in_fp8 +# torch.testing.assert_allclose(tensor, ref_tensor / scale, rtol=FP8_RTOL_THRESHOLD, atol=FP8_ATOL_THRESHOLD) +# else: +# tensor = convert_tensor_from_fp16(fp8_tensor, torch.float32) +# torch.testing.assert_close(tensor, ref_tensor / scale, rtol=FP16_RTOL_THRESHOLD, atol=FP16_ATOL_THRESHOLD) + + +@pytest.mark.parametrize( + "tensor_cls, dtype", + [ + (FP8Tensor, DTypes.FP8E4M3), + (FP8Tensor, DTypes.FP8E5M2), + (FP16Tensor, DTypes.KFLOAT16), + ], +) +def test_add_fp8_tensor(tensor_cls, dtype): + tensor = torch.randn((64, 64), dtype=torch.float32, device="cuda:0") + fp8_tensor = tensor_cls(deepcopy(tensor), dtype) + + with pytest.raises(ValueError): + fp8_tensor + 1 + + +@pytest.mark.parametrize( + "tensor_cls, dtype", + [ + (FP8Tensor, DTypes.FP8E4M3), + (FP8Tensor, DTypes.FP8E5M2), + (FP16Tensor, DTypes.KFLOAT16), + ], +) +def test_subtract_fp8_tensor(tensor_cls, dtype): + tensor = torch.randn((64, 64), dtype=torch.float32, device="cuda:0") + fp8_tensor = tensor_cls(deepcopy(tensor), dtype) + + with pytest.raises(ValueError): + fp8_tensor - 1 + + +@pytest.mark.parametrize( + "tensor_cls, dtype", + [ + (FP8Tensor, DTypes.FP8E4M3), + (FP8Tensor, DTypes.FP8E5M2), + (FP16Tensor, DTypes.KFLOAT16), + ], +) +def test_clone_fp8_tensor(tensor_cls, dtype): + tensor = torch.randn((64, 64), dtype=torch.float32, device="cuda:0") + fp8_tensor = tensor_cls(deepcopy(tensor), dtype) + + cloned_fp8_tensor = fp8_tensor.clone() + + assert isinstance(cloned_fp8_tensor, tensor_cls) + assert id(cloned_fp8_tensor) != id(fp8_tensor) + assert cloned_fp8_tensor.device == fp8_tensor.device + + assert torch.equal(cloned_fp8_tensor, fp8_tensor) + assert cloned_fp8_tensor.data_ptr() != fp8_tensor.data_ptr() + assert cloned_fp8_tensor.data.data_ptr() != fp8_tensor.data.data_ptr() + + assert cloned_fp8_tensor.fp8_meta == fp8_tensor.fp8_meta + assert id(cloned_fp8_tensor.fp8_meta) != id(fp8_tensor.fp8_meta) + + +@pytest.mark.parametrize( + "tensor_cls, dtype", + [ + (FP8Tensor, DTypes.FP8E4M3), + (FP8Tensor, DTypes.FP8E5M2), + # (FP16Tensor, DTypes.KFLOAT16), + ], +) +def test_transpose_fp8_tensor(tensor_cls, dtype): + tensor = torch.randn((16, 16), dtype=torch.float32, device="cuda:0") + ref_transposed_tensor = deepcopy(tensor).transpose() + fp8_tensor = tensor_cls(tensor, dtype) + + transposed_fp8_tensor = fp8_tensor.transpose() + + assert isinstance(transposed_fp8_tensor, FP8Tensor) + + dequant_transposed_fp8_tensor = 
transposed_fp8_tensor.to(torch.float32)
+    torch.testing.assert_close(dequant_transposed_fp8_tensor, ref_transposed_tensor)
+
+
+@pytest.mark.parametrize(
+    "tensor_cls, dtype",
+    [
+        (FP8Tensor, DTypes.FP8E4M3),
+        (FP8Tensor, DTypes.FP8E5M2),
+        (FP16Tensor, DTypes.KFLOAT16),
+    ],
+)
+def test_deterministic_quantization(tensor_cls, dtype):
+    tensor = torch.randn((64, 64), dtype=torch.float32, device="cuda:0")
+    fp8_tensor = tensor_cls(deepcopy(tensor), dtype)
+    ref_fp8_tensor = tensor_cls(deepcopy(tensor), dtype)
+
+    assert torch.equal(fp8_tensor, ref_fp8_tensor)
+    assert fp8_tensor.fp8_meta == ref_fp8_tensor.fp8_meta
+
+
+@pytest.mark.parametrize(
+    "tensor_cls, dtype", [(FP8Tensor, DTypes.FP8E4M3), (FP8Tensor, DTypes.FP8E5M2), (FP16Tensor, DTypes.KFLOAT16)]
+)
+def test_fp8_and_fp16_tensor_storage_memory(tensor_cls, dtype):
+    tensor = torch.randn((4, 4), dtype=torch.float32, device="cuda")
+    ref_tensor = deepcopy(tensor)
+
+    fp8_tensor = tensor_cls(tensor, dtype=dtype)
+
+    assert id(fp8_tensor) != id(ref_tensor)
+
+    assert isinstance(fp8_tensor.data, torch.Tensor)
+    assert id(fp8_tensor.data) != id(ref_tensor.data)
+    assert fp8_tensor.data_ptr() == fp8_tensor.data.data_ptr()
+    assert fp8_tensor.data.data_ptr() != ref_tensor.data_ptr()
+
+
+@pytest.mark.parametrize(
+    "tensor_cls, dtype", [(FP8Tensor, DTypes.FP8E4M3), (FP8Tensor, DTypes.FP8E5M2), (FP16Tensor, DTypes.KFLOAT16)]
+)
+@pytest.mark.parametrize("is_quantized", [True, False])
+def test_setting_new_data_for_fp8_and_fp16_tensor(tensor_cls, dtype, is_quantized):
+    RTOL, ATOL = (
+        (FP8_RTOL_THRESHOLD, FP8_ATOL_THRESHOLD)
+        if tensor_cls == FP8Tensor
+        else (FP16_RTOL_THRESHOLD, FP16_ATOL_THRESHOLD)
+    )
+
+    tensor = torch.randn((4, 4), dtype=torch.float32, device="cuda")
+    quant_tensor = tensor_cls(tensor, dtype=dtype)
+
+    new_data = torch.randn(quant_tensor.shape, dtype=torch.float32, device="cuda") * 2
+    ref_new_data = deepcopy(new_data)
+    expected_quantized_tensor = tensor_cls(ref_new_data, dtype=dtype)
+
+    new_data = tensor_cls(new_data, dtype=dtype) if is_quantized else new_data
+    quant_tensor.data = new_data
+    assert quant_tensor.data.data_ptr() == new_data.data.data_ptr()
+
+    assert quant_tensor.data.dtype == QTYPE_TO_DTYPE[dtype]
+    assert torch.equal(quant_tensor, expected_quantized_tensor)
+
+    if is_quantized:
+        # if tensor_cls == FP8Tensor:
+        #     dequant_tensor = convert_tensor_from_fp8(quant_tensor, fp8_tensor.fp8_meta, torch.float32)
+        # else:
+        #     dequantized_tensor = convert_tensor_from_fp16(fp8_tensor, torch.float32)
+
+        dequant_tensor = quant_tensor.to(torch.float32)
+        assert torch.allclose(dequant_tensor, ref_new_data, rtol=RTOL, atol=ATOL)
+
+
+# @pytest.mark.parametrize(
+#     "tensor_cls, dtype", [(FP8Tensor, DTypes.FP8E4M3), (FP8Tensor, DTypes.FP8E5M2), (FP16Tensor, DTypes.KFLOAT16)]
+# )
+# @pytest.mark.parametrize("is_quantized", [True, False])
+# def test_setting_None_data_for_fp8_and_fp16_tensor(tensor_cls, dtype, is_quantized):
+#     tensor = torch.randn((4, 4), dtype=torch.float32, device="cuda")
+#     fp8_tensor = tensor_cls(tensor, dtype=dtype)
+
+#     fp8_tensor.set_data(None)
+
+#     assert fp8_tensor is None
+#     assert fp8_tensor.data is None
+
+# @pytest.mark.parametrize(
+#     "tensor_cls, dtype", [(FP8Tensor, DTypes.FP8E4M3), (FP8Tensor, DTypes.FP8E5M2), (FP16Tensor, DTypes.KFLOAT16)]
+# )
+# def test_quantize_overflow_fp8_and_fp16_tensor(tensor_cls, dtype):
+#     tensor = torch.randn((64, 64), dtype=torch.float32, device="cuda:0")
+#     tensor[0, 0] = torch.tensor(float("inf"))
+#     fp8_tensor = tensor_cls(tensor, dtype)
+
+@pytest.mark.parametrize( + "tensor_cls, dtype", [(FP8Tensor, DTypes.FP8E4M3), (FP8Tensor, DTypes.FP8E5M2), (FP16Tensor, DTypes.KFLOAT16)] +) +def test_zero_out_data_of_fp8_and_fp16_tensor(tensor_cls, dtype): + tensor = torch.randn((4, 4), dtype=torch.float32, device="cuda") + quant_tensor = tensor_cls(tensor, dtype=dtype) + + quant_tensor.zero_() + + assert torch.equal(quant_tensor, torch.zeros_like(quant_tensor)) + + dequant_tensor = quant_tensor.to(torch.float32) + assert torch.equal(dequant_tensor, torch.zeros_like(tensor)) + + +# NOTE: add testing based on tensor metadata +@pytest.mark.parametrize("is_meta_the_same", [True, False]) +@pytest.mark.parametrize("dtype", [DTypes.FP8E4M3, DTypes.FP8E5M2]) +def test_fp8_and_fp16_tensor_equality_based_on_tensor_value(is_meta_the_same, dtype): + # TODO(xrsrke): support torch.equal for FP8Tensor + tensor = torch.randn((4, 4), dtype=torch.float32, device="cuda") + ref_tensor = deepcopy(tensor) + + fp8_tensor = FP8Tensor(tensor, dtype=dtype) + ref_fp8_tensor = FP8Tensor(ref_tensor, dtype=dtype) + + if not is_meta_the_same: + fp8_tensor.fp8_meta.scale = ref_fp8_tensor.fp8_meta.scale * 2 + + assert (fp8_tensor == ref_fp8_tensor) is is_meta_the_same + + new_data = torch.randn(tensor.shape, dtype=torch.float32, device="cuda") + ref_fp8_tensor.data = new_data + + assert not fp8_tensor == ref_fp8_tensor + + +# TODO(xrsrke): test it has all the methods of torch.Tensor + +# TODO(xrsrke): test it has all the attributes of its input tensor + + +@pytest.mark.parametrize( + "tensor_cls, dtype", [(FP8Tensor, DTypes.FP8E4M3), (FP8Tensor, DTypes.FP8E5M2), (FP16Tensor, DTypes.KFLOAT16)] +) +def test_serialize_fp8_tensor(tensor_cls, dtype): + test_context = TestContext() + store_folder = test_context.get_auto_remove_tmp_dir() + tensor = torch.randn((4, 4), dtype=torch.float32, device="cuda") + + fp8_tensor = tensor_cls(tensor, dtype=dtype) + + torch.save(fp8_tensor, f"{store_folder}/fp8_tensor.pt") + torch.load(f"{store_folder}/fp8_tensor.pt") diff --git a/tests/fp8/test_tensor.py b/tests/fp8/test_tensor.py index 40608d34..237e89d5 100644 --- a/tests/fp8/test_tensor.py +++ b/tests/fp8/test_tensor.py @@ -297,6 +297,8 @@ def test_transpose_fp8_tensor(tensor_cls, dtype): ref_fp8_tensor = tensor_cls(deepcopy(tensor), dtype) transposed_fp8_tensor = fp8_tensor.transpose_fp8() + # transposed_fp8_tensor = fp8_tensor.t() + # transposed_fp8_tensor = torch.transpose(fp8_tensor, 0, 1) # NOTE: we expect the original tensor to be the same assert fp8_tensor == ref_fp8_tensor diff --git a/tests/nanoset/test_build_nanoset_dataloader.py b/tests/nanoset/test_build_nanoset_dataloader.py index 113c545c..f7726262 100644 --- a/tests/nanoset/test_build_nanoset_dataloader.py +++ b/tests/nanoset/test_build_nanoset_dataloader.py @@ -8,7 +8,6 @@ import numpy as np import pytest -from helpers.context import TestContext from helpers.data import ( assert_batch_dataloader, assert_nanoset_sync_across_all_ranks, @@ -22,6 +21,7 @@ from nanotron.data.nanoset import Nanoset from nanotron.data.utils import count_dataset_indexes, normalize from nanotron.parallel import ParallelContext +from nanotron.testing.utils import TestContext from nanotron.utils import main_rank_first from transformers import AutoTokenizer diff --git a/tests/pytest.ini b/tests/pytest.ini index 0e0b2653..c0c6f3eb 100644 --- a/tests/pytest.ini +++ b/tests/pytest.ini @@ -1,4 +1,4 @@ [pytest] -addopts=-n 35 +; addopts=-n 10 markers = fa2: FA2-related diff --git a/tests/test_serialize.py b/tests/test_serialize.py index 
329ff279..62e2a584 100644 --- a/tests/test_serialize.py +++ b/tests/test_serialize.py @@ -1,6 +1,5 @@ import pytest import torch -from helpers.context import TestContext from helpers.dummy import dummy_infinite_data_loader, init_dummy_model from helpers.utils import ( available_gpus, @@ -33,6 +32,7 @@ save_weights, ) from nanotron.serialize.metadata import TensorMetadata +from nanotron.testing.utils import TestContext from torch.nn.parallel import DistributedDataParallel
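
Note on the updated FLOP accounting in benchmark/fp8_tp_speed.py: the forward pass of a TP-sharded linear costs 2*M*N*K / tp_size FLOPs on each rank, the backward pass adds one GEMM of the same cost for the weight gradient, and a third GEMM is only counted when the layer input itself requires a gradient; MFU is the achieved throughput divided by the H100 peak (989 TFLOPS for BF16, 1979 TFLOPS for FP8, both without sparsity). The sketch below restates that bookkeeping as a standalone helper; `linear_total_flops` and `mfu_percent` are hypothetical names used only for illustration and are not functions added by this patch.

# Minimal sketch of the FLOP/MFU accounting used by the benchmark (hypothetical helpers).
def linear_total_flops(M: int, N: int, K: int, tp_size: int,
                       include_backward: bool, input_requires_grad: bool) -> int:
    gemm = 2 * M * N * K // tp_size  # forward GEMM on one TP rank
    total = gemm
    if include_backward:
        total += gemm                # gradient with respect to the weight
        if input_requires_grad:
            total += gemm            # gradient with respect to the input
    return total


def mfu_percent(achieved_flops_per_s: float, peak_flops_per_s: float) -> float:
    # e.g. peak_flops_per_s = 1979e12 for dense FP8 tensor cores on an H100 SXM
    return achieved_flops_per_s / peak_flops_per_s * 100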