diff --git a/.gitignore b/.gitignore
index cbc04eaf..2da7cc5a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -163,3 +163,8 @@ cython_debug/
checkpoints/
wandb/
+
+*.csv
+*.html
+src/nanotron/.test_cache/
+log/
diff --git a/benchmark/fp8_tp_speed.py b/benchmark/fp8_tp_speed.py
index bba7c1f3..dbd58f13 100644
--- a/benchmark/fp8_tp_speed.py
+++ b/benchmark/fp8_tp_speed.py
@@ -1,5 +1,4 @@
import argparse
-import itertools
import pandas as pd
import torch
@@ -11,16 +10,13 @@
# H100 SXM specs: bottom of https://www.nvidia.com/en-us/data-center/h100/
h100_peak_flops_float32 = 67e12
+# NOTE: without sparsity
h100_peak_flops_fp16_tc = 989e12
h100_peak_tops_float8_tc = 1979e12
-
-# def color_scale(val, min_val, max_val):
-# """Generate a color scale from white to dark blue based on value."""
-# if pd.isna(val):
-# return 'background-color: white'
-# normalized = (val - min_val) / (max_val - min_val) if max_val != min_val else 0
-# return f'background-color: rgba(0, 0, 139, {normalized:.2f}); color: {"white" if normalized > 0.5 else "black"}'
+# # NOTE: with sparsity
+# h100_peak_flops_fp16_tc = 1979e12
+# h100_peak_tops_float8_tc = 3958e12
def color_scale(val, min_val, max_val, metric_type="default"):
@@ -43,55 +39,6 @@ def color_scale(val, min_val, max_val, metric_type="default"):
return f"{color}; color: {text_color}"
-# [Previous benchmark_fn_in_sec and run functions remain the same...]
-
-# def create_html_table(df, exp_number, tp_size):
-# # Style the dataframe
-# styled_df = df.style.format({
-# 'FP8_time_ms': '{:.2f}',
-# 'BF16_time_ms': '{:.2f}',
-# 'FP8_TFLOPS': '{:.2f}',
-# 'BF16_TFLOPS': '{:.2f}',
-# 'FP8_efficiency_%': '{:.2f}',
-# 'BF16_efficiency_%': '{:.2f}',
-# 'Speedup': '{:.2f}'
-# })
-
-# # Apply color scaling to specific columns
-# styled_df = styled_df.apply(lambda x: pd.Series([
-# color_scale(v, x.min(), x.max(), 'time') if col.endswith('time_ms')
-# else color_scale(v, x.min(), x.max(), 'performance') if col.endswith('TFLOPS')
-# else color_scale(v, x.min(), x.max(), 'efficiency') if col.endswith('efficiency_%')
-# else color_scale(v, x.min(), x.max()) if col == 'Speedup'
-# else '' for v in x
-# ]), axis=0)
-
-# # Generate HTML
-# html = f'''
-#
-#
-#
-#
-#
-#
-# {styled_df.to_html()}
-#
-#
-# '''
-
-# with open(f'{exp_number}_benchmark_results_tp{tp_size}.html', 'w') as f:
-# f.write(html)
-
-
def create_html_table(df, exp_number, tp_size):
def style_df(df):
# Create an empty DataFrame with the same shape as the input
@@ -199,6 +146,23 @@ def run_linear(input, M, N, K, parallel_context, include_backward=False):
sharded_output.sum().backward()
+# def parse_args():
+# parser = argparse.ArgumentParser(description="Run profiling experiments with configurable dimensions")
+# parser.add_argument("--exp_number", type=str, help="Experiment number")
+# parser.add_argument("--tp_size", type=int, default=1, help="Tensor Parallel size")
+# # parser.add_argument(
+# # "--dimensions",
+# # type=str,
+# # default="4096,16384,32768,28672,49152",
+# # help="Comma-separated list of dimensions to test. Used when M/K/N are not explicitly provided",
+# # )
+# # Add the missing argument definitions
+# parser.add_argument("--m_size", type=int, help="Explicitly set M dimension")
+# parser.add_argument("--k_size", type=int, help="Explicitly set K dimension")
+# parser.add_argument("--n_size", type=int, help="Explicitly set N dimension")
+# return parser.parse_args()
+
+
def parse_args():
parser = argparse.ArgumentParser(description="Run profiling experiments with configurable dimensions")
parser.add_argument("--exp_number", type=str, help="Experiment number")
@@ -207,14 +171,17 @@ def parse_args():
"--dimensions",
type=str,
default="4096,16384,32768,28672,49152",
- help="Comma-separated list of dimensions to test",
+ help="Comma-separated list of dimensions to test. Used when M/K/N are not explicitly provided",
)
+ parser.add_argument("--m_size", type=int, help="Explicitly set M dimension")
+ parser.add_argument("--k_size", type=int, help="Explicitly set K dimension")
+ parser.add_argument("--n_size", type=int, help="Explicitly set N dimension")
return parser.parse_args()
-def benchmark_linear_operations(M, N, K, parallel_context, include_backward):
- input = torch.randn(M, K, device="cuda", requires_grad=False)
- bfloat16_input = torch.randn(M, K, device="cuda", requires_grad=False, dtype=torch.bfloat16)
+def benchmark_linear_operations(M, N, K, parallel_context, include_backward, requires_grad):
+ input = torch.randn(M, K, device="cuda", requires_grad=requires_grad)
+ bfloat16_input = torch.randn(M, K, device="cuda", requires_grad=requires_grad, dtype=torch.bfloat16)
# Benchmark FP8
fp8_time = benchmark_fn_in_sec(run_fp8_linear, input, M, N, K, parallel_context, include_backward)
@@ -223,13 +190,18 @@ def benchmark_linear_operations(M, N, K, parallel_context, include_backward):
bfloat16_time = benchmark_fn_in_sec(run_linear, bfloat16_input, M, N, K, parallel_context, include_backward)
# Calculate FLOPS
- # Each linear operation performs 2*M*N*K FLOPs (multiply-add)
total_flops = 2 * M * N * K // parallel_context.tensor_parallel_size
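+    # NOTE: one GEMM costs ~2*M*N*K FLOPs (multiply-add), sharded across TP ranks; the
+    # backward pass adds a weight-gradient GEMM and, when the input needs gradients,
+    # an input-gradient GEMM of the same cost.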
+ if include_backward:
+ # Gradient with Respect to Parameters
+ total_flops += 2 * M * N * K // parallel_context.tensor_parallel_size
+ if requires_grad:
+ # Gradient with Respect to Inputs
+ total_flops += 2 * M * N * K // parallel_context.tensor_parallel_size
fp8_tflops = (total_flops / fp8_time) / 1e12
bfloat16_tflops = (total_flops / bfloat16_time) / 1e12
- # Calculate efficiency compared to peak performance
+    # Calculate MFU (achieved TFLOPS relative to the peak spec)
fp8_efficiency = (fp8_tflops / (h100_peak_tops_float8_tc / 1e12)) * 100
bfloat16_efficiency = (bfloat16_tflops / (h100_peak_flops_fp16_tc / 1e12)) * 100
@@ -238,12 +210,13 @@ def benchmark_linear_operations(M, N, K, parallel_context, include_backward):
"N": N,
"K": K,
"Include_Backward": include_backward,
+ "Input_Requires_Grad": requires_grad,
"FP8_time_ms": fp8_time * 1000,
"BF16_time_ms": bfloat16_time * 1000,
"FP8_TFLOPS": fp8_tflops,
"BF16_TFLOPS": bfloat16_tflops,
- "FP8_efficiency_%": fp8_efficiency,
- "BF16_efficiency_%": bfloat16_efficiency,
+ "FP8_MFU%": fp8_efficiency,
+ "BF16_MFU%": bfloat16_efficiency,
"Speedup": bfloat16_time / fp8_time,
}
@@ -252,79 +225,61 @@ def benchmark_linear_operations(M, N, K, parallel_context, include_backward):
torch.backends.cudnn.benchmark = True
args = parse_args()
-
dimensions = [int(d.strip()) for d in args.dimensions.split(",")]
- TP_SIZE = args.tp_size
- EXP_NUMBER = args.exp_number
-
- results = []
- total = len(list(itertools.product(dimensions, dimensions, dimensions)))
- parallel_context = ParallelContext(data_parallel_size=1, pipeline_parallel_size=1, tensor_parallel_size=TP_SIZE)
+ parallel_context = ParallelContext(
+ data_parallel_size=1, pipeline_parallel_size=1, tensor_parallel_size=args.tp_size
+ )
- # Run benchmarks and collect results
results = []
- i = 0
- for M, N, K in itertools.product(dimensions, dimensions, dimensions):
- i += 1
- # Run forward-only case
- result = benchmark_linear_operations(M, N, K, parallel_context, include_backward=False)
+ # combinations = list(itertools.product(
+ # dimensions, # M
+ # dimensions, # N
+ # dimensions, # K
+ # [False, True], # include_backward
+ # [False, True] # requires_grad
+ # ))
+ # Pair dimensions index-wise
+ # combinations = []
+ # for i in range(len(dimensions)):
+ # M = dimensions[i]
+ # N = dimensions[i]
+ # K = dimensions[i]
+ # # For each dimension pair, test all combinations of include_backward and requires_grad
+ # for include_backward in [False, True]:
+ # for requires_grad in [False, True]:
+ # combinations.append((M, N, K, include_backward, requires_grad))
+
+ combinations = []
+ # Check if explicit M, K, N dimensions are provided
+ if all(dim is not None for dim in [args.m_size, args.k_size, args.n_size]):
+ # Use explicitly provided dimensions
+ for include_backward in [False, True]:
+ # for requires_grad in [False, True]:
+ for requires_grad in [False]:
+ combinations.append((args.m_size, args.n_size, args.k_size, include_backward, requires_grad))
+ else:
+ # Use dimensions from the --dimensions argument
+ dimensions = [int(d.strip()) for d in args.dimensions.split(",")]
+ for i in range(len(dimensions)):
+ M = dimensions[i]
+ N = dimensions[i]
+ K = dimensions[i]
+ for include_backward in [False, True]:
+ # for requires_grad in [False, True]:
+ for requires_grad in [False]:
+ combinations.append((M, N, K, include_backward, requires_grad))
+
+ total = len(combinations)
+
+ for i, (M, N, K, include_backward, requires_grad) in enumerate(combinations, 1):
+ result = benchmark_linear_operations(M, N, K, parallel_context, include_backward, requires_grad)
results.append(result)
- print(f"Experiment {i}/{total} complete (Forward-only)")
-
- # Run forward+backward case
- result = benchmark_linear_operations(M, N, K, parallel_context, include_backward=True)
- results.append(result)
- print(f"Experiment {i}/{total} complete (Forward+Backward)")
+ print(f"Experiment {i}/{total} complete")
df = pd.DataFrame(results)
- df = df.round(2) # Round to 2 decimal places
- df = df.sort_values(by=["M", "N", "K", "Include_Backward"])
+ df = df.round(2)
+ df = df.sort_values(by=["M", "N", "K", "Include_Backward", "Input_Requires_Grad"])
print(df)
-
- # # Define columns to color and their respective color scales
- # color_columns = {
- # 'FP8_time_ms': 'Reds',
- # 'BF16_time_ms': 'Blues',
- # 'FP8_TFLOPS': 'Greens',
- # 'BF16_TFLOPS': 'Purples',
- # 'FP8_efficiency_%': 'Oranges',
- # 'BF16_efficiency_%': 'Viridis',
- # 'Speedup': 'RdYlBu'
- # }
-
- # # Create the table
- # fig = go.Figure(data=[go.Table(
- # header=dict(
- # values=list(df.columns),
- # fill_color='lightgrey',
- # align='left',
- # font=dict(size=12, color='black')
- # ),
- # cells=dict(
- # values=[df[col] for col in df.columns],
- # align='left',
- # font=dict(size=11),
- # # Format cells with colors for numeric columns
- # fill_color=[
- # 'white' if col not in color_columns else
- # [f'rgba({int(255*i)}, {int(255*i)}, {int(255*i)}, 0.5)'
- # for i in (df[col]-df[col].min())/(df[col].max()-df[col].min())]
- # for col in df.columns
- # ]
- # )
- # )])
-
- # # Update layout
- # fig.update_layout(
- # title=f'Benchmark Results (TP_SIZE={TP_SIZE})',
- # width=1200,
- # height=800,
- # margin=dict(l=20, r=20, t=40, b=20)
- # )
-
- # # Save the interactive HTML file
- # fig.write_html(f'{EXP_NUMBER}_benchmark_results_tp{TP_SIZE}.html')
- create_html_table(df, EXP_NUMBER, TP_SIZE)
-
- print(f"\nResults have been saved to {EXP_NUMBER}_benchmark_results_tp{TP_SIZE}.html")
+ create_html_table(df, args.exp_number, args.tp_size)
+ print(f"\nResults have been saved to {args.exp_number}_benchmark_results_tp{args.tp_size}.html")
diff --git a/examples/config_fp8_llama.yaml b/examples/config_fp8_llama.yaml
new file mode 100644
index 00000000..fbbad98f
--- /dev/null
+++ b/examples/config_fp8_llama.yaml
@@ -0,0 +1,109 @@
+checkpoints:
+ checkpoint_interval: 10
+ checkpoints_path: checkpoints
+ checkpoints_path_is_shared_file_system: false
+ resume_checkpoint_path: null
+ save_initial_state: false
+data_stages:
+- data:
+ dataset:
+ dataset_overwrite_cache: false
+ dataset_processing_num_proc_per_process: 1
+ hf_dataset_config_name: null
+ hf_dataset_or_datasets: stas/openwebtext-10k
+ hf_dataset_splits: train
+ text_column_name: text
+ num_loading_workers: 1
+ seed: 42
+ name: Stable Training Stage
+ start_training_step: 1
+- data:
+ dataset:
+ dataset_overwrite_cache: false
+ dataset_processing_num_proc_per_process: 1
+ hf_dataset_config_name: null
+ hf_dataset_or_datasets: stas/openwebtext-10k
+ hf_dataset_splits: train
+ text_column_name: text
+ num_loading_workers: 1
+ seed: 42
+ name: Annealing Phase
+ start_training_step: 10
+general:
+ benchmark_csv_path: null
+ consumed_train_samples: null
+ ignore_sanity_checks: true
+ project: debug
+ run: tiny_llama_%date_%jobid
+ seed: 42
+ step: null
+lighteval: null
+logging:
+ iteration_step_info_interval: 1
+ log_level: info
+ log_level_replica: info
+model:
+ ddp_bucket_cap_mb: 25
+ dtype: int8
+ init_method:
+ std: 0.025
+ make_vocab_size_divisible_by: 1
+ model_config:
+ bos_token_id: 1
+ eos_token_id: 2
+ hidden_act: silu
+ hidden_size: 16
+ initializer_range: 0.02
+ intermediate_size: 64
+ is_llama_config: true
+ max_position_embeddings: 256
+ num_attention_heads: 4
+ num_hidden_layers: 2
+ num_key_value_heads: 4
+ pad_token_id: null
+ pretraining_tp: 1
+ rms_norm_eps: 1.0e-05
+ rope_scaling: null
+ tie_word_embeddings: true
+ use_cache: true
+ vocab_size: 256
+optimizer:
+ accumulate_grad_in_fp32: true
+ clip_grad: 1.0
+ learning_rate_scheduler:
+ learning_rate: 0.0003
+ lr_decay_starting_step: null
+ lr_decay_steps: 13
+ lr_decay_style: cosine
+ lr_warmup_steps: 2
+ lr_warmup_style: linear
+ min_decay_lr: 1.0e-05
+ optimizer_factory:
+ adam_beta1: 0.9
+ adam_beta2: 0.95
+ adam_eps: 1.0e-08
+ name: adamW
+ torch_adam_is_fused: true
+ weight_decay: 0.01
+ zero_stage: 0
+parallelism:
+ dp: 2
+ expert_parallel_size: 1
+ pp: 1
+ pp_engine: 1f1b
+ tp: 4
+ tp_linear_async_communication: false
+ tp_mode: ALL_REDUCE
+profiler: null
+tokenizer:
+ tokenizer_max_length: null
+ tokenizer_name_or_path: robot-test/dummy-tokenizer-wordlevel
+ tokenizer_revision: null
+tokens:
+ batch_accumulation_per_replica: 1
+ limit_test_batches: 0
+ limit_val_batches: 0
+ micro_batch_size: 2
+ sequence_length: 256
+ train_steps: 15
+ val_check_interval: -1
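+# NOTE (illustrative): with dp=2, tp=4, pp=1 this config expects 8 GPUs, e.g.
+#   torchrun --nproc_per_node=8 run_train.py --config-file examples/config_fp8_llama.yaml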
diff --git a/examples/config_tiny_llama.yaml b/examples/config_tiny_llama.yaml
index 58645e2d..0fd639d5 100644
--- a/examples/config_tiny_llama.yaml
+++ b/examples/config_tiny_llama.yaml
@@ -44,7 +44,7 @@ logging:
log_level_replica: info
model:
ddp_bucket_cap_mb: 25
- dtype: bfloat16
+ dtype: float8
init_method:
std: 0.025
make_vocab_size_divisible_by: 1
diff --git a/examples/llama/tests/test_conversion.py b/examples/llama/tests/test_conversion.py
index b5ce3529..8fa02a81 100644
--- a/examples/llama/tests/test_conversion.py
+++ b/examples/llama/tests/test_conversion.py
@@ -15,6 +15,7 @@
from nanotron.models.base import init_on_device_and_dtype
from nanotron.models.llama import LlamaForTraining
from nanotron.parallel import ParallelContext
+from nanotron.testing.utils import TestContext
from nanotron.trainer import mark_tied_parameters
from examples.llama.convert_hf_to_nanotron import convert_checkpoint_and_save as convert_hf_to_nt_and_save
@@ -22,7 +23,6 @@
from examples.llama.convert_nanotron_to_hf import convert_checkpoint_and_save as convert_nt_to_hf_and_save
from examples.llama.convert_nanotron_to_hf import convert_nt_to_hf, get_hf_config
from examples.llama.convert_weights import load_nanotron_model, make_parallel_config
-from tests.helpers.context import TestContext
from tests.helpers.utils import init_distributed
CONFIG = NanotronLlamaConfig(
@@ -141,7 +141,7 @@ def _test_hf_to_nt(parallel_context: ParallelContext, input_ids: torch.Tensor):
logits_nt = model_nt.model(input_ids, input_mask).permute(1, 0, 2)
logits_hf = model_hf(input_ids).logits
assert logits_nt.size() == logits_hf.size()
- torch.testing.assert_allclose(logits_hf, logits_nt, atol=ATOL)
+ torch.testing.assert_allclose(logits_hf, logits_nt, atol=ATOL)
def test_hf_to_nt(input_ids: torch.Tensor):
diff --git a/scripts/01_standalone_fp8_tensor.py b/scripts/01_standalone_fp8_tensor.py
new file mode 100644
index 00000000..6de99488
--- /dev/null
+++ b/scripts/01_standalone_fp8_tensor.py
@@ -0,0 +1,73 @@
+import torch
+import torch.utils._pytree as pytree
+
+
+class QuantTensor(torch.Tensor):
+ @staticmethod
+ def __new__(cls, data: torch.Tensor):
+ return torch.Tensor._make_wrapper_subclass(
+ cls,
+ data.shape,
+ device=data.device,
+ )
+
+ # @torch._dynamo.disable
+ def __init__(self, data: torch.Tensor):
+ self._data = data
+
+ def __tensor_flatten__(self):
+ return ["_data"], []
+
+ @classmethod
+ def __tensor_unflatten__(cls, tensor_data_dict, tensor_attributes, outer_size=None, outer_stride=None):
+ return cls(tensor_data_dict["_data"], *tensor_attributes)
+
+ def __repr__(self):
+ return f"{self.__class__.__name__}(data={self._data})"
+
+ # @classmethod
+ # def __torch_function__(cls, func, types, args=(), kwargs=None):
+ # kwargs = kwargs or dict()
+
+ # if func is F.linear:
+ # return _BitNetTrainingLinear.apply(*args, **kwargs)
+
+ # with torch._C.DisableTorchFunctionSubclass():
+ # return func(*args, **kwargs)
+
+ # adapted from FP8 implementation of WeightWithDynamicFloat8CastTensor
+ @classmethod
+ def __torch_dispatch__(cls, func, types, args, kwargs):
+ out = func(
+ *pytree.tree_map_only(cls, lambda x: x._data, args),
+ **pytree.tree_map_only(cls, lambda x: x._data, kwargs),
+ )
+
+ if func is torch.ops.aten.copy_.default:
+ # return original object
+ return args[0]
+ elif func in {
+ torch.ops.aten.t.default,
+ torch.ops.aten.detach.default,
+ torch.ops.aten.empty_like.default,
+ torch.ops.aten.new_zeros.default,
+ torch.ops.aten.slice.Tensor,
+ torch.ops.aten.view.default,
+ torch.ops.aten.as_strided.default,
+ torch.ops.aten._to_copy.default,
+ torch.ops.aten._pin_memory.default,
+ torch.ops.aten.split.Tensor,
+ torch.ops.aten.clone.default,
+ }:
+ # return new wrapped object
+ return pytree.tree_map_only(torch.Tensor, lambda x: cls(x), out)
+ else:
+ # return new unwrapped object
+ return out
+
+
+if __name__ == "__main__":
+ tensor = torch.randn(2, 2, device="cuda")
+ quant_tensor = QuantTensor(tensor)
+
+ assert 1 == 1
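+
+    # NOTE: minimal sanity checks (illustrative): aten ops listed in __torch_dispatch__
+    # (e.g. detach/t) come back re-wrapped as QuantTensor, while ops outside that set
+    # (e.g. add) fall through and return a plain torch.Tensor
+    assert isinstance(quant_tensor.detach(), QuantTensor)
+    assert type(quant_tensor + 1) is torch.Tensor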
diff --git a/src/nanotron/config/utils_config.py b/src/nanotron/config/utils_config.py
index f4c07146..bf8407fc 100644
--- a/src/nanotron/config/utils_config.py
+++ b/src/nanotron/config/utils_config.py
@@ -75,8 +75,9 @@ def serialize(data) -> dict:
torch.complex128: "complex128",
torch.float16: "float16",
torch.bfloat16: "bfloat16",
- torch.uint8: "uint8",
- torch.int8: "int8",
+ # torch.uint8: "uint8",
+ # torch.int8: "int8",
+ torch.int8: "float8",
torch.int16: "int16",
torch.int32: "int32",
torch.int64: "int64",
diff --git a/src/nanotron/fp8/_tensor.py b/src/nanotron/fp8/_tensor.py
new file mode 100644
index 00000000..0d288850
--- /dev/null
+++ b/src/nanotron/fp8/_tensor.py
@@ -0,0 +1,90 @@
+import torch
+
+from nanotron.fp8.meta import FP8Meta
+
+
+# This lives within core; the end user never has to look at this
+class _WrapperTensor(torch.Tensor):
+ @staticmethod
+ def __new__(cls, *args, **kwargs):
+ t, kwargs = cls.get_wrapper_properties(*args, **kwargs)
+ if "size" not in kwargs:
+ size = t.size()
+ else:
+ size = kwargs["size"]
+ del kwargs["size"]
+ if "dtype" not in kwargs:
+ kwargs["dtype"] = t.dtype
+ if "layout" not in kwargs:
+ kwargs["layout"] = t.layout
+ if "device" not in kwargs:
+ kwargs["device"] = t.device
+ if "requires_grad" not in kwargs:
+ kwargs["requires_grad"] = False
+ # Ignore memory_format and pin memory for now as I don't know how to
+ # safely access them on a Tensor (if possible??)
+
+ wrapper = torch.Tensor._make_wrapper_subclass(cls, size, **kwargs)
+ wrapper._validate_methods()
+ return wrapper
+
+ @classmethod
+ def get_wrapper_properties(cls, *args, **kwargs):
+        # Should return both an example Tensor and a dictionary of kwargs
+        # to override any of that example Tensor's properties.
+ # This is very similar to the `t.new_*(args)` API
+ raise NotImplementedError("You need to implement get_wrapper_properties")
+
+ def _validate_methods(self):
+ # Skip this if not in debug mode?
+ # Changing these on the python side is wrong as it would not be properly reflected
+ # on the c++ side
+ # This doesn't catch attributes set in the __init__
+ forbidden_overrides = ["size", "stride", "dtype", "layout", "device", "requires_grad"]
+ for el in forbidden_overrides:
+ if getattr(self.__class__, el) is not getattr(torch.Tensor, el):
+ raise RuntimeError(
+ f"Subclass {self.__class__.__name__} is overwriting the "
+ f"property {el} but this is not allowed as such change would "
+ "not be reflected to c++ callers."
+ )
+
+ def __repr__(self):
+ return f"{self.__class__.__name__}({self.__dict__})"
+
+
+from torch.utils._pytree import tree_map
+
+
+class _FP8Tensor(_WrapperTensor):
+ @classmethod
+    def get_wrapper_properties(cls, data, fp8_meta=None):
+        # return data, {"size": data.size() + data.size()}
+        return data, {}
+
+ def __init__(self, data: torch.Tensor, fp8_meta: FP8Meta):
+ self._tensor = data
+
+ @property
+ def data(self):
+ return self._tensor
+
+ @data.setter
+ def data(self, data):
+ self._tensor = data
+
+ @classmethod
+    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
+        kwargs = kwargs or {}
+
+        def unwrap(e):
+            return e._tensor if isinstance(e, _FP8Tensor) else e
+
+        def wrap(e):
+            return _FP8Tensor(e, fp8_meta=None) if isinstance(e, torch.Tensor) else e
+
+ rs = tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)))
+ return rs
+
+
+class FP8E4M3Tensor(_FP8Tensor):
+ def __init__(self):
+ pass
diff --git a/src/nanotron/fp8/constant_recipe.py b/src/nanotron/fp8/constant_recipe.py
new file mode 100644
index 00000000..4173ca7b
--- /dev/null
+++ b/src/nanotron/fp8/constant_recipe.py
@@ -0,0 +1,12 @@
+from torch import nn
+
+from nanotron.parallel.tensor_parallel.nn import TensorParallelEmbedding
+
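+# NOTE: modules matched by these names/types are kept out of FP8 and stay in a
+# higher-precision dtype (see is_convert_to_fp16 in nanotron/fp8/utils.py)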
+MODULE_NAMES_THAT_NOT_FP8 = [
+ "token_embedding",
+ "input_layernorm",
+ "post_attention_layernorm",
+ "final_layer_norm",
+ "lm_head",
+]
+MODULES_THAT_IN_FLOAT16 = [TensorParallelEmbedding, nn.LayerNorm]
diff --git a/src/nanotron/fp8/linear.py b/src/nanotron/fp8/linear.py
index 0bc31b09..4936e42c 100644
--- a/src/nanotron/fp8/linear.py
+++ b/src/nanotron/fp8/linear.py
@@ -12,7 +12,6 @@
from nanotron.fp8.parameter import FP8Parameter
from nanotron.fp8.recipe import FP8LinearRecipe
from nanotron.fp8.tensor import FP8Tensor
-from nanotron.parallel.parameters import get_data_from_param
@dataclass
@@ -55,7 +54,10 @@ def __init__(
assert quant_w.dtype in [torch.uint8, torch.int8], f"got {self.weight.data.dtype}"
self.weight = quant_w
- assert self.weight.data.dtype in [torch.uint8, torch.int8], f"got {self.weight.data.dtype}"
+ if self.name == "model.decoder.0.attention.qkv_proj":
+ assert 1 == 1
+
+ assert self.weight.data.dtype in [torch.uint8, torch.int8], f"got {self.weight.data.dtype}, name: {self.name}"
if self.bias is not None:
self.bias = nn.Parameter(self.bias.to(recipe.accum_dtype))
@@ -66,6 +68,7 @@ def __init__(
def forward(self, input: Union[FP8Tensor, torch.Tensor]) -> torch.Tensor:
import nanotron.fp8.functional as F
+ from nanotron.parallel.parameters import get_data_from_param
return F.linear(
input=input,
diff --git a/src/nanotron/fp8/meta.py b/src/nanotron/fp8/meta.py
index f9f3afbe..3064cf1f 100644
--- a/src/nanotron/fp8/meta.py
+++ b/src/nanotron/fp8/meta.py
@@ -30,6 +30,7 @@ class FP8Meta:
# TODO(xrsrke): change to Literal[torch.int8, torch.uint8]
dtype: DTypes
interval: int
+ sync_amax: bool = False
@property
def te_dtype(self) -> tex.DType:
diff --git a/src/nanotron/fp8/tensor.py b/src/nanotron/fp8/tensor.py
index 80fbc533..368ae6fc 100644
--- a/src/nanotron/fp8/tensor.py
+++ b/src/nanotron/fp8/tensor.py
@@ -78,6 +78,16 @@ def __new__(
return obj
+ def __init__(
+ self,
+ tensor: torch.Tensor,
+ dtype: Optional[DTypes] = None,
+ interval: Optional[int] = 1,
+ fp8_meta: Optional[FP8Meta] = None,
+ sync: bool = False,
+ ) -> None:
+ pass
+
@staticmethod
# @torch.no_grad()
def _get_metadata(tensor: torch.Tensor, dtype: DTypes, interval: int, sync: bool) -> "FP8Meta":
@@ -188,6 +198,20 @@ def clone(self) -> FP8Tensor:
tensor.fp8_meta = deepcopy(self.fp8_meta)
return tensor
+ # def __torch_function__(self, func, types, args=(), kwargs=None):
+ # return super().__torch_function__(func, types, args, kwargs)
+
+ # @classmethod
+ # def __torch_function__(cls, func, types, args, kwargs=None):
+ # kwargs = kwargs or {}
+ # if func is torch.transpose:
+ # assert type(args[0]) == cls
+ # assert type(args[1]) == type(args[2]) == int
+ # # return CustomMaskedSum.apply(*args, **kwargs)
+ # assert 1 == 1
+ # else:
+ # super().__torch_function__(func, types, args, kwargs)
+
class FP8Tensor(LowPrecisionTensor):
"""FP8 Tensor."""
diff --git a/src/nanotron/fp8/utils.py b/src/nanotron/fp8/utils.py
index b1934aa0..0dc7c3d3 100644
--- a/src/nanotron/fp8/utils.py
+++ b/src/nanotron/fp8/utils.py
@@ -313,10 +313,10 @@ def is_convert_to_fp16(module) -> bool:
IS_CONVERT_TO_FLOAT16 = False
name_of_modules_not_in_fp16 = get_modules_not_in_fp16()
- if hasattr(module, "name") and "lm_head" in module.name:
- assert 1 == 1
+ # if hasattr(module, "name") and "lm_head" in module.name:
+ # assert 1 == 1
- if constants.CONFIG.fp8.model is None:
+ if constants.CONFIG is not None and constants.CONFIG.fp8.model is None:
if any(isinstance(module, m) for m in MODULES_THAT_IN_FLOAT16):
IS_CONVERT_TO_FLOAT16 = True
else:
diff --git a/src/nanotron/models/base.py b/src/nanotron/models/base.py
index 14ac6908..78fd3965 100644
--- a/src/nanotron/models/base.py
+++ b/src/nanotron/models/base.py
@@ -71,7 +71,7 @@ def get_embeddings_lm_head_tied_names(self) -> list[str]:
Example for GPT2 model: ["model.token_position_embeddings.pp_block.token_embedding.weight", "model.lm_head.pp_block.weight"]
"""
return []
-
+
def get_named_params_without_weight_decay(self) -> List[str]:
"""Return a list of named parameters that should not have weight decay applied to them."""
return []
@@ -237,11 +237,118 @@ def build_model(
return model
+old_register_parameter = nn.Module.register_parameter
+old_register_buffer = nn.Module.register_buffer
+
+
+def _register_empty_parameter_for_fp8(module, name, param):
+ old_register_parameter(module, name, param)
+
+ # from nanotron.fp8.constant_recipe import MODULES_THAT_IN_FLOAT16
+ # from nanotron.fp8.utils import get_modules_not_in_fp16
+
+ # MODULES_THAT_IN_FLOAT16 = [TensorParallelEmbedding, nn.LayerNorm]
+ # name_of_modules_not_in_fp16 = get_modules_not_in_fp16()
+
+ if param is not None:
+ from nanotron.fp8.utils import is_convert_to_fp16
+
+ # IS_CONVERT_TO_FLOAT16 = False
+ # if constants.CONFIG.fp8.model is None:
+ # if any(isinstance(module, m) for m in MODULES_THAT_IN_FLOAT16):
+ # IS_CONVERT_TO_FLOAT16 = True
+ # else:
+ # if any(isinstance(module, m) for m in MODULES_THAT_IN_FLOAT16) or (
+ # hasattr(module, "name") and module.name not in name_of_modules_not_in_fp16
+ # ):
+ # IS_CONVERT_TO_FLOAT16 = True
+ is_convert_to_float16 = is_convert_to_fp16(module)
+
+ if is_convert_to_float16:
+ import nanotron
+ from nanotron import constants
+
+ if constants.CONFIG is not None:
+ from typing import cast
+
+ from nanotron.config.fp8_config import FP8Args
+
+ fp8_config = cast(FP8Args, constants.CONFIG.fp8)
+ resid_dtype = fp8_config.resid_dtype
+ else:
+ resid_dtype = nanotron.fp8.constants.FP8LM_LINEAR_RECIPE.accum_dtype
+
+ param.data = param.data.to(torch.device("cuda"), resid_dtype)
+ else:
+ param.data = param.data.to(torch.device("cuda"))
+
+
+def _register_empty_buffer_for_fp8(module, name, buffer, persistent=True):
+ old_register_buffer(module, name, buffer, persistent=persistent)
+
+ # from nanotron.fp8.constant_recipe import MODULES_THAT_IN_FLOAT16
+ # from nanotron.fp8.utils import get_modules_not_in_fp16
+
+ # name_of_modules_not_in_fp16 = get_modules_not_in_fp16()
+
+ if buffer is not None:
+ # IS_CONVERT_TO_FLOAT16 = False
+
+ # # NOTE: convert all modules in FP8 except MODULES_THAT_IN_FLOAT16
+ # # if fp8.model is None
+ # if constants.CONFIG.fp8.model is None:
+ # if any(isinstance(module, m) for m in MODULES_THAT_IN_FLOAT16):
+ # IS_CONVERT_TO_FLOAT16 = True
+ # else:
+ # if any(isinstance(module, m) for m in MODULES_THAT_IN_FLOAT16) or (
+ # hasattr(module, "name") and module.name not in name_of_modules_not_in_fp16
+ # ):
+ # IS_CONVERT_TO_FLOAT16 = True
+
+ from nanotron.fp8.utils import is_convert_to_fp16
+
+ is_convert_to_float16 = is_convert_to_fp16(module)
+
+ if is_convert_to_float16:
+
+ # from nanotron import constants
+ from nanotron.fp8 import constants
+
+ # fp8_config = cast(FP8Args, constants.CONFIG.fp8)
+ # fp8_config = cast(FP8Args, constants.CONFIG.fp8)
+ # module._buffers[name] = module._buffers[name].to(torch.device("cuda"), torch.float16)
+ module._buffers[name] = module._buffers[name].to(
+ torch.device("cuda"), constants.FP8LM_LINEAR_RECIPE.accum_dtype
+ )
+ else:
+ buffer.data = buffer.data.to(torch.device("cuda"))
+
+
+def _register_empty_parameter(module, name, param, device, dtype):
+ old_register_parameter(module, name, param)
+ if param is not None:
+ if isinstance(param, DTypeInvariantTensor):
+ # if param is DTypeInvariantTensor we should avoid updating it
+ param.data = param.data.to(device)
+ else:
+ param.data = param.data.to(device, dtype)
+
+
+def _register_empty_buffer(module, name, buffer, device, dtype, persistent=True):
+ old_register_buffer(module, name, buffer, persistent=persistent)
+ if buffer is not None:
+ if isinstance(buffer, DTypeInvariantTensor):
+ # if buffer is DTypeInvariantTensor we should avoid updating it
+ buffer.data = buffer.data.to(device)
+ else:
+ module._buffers[name] = module._buffers[name].to(device, dtype)
+
+
# TODO @thomasw21: Should this option override user defined options? Maybe not ... right now it does.
@contextmanager
def init_on_device_and_dtype(
device: torch.device = torch.device("cpu"),
- dtype: torch.dtype = torch.float,
+ dtype: torch.dtype = torch.float32,
):
"""
A context manager under which models are initialized with all parameters on the specified device.
@@ -258,28 +365,32 @@ def init_on_device_and_dtype(
from accelerate import init_on_device
with init_on_device_and_dtype(device=torch.device("cuda")):
tst = nn.Liner(100, 100) # on `cuda` device
+
+    NOTE: in order to initialize a hybrid FP8 model properly, you should use this context manager
```
"""
+ from functools import wraps
+
+ def method_partial(func, *args, **kwargs):
+ @wraps(func)
+ def wrapper(self, *fargs, **fkwargs):
+ return func(self, *args, *fargs, **kwargs, **fkwargs)
+
+ return wrapper
+
old_register_parameter = nn.Module.register_parameter
old_register_buffer = nn.Module.register_buffer
- def register_empty_parameter(module, name, param):
- old_register_parameter(module, name, param)
- if param is not None:
- if isinstance(param, DTypeInvariantTensor):
- # if param is DTypeInvariantTensor we should avoid updating it
- param.data = param.data.to(device)
- else:
- param.data = param.data.to(device, dtype)
-
- def register_empty_buffer(module, name, buffer, persistent=True):
- old_register_buffer(module, name, buffer, persistent=persistent)
- if buffer is not None:
- if isinstance(buffer, DTypeInvariantTensor):
- # if buffer is DTypeInvariantTensor we should avoid updating it
- buffer.data = buffer.data.to(device)
- else:
- module._buffers[name] = module._buffers[name].to(device, dtype)
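+    # NOTE: dtype == torch.int8 is used as the marker for FP8 training; in that case
+    # parameters and buffers go through the FP8-aware registration hooks defined above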
+ register_empty_parameter = (
+ _register_empty_parameter_for_fp8
+ if dtype == torch.int8
+ else method_partial(_register_empty_parameter, device=device, dtype=dtype)
+ )
+ register_empty_buffer = (
+ _register_empty_buffer_for_fp8
+ if dtype == torch.int8
+ else method_partial(_register_empty_buffer, device=device, dtype=dtype)
+ )
# Patch tensor creation
tensor_constructors_to_patch = {
@@ -289,8 +400,10 @@ def register_empty_buffer(module, name, buffer, persistent=True):
def patch_tensor_constructor(fn):
def wrapper(*args, **kwargs):
+            # NOTE: nanotron automatically sets the device and dtype of the tensor,
+            # but for FP8 training we initialize in float32 first
kwargs["device"] = device
- kwargs["dtype"] = dtype
+ kwargs["dtype"] = torch.float32 if dtype == torch.int8 else dtype
return fn(*args, **kwargs)
return wrapper
@@ -308,6 +421,77 @@ def wrapper(*args, **kwargs):
setattr(torch, torch_function_name, old_torch_function)
+# TODO @thomasw21: Should this option override user defined options? Maybe not ... right now it does.
+# @contextmanager
+# def init_on_device_and_dtype(
+# device: torch.device = torch.device("cpu"),
+# dtype: torch.dtype = torch.float,
+# ):
+# """
+# A context manager under which models are initialized with all parameters on the specified device.
+# Args:
+# device (`torch.device` defaults to `cpu`):
+# Device to initialize all parameters on.
+# dtype (`torch.dtype` defaults to `torch.float`):
+# Dtype to initialize all parameters on.
+# include_buffers (`bool`, defaults to `False`):
+# Whether or not to also default all buffers constructors given previous arguments.
+# Example:
+# ```python
+# import torch.nn as nn
+# from accelerate import init_on_device
+# with init_on_device_and_dtype(device=torch.device("cuda")):
+# tst = nn.Liner(100, 100) # on `cuda` device
+# ```
+# """
+# old_register_parameter = nn.Module.register_parameter
+# old_register_buffer = nn.Module.register_buffer
+
+# def register_empty_parameter(module, name, param):
+# old_register_parameter(module, name, param)
+# if param is not None:
+# if isinstance(param, DTypeInvariantTensor):
+# # if param is DTypeInvariantTensor we should avoid updating it
+# param.data = param.data.to(device)
+# else:
+# param.data = param.data.to(device, dtype)
+
+# def register_empty_buffer(module, name, buffer, persistent=True):
+# old_register_buffer(module, name, buffer, persistent=persistent)
+# if buffer is not None:
+# if isinstance(buffer, DTypeInvariantTensor):
+# # if buffer is DTypeInvariantTensor we should avoid updating it
+# buffer.data = buffer.data.to(device)
+# else:
+# module._buffers[name] = module._buffers[name].to(device, dtype)
+
+# # Patch tensor creation
+# tensor_constructors_to_patch = {
+# torch_function_name: getattr(torch, torch_function_name)
+# for torch_function_name in ["empty", "zeros", "ones", "full"]
+# }
+
+# def patch_tensor_constructor(fn):
+# def wrapper(*args, **kwargs):
+# kwargs["device"] = device
+# kwargs["dtype"] = dtype
+# return fn(*args, **kwargs)
+
+# return wrapper
+
+# try:
+# nn.Module.register_parameter = register_empty_parameter
+# nn.Module.register_buffer = register_empty_buffer
+# for torch_function_name in tensor_constructors_to_patch.keys():
+# setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name)))
+# yield
+# finally:
+# nn.Module.register_parameter = old_register_parameter
+# nn.Module.register_buffer = old_register_buffer
+# for torch_function_name, old_torch_function in tensor_constructors_to_patch.items():
+# setattr(torch, torch_function_name, old_torch_function)
+
+
def check_model_has_grad(model: NanotronModel, parallel_context: "ParallelContext"):
"""Check that there's at least a parameter in current PP rank that has a gradient."""
for param in model.parameters():
diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py
index 88fb6bcb..e6e74ecb 100644
--- a/src/nanotron/models/llama.py
+++ b/src/nanotron/models/llama.py
@@ -35,10 +35,10 @@
from nanotron.parallel.pipeline_parallel.p2p import P2P
from nanotron.parallel.tensor_parallel.functional import sharded_cross_entropy
from nanotron.parallel.tensor_parallel.nn import (
- TensorParallelColumnLinear,
+ FP8TensorParallelColumnLinear,
+ FP8TensorParallelRowLinear,
TensorParallelEmbedding,
TensorParallelLinearMode,
- TensorParallelRowLinear,
)
from nanotron.random import RandomStates
from nanotron.scaling.parametrization import SpectralMupParametrizator, StandardParametrizator
@@ -207,6 +207,7 @@ def __init__(
config: LlamaConfig,
parallel_config: Optional[ParallelismArgs],
tp_pg: dist.ProcessGroup,
+ layer_idx: int,
):
super().__init__()
@@ -220,7 +221,8 @@ def __init__(
config.intermediate_size, # shape of gate_linear
config.intermediate_size, # shape of up_linear
)
- self.gate_up_proj = TensorParallelColumnLinear(
+ # self.gate_up_proj = TensorParallelColumnLinear(
+ self.gate_up_proj = FP8TensorParallelColumnLinear(
config.hidden_size,
2 * config.intermediate_size,
pg=tp_pg,
@@ -228,15 +230,18 @@ def __init__(
bias=False,
async_communication=tp_linear_async_communication,
contiguous_chunks=gate_up_contiguous_chunks,
- tp_recompute_allgather=parallel_config.tp_recompute_allgather,
+ name=f"model.decoder.{layer_idx}.mlp.gate_up_proj",
+ # tp_recompute_allgather=parallel_config.tp_recompute_allgather,
)
- self.down_proj = TensorParallelRowLinear(
+ # self.down_proj = TensorParallelRowLinear(
+ self.down_proj = FP8TensorParallelRowLinear(
config.intermediate_size,
config.hidden_size,
pg=tp_pg,
mode=tp_mode,
bias=False,
async_communication=tp_linear_async_communication and tp_mode is TensorParallelLinearMode.REDUCE_SCATTER,
+ name=f"model.decoder.{layer_idx}.mlp.down_proj",
)
self.split_silu_mul = GLUActivation(config.hidden_act)
@@ -381,7 +386,8 @@ def __init__(
config.num_key_value_heads * self.d_qk, # shape of k
config.num_key_value_heads * self.d_qk, # shape of v
)
- self.qkv_proj = TensorParallelColumnLinear(
+ # self.qkv_proj = TensorParallelColumnLinear(
+ self.qkv_proj = FP8TensorParallelColumnLinear(
self.d_model,
config.num_attention_heads * self.d_qk + 2 * config.num_key_value_heads * self.d_qk,
pg=tp_pg,
@@ -389,7 +395,8 @@ def __init__(
bias=False,
async_communication=tp_linear_async_communication,
contiguous_chunks=qkv_contiguous_chunks,
- tp_recompute_allgather=parallel_config.tp_recompute_allgather,
+ name=f"model.decoder.{layer_idx}.attention.qkv_proj",
+ # tp_recompute_allgather=parallel_config.tp_recompute_allgather,
)
# TODO(kunhao): We want to have only one version per device and not one version per layer.
if config.rope_interleaved:
@@ -411,13 +418,15 @@ def __init__(
dim=self.d_qk, base=config.rope_theta, interleaved=config.rope_interleaved
)
- self.o_proj = TensorParallelRowLinear(
+ # self.o_proj = TensorParallelRowLinear(
+ self.o_proj = FP8TensorParallelRowLinear(
config.num_attention_heads * self.d_qk,
self.d_model,
pg=tp_pg,
mode=tp_mode,
bias=False,
async_communication=tp_linear_async_communication,
+ name=f"model.decoder.{layer_idx}.attention.o_proj",
)
self.attention = CoreAttention(
@@ -710,7 +719,7 @@ def __init__(
)
self.post_attention_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
- self.mlp = MLP(config=config, parallel_config=parallel_config, tp_pg=tp_pg)
+ self.mlp = MLP(config=config, parallel_config=parallel_config, tp_pg=tp_pg, layer_idx=layer_idx)
self.recompute_layer = parallel_config.recompute_layer
@@ -856,7 +865,8 @@ def __init__(
self.lm_head = PipelineBlock(
p2p=self.p2p,
# Understand that this means that we return sharded logits that are going to need to be gathered
- module_builder=TensorParallelColumnLinear,
+ # module_builder=TensorParallelColumnLinear,
+ module_builder=FP8TensorParallelColumnLinear,
module_kwargs={
"in_features": config.hidden_size,
"out_features": config.vocab_size,
@@ -865,7 +875,7 @@ def __init__(
# TODO @thomasw21: refactor so that we store that default in a single place.
"mode": self.tp_mode,
"async_communication": tp_linear_async_communication,
- "tp_recompute_allgather": parallel_config.tp_recompute_allgather,
+ # "tp_recompute_allgather": parallel_config.tp_recompute_allgather,
},
module_input_keys={"x"},
module_output_keys={"logits"},
@@ -920,7 +930,8 @@ def get_block_compute_costs(self):
LlamaDecoderLayer: 4 * model_config.num_attention_heads * d_qkv * model_config.hidden_size
+ 3 * d_ff * model_config.hidden_size,
# This is the last lm_head
- TensorParallelColumnLinear: model_config.vocab_size * model_config.hidden_size,
+ # TensorParallelColumnLinear: model_config.vocab_size * model_config.hidden_size,
+ FP8TensorParallelColumnLinear: model_config.vocab_size * model_config.hidden_size,
}
return block_compute_costs
diff --git a/src/nanotron/testing/parallel.py b/src/nanotron/testing/parallel.py
new file mode 100644
index 00000000..1dbe8a3c
--- /dev/null
+++ b/src/nanotron/testing/parallel.py
@@ -0,0 +1,159 @@
+import os
+import re
+from inspect import signature
+from typing import Callable
+
+import torch.cuda
+import torch.multiprocessing as mp
+from nanotron.parallel import ParallelContext
+from packaging import version
+
+
+def global_wrapper(rank, func, tp, pp, dp, port, kwargs):
+ def setup_dist_env(rank, world_size, port):
+ os.environ["WORLD_SIZE"] = str(world_size)
+ os.environ["RANK"] = str(rank)
+        # NOTE: unit tests run on a single node, so LOCAL_RANK == RANK is fine
+ os.environ["LOCAL_RANK"] = str(rank)
+ os.environ["MASTER_ADDR"] = "localhost"
+ os.environ["MASTER_PORT"] = str(port)
+
+ world_size = tp * pp * dp
+ setup_dist_env(rank, world_size, port)
+ parallel_context = ParallelContext(data_parallel_size=dp, pipeline_parallel_size=pp, tensor_parallel_size=tp)
+ func(parallel_context, **kwargs)
+
+
+def init_distributed(tp: int, dp: int, pp: int):
+ def _init_distributed(func):
+ def wrapper(**kwargs):
+ from nanotron.utils import find_free_port
+
+ world_size = tp * pp * dp
+ port = find_free_port()
+
+ # Note that kwargs needs to be passed as part of args in a way that can be unpacked
+ args = (func, tp, pp, dp, port, kwargs)
+ mp.spawn(global_wrapper, args=args, nprocs=world_size)
+
+ return wrapper
+
+ return _init_distributed
+
+
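+# Example usage of init_distributed (illustrative): the decorated function receives a
+# ParallelContext as its first argument and runs in tp * dp * pp spawned processes.
+#
+#   @init_distributed(tp=2, dp=1, pp=1)
+#   def _test_something(parallel_context: ParallelContext, some_value: int):
+#       ...
+#
+#   _test_something(some_value=42)
+
+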
+def rerun_if_address_is_in_use(max_try: int = 500):
+ """
+ This function reruns a wrapped function if "address already in use" occurs
+ in testing spawned with torch.multiprocessing
+
+ Credits: https://github.com/hpcaitech/ColossalAI/blob/adae123df3badfb15d044bd416f0cf29f250bc86/colossalai/testing/utils.py#L157
+
+ Usage::
+
+ @rerun_if_address_is_in_use()
+ def test_something():
+ ...
+
+ """
+ # check version
+ torch_version = version.parse(torch.__version__)
+ assert torch_version.major >= 1
+
+ # only torch >= 1.8 has ProcessRaisedException
+ if torch_version >= version.parse("1.8.0"):
+ exception = torch.multiprocessing.ProcessRaisedException
+ else:
+ exception = Exception
+
+ func_wrapper = rerun_on_exception(exception_type=exception, pattern=".*Address already in use.*", max_try=max_try)
+ return func_wrapper
+
+
+def rerun_on_exception(exception_type: Exception = Exception, pattern: str = None, max_try: int = 10) -> Callable:
+ """
+ A decorator on a function to re-run when an exception occurs.
+
+ Credits: https://github.com/hpcaitech/ColossalAI/blob/adae123df3badfb15d044bd416f0cf29f250bc86/colossalai/testing/utils.py#L71
+
+ Usage::
+
+ # rerun for all kinds of exception
+ @rerun_on_exception()
+ def test_method():
+ print('hey')
+ raise RuntimeError('Address already in use')
+
+ # rerun for RuntimeError only
+ @rerun_on_exception(exception_type=RuntimeError)
+ def test_method():
+ print('hey')
+ raise RuntimeError('Address already in use')
+
+ # rerun for maximum 10 times if Runtime error occurs
+ @rerun_on_exception(exception_type=RuntimeError, max_try=10)
+ def test_method():
+ print('hey')
+ raise RuntimeError('Address already in use')
+
+ # rerun for infinite times if Runtime error occurs
+ @rerun_on_exception(exception_type=RuntimeError, max_try=None)
+ def test_method():
+ print('hey')
+ raise RuntimeError('Address already in use')
+
+ # rerun only the exception message is matched with pattern
+ # for infinite times if Runtime error occurs
+ @rerun_on_exception(exception_type=RuntimeError, pattern="^Address.*$")
+ def test_method():
+ print('hey')
+ raise RuntimeError('Address already in use')
+
+ Args:
+ exception_type (Exception, Optional): The type of exception to detect for rerun
+ pattern (str, Optional): The pattern to match the exception message.
+ If the pattern is not None and matches the exception message,
+ the exception will be detected for rerun
+        max_try (int, Optional): Maximum reruns for this function. The default value is 10.
+ If max_try is None, it will rerun forever if exception keeps occurring
+ """
+
+ def _match_lines(lines, pattern):
+ for line in lines:
+ if re.match(pattern, line):
+ return True
+ return False
+
+ def _wrapper(func):
+ def _run_until_success(*args, **kwargs):
+ try_count = 0
+ assert max_try is None or isinstance(
+ max_try, int
+ ), f"Expected max_try to be None or int, but got {type(max_try)}"
+
+ while max_try is None or try_count < max_try:
+ try:
+ try_count += 1
+ ret = func(*args, **kwargs)
+ return ret
+ except exception_type as e:
+ error_lines = str(e).split("\n")
+                    if (max_try is None or try_count < max_try) and (pattern is None or _match_lines(error_lines, pattern)):
+
+ print("Exception is caught, retrying...")
+ # when pattern is not specified, we always skip the exception
+ # when pattern is specified, we only skip when pattern is matched
+ continue
+ else:
+ print("Maximum number of attempts is reached or pattern is not matched, no more retrying...")
+ raise e
+
+ # Override signature
+ # otherwise pytest.mark.parameterize will raise the following error:
+ # function does not use argument xxx
+ sig = signature(func)
+ _run_until_success.__signature__ = sig
+
+ return _run_until_success
+
+ return _wrapper
diff --git a/src/nanotron/testing/utils.py b/src/nanotron/testing/utils.py
new file mode 100644
index 00000000..77fa5d70
--- /dev/null
+++ b/src/nanotron/testing/utils.py
@@ -0,0 +1,21 @@
+import shutil
+import uuid
+from functools import lru_cache
+from pathlib import Path
+
+
+class TestContext:
+ def __init__(self):
+ self._random_string = str(uuid.uuid1())
+ self._root_dir = Path(__file__).parent.parent / ".test_cache"
+ self._root_dir.mkdir(parents=True, exist_ok=True)
+
+ @lru_cache(maxsize=1)
+ def get_auto_remove_tmp_dir(self):
+ path = self._root_dir / self._random_string
+ path.mkdir(parents=True, exist_ok=True)
+ return path
+
+ def __del__(self):
+ path = self.get_auto_remove_tmp_dir()
+ shutil.rmtree(path)
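+
+
+# Example usage (illustrative): each TestContext owns a unique temporary directory under
+# src/nanotron/.test_cache/ (added to .gitignore in this change) that is removed when the
+# context is garbage-collected.
+#
+#   ctx = TestContext()
+#   tmp_dir = ctx.get_auto_remove_tmp_dir()
+#   torch.save(obj, tmp_dir / "checkpoint.pt")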
diff --git a/tests/fp8/test_new_tensor.py b/tests/fp8/test_new_tensor.py
new file mode 100644
index 00000000..05912e11
--- /dev/null
+++ b/tests/fp8/test_new_tensor.py
@@ -0,0 +1,452 @@
+from copy import deepcopy
+from typing import cast
+
+import numpy as np
+import pytest
+import torch
+from nanotron.fp8.constants import (
+ FP8_ATOL_THRESHOLD,
+ FP8_RTOL_THRESHOLD,
+ FP16_ATOL_THRESHOLD,
+ FP16_RTOL_THRESHOLD,
+ QTYPE_TO_DTYPE,
+)
+from nanotron.fp8.dtypes import DTypes
+from nanotron.fp8.meta import FP8Meta
+from nanotron.fp8.tensor import FP8Tensor, FP16Tensor
+from nanotron.testing.utils import TestContext
+
+
+@pytest.mark.parametrize(
+ "tensor_cls, dtype", [(FP8Tensor, DTypes.FP8E4M3), (FP8Tensor, DTypes.FP8E5M2), (FP16Tensor, DTypes.KFLOAT16)]
+)
+@pytest.mark.parametrize("interval", [1, 5])
+def test_fp8_and_fp16_metadata(tensor_cls, dtype, interval):
+ import transformer_engine as te # noqa
+ import transformer_engine_torch as tex
+
+ tensor = torch.randn((4, 4), dtype=torch.float32, device="cuda")
+ ref_tensor = deepcopy(tensor)
+
+ fp8_tensor = tensor_cls(tensor, dtype=dtype, interval=interval)
+ fp8_meta = cast(FP8Meta, fp8_tensor.fp8_meta)
+
+ # TODO(xrsrke): remove the fixed 1 factor
+ # it couples with the current implementation of FP8Meta
+ # because we initialize scale with 1
+ assert fp8_meta.amax == ref_tensor.abs().max()
+ assert isinstance(fp8_meta.inverse_scale, torch.Tensor)
+ assert fp8_meta.scale != 0.1 and fp8_meta.scale != 1.0
+ assert isinstance(fp8_meta.te_dtype, tex.DType)
+ assert fp8_meta.interval == interval
+
+
+@pytest.mark.parametrize("size", [4, 8, 16, 64])
+@pytest.mark.parametrize("dtype", [DTypes.FP8E4M3, DTypes.FP8E5M2])
+def test_quantize_and_dequantize_tensor_in_fp8(size, dtype):
+ tensor = torch.randn((size, size), dtype=torch.float32, device="cuda")
+ ref_tensor = deepcopy(tensor)
+ fp8_tensor = FP8Tensor(tensor, dtype=dtype)
+
+ assert not np.array_equal(fp8_tensor.cpu().numpy(), ref_tensor.cpu().numpy())
+
+ tensor = fp8_tensor.to(torch.float32)
+    # NOTE: an isinstance check would also pass for FP8Tensor, so compare the class directly
+    # to make sure it's a plain torch.Tensor
+ assert tensor.__class__ == torch.Tensor
+ assert tensor.dtype == ref_tensor.dtype
+
+ torch.testing.assert_close(tensor, ref_tensor, rtol=FP8_RTOL_THRESHOLD, atol=FP8_ATOL_THRESHOLD)
+
+
+# @pytest.mark.parametrize("interval", [1, 5])
+@pytest.mark.parametrize("dtype", [DTypes.FP8E4M3, DTypes.FP8E5M2])
+def test_create_fp8_tensor_from_metadata(dtype):
+ INTERVAL = 5
+ TOTAL_STEPS, REMAINING_STEPS = 20, 16
+ tensor = torch.randn((16, 16), dtype=torch.float32, device="cuda")
+ fp8_tensor = FP8Tensor(tensor, dtype=dtype, interval=INTERVAL)
+
+ new_values = [torch.randn((16, 16), dtype=torch.float32, device="cuda") for i in range(TOTAL_STEPS)]
+
+ for i in range(TOTAL_STEPS):
+ if TOTAL_STEPS - REMAINING_STEPS == i:
+ current_tensor = fp8_tensor.orig_data
+ fp8_meta = deepcopy(fp8_tensor.fp8_meta)
+
+ fp8_tensor.data = new_values[i]
+
+ resumed_fp8_tensor = FP8Tensor.from_metadata(current_tensor, fp8_meta)
+ for i in range(TOTAL_STEPS - REMAINING_STEPS, TOTAL_STEPS):
+ resumed_fp8_tensor.data = new_values[i]
+
+    # NOTE: we expect a resumed tensor to follow the state trajectory of the original tensor
+ assert resumed_fp8_tensor == fp8_tensor
+
+
+@pytest.mark.parametrize("size", [4, 8, 16, 64])
+def test_quantize_and_dequantize_tensor_in_fp16(size):
+ tensor = torch.randn((size, size), dtype=torch.float32, device="cuda")
+ ref_tensor = deepcopy(tensor)
+
+ fp16_tensor = FP16Tensor(tensor, dtype=DTypes.KFLOAT16)
+
+ assert not np.array_equal(fp16_tensor.cpu().numpy(), ref_tensor.cpu().numpy())
+
+ # tensor = convert_tensor_from_fp16(fp16_tensor, torch.float32)
+ tensor = fp16_tensor.to(torch.float32)
+    # NOTE: an isinstance check would also pass for FP16Tensor, so compare the class directly
+ assert tensor.__class__ == torch.Tensor
+ assert tensor.dtype == ref_tensor.dtype
+
+ # NOTE: this tolerance is from FP8-LM's implementation
+ # reference: https://github.com/Azure/MS-AMP/blob/9ac98df5371f3d4174d8f103a1932b3a41a4b8a3/tests/common/tensor/test_cast.py#L35
+ torch.testing.assert_close(tensor, ref_tensor, rtol=FP16_RTOL_THRESHOLD, atol=FP16_ATOL_THRESHOLD)
+
+
+@pytest.mark.parametrize(
+ "tensor_cls, dtype", [(FP8Tensor, DTypes.FP8E4M3), (FP8Tensor, DTypes.FP8E5M2), (FP16Tensor, DTypes.KFLOAT16)]
+)
+def test_fp8_and_fp16_tensor_repr(tensor_cls, dtype):
+ tensor = torch.randn((64, 64), dtype=torch.float32, device="cuda:0")
+ fp8_tensor = tensor_cls(tensor, dtype)
+
+ # NOTE: in some cases, it causes an infinite loop
+ # in repr(tensor), so just check if it doesn't loop
+ assert isinstance(repr(fp8_tensor), str)
+
+
+@pytest.mark.parametrize(
+ "tensor_cls, dtype, expected_dtype",
+ [
+ (FP8Tensor, DTypes.FP8E4M3, torch.uint8),
+ (FP8Tensor, DTypes.FP8E5M2, torch.uint8),
+ (FP16Tensor, DTypes.KFLOAT16, torch.float16),
+ ],
+)
+def test_fp8_and_fp16_tensor_attrs(tensor_cls, dtype, expected_dtype):
+ tensor = torch.randn((64, 64), dtype=torch.float32, device="cuda:0")
+ ref_tensor = tensor.detach().clone()
+
+ fp8_tensor = tensor_cls(tensor, dtype)
+
+ assert isinstance(fp8_tensor, tensor_cls)
+ assert isinstance(fp8_tensor.fp8_meta, FP8Meta)
+ assert fp8_tensor.dtype == expected_dtype
+ assert fp8_tensor.device == ref_tensor.device
+ assert fp8_tensor.shape == ref_tensor.shape
+ assert fp8_tensor.numel() == ref_tensor.numel()
+ assert fp8_tensor.device == ref_tensor.device
+
+
+# @pytest.mark.parametrize(
+# "tensor_cls, dtype",
+# [
+# (FP8Tensor, DTypes.FP8E4M3),
+# (FP8Tensor, DTypes.FP8E5M2),
+# (FP16Tensor, DTypes.KFLOAT16),
+# ],
+# )
+# @pytest.mark.parametrize(
+# "scale",
+# [
+# torch.ones(1, device="cuda:0").squeeze() * 2, # an random scalar
+# torch.ones(1, device="cuda:0") * 2,
+# torch.ones(4, 4, device="cuda:0") * 2,
+# ],
+# )
+# def test_multiple_fp8_tensor(tensor_cls, dtype, scale):
+# RTOL, ATOL = (
+# (FP8_RTOL_THRESHOLD, FP8_ATOL_THRESHOLD)
+# if tensor_cls == FP8Tensor
+# else (FP16_RTOL_THRESHOLD, FP16_ATOL_THRESHOLD)
+# )
+# tensor = torch.randn((4, 4), dtype=torch.float32, device="cuda:0")
+# ref_tensor = tensor.detach().clone()
+
+# fp8_tensor = tensor_cls(deepcopy(tensor), dtype)
+# ref_fp8_tensor = fp8_tensor.clone()
+
+# with fail_if_expect_to_fail(expect_to_fail=scale.ndim > 1):
+# fp8_tensor.mul_(scale)
+
+# assert torch.equal(fp8_tensor, ref_fp8_tensor)
+# assert fp8_tensor.fp8_meta.scale != ref_fp8_tensor.fp8_meta.scale
+
+# if isinstance(fp8_tensor, FP8Tensor):
+# # NOTE: with the current implementation, we only scale the metadata
+# # not the tensor itself, so we expect the tensor to be the same
+# tensor = convert_tensor_from_fp8(fp8_tensor, fp8_tensor.fp8_meta, torch.float32)
+# else:
+# tensor = convert_tensor_from_fp16(fp8_tensor, torch.float32)
+
+# torch.testing.assert_allclose(tensor, ref_tensor * scale, rtol=RTOL, atol=ATOL)
+
+
+# @pytest.mark.parametrize(
+# "tensor_cls, dtype",
+# [
+# (FP8Tensor, DTypes.FP8E4M3),
+# (FP8Tensor, DTypes.FP8E5M2),
+# (FP16Tensor, DTypes.KFLOAT16),
+# ],
+# )
+# @pytest.mark.parametrize(
+# "scale",
+# [
+# torch.ones(1, device="cuda:0").squeeze() * 2, # an random scalar
+# torch.ones(1, device="cuda:0") * 2,
+# torch.ones(4, 4, device="cuda:0") * 2,
+# ],
+# )
+# def test_divide_fp8_tensor(tensor_cls, dtype, scale):
+# # NOTE: the reason that we use 2 as the scale is because
+# # the purpose of this test is to test whether we scale the magnitude
+# # of the tensor, so if use other values from normal distribution,
+# # some values could lead to quantization error, and for this test we don't
+# # test the quantization error
+# tensor = torch.randn((64, 64), dtype=torch.float32, device="cuda:0")
+# ref_tensor = deepcopy(tensor)
+
+# fp8_tensor = tensor_cls(deepcopy(tensor), dtype)
+# ref_fp8_tensor = fp8_tensor.clone()
+
+# with fail_if_expect_to_fail(expect_to_fail=scale.ndim > 1):
+# fp8_tensor.div_(scale)
+
+# assert torch.equal(fp8_tensor, ref_fp8_tensor)
+# assert fp8_tensor.fp8_meta.scale != ref_fp8_tensor.fp8_meta.scale
+
+# if isinstance(fp8_tensor, FP8Tensor):
+# tensor = convert_tensor_from_fp8(fp8_tensor, fp8_tensor.fp8_meta, torch.float32)
+# # NOTE: use the same tolerance as test_quantize_and_dequantize_tensor_in_fp8
+# torch.testing.assert_allclose(tensor, ref_tensor / scale, rtol=FP8_RTOL_THRESHOLD, atol=FP8_ATOL_THRESHOLD)
+# else:
+# tensor = convert_tensor_from_fp16(fp8_tensor, torch.float32)
+# torch.testing.assert_close(tensor, ref_tensor / scale, rtol=FP16_RTOL_THRESHOLD, atol=FP16_ATOL_THRESHOLD)
+
+
+@pytest.mark.parametrize(
+ "tensor_cls, dtype",
+ [
+ (FP8Tensor, DTypes.FP8E4M3),
+ (FP8Tensor, DTypes.FP8E5M2),
+ (FP16Tensor, DTypes.KFLOAT16),
+ ],
+)
+def test_add_fp8_tensor(tensor_cls, dtype):
+ tensor = torch.randn((64, 64), dtype=torch.float32, device="cuda:0")
+ fp8_tensor = tensor_cls(deepcopy(tensor), dtype)
+
+ with pytest.raises(ValueError):
+ fp8_tensor + 1
+
+
+@pytest.mark.parametrize(
+ "tensor_cls, dtype",
+ [
+ (FP8Tensor, DTypes.FP8E4M3),
+ (FP8Tensor, DTypes.FP8E5M2),
+ (FP16Tensor, DTypes.KFLOAT16),
+ ],
+)
+def test_subtract_fp8_tensor(tensor_cls, dtype):
+ tensor = torch.randn((64, 64), dtype=torch.float32, device="cuda:0")
+ fp8_tensor = tensor_cls(deepcopy(tensor), dtype)
+
+ with pytest.raises(ValueError):
+ fp8_tensor - 1
+
+
+@pytest.mark.parametrize(
+ "tensor_cls, dtype",
+ [
+ (FP8Tensor, DTypes.FP8E4M3),
+ (FP8Tensor, DTypes.FP8E5M2),
+ (FP16Tensor, DTypes.KFLOAT16),
+ ],
+)
+def test_clone_fp8_tensor(tensor_cls, dtype):
+ tensor = torch.randn((64, 64), dtype=torch.float32, device="cuda:0")
+ fp8_tensor = tensor_cls(deepcopy(tensor), dtype)
+
+ cloned_fp8_tensor = fp8_tensor.clone()
+
+ assert isinstance(cloned_fp8_tensor, tensor_cls)
+ assert id(cloned_fp8_tensor) != id(fp8_tensor)
+ assert cloned_fp8_tensor.device == fp8_tensor.device
+
+ assert torch.equal(cloned_fp8_tensor, fp8_tensor)
+ assert cloned_fp8_tensor.data_ptr() != fp8_tensor.data_ptr()
+ assert cloned_fp8_tensor.data.data_ptr() != fp8_tensor.data.data_ptr()
+
+ assert cloned_fp8_tensor.fp8_meta == fp8_tensor.fp8_meta
+ assert id(cloned_fp8_tensor.fp8_meta) != id(fp8_tensor.fp8_meta)
+
+
+@pytest.mark.parametrize(
+ "tensor_cls, dtype",
+ [
+ (FP8Tensor, DTypes.FP8E4M3),
+ (FP8Tensor, DTypes.FP8E5M2),
+ # (FP16Tensor, DTypes.KFLOAT16),
+ ],
+)
+def test_transpose_fp8_tensor(tensor_cls, dtype):
+ tensor = torch.randn((16, 16), dtype=torch.float32, device="cuda:0")
+    ref_transposed_tensor = deepcopy(tensor).transpose(0, 1)
+ fp8_tensor = tensor_cls(tensor, dtype)
+
+ transposed_fp8_tensor = fp8_tensor.transpose()
+
+ assert isinstance(transposed_fp8_tensor, FP8Tensor)
+
+ dequant_transposed_fp8_tensor = transposed_fp8_tensor.to(torch.float32)
+ torch.testing.assert_close(dequant_transposed_fp8_tensor, ref_transposed_tensor)
+
+
+@pytest.mark.parametrize(
+ "tensor_cls, dtype",
+ [
+ (FP8Tensor, DTypes.FP8E4M3),
+ (FP8Tensor, DTypes.FP8E5M2),
+ (FP16Tensor, DTypes.KFLOAT16),
+ ],
+)
+def test_determistic_quantization(tensor_cls, dtype):
+ tensor = torch.randn((64, 64), dtype=torch.float32, device="cuda:0")
+ fp8_tensor = tensor_cls(deepcopy(tensor), dtype)
+ ref_fp8_tensor = tensor_cls(deepcopy(tensor), dtype)
+
+ assert torch.equal(fp8_tensor, ref_fp8_tensor)
+ assert fp8_tensor.fp8_meta == ref_fp8_tensor.fp8_meta
+
+
+@pytest.mark.parametrize(
+ "tensor_cls, dtype", [(FP8Tensor, DTypes.FP8E4M3), (FP8Tensor, DTypes.FP8E5M2), (FP16Tensor, DTypes.KFLOAT16)]
+)
+def test_fp8_and_fp16_tensor_storage_memory(tensor_cls, dtype):
+ tensor = torch.randn((4, 4), dtype=torch.float32, device="cuda")
+ ref_tensor = deepcopy(tensor)
+
+ fp8_tensor = tensor_cls(tensor, dtype=dtype)
+
+ assert id(fp8_tensor) != id(ref_tensor)
+
+ assert isinstance(fp8_tensor.data, torch.Tensor)
+ assert id(fp8_tensor.data) != id(ref_tensor.data)
+ assert fp8_tensor.data_ptr() == fp8_tensor.data.data_ptr()
+ assert fp8_tensor.data.data_ptr() != ref_tensor.data_ptr()
+
+
+@pytest.mark.parametrize(
+ "tensor_cls, dtype", [(FP8Tensor, DTypes.FP8E4M3), (FP8Tensor, DTypes.FP8E5M2), (FP16Tensor, DTypes.KFLOAT16)]
+)
+@pytest.mark.parametrize("is_quantized", [True, False])
+def test_setting_new_data_for_fp8_and_fp16_tensor(tensor_cls, dtype, is_quantized):
+ RTOL, ATOL = (
+ (FP8_RTOL_THRESHOLD, FP8_ATOL_THRESHOLD)
+ if tensor_cls == FP8Tensor
+ else (FP16_RTOL_THRESHOLD, FP16_ATOL_THRESHOLD)
+ )
+
+ tensor = torch.randn((4, 4), dtype=torch.float32, device="cuda")
+ quant_tensor = tensor_cls(tensor, dtype=dtype)
+
+ new_data = torch.randn(quant_tensor.shape, dtype=torch.float32, device="cuda") * 2
+ ref_new_data = deepcopy(new_data)
+ expected_quantized_tensor = tensor_cls(ref_new_data, dtype=dtype)
+
+ new_data = tensor_cls(new_data, dtype=dtype) if is_quantized else new_data
+ quant_tensor.data = new_data
+    assert quant_tensor.data.data_ptr() == new_data.data.data_ptr()
+
+ assert quant_tensor.data.dtype == QTYPE_TO_DTYPE[dtype]
+ assert torch.equal(quant_tensor, expected_quantized_tensor)
+
+ if is_quantized:
+ # if tensor_cls == FP8Tensor:
+ # dequant_tensor = convert_tensor_from_fp8(quant_tensor, fp8_tensor.fp8_meta, torch.float32)
+ # else:
+ # dequantized_tensor = convert_tensor_from_fp16(fp8_tensor, torch.float32)
+
+ dequant_tensor = quant_tensor.to(torch.float32)
+ assert torch.allclose(dequant_tensor, ref_new_data, rtol=RTOL, atol=ATOL)
+
+
+# @pytest.mark.parametrize(
+# "tensor_cls, dtype", [(FP8Tensor, DTypes.FP8E4M3), (FP8Tensor, DTypes.FP8E5M2), (FP16Tensor, DTypes.KFLOAT16)]
+# )
+# @pytest.mark.parametrize("is_quantized", [True, False])
+# def test_setting_None_data_for_fp8_and_fp16_tensor(tensor_cls, dtype, is_quantized):
+# tensor = torch.randn((4, 4), dtype=torch.float32, device="cuda")
+# fp8_tensor = tensor_cls(tensor, dtype=dtype)
+
+# fp8_tensor.set_data(None)
+
+# assert fp8_tensor is None
+# assert fp8_tensor.data is None
+
+# @pytest.mark.parametrize(
+# "tensor_cls, dtype", [(FP8Tensor, DTypes.FP8E4M3), (FP8Tensor, DTypes.FP8E5M2), (FP16Tensor, DTypes.KFLOAT16)]
+# )
+# def test_quantize_overflow_fp8_and_fp16_tensor(tensor_cls, dtype):
+# tensor = torch.randn((64, 64), dtype=torch.float32, device="cuda:0")
+# tensor[0, 0] = torch.tensor(float("inf"))
+# fp8_tensor = tensor_cls(tensor, dtype)
+
+
+@pytest.mark.parametrize(
+ "tensor_cls, dtype", [(FP8Tensor, DTypes.FP8E4M3), (FP8Tensor, DTypes.FP8E5M2), (FP16Tensor, DTypes.KFLOAT16)]
+)
+def test_zero_out_data_of_fp8_and_fp16_tensor(tensor_cls, dtype):
+ tensor = torch.randn((4, 4), dtype=torch.float32, device="cuda")
+ quant_tensor = tensor_cls(tensor, dtype=dtype)
+
+ quant_tensor.zero_()
+
+ assert torch.equal(quant_tensor, torch.zeros_like(quant_tensor))
+
+ dequant_tensor = quant_tensor.to(torch.float32)
+ assert torch.equal(dequant_tensor, torch.zeros_like(tensor))
+
+
+# NOTE: add testing based on tensor metadata
+@pytest.mark.parametrize("is_meta_the_same", [True, False])
+@pytest.mark.parametrize("dtype", [DTypes.FP8E4M3, DTypes.FP8E5M2])
+def test_fp8_and_fp16_tensor_equality_based_on_tensor_value(is_meta_the_same, dtype):
+ # TODO(xrsrke): support torch.equal for FP8Tensor
+ tensor = torch.randn((4, 4), dtype=torch.float32, device="cuda")
+ ref_tensor = deepcopy(tensor)
+
+ fp8_tensor = FP8Tensor(tensor, dtype=dtype)
+ ref_fp8_tensor = FP8Tensor(ref_tensor, dtype=dtype)
+
+ if not is_meta_the_same:
+ fp8_tensor.fp8_meta.scale = ref_fp8_tensor.fp8_meta.scale * 2
+
+ assert (fp8_tensor == ref_fp8_tensor) is is_meta_the_same
+
+ new_data = torch.randn(tensor.shape, dtype=torch.float32, device="cuda")
+ ref_fp8_tensor.data = new_data
+
+ assert not fp8_tensor == ref_fp8_tensor
+
+
+# TODO(xrsrke): test it has all the methods of torch.Tensor
+
+# TODO(xrsrke): test it has all the attributes of its input tensor
+
+
+@pytest.mark.parametrize(
+ "tensor_cls, dtype", [(FP8Tensor, DTypes.FP8E4M3), (FP8Tensor, DTypes.FP8E5M2), (FP16Tensor, DTypes.KFLOAT16)]
+)
+def test_serialize_fp8_tensor(tensor_cls, dtype):
+ test_context = TestContext()
+ store_folder = test_context.get_auto_remove_tmp_dir()
+ tensor = torch.randn((4, 4), dtype=torch.float32, device="cuda")
+
+ fp8_tensor = tensor_cls(tensor, dtype=dtype)
+
+ torch.save(fp8_tensor, f"{store_folder}/fp8_tensor.pt")
+ torch.load(f"{store_folder}/fp8_tensor.pt")
diff --git a/tests/fp8/test_tensor.py b/tests/fp8/test_tensor.py
index 40608d34..237e89d5 100644
--- a/tests/fp8/test_tensor.py
+++ b/tests/fp8/test_tensor.py
@@ -297,6 +297,8 @@ def test_transpose_fp8_tensor(tensor_cls, dtype):
ref_fp8_tensor = tensor_cls(deepcopy(tensor), dtype)
transposed_fp8_tensor = fp8_tensor.transpose_fp8()
+ # transposed_fp8_tensor = fp8_tensor.t()
+ # transposed_fp8_tensor = torch.transpose(fp8_tensor, 0, 1)
# NOTE: we expect the original tensor to be the same
assert fp8_tensor == ref_fp8_tensor
diff --git a/tests/nanoset/test_build_nanoset_dataloader.py b/tests/nanoset/test_build_nanoset_dataloader.py
index 113c545c..f7726262 100644
--- a/tests/nanoset/test_build_nanoset_dataloader.py
+++ b/tests/nanoset/test_build_nanoset_dataloader.py
@@ -8,7 +8,6 @@
import numpy as np
import pytest
-from helpers.context import TestContext
from helpers.data import (
assert_batch_dataloader,
assert_nanoset_sync_across_all_ranks,
@@ -22,6 +21,7 @@
from nanotron.data.nanoset import Nanoset
from nanotron.data.utils import count_dataset_indexes, normalize
from nanotron.parallel import ParallelContext
+from nanotron.testing.utils import TestContext
from nanotron.utils import main_rank_first
from transformers import AutoTokenizer
diff --git a/tests/pytest.ini b/tests/pytest.ini
index 0e0b2653..c0c6f3eb 100644
--- a/tests/pytest.ini
+++ b/tests/pytest.ini
@@ -1,4 +1,4 @@
[pytest]
-addopts=-n 35
+; addopts=-n 10
markers =
fa2: FA2-related
diff --git a/tests/test_serialize.py b/tests/test_serialize.py
index 329ff279..62e2a584 100644
--- a/tests/test_serialize.py
+++ b/tests/test_serialize.py
@@ -1,6 +1,5 @@
import pytest
import torch
-from helpers.context import TestContext
from helpers.dummy import dummy_infinite_data_loader, init_dummy_model
from helpers.utils import (
available_gpus,
@@ -33,6 +32,7 @@
save_weights,
)
from nanotron.serialize.metadata import TensorMetadata
+from nanotron.testing.utils import TestContext
from torch.nn.parallel import DistributedDataParallel