diff --git a/.gitignore b/.gitignore index cbc04eaf..2da7cc5a 100644 --- a/.gitignore +++ b/.gitignore @@ -163,3 +163,8 @@ cython_debug/ checkpoints/ wandb/ + +*.csv +*.html +src/nanotron/.test_cache/ +log/ diff --git a/benchmark/fp8_gemm.py b/benchmark/fp8_gemm.py new file mode 100644 index 00000000..f534fa3b --- /dev/null +++ b/benchmark/fp8_gemm.py @@ -0,0 +1,256 @@ +import torch +import transformer_engine.pytorch.cpp_extensions as texcpp + +# from transformer_engine.pytorch.module import get_workspace +# import transformer_engine_extensions as tex +import transformer_engine_torch as tex + +scale = 1.0 + +meta = tex.FP8TensorMeta() +meta.scale = torch.ones(1, dtype=torch.float32, device="cuda") * scale +meta.scale_inv = torch.ones(1, dtype=torch.float32, device="cuda") / scale +meta.amax_history = torch.zeros(1, 1, dtype=torch.float32, device="cuda") + + +def cast_to_fp8(x, qtype): + ret = texcpp.cast_to_fp8(x, meta, tex.FP8FwdTensors.GEMM1_INPUT, qtype) + ret._fp8_qtype = qtype + return ret + + +def cast_from_fp8(x, qtype): + ret = texcpp.cast_from_fp8(x, meta, tex.FP8FwdTensors.GEMM1_INPUT, x._fp8_qtype, qtype) + ret._fp8_qtype = qtype + return ret + + +one_scale_inv = torch.ones(1, dtype=torch.float32, device="cuda") +empty_tensor = torch.Tensor() +# workspace = get_workspace() +workspace = torch.empty(33_554_432, dtype=torch.int8, device="cuda") +assert workspace.is_cuda + + +# PT_DType = dict([(v, k) for k, v in texcpp.TE_DType.items()]) +# PT_DType[tex.DType.kFloat8E4M3] = torch.uint8 +# PT_DType[tex.DType.kFloat8E5M2] = torch.uint8 + + +def convert_torch_dtype_to_te_dtype(dtype: torch.dtype) -> tex.DType: + # NOTE: transformer engine maintains it own dtype mapping + # so we need to manually map torch dtypes to TE dtypes + TORCH_DTYPE_TE_DTYPE_NAME_MAPPING = { + torch.int32: "kInt32", + torch.float32: "kFloat32", + torch.float16: "kFloat16", + torch.bfloat16: "kBFloat16", + # DTypes.FP8E4M3: "kFloat8E4M3", + # DTypes.FP8E5M2: "kFloat8E5M2", + # DTypes.KFLOAT16: "kFloat16", + } + return getattr(tex.DType, TORCH_DTYPE_TE_DTYPE_NAME_MAPPING[dtype]) + + +def fp8_gemm(fa, fb, trans_a, trans_b, bias=None, qtype=tex.DType.kFloat32): + """ + # te_gemm + + input_A: (A_row, A_col) + input_B: (B_row, B_col) + + when transa, transb = True, False + m, k, n = A_row, A_col, B_row + lda, ldb, ldd = A_col, A_col, A_row + output_D: (B_row, A_row) + + when transa, transb = False, False + m, k, n = A_col, A_row, B_row + lda, ldb, ldd = A_col, A_row, A_col + output_D: (B_row, A_col) + + when transa, transb = False, True + m, k, n = A_col, A_row, B_col + lda, ldb, ldd = A_col, B_col, A_col + output_D: (B_col, A_col) + """ + assert fa.is_cuda and fb.is_cuda + assert fa.is_contiguous() + assert fb.is_contiguous() + device = fa.device + fa_qtype, fb_qtype = fa._fp8_qtype, fb._fp8_qtype + A_row, A_col = fa.shape + B_row, B_col = fb.shape + if trans_a and not trans_b: + assert A_col == B_col + C_row, C_col = B_row, A_row + elif not trans_a and not trans_b: + assert A_row == B_col + C_row, C_col = B_row, A_col + elif not trans_a and trans_b: + assert A_row == B_row + C_row, C_col = B_col, A_col + out_shape = (C_row, C_col) + + # dtype = PT_DType[qtype] + if qtype == tex.DType.kFloat32: + dtype = torch.float32 + elif qtype == tex.DType.kFloat16: + dtype = torch.float16 + + out = torch.empty(out_shape, dtype=dtype, device=device) + # te_gemm is column-order. + + # tex.te_gemm( + # fa, one_scale_inv, fa_qtype, trans_a, + # fb, one_scale_inv, fb_qtype, trans_b, + # out, qtype, + # bias or empty_tensor, empty_tensor, False, + # workspace, workspace.shape[0], + # False, True, + # ) + + _empty_tensor = torch.Tensor() + SCALE = AMAX = _empty_tensor + TE_CONFIG_TRANSPOSE_BIAS = False + + tex.te_gemm( + fa, + one_scale_inv, + fa_qtype, + trans_a, + fb, + one_scale_inv, + fb_qtype, + trans_b, + # out, SCALE, qtype, AMAX, + # bias or empty_tensor, qtype, False, + # workspace, workspace.shape[0], + # False, True, + out, + SCALE, + qtype, + AMAX, + torch.tensor([], dtype=dtype), + qtype, + _empty_tensor, + TE_CONFIG_TRANSPOSE_BIAS, + workspace, + workspace.shape[0], + False, + True, + 0, + ) + + out._fp8_qtype = qtype + return out + + +def fp8_matmul(fa, fb, bias=None, qtype=tex.DType.kFloat32): + # trans_a = False and trans_b = False is not implemented. + fb_qtype = fb._fp8_qtype + fb = fb.T.contiguous() + fb._fp8_qtype = fb_qtype + return fp8_gemm(fb, fa, trans_a=True, trans_b=False, bias=bias, qtype=qtype) + + +h100_peak_flops_float32 = 67e12 +h100_peak_flops_fp16_tc = 989e12 +h100_peak_tops_float8_tc = 1979e12 + +dtype_to_peak_tops = { + torch.float32: h100_peak_flops_float32, + torch.float16: h100_peak_flops_fp16_tc, + torch.bfloat16: h100_peak_flops_fp16_tc, + torch.float8_e4m3fn: h100_peak_tops_float8_tc, + torch.float8_e5m2: h100_peak_tops_float8_tc, +} + +from torch.utils import benchmark + + +def benchmark_fn_in_sec(f, *args, **kwargs): + # Manual warmup + for _ in range(4): + f(*args, **kwargs) + + t0 = benchmark.Timer(stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}) + measurement = t0.blocked_autorange() + return measurement.mean + + +def run_fp8(a, b): + fa = cast_to_fp8(a, tex.DType.kFloat8E4M3) + fb = cast_to_fp8(b, tex.DType.kFloat8E4M3) + fp8_matmul(fa, fb, qtype=tex.DType.kFloat16) + + +def run_bfloat16(a, b): + a = a.to(torch.bfloat16) + b = b.to(torch.bfloat16) + torch.matmul(a, b) + + +def benchmark_linear_operations(a, b): + M, K = a.shape + N, _ = b.shape + + # Benchmark FP8 + fp8_time = benchmark_fn_in_sec(run_fp8, a, b) + + # Benchmark BFloat16 + bfloat16_time = benchmark_fn_in_sec(run_bfloat16, a, b) + + # Calculate FLOPS + # Each linear operation performs 2*M*N*K FLOPs (multiply-add) + total_flops = 2 * M * N * K + + fp8_tflops = (total_flops / fp8_time) / 1e12 + bfloat16_tflops = (total_flops / bfloat16_time) / 1e12 + + # Calculate efficiency compared to peak performance + fp8_efficiency = (fp8_tflops / (h100_peak_tops_float8_tc / 1e12)) * 100 + bfloat16_efficiency = (bfloat16_tflops / (h100_peak_flops_fp16_tc / 1e12)) * 100 + + return { + "M": M, + "N": N, + "K": K, + "FP8_time_ms": fp8_time * 1000, + "BF16_time_ms": bfloat16_time * 1000, + "FP8_TFLOPS": fp8_tflops, + "BF16_TFLOPS": bfloat16_tflops, + "FP8_eff%": fp8_efficiency, + "BF16_eff%": bfloat16_efficiency, + "Speedup": bfloat16_time / fp8_time, + } + + +if __name__ == "__main__": + # a = torch.randn(128, 128).cuda() + # b = torch.randn(128, 128).cuda() + # qa = cast_from_fp8(fa, tex.DType.kFloat32) + # qb = cast_from_fp8(fb, tex.DType.kFloat32) + # qc = torch.matmul(qa, qb) + + # E4M3/E5M2 @ E4M3/E5M2 = FP16/FP32 + # print(qc, qc2) + + import pandas as pd + + def create_benchmark_table(sizes): + results = [] + for size in sizes: + a = torch.randn(size, size).cuda() + b = torch.randn(size, size).cuda() + result = benchmark_linear_operations(a, b) + results.append(result) + + df = pd.DataFrame(results) + df = df.round(2) # Round to 2 decimal places + return df + + # Example usage: + sizes = [4096, 16384, 32768, 28672, 49152] + benchmark_table = create_benchmark_table(sizes) + print(benchmark_table) diff --git a/benchmark/fp8_tp.py b/benchmark/fp8_tp.py new file mode 100644 index 00000000..3578de34 --- /dev/null +++ b/benchmark/fp8_tp.py @@ -0,0 +1,153 @@ +import argparse +import itertools + +import pandas as pd +import torch +import torch.distributed as dist +from nanotron.parallel import ParallelContext +from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode +from nanotron.parallel.tensor_parallel.nn import FP8TensorParallelColumnLinear +from torch.profiler import ProfilerActivity + + +def run_experiment(exp_name, M, N, K, TP_SIZE, parallel_context): + torch.cuda.synchronize() + input = torch.randn(M, K, device="cuda", requires_grad=True) + column_linear = FP8TensorParallelColumnLinear( + in_features=K, + out_features=N, + pg=parallel_context.tp_pg, + mode=TensorParallelLinearMode.ALL_REDUCE, + device="cuda", + async_communication=False, + bias=False, + ) + + sharded_output = column_linear(input) + + with torch.profiler.profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + on_trace_ready=torch.profiler.tensorboard_trace_handler(f"./log/{exp_name}"), + record_shapes=True, + # profile_memory=True, + with_stack=True, + with_modules=True, + experimental_config=torch._C._profiler._ExperimentalConfig(verbose=True), + use_cuda=True, + ) as prof: + prof.step() + sharded_output.sum().backward() + + return prof + + +def print_profiling_table(prof, sort_by="cpu_time_total"): + print(f"###### sorted by {sort_by} ######") + print( + prof.key_averages(group_by_stack_n=100).table( + sort_by=sort_by, + row_limit=20, + top_level_events_only=False, + # max_src_column_width=2000, # Increase source column width + # max_name_column_width=2000, + # max_shapes_column_width=1000, + max_src_column_width=100, # Increase source column width + max_name_column_width=30, + max_shapes_column_width=100, + ) + ) + + +def explore_event_values(event): + for attr in dir(event): + if not attr.startswith("_"): # Skip internal attributes + try: + value = getattr(event, attr) + if callable(value): # Skip methods + continue + print(f"\n{attr}:") + print(value) + print("-" * 50) # Separator for better readability + except Exception: + print(f"{attr}: ") + + +def parse_args(): + parser = argparse.ArgumentParser(description="Run profiling experiments with configurable dimensions") + parser.add_argument("--exp_number", type=str, help="Experiment number") + parser.add_argument("--tp_size", type=int, default=1, help="Tensor Parallel size") + parser.add_argument( + "--dimensions", + type=str, + default="1024,2048,4096,8192,16384,32768", + help="Comma-separated list of dimensions to test", + ) + return parser.parse_args() + + +if __name__ == "__main__": + torch.backends.cudnn.benchmark = True + + args = parse_args() + + # Parse dimensions from comma-separated string to list of integers + dimensions = [int(d.strip()) for d in args.dimensions.split(",")] + TP_SIZE = args.tp_size + EXP_NUMBER = args.exp_number + + # dimensions = [1024, 2048, 4096, 8192, 16384] + # TP_SIZE = 8 + + results = [] + total = len(list(itertools.product(dimensions, dimensions, dimensions))) + experiment_count = 0 + parallel_context = ParallelContext(data_parallel_size=1, pipeline_parallel_size=1, tensor_parallel_size=TP_SIZE) + + for M, N, K in itertools.product(dimensions, dimensions, dimensions): + exp_name = f"{EXP_NUMBER}_fp8_m{M}_n{N}_k{K}_and_tp{TP_SIZE}" + total += 1 + print(f"Running experiment with M={M}, N={N}, K={K}, {experiment_count}/{total}") + + prof = run_experiment(exp_name, M, N, K, TP_SIZE=TP_SIZE, parallel_context=parallel_context) + + if dist.get_rank() == 0: + print_profiling_table(prof, sort_by="cpu_time_total") + print_profiling_table(prof, sort_by="cuda_time_total") + print_profiling_table(prof, sort_by="self_cuda_time_total") + # explore_event_values(table) + + # Get top 5 operations by CPU time + # sorted_events = prof.key_averages().table(sort_by="cpu_time_total") + + # NOTE: loop through all events and sum up the total time, then calculate the percent + averages = prof.key_averages(group_by_stack_n=100) + # NOTE: why sum .self_cpu_time_total instead of .cpu_time_total? + # source: https://github.com/pytorch/pytorch/blob/f14f245747db2f80e963bd72561f5bd5ed216a4a/torch/autograd/profiler_util.py#L976-L990 + # i test and it matches the torch's generated table + cpu_time_total_of_all_events = sum([event.self_cpu_time_total for event in averages]) + sorted_events = sorted(averages, key=lambda x: x.cpu_time_total, reverse=True)[:5] + + for event in sorted_events: + event_cpu_time_percent = (event.cpu_time_total / cpu_time_total_of_all_events) * 100 + + results.append( + { + "M": M, + "N": N, + "K": K, + "Operation": event.key, + "CPU Time (ms)": event.cpu_time_total / 1000, # Convert to milliseconds + "CPU Time %": f"{event_cpu_time_percent:.2f}%", + "CUDA Time (ms)": event.cuda_time_total / 1000, # Convert to milliseconds + # 'Memory Used (MB)': event.cpu_memory_usage / (1024 * 1024) if event.cpu_memory_usage else 0 + } + ) + +if dist.get_rank() == 0: + df = pd.DataFrame(results) + print("\nTop 5 most time-consuming operations for each dimension combination:") + print(df.to_string()) + df.to_csv( + f'{EXP_NUMBER}_profiling_results_with_m_n_k_with_cartesian_product_{"_".join(map(str, dimensions))}.csv', + index=False, + ) diff --git a/benchmark/fp8_tp_speed.py b/benchmark/fp8_tp_speed.py new file mode 100644 index 00000000..dbd58f13 --- /dev/null +++ b/benchmark/fp8_tp_speed.py @@ -0,0 +1,285 @@ +import argparse + +import pandas as pd +import torch +from nanotron.models.base import init_on_device_and_dtype +from nanotron.parallel import ParallelContext +from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode +from nanotron.parallel.tensor_parallel.nn import FP8TensorParallelColumnLinear, TensorParallelColumnLinear +from torch.utils import benchmark + +# H100 SXM specs: bottom of https://www.nvidia.com/en-us/data-center/h100/ +h100_peak_flops_float32 = 67e12 +# NOTE: without sparsity +h100_peak_flops_fp16_tc = 989e12 +h100_peak_tops_float8_tc = 1979e12 + +# # NOTE: with sparity +# h100_peak_flops_fp16_tc = 1979e12 +# h100_peak_tops_float8_tc = 3958e12 + + +def color_scale(val, min_val, max_val, metric_type="default"): + """Generate background color based on value and metric type.""" + if pd.isna(val): + return "background-color: white" + + normalized = (val - min_val) / (max_val - min_val) if max_val != min_val else 0 + + if metric_type == "time": # Lower is better - red scale + color = f"background-color: rgba(255, {int(255 * (1-normalized))}, {int(255 * (1-normalized))}, 0.8)" + elif metric_type == "performance": # Higher is better - green scale + color = f"background-color: rgba({int(255 * (1-normalized))}, 255, {int(255 * (1-normalized))}, 0.8)" + elif metric_type == "efficiency": # Higher is better - blue scale + color = f"background-color: rgba({int(255 * (1-normalized))}, {int(255 * (1-normalized))}, 255, 0.8)" + else: # Default purple scale + color = f"background-color: rgba({int(255 * (1-normalized))}, 0, 255, 0.8)" + + text_color = "white" if normalized > 0.5 else "black" + return f"{color}; color: {text_color}" + + +def create_html_table(df, exp_number, tp_size): + def style_df(df): + # Create an empty DataFrame with the same shape as the input + styled = pd.DataFrame("", index=df.index, columns=df.columns) + + # Style specific columns + for column in df.columns: + if column.endswith("time_ms"): + styled[column] = df[column].apply(lambda x: color_scale(x, df[column].min(), df[column].max(), "time")) + elif column.endswith("TFLOPS"): + styled[column] = df[column].apply( + lambda x: color_scale(x, df[column].min(), df[column].max(), "performance") + ) + elif column.endswith("efficiency_%"): + styled[column] = df[column].apply( + lambda x: color_scale(x, df[column].min(), df[column].max(), "efficiency") + ) + elif column == "Speedup": + styled[column] = df[column].apply(lambda x: color_scale(x, df[column].min(), df[column].max())) + return styled + + # Format numbers and apply styling + styled_df = df.style.format( + { + "FP8_time_ms": "{:.2f}", + "BF16_time_ms": "{:.2f}", + "FP8_TFLOPS": "{:.2f}", + "BF16_TFLOPS": "{:.2f}", + "FP8_efficiency_%": "{:.2f}", + "BF16_efficiency_%": "{:.2f}", + "Speedup": "{:.2f}", + } + ).apply(lambda _: style_df(df), axis=None) + + # Generate HTML + html = f""" + + + + + +
+

Benchmark Results (TP_SIZE={tp_size})

+

Experiment: {exp_number}

+
+ {styled_df.to_html(table_id="results")} + + + """ + + with open(f"{exp_number}_benchmark_results_tp{tp_size}.html", "w") as f: + f.write(html) + + +def benchmark_fn_in_sec(f, *args, **kwargs): + # Manual warmup + for _ in range(4): + f(*args, **kwargs) + + t0 = benchmark.Timer(stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}) + measurement = t0.blocked_autorange() + return measurement.mean + + +def run_fp8_linear(input, M, N, K, parallel_context, include_backward=False): + column_linear = FP8TensorParallelColumnLinear( + in_features=K, + out_features=N, + pg=parallel_context.tp_pg, + mode=TensorParallelLinearMode.ALL_REDUCE, + device="cuda", + async_communication=False, + bias=False, + ) + + sharded_output = column_linear(input) + + if include_backward is True: + sharded_output.sum().backward() + + +def run_linear(input, M, N, K, parallel_context, include_backward=False): + with init_on_device_and_dtype(device="cuda", dtype=torch.bfloat16): + column_linear = TensorParallelColumnLinear( + in_features=K, + out_features=N, + pg=parallel_context.tp_pg, + mode=TensorParallelLinearMode.ALL_REDUCE, + device="cuda", + async_communication=False, + bias=False, + ) + + sharded_output = column_linear(input) + + if include_backward is True: + sharded_output.sum().backward() + + +# def parse_args(): +# parser = argparse.ArgumentParser(description="Run profiling experiments with configurable dimensions") +# parser.add_argument("--exp_number", type=str, help="Experiment number") +# parser.add_argument("--tp_size", type=int, default=1, help="Tensor Parallel size") +# # parser.add_argument( +# # "--dimensions", +# # type=str, +# # default="4096,16384,32768,28672,49152", +# # help="Comma-separated list of dimensions to test. Used when M/K/N are not explicitly provided", +# # ) +# # Add the missing argument definitions +# parser.add_argument("--m_size", type=int, help="Explicitly set M dimension") +# parser.add_argument("--k_size", type=int, help="Explicitly set K dimension") +# parser.add_argument("--n_size", type=int, help="Explicitly set N dimension") +# return parser.parse_args() + + +def parse_args(): + parser = argparse.ArgumentParser(description="Run profiling experiments with configurable dimensions") + parser.add_argument("--exp_number", type=str, help="Experiment number") + parser.add_argument("--tp_size", type=int, default=1, help="Tensor Parallel size") + parser.add_argument( + "--dimensions", + type=str, + default="4096,16384,32768,28672,49152", + help="Comma-separated list of dimensions to test. Used when M/K/N are not explicitly provided", + ) + parser.add_argument("--m_size", type=int, help="Explicitly set M dimension") + parser.add_argument("--k_size", type=int, help="Explicitly set K dimension") + parser.add_argument("--n_size", type=int, help="Explicitly set N dimension") + return parser.parse_args() + + +def benchmark_linear_operations(M, N, K, parallel_context, include_backward, requires_grad): + input = torch.randn(M, K, device="cuda", requires_grad=requires_grad) + bfloat16_input = torch.randn(M, K, device="cuda", requires_grad=requires_grad, dtype=torch.bfloat16) + + # Benchmark FP8 + fp8_time = benchmark_fn_in_sec(run_fp8_linear, input, M, N, K, parallel_context, include_backward) + + # Benchmark BFloat16 + bfloat16_time = benchmark_fn_in_sec(run_linear, bfloat16_input, M, N, K, parallel_context, include_backward) + + # Calculate FLOPS + total_flops = 2 * M * N * K // parallel_context.tensor_parallel_size + if include_backward: + # Gradient with Respect to Parameters + total_flops += 2 * M * N * K // parallel_context.tensor_parallel_size + if requires_grad: + # Gradient with Respect to Inputs + total_flops += 2 * M * N * K // parallel_context.tensor_parallel_size + + fp8_tflops = (total_flops / fp8_time) / 1e12 + bfloat16_tflops = (total_flops / bfloat16_time) / 1e12 + + # Calculate efficiency + fp8_efficiency = (fp8_tflops / (h100_peak_tops_float8_tc / 1e12)) * 100 + bfloat16_efficiency = (bfloat16_tflops / (h100_peak_flops_fp16_tc / 1e12)) * 100 + + return { + "M": M, + "N": N, + "K": K, + "Include_Backward": include_backward, + "Input_Requires_Grad": requires_grad, + "FP8_time_ms": fp8_time * 1000, + "BF16_time_ms": bfloat16_time * 1000, + "FP8_TFLOPS": fp8_tflops, + "BF16_TFLOPS": bfloat16_tflops, + "FP8_MFU%": fp8_efficiency, + "BF16_MFU%": bfloat16_efficiency, + "Speedup": bfloat16_time / fp8_time, + } + + +if __name__ == "__main__": + torch.backends.cudnn.benchmark = True + + args = parse_args() + dimensions = [int(d.strip()) for d in args.dimensions.split(",")] + parallel_context = ParallelContext( + data_parallel_size=1, pipeline_parallel_size=1, tensor_parallel_size=args.tp_size + ) + + results = [] + # combinations = list(itertools.product( + # dimensions, # M + # dimensions, # N + # dimensions, # K + # [False, True], # include_backward + # [False, True] # requires_grad + # )) + # Pair dimensions index-wise + # combinations = [] + # for i in range(len(dimensions)): + # M = dimensions[i] + # N = dimensions[i] + # K = dimensions[i] + # # For each dimension pair, test all combinations of include_backward and requires_grad + # for include_backward in [False, True]: + # for requires_grad in [False, True]: + # combinations.append((M, N, K, include_backward, requires_grad)) + + combinations = [] + # Check if explicit M, K, N dimensions are provided + if all(dim is not None for dim in [args.m_size, args.k_size, args.n_size]): + # Use explicitly provided dimensions + for include_backward in [False, True]: + # for requires_grad in [False, True]: + for requires_grad in [False]: + combinations.append((args.m_size, args.n_size, args.k_size, include_backward, requires_grad)) + else: + # Use dimensions from the --dimensions argument + dimensions = [int(d.strip()) for d in args.dimensions.split(",")] + for i in range(len(dimensions)): + M = dimensions[i] + N = dimensions[i] + K = dimensions[i] + for include_backward in [False, True]: + # for requires_grad in [False, True]: + for requires_grad in [False]: + combinations.append((M, N, K, include_backward, requires_grad)) + + total = len(combinations) + + for i, (M, N, K, include_backward, requires_grad) in enumerate(combinations, 1): + result = benchmark_linear_operations(M, N, K, parallel_context, include_backward, requires_grad) + results.append(result) + print(f"Experiment {i}/{total} complete") + + df = pd.DataFrame(results) + df = df.round(2) + df = df.sort_values(by=["M", "N", "K", "Include_Backward", "Input_Requires_Grad"]) + + print(df) + create_html_table(df, args.exp_number, args.tp_size) + print(f"\nResults have been saved to {args.exp_number}_benchmark_results_tp{args.tp_size}.html") diff --git a/benchmark/ref_tp.py b/benchmark/ref_tp.py new file mode 100644 index 00000000..e8be8cdf --- /dev/null +++ b/benchmark/ref_tp.py @@ -0,0 +1,170 @@ +# from nanotron import distributed as dist +import nanotron.fp8.distributed as dist + +# import torch.distributed as dist +import torch +from nanotron.fp8.tensor import FP8Tensor, convert_tensor_from_fp8 +from nanotron.parallel import ParallelContext +from nanotron.parallel.parameters import get_data_from_param, get_grad_from_parameter +from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode +from nanotron.parallel.tensor_parallel.nn import ( + FP8TensorParallelColumnLinear, +) +from nanotron.sanity_checks import assert_tensor_synced_across_pg +from torch import nn + +if __name__ == "__main__": + with_bias = False + # NOTE: divisible by 16 for TP + in_features = 32 + out_features_per_tp_rank = 16 + + parallel_context = ParallelContext(data_parallel_size=1, pipeline_parallel_size=1, tensor_parallel_size=2) + + out_features = parallel_context.tp_pg.size() * out_features_per_tp_rank + + # Sharded + column_linear = FP8TensorParallelColumnLinear( + in_features=in_features, + out_features=out_features, + pg=parallel_context.tp_pg, + mode=TensorParallelLinearMode.ALL_REDUCE, + device="cuda", + async_communication=False, + bias=with_bias, + ) + + # Un-sharded + reference_linear = nn.Linear(in_features=in_features, out_features=out_features, bias=with_bias, device="cuda") + + # Copy weights/bias from sharded to un-sharded + with torch.inference_mode(): + # weight = column_linear.weight.data + # weight = convert_tensor_from_fp8(weight, weight.fp8_meta, torch.bfloat16), + dist.all_gather( + tensor_list=list(reference_linear.weight.split(out_features_per_tp_rank, dim=0)), + # tensor=column_linear.weight.data, + tensor=get_data_from_param(column_linear.weight), + group=parallel_context.tp_pg, + ) + + if with_bias is True: + # TODO(xrsrke): support if bias is in FP8 + # bias = column_linear.bias.data + bias = get_data_from_param(column_linear.bias) + bias = bias.to(reference_linear.bias.dtype) if bias.dtype != reference_linear.bias.dtype else bias + dist.all_gather( + tensor_list=list(reference_linear.bias.split(out_features_per_tp_rank, dim=0)), + tensor=bias, + group=parallel_context.tp_pg, + ) + + # TODO(xrsrke) + if with_bias is True: + assert column_linear.bias.requires_grad is (with_bias is True) + # assert column_linear.bias.data.__class__ == torch.Tensor + assert get_data_from_param(column_linear.bias).__class__ == nn.Parameter + # assert column_linear.bias.data.requires_grad is (with_bias is True) + + # Generate random input + random_input: torch.Tensor + sharded_random_input: torch.Tensor + + # batch_size = 5 + batch_size = 16 + random_input = torch.randn(batch_size, in_features, device="cuda") + # synchronize random_input across tp + dist.all_reduce(random_input, op=dist.ReduceOp.AVG, group=parallel_context.tp_pg) + sharded_random_input = random_input + + dist.barrier() + assert_tensor_synced_across_pg(random_input, pg=parallel_context.tp_pg) + + # It's important that `random_input` and `sharded_random_input` are two separate tensors with separate storage + sharded_random_input = sharded_random_input.clone() + sharded_random_input = sharded_random_input.contiguous() + random_input.requires_grad = True + sharded_random_input.requires_grad = True + + # Test that we get the same output after forward pass + sharded_output = column_linear(sharded_random_input) + + reference_output = reference_linear(random_input) + # reference_output = ReferenceLinear.apply(random_input, reference_linear.weight, reference_linear.bias) + + # TODO @thomasw21: Tune tolerance + try: + torch.testing.assert_close( + sharded_output, + # TODO(xrsrke): retrieve accumulation precision from recipe + # NOTE: before the reference_output.to(torch.bfloat16) + reference_output[ + :, + dist.get_rank(parallel_context.tp_pg) + * out_features_per_tp_rank : (dist.get_rank(parallel_context.tp_pg) + 1) + * out_features_per_tp_rank, + ].to(torch.bfloat16), + rtol=0, + atol=0.1, + ) + except BaseException as e: + print(f"Rank {dist.get_rank(parallel_context.tp_pg)}: FAIL.") + dist.barrier() + raise e + + print(f"Rank {dist.get_rank(parallel_context.tp_pg)}: SUCCESS.") + dist.barrier() + + # Test that we get the same gradient after backward pass + sharded_output.sum().backward() + reference_output.sum().backward() + hidden_dim_slice = slice( + dist.get_rank(parallel_context.tp_pg) * out_features_per_tp_rank, + (dist.get_rank(parallel_context.tp_pg) + 1) * out_features_per_tp_rank, + ) + + torch.testing.assert_close( + # convert_tensor_from_fp8(column_linear.weight.data, column_linear.weight.data.fp8_meta, torch.bfloat16), + convert_tensor_from_fp8( + get_data_from_param(column_linear.weight), + get_data_from_param(column_linear.weight).fp8_meta, + torch.bfloat16, + ), + reference_linear.weight[hidden_dim_slice].to(torch.bfloat16), + rtol=0.1, + atol=0.1, + ) + + # TODO(xrsrke): retrieve accumulation precision from recipe + assert sharded_output.dtype == torch.bfloat16 + # NOTE(xrsrke): we expect the output is a raw torch.Tensor, not FP8Paramter, or NanotronParameter + # assert isinstance(sharded_output, torch.Tensor) + assert sharded_output.__class__ == torch.Tensor + assert sharded_output.requires_grad is True + + torch.testing.assert_close(sharded_random_input.grad, random_input.grad, rtol=0.1, atol=0.1) + + if with_bias is True: + torch.testing.assert_close( + column_linear.bias.grad, + reference_linear.bias.grad[hidden_dim_slice], + ) + + if isinstance(get_data_from_param(column_linear.weight), FP8Tensor): + # grad = column_linear.weight.data._temp_grad + # grad = convert_tensor_from_fp8(grad, column_linear.weight.data._temp_grad.fp8_meta, torch.bfloat16) + grad = get_grad_from_parameter(column_linear.weight) + grad = convert_tensor_from_fp8(grad, grad.fp8_meta, torch.bfloat16) + else: + # grad = column_linear.weight.grad + grad = get_grad_from_parameter(column_linear.weight) + + torch.testing.assert_close( + grad, + reference_linear.weight.grad[hidden_dim_slice].to(torch.bfloat16), + # rtol=0.1, atol=0.1 + rtol=0.2, + atol=0.2, + ) + + parallel_context.destroy() diff --git a/examples/config_tiny_fp8_llama.yaml b/examples/config_tiny_fp8_llama.yaml new file mode 100644 index 00000000..2c43eb45 --- /dev/null +++ b/examples/config_tiny_fp8_llama.yaml @@ -0,0 +1,109 @@ +checkpoints: + checkpoint_interval: 10000 + checkpoints_path: checkpoints + checkpoints_path_is_shared_file_system: false + resume_checkpoint_path: null + save_initial_state: false +data_stages: +- data: + dataset: + dataset_overwrite_cache: false + dataset_processing_num_proc_per_process: 1 + hf_dataset_config_name: null + hf_dataset_or_datasets: roneneldan/TinyStories + hf_dataset_splits: train + text_column_name: text + num_loading_workers: 1 + seed: 42 + name: Stable Training Stage + start_training_step: 1 +# - data: +# dataset: +# dataset_overwrite_cache: false +# dataset_processing_num_proc_per_process: 1 +# hf_dataset_config_name: null +# hf_dataset_or_datasets: stas/openwebtext-10k +# hf_dataset_splits: train +# text_column_name: text +# num_loading_workers: 1 +# seed: 42 +# name: Annealing Phase +# start_training_step: 10 +general: + benchmark_csv_path: null + consumed_train_samples: null + ignore_sanity_checks: false + project: debug + run: tiny_llama_%date_%jobid + seed: 42 + step: null +lighteval: null +logging: + iteration_step_info_interval: 1 + log_level: info + log_level_replica: info +model: + ddp_bucket_cap_mb: 25 + dtype: float8 + init_method: + std: 0.025 + make_vocab_size_divisible_by: 1 + model_config: + bos_token_id: 1 + eos_token_id: 2 + hidden_act: silu + hidden_size: 1024 + initializer_range: 0.02 + intermediate_size: 4096 + is_llama_config: true + max_position_embeddings: 1024 + num_attention_heads: 4 + num_hidden_layers: 6 + num_key_value_heads: 4 + pad_token_id: null + pretraining_tp: 1 + rms_norm_eps: 1.0e-05 + rope_scaling: null + tie_word_embeddings: false + use_cache: true + vocab_size: 1024 +optimizer: + accumulate_grad_in_fp32: true + clip_grad: 1.0 + learning_rate_scheduler: + learning_rate: 0.0003 + lr_decay_starting_step: null + lr_decay_steps: 13 + lr_decay_style: cosine + lr_warmup_steps: 2 + lr_warmup_style: linear + min_decay_lr: 1.0e-05 + optimizer_factory: + adam_beta1: 0.9 + adam_beta2: 0.95 + adam_eps: 1.0e-08 + name: adamW + torch_adam_is_fused: true + weight_decay: 0.01 + zero_stage: 0 +parallelism: + dp: 1 + expert_parallel_size: 1 + pp: 1 + pp_engine: 1f1b + tp: 2 + tp_linear_async_communication: false + tp_mode: ALL_REDUCE +profiler: null +tokenizer: + tokenizer_max_length: null + tokenizer_name_or_path: robot-test/dummy-tokenizer-wordlevel + tokenizer_revision: null +tokens: + batch_accumulation_per_replica: 1 + limit_test_batches: 0 + limit_val_batches: 0 + micro_batch_size: 2 + sequence_length: 1024 + train_steps: 1500 + val_check_interval: -1 diff --git a/examples/config_tiny_llama.yaml b/examples/config_tiny_llama.yaml index 58645e2d..e1fac82e 100644 --- a/examples/config_tiny_llama.yaml +++ b/examples/config_tiny_llama.yaml @@ -1,5 +1,5 @@ checkpoints: - checkpoint_interval: 10 + checkpoint_interval: 10000 checkpoints_path: checkpoints checkpoints_path_is_shared_file_system: false resume_checkpoint_path: null @@ -10,25 +10,25 @@ data_stages: dataset_overwrite_cache: false dataset_processing_num_proc_per_process: 1 hf_dataset_config_name: null - hf_dataset_or_datasets: stas/openwebtext-10k + hf_dataset_or_datasets: roneneldan/TinyStories hf_dataset_splits: train text_column_name: text num_loading_workers: 1 seed: 42 name: Stable Training Stage start_training_step: 1 -- data: - dataset: - dataset_overwrite_cache: false - dataset_processing_num_proc_per_process: 1 - hf_dataset_config_name: null - hf_dataset_or_datasets: stas/openwebtext-10k - hf_dataset_splits: train - text_column_name: text - num_loading_workers: 1 - seed: 42 - name: Annealing Phase - start_training_step: 10 +# - data: +# dataset: +# dataset_overwrite_cache: false +# dataset_processing_num_proc_per_process: 1 +# hf_dataset_config_name: null +# hf_dataset_or_datasets: stas/openwebtext-10k +# hf_dataset_splits: train +# text_column_name: text +# num_loading_workers: 1 +# seed: 42 +# name: Annealing Phase +# start_training_step: 10 general: benchmark_csv_path: null consumed_train_samples: null @@ -52,13 +52,13 @@ model: bos_token_id: 1 eos_token_id: 2 hidden_act: silu - hidden_size: 16 + hidden_size: 1024 initializer_range: 0.02 - intermediate_size: 64 + intermediate_size: 4096 is_llama_config: true - max_position_embeddings: 256 + max_position_embeddings: 1024 num_attention_heads: 4 - num_hidden_layers: 2 + num_hidden_layers: 6 num_key_value_heads: 4 pad_token_id: null pretraining_tp: 1 @@ -66,7 +66,7 @@ model: rope_scaling: null tie_word_embeddings: true use_cache: true - vocab_size: 256 + vocab_size: 1024 optimizer: accumulate_grad_in_fp32: true clip_grad: 1.0 @@ -87,13 +87,13 @@ optimizer: weight_decay: 0.01 zero_stage: 0 parallelism: - dp: 2 + dp: 1 expert_parallel_size: 1 - pp: 2 + pp: 1 pp_engine: 1f1b tp: 2 - tp_linear_async_communication: true - tp_mode: REDUCE_SCATTER + tp_linear_async_communication: false + tp_mode: ALL_REDUCE profiler: null tokenizer: tokenizer_max_length: null @@ -104,6 +104,6 @@ tokens: limit_test_batches: 0 limit_val_batches: 0 micro_batch_size: 2 - sequence_length: 256 - train_steps: 15 + sequence_length: 1024 + train_steps: 1500 val_check_interval: -1 diff --git a/examples/llama/tests/test_conversion.py b/examples/llama/tests/test_conversion.py index b5ce3529..8fa02a81 100644 --- a/examples/llama/tests/test_conversion.py +++ b/examples/llama/tests/test_conversion.py @@ -15,6 +15,7 @@ from nanotron.models.base import init_on_device_and_dtype from nanotron.models.llama import LlamaForTraining from nanotron.parallel import ParallelContext +from nanotron.testing.utils import TestContext from nanotron.trainer import mark_tied_parameters from examples.llama.convert_hf_to_nanotron import convert_checkpoint_and_save as convert_hf_to_nt_and_save @@ -22,7 +23,6 @@ from examples.llama.convert_nanotron_to_hf import convert_checkpoint_and_save as convert_nt_to_hf_and_save from examples.llama.convert_nanotron_to_hf import convert_nt_to_hf, get_hf_config from examples.llama.convert_weights import load_nanotron_model, make_parallel_config -from tests.helpers.context import TestContext from tests.helpers.utils import init_distributed CONFIG = NanotronLlamaConfig( @@ -141,7 +141,7 @@ def _test_hf_to_nt(parallel_context: ParallelContext, input_ids: torch.Tensor): logits_nt = model_nt.model(input_ids, input_mask).permute(1, 0, 2) logits_hf = model_hf(input_ids).logits assert logits_nt.size() == logits_hf.size() - torch.testing.assert_allclose(logits_hf, logits_nt, atol=ATOL) + torch.testing.assert_allclose(logits_hf, logits_nt, atol=ATOL) def test_hf_to_nt(input_ids: torch.Tensor): diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py index 8a8c8926..73ca6286 100644 --- a/src/nanotron/config/config.py +++ b/src/nanotron/config/config.py @@ -11,6 +11,7 @@ from datasets.download.streaming_download_manager import xPath from yaml.loader import SafeLoader +from nanotron.config.fp8_config import FP8Args from nanotron.config.lighteval_config import LightEvalConfig from nanotron.config.models_config import ExistingCheckpointInit, NanotronConfigs, RandomInit, SpectralMupInit from nanotron.config.parallelism_config import ParallelismArgs @@ -317,6 +318,7 @@ class OptimizerArgs: clip_grad: Optional[float] accumulate_grad_in_fp32: bool learning_rate_scheduler: LRSchedulerArgs + master_weight_dtype: torch.dtype = torch.float32 @dataclass @@ -353,6 +355,7 @@ class Config: profiler: Optional[ProfilerArgs] = None lighteval: Optional[LightEvalConfig] = None s3_upload: Optional[S3UploadArgs] = None + fp8: Optional[FP8Args] = None @classmethod def create_empty(cls): @@ -400,6 +403,10 @@ def __post_init__(self): # if self.checkpoints.lighteval is not None: # assert self.tokenizer.tokenizer_name_or_path is not None + if self.model.dtype == torch.int8: + if self.fp8 is None: + self.fp8 = FP8Args() + @property def global_batch_size(self): return self.tokens.micro_batch_size * self.tokens.batch_accumulation_per_replica * self.parallelism.dp @@ -440,6 +447,9 @@ def get_config_from_dict( for k, v in config_dict.items() if v is not None } + + from nanotron.fp8.dtypes import DTypes + return from_dict( data_class=config_class, data=config_dict, @@ -451,6 +461,7 @@ def get_config_from_dict( TensorParallelLinearMode: lambda x: TensorParallelLinearMode[x.upper()], RecomputeGranularity: lambda x: RecomputeGranularity[x.upper()], SamplerType: lambda x: SamplerType[x.upper()], + DTypes: lambda x: DTypes[x.upper()], # Add this line, }, # strict_unions_match=True, strict=True, diff --git a/src/nanotron/config/fp8_config.py b/src/nanotron/config/fp8_config.py new file mode 100644 index 00000000..59a25243 --- /dev/null +++ b/src/nanotron/config/fp8_config.py @@ -0,0 +1,50 @@ +from dataclasses import dataclass +from typing import List, Optional + +import torch + +from nanotron.fp8.constants import FP8LM_OPTIM_RECIPE +from nanotron.fp8.recipe import FP8LinearRecipe, FP8OptimRecipe + + +@dataclass +class FP8LayerArgs(FP8LinearRecipe): + module_name: Optional[str] = None + + def __post_init__(self): + assert self.module_name is not None, "module_name must be specified" + + +@dataclass +class FP8Args: + # NOTE: this is the datatype of model initialization, before casting to fp8 + init_dtype: torch.dtype = torch.float32 + # NOTE: this is the datatype for residual stream (aka: non-fp8 operation) + resid_dtype: torch.dtype = torch.float32 + # NOTE: the datatype for fp8 operation's accumulation + accum_dtype: torch.dtype = torch.bfloat16 + + model: Optional[List[FP8LayerArgs]] = None + optim: Optional[FP8OptimRecipe] = FP8LM_OPTIM_RECIPE + + run_fp8_sanity_check: bool = False + + update_clipping: bool = False + skip_param_update_if_nan: bool = False + + sync_amax_in_input: bool = False + sync_amax_in_weight: bool = False + sync_amax_in_igrad: bool = False + sync_amax_in_wgrad: bool = False + sync_amax_func: str = "default" + weight_decay_without_lr_decay: bool = False + + triton_rms_norm: bool = False + + is_sanity_logging: bool = False + is_post_scaling_all_reduce: bool = True + # NOTE: 1.0e-6 was the default + gradient_clipping_eps: float = 1.0e-6 + + is_quant_all_except_first_and_last: Optional[bool] = None + fp8_linear_config_temp: Optional[FP8LayerArgs] = None diff --git a/src/nanotron/config/utils_config.py b/src/nanotron/config/utils_config.py index f4c07146..6cc092d4 100644 --- a/src/nanotron/config/utils_config.py +++ b/src/nanotron/config/utils_config.py @@ -62,6 +62,7 @@ def serialize(data) -> dict: "bfloat16": torch.bfloat16, "uint8": torch.uint8, "int8": torch.int8, + "float8": torch.int8, "int16": torch.int16, "int32": torch.int32, "int64": torch.int64, @@ -75,8 +76,9 @@ def serialize(data) -> dict: torch.complex128: "complex128", torch.float16: "float16", torch.bfloat16: "bfloat16", - torch.uint8: "uint8", - torch.int8: "int8", + # torch.uint8: "uint8", + # torch.int8: "int8", + torch.int8: "float8", torch.int16: "int16", torch.int32: "int32", torch.int64: "int64", diff --git a/src/nanotron/constants.py b/src/nanotron/constants.py index 580bd99d..e0603164 100644 --- a/src/nanotron/constants.py +++ b/src/nanotron/constants.py @@ -10,3 +10,16 @@ CHECKPOINT_FILE_NAME = "checkpoint_metadata.json" MODEL_CONFIG_FILE_NAME = "model_config.json" + + +# TODO(xrsrke): remove this shit +ITERATION_STEP = 1 +# TODO(xrsrke): refactor to training stage, +# keep it in the same class as iteration_step +CONFIG = None + +is_ready_to_log = False + +# TODO(xrsrke): refactor +CPU_WEIGHTS = {} +ACCUM_GRADS = {} diff --git a/src/nanotron/fp8/DESIGN.md b/src/nanotron/fp8/DESIGN.md new file mode 100644 index 00000000..ff8a2a26 --- /dev/null +++ b/src/nanotron/fp8/DESIGN.md @@ -0,0 +1,25 @@ + +- For parameters like input embedding, where we are just indexing to get the corresponding embedding vectors, FP8 doesn't have to speed up the matmul since this is not a GEMM operation. +- We only keep master weights of FP8 modules. For non-FP8 modules, we directly keep them in float16. + +### Key Technical Details + +- Selectively choose a suitable FP8 format for weights, gradients, and activations. +- Selectively choose which layers should be in FP8. +- Perform delayed and dynamic quantization on the fly. +- Use mixed precision training for FP8 parameters (we don't keep a master weight for non-FP8 parameters). +- Loss scaling. +- Direct communication in FP8. +- Minimize quantization errors in FP8 all-reduce by taking into account the min/max range of participant tensors. +- Perform optimizer state calculations in FP32 to retain precision. + + +### Tips +- FP8 gives a net positive if the model is large. + + +- In the FP8 recipe, when you set quantize an input to FP8, that means after an FP8 gemm operation, we will store the outputs as FP8, communicate it using the FP8 format, but if you set an input to float16, that means we store it as float16, but we will also need to quantize it to FP8 in order to performs + + +k{dtype}: a data type with scaling factor +dtype: a torch dtype diff --git a/src/nanotron/fp8/__init__.py b/src/nanotron/fp8/__init__.py index 2adc80c2..963d424b 100644 --- a/src/nanotron/fp8/__init__.py +++ b/src/nanotron/fp8/__init__.py @@ -7,6 +7,6 @@ try: import transformer_engine as te # noqa - import transformer_engine_extensions as tex # noqa + import transformer_engine_torch as tex # noqa except ImportError: warnings.warn("Please install Transformer engine for FP8 training!") diff --git a/src/nanotron/fp8/constant_recipe.py b/src/nanotron/fp8/constant_recipe.py new file mode 100644 index 00000000..5d3ea7bb --- /dev/null +++ b/src/nanotron/fp8/constant_recipe.py @@ -0,0 +1,13 @@ +from torch import nn + +from nanotron.parallel.tensor_parallel.nn import TensorParallelEmbedding + +MODULE_NAMES_THAT_NOT_FP8 = [ + "rotary_embedding", + "token_embedding", + "input_layernorm", + "post_attention_layernorm", + "final_layer_norm", + "lm_head", +] +MODULES_THAT_IN_FLOAT16 = [TensorParallelEmbedding, nn.LayerNorm] diff --git a/src/nanotron/fp8/constants.py b/src/nanotron/fp8/constants.py index 996843bc..730a7afe 100644 --- a/src/nanotron/fp8/constants.py +++ b/src/nanotron/fp8/constants.py @@ -1,18 +1,114 @@ import torch from nanotron.fp8.dtypes import DTypes +from nanotron.fp8.recipe import ( + FP8LinearRecipe, + FP8OptimRecipe, + FP8SplitAccumulator, + FP8TensorRecipe, + FP8TrainingRecipe, +) FP8_GPU_NAMES = ["h100", "rtx 4090"] -INITIAL_AMAX = 1.0 -INITIAL_SCALING_FACTOR = 1.0 - -# FP8_DTYPES = [torch.fp8e4m3, torch.fp8e5m2] -# FP8E4M3_DTYPE = torch.fp8e4m3 -# FP8E5M2_DTYPE = torch.fp8e5m2 +INITIAL_AMAX = torch.tensor(1.0, dtype=torch.float32) +INITIAL_SCALING_FACTOR = torch.tensor(1.0, dtype=torch.float32) FP8_DTYPES = [torch.int8, torch.uint8] FP8E4M3_DTYPE = torch.int8 FP8E5M2_DTYPE = torch.uint8 +# TODO(xrsrke): rename to DTYPE_TO_FP_MAX +# TODO(xrsrke): change to QDTYPE DTYPE_TO_FP8_MAX = {DTypes.FP8E4M3: 448.0, DTypes.FP8E5M2: 57344.0, DTypes.KFLOAT16: 65504.0} + +QTYPE_TO_DTYPE = { + # DTypes.FP8E4M3: torch.int8, + # TODO(xrsrke): FP8E4M3 stores as uint8? + DTypes.FP8E4M3: torch.int8, + DTypes.FP8E5M2: torch.uint8, + DTypes.KFLOAT16: torch.float16, + DTypes.KFLOAT32: torch.float32, + DTypes.KBFLOAT16: torch.bfloat16, + torch.float16: torch.float16, + torch.float32: torch.float32, + torch.bfloat16: torch.bfloat16, +} + +# NOTE: the training recipe of the FP8-LM paper +# FP8-LM: Training FP8 Large Language Models +# https://arxiv.org/abs/2310.18313 + +# FP8-LM +# weight.window_size = 1, input.window_size = 16, +# wgrad.window_size = 1, ograd.window_size = 16 (this one is the input of the backward pass), +# input_grad.window_size = 1 (this one is the output of the backward pass) + +# TODO(xrsrke): differentiate the precision that you initializes model weight +# and the accumulation precision in FP8 recipe + +FP8LM_LINEAR_RECIPE = FP8LinearRecipe( + accum_dtype=torch.bfloat16, + input=FP8TensorRecipe(dtype=DTypes.FP8E4M3, margin=0, interval=16), + weight=FP8TensorRecipe(dtype=DTypes.FP8E4M3, margin=0, interval=1), + bias=torch.float16, + # NOTE: these are the dtypes for the gradients + input_grad=FP8TensorRecipe(dtype=DTypes.FP8E5M2, margin=0, interval=16), # NOTE: this is output_grad + weight_grad=FP8TensorRecipe(dtype=DTypes.FP8E4M3, margin=0, interval=1), + output_grad=FP8TensorRecipe(dtype=DTypes.FP8E5M2, margin=0, interval=16), + # NOTE: tested, and it works + split_accumulator=FP8SplitAccumulator(output=True, input_grad=True, weight_grad=True), + accumulate=FP8SplitAccumulator(output=True, input_grad=True, weight_grad=True), + # # NOTE: passes the test with 4% speed up relative to the above + # split_accumulator=FP8SplitAccumulator(output=False, input_grad=True, weight_grad=True), + # accumulate=FP8SplitAccumulator(output=False, input_grad=False, weight_grad=True), +) + +FP8LM_OPTIM_RECIPE = FP8OptimRecipe( + accum_dtype=torch.float32, + master_weight_dtype=torch.float32, + exp_avg_dtype=torch.float32, + exp_avg_sq_dtype=torch.float32, +) + +FP8LM_RECIPE = FP8TrainingRecipe( + linear=FP8LM_LINEAR_RECIPE, + optim=FP8LM_OPTIM_RECIPE, +) + +### FOR DYNAMIC LOSS SCALING ### + +# TODO(xrsrke): Make it more deliberate, like if people import this constant, +# they should know that it is a constant for dynamic loss scaling +# NOTE: these initial scaling factors are from deepspeed, but we are technically free to choose our own +# LS_INITIAL_SCALING_VALUE = torch.tensor(2**32, dtype=torch.float32) +# 2^15 = 32768 +# LS_INITIAL_SCALING_VALUE = torch.tensor(2**32, dtype=torch.float32) +LS_INITIAL_SCALING_VALUE = torch.tensor(32768, dtype=torch.float32) +LS_INITIAL_SCALING_FACTOR = torch.tensor(2.0, dtype=torch.float32) +LS_INTERVAL = 1000 + + +# FOR TESTING +# NOTE: this tolerance is from FP8-LM's implementation +# reference: https://github.com/Azure/MS-AMP/blob/9ac98df5371f3d4174d8f103a1932b3a41a4b8a3/tests/common/tensor/test_cast.py#L23 +# NOTE: i tried to use rtol=0, atol=0.1 +# but even msamp fails to pass 6/8 tests +# so now use 0.1, but better do a systematic tuning +FP8_RTOL_THRESHOLD = 0.1 +FP8_ATOL_THRESHOLD = 0.1 + +FP16_RTOL_THRESHOLD = 0 +FP16_ATOL_THRESHOLD = 1e-03 + +# NOTE: FP8-LM use RTOL is 0, and ATOL is 3e-4 for model weights +FP8_WEIGHT_RTOL_THRESHOLD = 0 +FP8_WEIGHT_ATOL_THRESHOLD = 0.1 + +FP8_1ST_OPTIM_STATE_RTOL_THRESHOLD = 0 +FP8_1ST_OPTIM_STATE_ATOL_THRESHOLD = 0.1 +FP8_2ND_OPTIM_STATE_RTOL_THRESHOLD = 0 +FP8_2ND_OPTIM_STATE_ATOL_THRESHOLD = 0.1 + +workspace = torch.empty(33_554_432, dtype=torch.int8, device="cuda") +_empty_tensor = torch.Tensor() diff --git a/src/nanotron/fp8/distributed.py b/src/nanotron/fp8/distributed.py new file mode 100644 index 00000000..3ca10ce0 --- /dev/null +++ b/src/nanotron/fp8/distributed.py @@ -0,0 +1,65 @@ +from typing import List, Union + +import torch +import torch.distributed as dist +from torch.distributed import * # noqa + +from nanotron.distributed import * +from nanotron.fp8.tensor import FP8Tensor, convert_tensor_from_fp8 +from nanotron.parallel.parameters import NanotronParameter + + +def post_scaling_all_reduce_mean( + tensor: torch.Tensor, group: Optional[dist.ProcessGroup] = None, async_op=False +) -> torch.Tensor: + dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group, async_op=async_op) + tensor.div_(dist.get_world_size()) + return tensor + + +def post_scaling_all_reduce_coalesced_mean( + tensor: torch.Tensor, group: Optional[dist.ProcessGroup] = None, async_op=False +) -> torch.Tensor: + dist.all_reduce_coalesced(tensor, op=dist.ReduceOp.SUM, group=group, async_op=async_op) + tensor.div_(dist.get_world_size()) + return tensor + + +def all_reduce( + tensor: Union[torch.Tensor, NanotronParameter], + op: dist.ReduceOp = dist.ReduceOp.SUM, + group: Optional[dist.ProcessGroup] = None, + async_op: bool = False, +): + # NOTE: if keep nn.Parameter, got the following error: + # File ".../.vscode-server/extensions/ms-python.debugpy-2024.12.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_vars.py", line 267, in eval_in_context + # result = eval(compiled, global_vars, local_vars) + # File "", line 1, in + # AttributeError: module 'torch.distributed.nn' has no attribute 'Parameter' + assert tensor.__class__ in [torch.Tensor, torch.nn.Parameter, NanotronParameter] + data = get_data_from_param(tensor) if tensor.__class__ == NanotronParameter else tensor + + # if data.__class__ == FP8Tensor: + # dist.all_reduce(data, op=op, group=group, async_op=async_op) + # else: + # dist.all_reduce(data, op=op, group=group, async_op=async_op) + dist.all_reduce(data, op=op, group=group, async_op=async_op) + + +def all_gather( + tensor_list: List[torch.Tensor], + tensor: Union[FP8Tensor, NanotronParameter], + group: dist.ProcessGroup, + async_op: bool = False, +) -> torch.Tensor: + tensor = get_data_from_param(tensor) if tensor.__class__ == NanotronParameter else tensor + + if tensor.__class__ == FP8Tensor: + # TODO(xrsrke): convert to the dtype of the first tensor in the list + tensor = ( + convert_tensor_from_fp8(tensor, tensor.fp8_meta, torch.float32) + if tensor_list[0].dtype != tensor.dtype + else tensor + ) + + dist.all_gather(tensor_list=tensor_list, tensor=tensor, group=group, async_op=async_op) diff --git a/src/nanotron/fp8/dtypes.py b/src/nanotron/fp8/dtypes.py index 347ce8b6..e6b6dad7 100644 --- a/src/nanotron/fp8/dtypes.py +++ b/src/nanotron/fp8/dtypes.py @@ -1,7 +1,12 @@ -from enum import Enum, auto +from enum import Enum +# TODO(xrsrke): don't use plural +# TODO(xrsrke): change to QDType, so we don't mistaken it with the torch dtype +# QDType = Quantization DType class DTypes(Enum): - FP8E4M3 = auto() - FP8E5M2 = auto() - KFLOAT16 = auto() + FP8E4M3 = "FP8E4M3" + FP8E5M2 = "FP8E5M2" + KFLOAT16 = "kfloat16" + KFLOAT32 = "KFLOAT32" + KBFLOAT16 = "kbfloat16" diff --git a/src/nanotron/fp8/functional.py b/src/nanotron/fp8/functional.py new file mode 100644 index 00000000..cdf9e58a --- /dev/null +++ b/src/nanotron/fp8/functional.py @@ -0,0 +1,55 @@ +from typing import Optional + +import torch + +from nanotron.fp8.linear import FP8LinearMeta +from nanotron.fp8.recipe import FP8LinearRecipe +from nanotron.parallel.parameters import NanotronParameter + + +def linear( + input: torch.Tensor, + weight: NanotronParameter, + bias: Optional[torch.Tensor] = None, + metadatas: FP8LinearMeta = None, + recipe: FP8LinearRecipe = None, + name: Optional[str] = None, +): + assert isinstance(weight, NanotronParameter) + + assert metadatas is not None, "metadatas must be specified" + assert recipe is not None, "recipe must be specified" + assert input.device != torch.device("cpu"), "FP8Linear only supports CUDA tensors" + + # TODO(xrsrke): refactor this out, don't duplicate the code + from einops import rearrange + + from nanotron.fp8.linear import _FP8Matmul + + seq_len = None + batch_size = None + is_input_flat = False + if input.ndim == 3: + batch_size = input.shape[0] + seq_len = input.shape[1] + is_input_flat = True + input = rearrange(input, "b n h -> (b n) h") + elif input.ndim > 3: + raise ValueError(f"Unsupported input shape: {input.shape}") + + # NOTE: just a phony tensor to make pytorch trigger the backward pass + # because weight and bias's requires_grad are set to False + # so that we can compute the gradients using the fp8 kernels by ourselves + phony = torch.empty(0, device=input.device, requires_grad=True) + # NOTE: interesting that if i initialize the output buffer as torch.empty + # it leads to nan matmul, so i do torch.zeros instead + # output = torch.empty(input.shape[0], weight.shape[0], device="cuda", dtype=recipe.accum_dtype) + output = torch.zeros(input.shape[0], weight.shape[0], device="cuda", dtype=recipe.accum_dtype) + output, _ = _FP8Matmul.apply(input, weight, output, phony, metadatas, recipe, name) + + # TODO(xrsrke): add support for adding bias in fp8 + # TODO(xrsrke): support return an fp8 tensor as output + # since we will quantize it back to FP8 anyway in the next linear + output = rearrange(output, "(b n) h -> b n h", n=seq_len, b=batch_size) if is_input_flat is True else output + output = output if bias is None else output + bias + return output diff --git a/src/nanotron/fp8/kernel.py b/src/nanotron/fp8/kernel.py index e2c80981..0f09aab6 100644 --- a/src/nanotron/fp8/kernel.py +++ b/src/nanotron/fp8/kernel.py @@ -1,34 +1,50 @@ import torch import transformer_engine as te # noqa -import transformer_engine_extensions as tex +import transformer_engine_torch as tex -from nanotron.fp8.tensor import FP8Tensor from nanotron.fp8.meta import FP8Meta +from nanotron.fp8.tensor import FP8Tensor @torch.no_grad() def fp8_matmul_kernel( mat_a: FP8Tensor, - transpose_a: bool, mat_b: FP8Tensor, - transpose_b: bool, + output, use_split_accumulator: bool, + accumulate: bool, + accum_qtype: torch.dtype, + # TODO(xrsrke): remove this flag ) -> torch.Tensor: + # from nanotron.fp8.constants import _empty_tensor, workspace + assert ( mat_a.device != "cpu" and mat_b.device != "cpu" - ), "The tensors must be on a CUDA device in order to use the FP8 kernel!!" + ), "The tensors must be on a CUDA device in order to use FP8!!" + # assert isinstance(accum_qtype, DTypes) + assert isinstance(accum_qtype, torch.dtype) device = mat_a.device + # NOTE: this is the accumulation precision dtype + if accum_qtype == torch.float32: + out_dtype = getattr(tex.DType, "kFloat32") + out_torch_dtype = torch.float32 + elif accum_qtype == torch.float16: + out_dtype = getattr(tex.DType, "kFloat16") + out_torch_dtype = torch.float16 + elif accum_qtype == torch.bfloat16: + out_dtype = getattr(tex.DType, "kBFloat16") + out_torch_dtype = torch.bfloat16 + else: + raise ValueError(f"Unsupported accumulation dtype: {accum_qtype}") + _empty_tensor = torch.Tensor() - output = torch.empty(mat_a.shape[0], mat_b.shape[1], device=device, dtype=torch.float32) workspace = torch.empty(33_554_432, dtype=torch.int8, device=device) - accumulate = False - out_dtype = getattr(tex.DType, "kFloat32") # NOTE: currently TE don't support adding bias in FP8 # along with matmul, it only takes an empty bias - bias = torch.tensor([], dtype=torch.float32) + bias = torch.tensor([], dtype=out_torch_dtype) TE_CONFIG_TRANSPOSE_BIAS = False mat_a_fp8_meta: FP8Meta = mat_a.fp8_meta @@ -40,9 +56,6 @@ def fp8_matmul_kernel( TE_CONFIG_TRANSPOSE_B = False SCALE = AMAX = _empty_tensor - mat_a = tex.fp8_transpose(mat_a, mat_a_fp8_meta.te_dtype) if transpose_a is False else mat_a - mat_b = tex.fp8_transpose(mat_b, mat_b_fp8_meta.te_dtype) if transpose_b is True else mat_b - tex.te_gemm( mat_a, mat_a_fp8_meta.inverse_scale, diff --git a/src/nanotron/fp8/linear.py b/src/nanotron/fp8/linear.py index dcda9b1e..96777e97 100644 --- a/src/nanotron/fp8/linear.py +++ b/src/nanotron/fp8/linear.py @@ -1,112 +1,242 @@ -from typing import Optional, Tuple, TypedDict, Union +from dataclasses import dataclass +from typing import Optional, Tuple, Union, cast import torch -import torch.nn.functional as F import transformer_engine as te # noqa from torch import nn -from nanotron.fp8.constants import INITIAL_AMAX, INITIAL_SCALING_FACTOR -from nanotron.fp8.dtypes import DTypes +from nanotron.fp8.constants import FP8LM_LINEAR_RECIPE from nanotron.fp8.kernel import fp8_matmul_kernel from nanotron.fp8.meta import FP8Meta -from nanotron.fp8.parameter import FP8Parameter -from nanotron.fp8.tensor import FP8Tensor, update_scaling_factor +from nanotron.fp8.recipe import FP8LinearRecipe +from nanotron.fp8.tensor import FP8Tensor +from nanotron.parallel.parameters import NanotronParameter -class FP8LinearMeta(TypedDict): +@dataclass +class FP8LinearMeta: """FP8 metadata for FP8Linear.""" - input_grad: FP8Meta - weight_grad: FP8Meta - output_grad: FP8Meta + input: Optional[FP8Meta] = None + weight: Optional[FP8Meta] = None + input_grad: Optional[FP8Meta] = None + weight_grad: Optional[FP8Meta] = None class FP8Linear(nn.Linear): - def __init__(self, in_features: int, out_features: int, bias: bool = True, device: Optional[torch.device] = None): - super().__init__(in_features, out_features, bias, device) - # TODO(xrsrke): add device, and 2 fp8 dtypes - if self.weight.device != torch.device("cpu"): - self.weight = FP8Parameter(self.weight, dtype=DTypes.FP8E4M3) - - # NOTE: quantization metadata for input gradients, weight gradients, and output gradients - # TODO(xrsrke): don't fixed this - fp8e4m3_scale = update_scaling_factor( - amax=torch.tensor(INITIAL_AMAX, dtype=torch.float32), - scaling_factor=torch.tensor(INITIAL_SCALING_FACTOR), - dtype=DTypes.FP8E4M3, - ) - fp8e5m2_scale = update_scaling_factor( - amax=torch.tensor(INITIAL_AMAX, dtype=torch.float32), - scaling_factor=torch.tensor(INITIAL_SCALING_FACTOR, dtype=torch.float32), - dtype=DTypes.FP8E5M2, - ) - self.fp8_meta: FP8LinearMeta = { - # kfloat8_e4m3 - "input_grad": FP8Meta(amax=1, dtype=DTypes.FP8E4M3, scale=fp8e4m3_scale), - "weight_grad": FP8Meta(amax=1, dtype=DTypes.FP8E4M3, scale=fp8e4m3_scale), - # kfloat8_e5m2 - "output_grad": FP8Meta(amax=1, dtype=DTypes.FP8E5M2, scale=fp8e5m2_scale), - } + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + device: Optional[torch.device] = None, + ): + """ + Args: + qtype (DTypes, optional): This is accumulation precision dtype + """ + assert device != torch.device("cpu"), "FP8Linear only supports CUDA tensors" - def forward(self, input: Union[FP8Tensor, torch.Tensor]) -> torch.Tensor: - # NOTE: only do fp8 kernel if both input and weight are on CUDA device - if input.device == torch.device("cpu") or self.weight.device == torch.device("cpu"): - return F.linear(input, self.weight, self.bias) + # TODO(xrsrke): take initialization dtype from recipe + # NOTE: initialize in float32 + super().__init__(in_features, out_features, bias, device, dtype=torch.float32) + self._set_and_quantize_weights(self.weight.data) - # NOTE: just a phony tensor to make pytorch trigger the backward pass - # because weight and bias's requires_grad are set to False - # so that we can compute the gradients using the fp8 kernels by ourselves - phony = torch.empty(0, device=input.device, requires_grad=True) - output, _ = _FP8Matmul.apply(input, self.weight, self.fp8_meta, phony) + def _set_and_quantize_weights(self, data: torch.Tensor, recipe: FP8LinearRecipe = FP8LM_LINEAR_RECIPE): + """ + data: if set to None, then we quantize the module's current weights, otherwise, we quantize + the provided tensor + """ + assert data is None or isinstance(data, torch.Tensor) + quant_w = FP8Tensor(data, dtype=recipe.weight.dtype, interval=recipe.weight.interval) - # TODO(xrsrke): add support for adding bias in fp8 - # TODO(xrsrke): support return an fp8 tensor as output - # since we will quantize it back to FP8 anyway in the next linear - output = output if self.bias is None else output + self.bias - return output + # NOTE: if we create a new parameter, then we can have that new quantized parameter + # in [torch.int8, torch.uint8] dtype, then we can assign int|uint8 gradient to it + # TODO(xrsrke): keep the metadata of the original NanotronParameter + new_param = NanotronParameter.create_param_that_share_metadata(quant_w, param=self.weight) + setattr(self, "weight", new_param) + + # NOTE: assume each time we requantize the weights, we reset the metadata + self.metadatas = FP8LinearMeta() + self.recipe = recipe + + def forward(self, input: Union[FP8Tensor, torch.Tensor]) -> torch.Tensor: + import nanotron.fp8.functional as F + + return F.linear( + input=input, + weight=self.weight, + bias=self.bias, + metadatas=self.metadatas, + recipe=self.recipe, + ) + + # def __repr__(self) -> str: + # return f"FP8{super().__repr__()}" class _FP8Matmul(torch.autograd.Function): @staticmethod - @torch.no_grad() + # @torch.no_grad() def forward( - ctx, input: FP8Tensor, weight: FP8Tensor, fp8_meta: FP8LinearMeta, phony: torch.Tensor + ctx, + input: Union[FP8Tensor, torch.Tensor], + weight: NanotronParameter, + output: torch.Tensor, + phony: torch.Tensor, + metadatas: FP8LinearMeta, + recipe: FP8LinearRecipe, + name, ) -> torch.Tensor: - if type(input) == torch.Tensor: - input = FP8Tensor(input, dtype=DTypes.FP8E4M3) + assert not isinstance(input, FP8Tensor) + assert isinstance(weight, NanotronParameter) + + from nanotron import constants + from nanotron.config.fp8_config import FP8Args + + if constants.CONFIG is None: + fp8_config = FP8Args() + else: + fp8_config = cast(FP8Args, constants.CONFIG.fp8) + + sync_amax_in_input = fp8_config.sync_amax_in_input + + if metadatas.input is None: + fp8_input = FP8Tensor( + input, dtype=recipe.input.dtype, interval=recipe.input.interval, sync=sync_amax_in_input + ) + metadatas.input = fp8_input.fp8_meta + else: + fp8_input = FP8Tensor.from_metadata(input, metadatas.input, sync=sync_amax_in_input) - ctx.save_for_backward(input, weight) - ctx.fp8_meta = fp8_meta + ctx.save_for_backward(fp8_input, weight) + ctx.is_input_require_grad = input.requires_grad + ctx.metadatas = metadatas + ctx.name = name + ctx.recipe = recipe + + accum_output = output - # NOTE: pass FP8Tensor instead of FP8Parameter output = fp8_matmul_kernel( - mat_a=weight.data, transpose_a=True, mat_b=input, transpose_b=False, use_split_accumulator=False + # NOTE: that works + mat_a=weight.data, + mat_b=fp8_input, + output=accum_output, + use_split_accumulator=recipe.split_accumulator.output, + accumulate=recipe.accumulate.output, + accum_qtype=recipe.accum_dtype, ) - return output, phony @staticmethod - @torch.no_grad() + # @torch.no_grad() # NOTE: drop 5% speed up in fwd only, and add 2% speed up in fwd+bwd def backward(ctx, grad_output: torch.Tensor, grad_phony: torch.Tensor) -> Tuple[torch.Tensor, None, None, None]: """ ∂L/∂X = ∂L/∂Y @ Wᵀ ∂L/∂W = Xᵀ @ ∂L/∂Y - Source: https://web.eecs.umich.edu/~justincj/teaching/eecs442/notes/linear-backprop.html + Reference: https://web.eecs.umich.edu/~justincj/teaching/eecs442/notes/linear-backprop.html """ - # TODO(xrsrke): investigate how does grad_output.contiguous() affect the outputs - input, weight = ctx.saved_tensors + from typing import cast + + from nanotron import constants + from nanotron.config.fp8_config import FP8Args + from nanotron.fp8.utils import is_overflow_underflow_nan + + if constants.CONFIG is None: + fp8_config = FP8Args() + else: + fp8_config = cast(FP8Args, constants.CONFIG.fp8) - if type(grad_output) == torch.Tensor: - grad_output = torch.ones_like(grad_output) - grad_output = grad_output.contiguous() - grad_output = FP8Tensor(grad_output, dtype=DTypes.FP8E5M2) + sync_amax_in_igrad = fp8_config.sync_amax_in_igrad + sync_amax_in_wgrad = fp8_config.sync_amax_in_wgrad - grad_input = fp8_matmul_kernel( - mat_a=grad_output, transpose_a=True, mat_b=weight, transpose_b=True, use_split_accumulator=True + fp8_input, fp8_weight_param = ctx.saved_tensors + fp8_weight = fp8_weight_param.data + recipe = ctx.recipe + recipe = cast(FP8LinearRecipe, recipe) + + fp8_input = cast(FP8Tensor, fp8_input) + fp8_weight = cast(FP8Tensor, fp8_weight) + + assert is_overflow_underflow_nan(grad_output) is False, f"name: {ctx.name}" + + ctx.metadatas = cast(FP8LinearMeta, ctx.metadatas) + if ctx.metadatas.input_grad is None: + fp8_grad_output = FP8Tensor( + grad_output, + dtype=recipe.input_grad.dtype, + interval=recipe.input_grad.interval, + sync=sync_amax_in_igrad, + ) + ctx.metadatas.input_grad = fp8_grad_output.fp8_meta + else: + fp8_grad_output = FP8Tensor.from_metadata(grad_output, ctx.metadatas.input_grad, sync=sync_amax_in_igrad) + + if ctx.is_input_require_grad: + transposed_fp8_weight = fp8_weight.transpose_fp8() + # NOTE: same reason as output buffer in .forward + grad_input_temp = torch.zeros( + fp8_grad_output.shape[0], + transposed_fp8_weight.shape[0], + device="cuda", + dtype=recipe.accum_dtype, + ) + grad_input = fp8_matmul_kernel( + mat_a=transposed_fp8_weight, + mat_b=fp8_grad_output, + output=grad_input_temp, + use_split_accumulator=recipe.split_accumulator.input_grad, + accum_qtype=recipe.accum_dtype, + accumulate=recipe.accumulate.input_grad, + ) + grad_input.__debug_is_from_fp8 = True + else: + grad_input = None + + assert is_overflow_underflow_nan(grad_input) is False + + # TODO(xrsrke): fuse cast and transpose + transposed_fp8_grad_output = fp8_grad_output.transpose_fp8() + transposed_fp8_input = fp8_input.transpose_fp8() + + # NOTE: same reason as output buffer in .forward + grad_weight_temp = torch.zeros( + transposed_fp8_input.shape[0], + transposed_fp8_grad_output.shape[0], + device="cuda", + dtype=recipe.accum_dtype, ) grad_weight = fp8_matmul_kernel( - mat_a=input, transpose_a=False, mat_b=grad_output, transpose_b=False, use_split_accumulator=True + mat_a=transposed_fp8_input, + mat_b=transposed_fp8_grad_output, + output=grad_weight_temp, + use_split_accumulator=recipe.split_accumulator.weight_grad, + accumulate=recipe.accumulate.weight_grad, + accum_qtype=recipe.accum_dtype, ) - weight.grad = grad_weight + assert is_overflow_underflow_nan(grad_weight) is False + + if ctx.is_input_require_grad: + assert grad_input.dtype == recipe.accum_dtype + + assert grad_weight.dtype == recipe.accum_dtype + # TODO(xrsrke): maintain a persistence metadata across training + + grad_weight = grad_weight.reshape(grad_weight.shape[::-1]) + + if ctx.metadatas.weight_grad is None: + fp8_weight_grad = FP8Tensor( + grad_weight, + dtype=recipe.weight_grad.dtype, + interval=recipe.weight_grad.interval, + sync=sync_amax_in_wgrad, + ) + ctx.metadatas.weight_grad = fp8_weight_grad.fp8_meta + else: + fp8_weight_grad = FP8Tensor.from_metadata(grad_weight, ctx.metadatas.weight_grad, sync=sync_amax_in_wgrad) + + fp8_weight_param.grad = fp8_weight_grad - return grad_input, None, None, None + # NOTE: sanity check + assert isinstance(fp8_weight_param.grad, FP8Tensor) + return grad_input, None, None, None, None, None, None diff --git a/src/nanotron/fp8/meta.py b/src/nanotron/fp8/meta.py index 64f71d42..3064cf1f 100644 --- a/src/nanotron/fp8/meta.py +++ b/src/nanotron/fp8/meta.py @@ -1,31 +1,64 @@ from dataclasses import dataclass -from typing import Union +from typing import List import torch import transformer_engine as te # noqa -import transformer_engine_extensions as tex +import transformer_engine_torch as tex from nanotron.fp8.constants import DTYPE_TO_FP8_MAX -from nanotron.fp8.tensor import convert_torch_dtype_to_te_dtype +from nanotron.fp8.dtypes import DTypes @dataclass class FP8Meta: - """Metadata for FP8Tensor.""" + """ + Metadata for FP8Tensor. - amax: Union[int, float] + NOTE: + scale is the scaling factor for quantization, it doesn't not necessary represent the scaling factor + for the current tensor value, but rather a running scaling factor if we use delayed quantization + + amax is the current absolute maximum value of the tensor + + interval is the number of steps before we rescale the scaling factor + """ + + # TODO(xrsrke): add "margin" + amax: torch.Tensor scale: torch.Tensor # TODO(xrsrke): change to Literal[torch.int8, torch.uint8] - dtype: torch.dtype + dtype: DTypes + interval: int + sync_amax: bool = False @property def te_dtype(self) -> tex.DType: + from nanotron.fp8.tensor import convert_torch_dtype_to_te_dtype + return convert_torch_dtype_to_te_dtype(self.dtype) def __post_init__(self): + # assert isinstance(self.scale, torch.Tensor) + assert isinstance(self.amax, torch.Tensor) + assert isinstance(self.dtype, DTypes) + assert isinstance(self.interval, int) + assert self.interval > 0, f"Expected interval to be greater than 0, got {self.interval}" + assert ( + self.scale.dtype == torch.float32 + ), f"Expected scale to be of dtype torch.float32, got {self.scale.dtype}" + + # TODO(xrsrke): move these to a constant + assert self.amax.dtype in [ + torch.float32, + torch.float16, + torch.bfloat16, + ], f"Expected amax to be of dtype torch.float32 or torch.float16, got {self.amax.dtype}" + # NOTE: transformer engine only accepts torch tensors self.amax = torch.tensor(self.amax, device="cuda") if not isinstance(self.amax, torch.Tensor) else self.amax + self._amaxs: List[torch.Tensor] = [self.amax] + self._num_remaining_steps_until_rescale: int = self.interval - 1 @property def fp8_max(self) -> float: @@ -34,7 +67,63 @@ def fp8_max(self) -> float: @property def inverse_scale(self) -> torch.Tensor: + # TODO(xrsrke): this is a hacky way, remove the _inverse_scale return 1 / self.scale + # TODO(xrsrke): move to strategy pattern + def add_amax(self, amax: torch.Tensor): + from nanotron.fp8.utils import is_overflow_underflow_nan + + if len(self._amaxs) == self.interval: + # TODO(xrsrke): do we have to clear the old amax + # from memory? + self._amaxs.pop(0) + + is_overflowed = is_overflow_underflow_nan(amax) + + if is_overflowed: + # NOTE: if amax is inf or nan, we use 0 as the new amax + amax = torch.tensor(0.0, dtype=torch.float32, device="cuda") + + self.amax = amax + self._amaxs.append(amax) + + if is_overflowed: + self._num_remaining_steps_until_rescale = 0 + elif self.interval != 1: + self._num_remaining_steps_until_rescale -= 1 + + if self.is_ready_to_scale: + self.rescale() + + @property + def amaxs(self) -> List[torch.Tensor]: + return self._amaxs + + @property + def is_delayed_scaling(self) -> bool: + return self.interval > 1 + + @property + def is_ready_to_scale(self) -> bool: + if self.is_delayed_scaling is False: + # NOTE: if this is not dynamic scaling, then we scale every interval + return True + + if self.is_delayed_scaling and self._num_remaining_steps_until_rescale == 0: + # NOTE: if this is dynamic scaling, then we only scale once we reach the interval + return True + + return False + + def rescale(self): + assert self.is_ready_to_scale is True, "Cannot rescale if not ready to scale" + from nanotron.fp8.tensor import update_scaling_factor + + max_amax = torch.max(torch.stack(self.amaxs)) + current_scale = self.scale + self.scale = update_scaling_factor(max_amax, current_scale, self.dtype) + self._num_remaining_steps_until_rescale = self.interval + def __repr__(self) -> str: - return f"FP8Meta(amax={self.amax}, scale={self.scale}, inverse_scale={self.inverse_scale}, dtype={self.dtype})" + return f"FP8Meta(amax={self.amax}, scale={self.scale}, inverse_scale={self.inverse_scale}, dtype={self.dtype}, interval={self.interval}" diff --git a/src/nanotron/fp8/optim.py b/src/nanotron/fp8/optim.py new file mode 100644 index 00000000..a14eb2dd --- /dev/null +++ b/src/nanotron/fp8/optim.py @@ -0,0 +1,332 @@ +from typing import Any, Dict, List, Tuple + +import torch +from torch import nn +from torch.optim import Optimizer + +from nanotron import logging +from nanotron.fp8.constants import FP8LM_RECIPE +from nanotron.fp8.dtypes import DTypes +from nanotron.fp8.recipe import FP8OptimRecipe +from nanotron.fp8.tensor import ( + FP8Tensor, + FP16Tensor, + convert_tensor_from_fp8, + convert_tensor_from_fp16, +) +from nanotron.fp8.utils import is_overflow_underflow_nan +from nanotron.logging import log_rank + +logger = logging.get_logger(__name__) + + +class FP8AdamW(Optimizer): + """ + FP8 AdamW optimizer. + """ + + def __init__( + self, + params: List[nn.Parameter], + lr: float = 1e-3, + betas: Tuple[float, float] = (0.9, 0.999), + eps: float = 1e-8, + weight_decay: float = 0, + recipe: FP8OptimRecipe = FP8LM_RECIPE, + ): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + + defaults = {"lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay, "amsgrad": False} + + super().__init__(params, defaults) + + # TODO(xrsrke): make FP8Adam take a FP8Recipe + # then retrieve the exp_avg_dtype from the recipe + self.recipe = recipe + self.master_weight_dtype = recipe.master_weight_dtype + self.optim_accum_dtype = recipe.accum_dtype + + self.loggings = [] + self._is_overflow = False + + def _create_master_weight(self, data): + if self.master_weight_dtype == DTypes.KFLOAT16: + master_p = FP16Tensor(data, dtype=DTypes.KFLOAT16) + elif isinstance(self.master_weight_dtype, torch.dtype): + master_p = data.to(self.master_weight_dtype) if data.dtype != self.master_weight_dtype else data + else: + raise ValueError(f"accum_dtype={self.master_weight_dtype}") + return master_p + + def __setstate__(self, state): + super().__setstate__(state) + for group in self.param_groups: + group.setdefault("amsgrad", False) + + # TODO(xrsrke): this is similar with master weight func, remove this + def _quantize_optim_state(self, tensor, dtype): + if dtype == DTypes.FP8E4M3 or dtype == DTypes.FP8E5M2: + tensor = FP8Tensor(tensor, dtype=dtype) + elif dtype == DTypes.KFLOAT16: + tensor = FP16Tensor(tensor, dtype=DTypes.KFLOAT16) + elif isinstance(dtype, torch.dtype): + tensor = tensor.to(dtype) + else: + raise ValueError(f"supported dtype={dtype}") + return tensor + + def _init_optim_states( + self, + state: Dict[str, Any], + p: nn.Parameter, + ) -> None: + # TODO(xrsrke): could we initialize these at a lower precision + # than the accumulation precision (eg: float32) because + # these are just zero tensors anyway? + exp_avg = torch.zeros_like(p.data, dtype=self.optim_accum_dtype) + exp_avg_sq = torch.zeros_like(p.data, dtype=self.optim_accum_dtype) + + exp_avg = self._quantize_optim_state(exp_avg, self.recipe.exp_avg_dtype) + exp_avg_sq = self._quantize_optim_state(exp_avg_sq, self.recipe.exp_avg_sq_dtype) + + # state["step"] = torch.tensor(0.0, dtype=torch.float32, device="cuda") + state["step"] = torch.tensor(0.0, dtype=self.optim_accum_dtype, device="cuda") + state["exp_avg"] = exp_avg + state["exp_avg_sq"] = exp_avg_sq + + def _calculate_mean_sqrt_ignoring_nans(self, numerator, denominator): + # Calculate the division, ignoring division by zero + division_result = torch.where(denominator != 0, numerator / denominator, torch.zeros_like(numerator)) + + # Calculate the mean, ignoring NaN values + valid_values = division_result[~torch.isnan(division_result)] + + if valid_values.numel() > 0: + mean_result = valid_values.mean() + return torch.sqrt(mean_result) + else: + raise ValueError("All values are NaN") + + # def _get_optim_logs(self): + # from nanotron.scaling.monitor import convert_logs_to_flat_logs + + # optim_loggings = {} + # for p in self.loggings: + # param_name = self.params_id_to_param_names[id(p)] + # optim_loggings[param_name] = self.loggings[p] + # return convert_logs_to_flat_logs(optim_loggings) + + def _dequantize_optim_state(self, state): + if state.__class__ == FP8Tensor: + fp32_state = convert_tensor_from_fp8(state, state.fp8_meta, self.optim_accum_dtype) + elif state.__class__ == FP16Tensor: + fp32_state = convert_tensor_from_fp16(state, self.optim_accum_dtype) + elif state.dtype == self.optim_accum_dtype: + fp32_state = state + elif isinstance(state.dtype, torch.dtype): + fp32_state = state.to(self.optim_accum_dtype) if state.dtype != self.optim_accum_dtype else state + + return fp32_state + + @torch.no_grad() + def step(self, closure=None): + # NOTE: sanity check the entire params has at least one grad + # TODO(xrsrke): remove this after debugging + from typing import cast + + from nanotron import constants + from nanotron.config.fp8_config import FP8Args + + cast(FP8Args, constants.CONFIG.fp8) + + for i, group in enumerate(self.param_groups): + for p in group["params"]: + + if not isinstance(p.data, FP8Tensor) and p.requires_grad is False: + continue + + assert p.grad is not None + + state = self.state[p] + if len(state) == 0: + self._init_optim_states(state, p) + + # NOTE: Case 1: With gradient accumulator => the grad is already in the correct dtype + # Case 2: Without gradient accumulator => + # 2.1 Non-FP8 parameter => cast the grad to the correct dtype + # 2.2 FP8 parameter => dequantize the grad to the correct dtype + + fp32_grad = p.grad + fp32_data = p.data + assert fp32_grad.dtype == self.optim_accum_dtype + assert p.data.dtype == torch.float32 + + if is_overflow_underflow_nan(fp32_grad): + self._is_overflow = True + + if constants.CONFIG.fp8.skip_param_update_if_nan is True: + log_rank( + f"[Optim] param_name, skipping update due to overflow/underflow/nan", # noqa + logger=logger, + level=logging.INFO, + ) + continue + else: + raise ValueError("Overflow, underflow, or NaN detected in the gradients") + + exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] + + fp32_exp_avg = self._dequantize_optim_state(exp_avg) + fp32_exp_avg_sq = self._dequantize_optim_state(exp_avg_sq) + + assert fp32_exp_avg.dtype == self.optim_accum_dtype + assert fp32_exp_avg_sq.dtype == self.optim_accum_dtype + + beta1, beta2 = group["betas"] + lr = group["lr"] + step = state["step"] + step += 1 + + fp32_exp_avg = beta1 * fp32_exp_avg + (1 - beta1) * fp32_grad + fp32_exp_avg_sq = beta2 * fp32_exp_avg_sq + (1 - beta2) * fp32_grad.pow(2) + + bias_correction1 = 1 / (1 - (beta1**step)) + bias_correction2 = 1 / (1 - (beta2**step)) + + unbiased_fp32_exp_avg = fp32_exp_avg * bias_correction1 + unbiased_fp32_exp_avg_sq = fp32_exp_avg_sq * bias_correction2 + + denom = unbiased_fp32_exp_avg_sq.sqrt() + group["eps"] + normalized_grad = unbiased_fp32_exp_avg / denom + + if constants.CONFIG.fp8.update_clipping is True: + rms = self._calculate_mean_sqrt_ignoring_nans( + fp32_grad.pow(2), + torch.max( + unbiased_fp32_exp_avg_sq, + torch.tensor(group["eps"], dtype=self.optim_accum_dtype, device="cuda").pow(2), + ), + ) + + if rms > 1: + # NOTE: only scale down the lr, not scale it up + update_lr = lr / torch.max(torch.tensor(1.0, dtype=self.optim_accum_dtype, device="cuda"), rms) + log_rank( + f"[Gradient clipping] param_name=, grad_norm: {fp32_grad.norm(p=2)}, RMS is {rms}, original lr is {lr}, new lr is {update_lr}", # noqa + logger=logger, + level=logging.INFO, + rank=0, + ) + else: + update_lr = lr + else: + update_lr = lr + + # NOTE: keep weight decay for biases + # TODO(xrsrke): we should explicitly set weight_decay_factor to 0 for biases + # in optimizer's param_groups + weight_decay_factor = group["weight_decay"] if p.data.ndim >= 2 else 0.0 + + if weight_decay_factor != 0: + fp32_new_changes_from_grad = update_lr * normalized_grad + fp32_weight_decay_grad = weight_decay_factor * fp32_data + + if constants.CONFIG.fp8.weight_decay_without_lr_decay is False: + fp32_new_changes_from_weight_decay = update_lr * fp32_weight_decay_grad + else: + fp32_new_changes_from_weight_decay = ( + constants.CONFIG.optimizer.learning_rate_scheduler.learning_rate * fp32_weight_decay_grad + ) + else: + fp32_new_changes_from_grad = update_lr * normalized_grad + fp32_new_changes_from_weight_decay = 0 + + fp32_new_changes_in_p = fp32_new_changes_from_grad + fp32_new_changes_from_weight_decay + new_fp32_data = fp32_data - fp32_new_changes_in_p + + p.data = new_fp32_data + + exp_avg = self._quantize_optim_state(fp32_exp_avg, self.recipe.exp_avg_dtype) + exp_avg_sq = self._quantize_optim_state(fp32_exp_avg_sq, self.recipe.exp_avg_sq_dtype) + + state["step"] = step + state["exp_avg"] = exp_avg + state["exp_avg_sq"] = exp_avg_sq + + assert state["step"] == step + assert state["exp_avg"] is exp_avg + assert state["exp_avg_sq"] is exp_avg_sq + + # NOTE: remove this shit + # if constants.is_ready_to_log is True: + # loggings[p]["step"] = {"value": step} + # loggings[p]["group:lr"] = {"value": lr} + # loggings[p]["group:eps"] = {"value": group["eps"]} + # loggings[p]["group:beta1"] = {"value": beta1} + # loggings[p]["group:beta2"] = {"value": beta2} + + # loggings[p]["bias_correction1"] = {"value": bias_correction1} + # loggings[p]["bias_correction2"] = {"value": bias_correction2} + # loggings[p]["fp32_exp_avg"] = compute_stas(fp32_exp_avg) + # loggings[p]["fp32_exp_avg_sq"] = compute_stas(fp32_exp_avg_sq) + + # loggings[p]["normalized_grad"] = compute_stas(normalized_grad) + + # if fp8_config.adam_atan2 is False: + # loggings[p]["denom"] = compute_stas(denom) + + # loggings[p]["update_lr"] = {"value": update_lr} + + # loggings[p]["fp32_p"] = compute_stas(fp32_data) + # loggings[p]["fp32_new_changes_in_p"] = { + # # "abs_total": fp32_new_changes_in_p.abs().sum(), + # # "abs_mean": fp32_new_changes_in_p.abs().mean(), + # "rms": fp32_new_changes_in_p.pow(2) + # .mean() + # .sqrt(), + # } + # loggings[p]["fp32_new_changes_from_grad"] = { + # "rms": fp32_new_changes_from_grad.pow(2).mean().sqrt(), + # } + + # p_norm = fp32_data.norm() + + # loggings[p]["fp32_grad"] = compute_stas(fp32_grad) + # loggings[p]["update_lr"] = {"value": update_lr} + # loggings[p]["weight_norm_and_normalized_grad_norm_ratio"] = { + # "value": p_norm / fp32_new_changes_from_grad.norm() + # } + # loggings[p]["weight_norm_and_weight_update_norm_ratio"] = { + # "value": p_norm / fp32_new_changes_in_p.norm() + # } + + # if weight_decay_factor != 0: + # loggings[p]["fp32_new_changes_from_weight_decay"] = { + # "rms": fp32_new_changes_from_weight_decay.pow(2).mean().sqrt(), + # } + # loggings[p]["weight_norm_and_weight_decay_grad_norm_ratio"] = { + # "value": p_norm / fp32_weight_decay_grad.norm() + # } + + # if constants.CONFIG.fp8.update_clipping is True: + # loggings[p]["grad_rms"] = {"value": rms} + + # if constants.is_ready_to_log is True: + # self.loggings = loggings + # self.loggings = self._get_optim_logs() + + def zero_grad(self): + for group in self.param_groups: + for p in group["params"]: + # NOTE: take the assumption that nanotron requires all parameters to have gradients + p.grad = None + + assert p.grad is None + assert p.data.grad is None diff --git a/src/nanotron/fp8/parallel.py b/src/nanotron/fp8/parallel.py new file mode 100644 index 00000000..2240da29 --- /dev/null +++ b/src/nanotron/fp8/parallel.py @@ -0,0 +1,36 @@ +from functools import partial + +import torch +from torch import nn + +from nanotron.parallel import ParallelContext + + +class DistributedDataParallel: + def __init__(self, module: nn.Module, parallel_context: ParallelContext): + self.parallel_context = parallel_context + + self._parallelize(module) + + @torch.no_grad() + def _parallelize(self, module) -> nn.Module: + if self.parallel_context.data_parallel_size > 1: + self._register_grad_avg_hook(module) + + return module + + def _register_grad_avg_hook(self, module: nn.Module): + for p in module.parameters(): + p.register_hook(partial(self._average_grad)) + + def _average_grad(self, grad: torch.Tensor, is_expert: bool): + # NOTE: (grad1 + grad2 + ... + gradn) / n = grad1/n + grad2/n + ... + gradn/n + assert 1 == 1 + # grad.div_(self.parallel_context.data_parallel_size) + + # all_reduce( + # grad, + # op=dist.ReduceOp.SUM, + # parallel_context=self.parallel_context, + # parallel_mode=ParallelMode.EXPERT_DATA if is_expert else ParallelMode.DATA, + # ) diff --git a/src/nanotron/fp8/parameter.py b/src/nanotron/fp8/parameter.py index e3499a62..3ed40f98 100644 --- a/src/nanotron/fp8/parameter.py +++ b/src/nanotron/fp8/parameter.py @@ -1,30 +1,87 @@ +from typing import Optional, Union + import torch from torch import nn -from nanotron.fp8.constants import FP8_DTYPES +from nanotron import constants +from nanotron.fp8.constants import FP8_DTYPES, FP8LM_RECIPE, INITIAL_AMAX, INITIAL_SCALING_FACTOR from nanotron.fp8.dtypes import DTypes from nanotron.fp8.meta import FP8Meta -from nanotron.fp8.tensor import FP8Tensor +from nanotron.fp8.tensor import FP8Tensor, update_scaling_factor class FP8Parameter(nn.Parameter): """ - A custom FP8 parameter class that allows gradients - to flow into FP8 tensors (which are integer tensors). + A custom FP8 parameter class that allows + fp8 gradients (which are integer tensors) + to flow into FP8 tensors. """ - def __new__(cls, data: torch.Tensor, dtype: DTypes, requires_grad: bool = True) -> nn.Parameter: + def __new__(cls, data: torch.Tensor, dtype: DTypes, requires_grad: bool = True, interval: int = 1) -> nn.Parameter: assert isinstance(data, torch.Tensor), "data must be a tensor" assert data.dtype not in FP8_DTYPES, "Currently only support turn a non-fp8 tensor to an fp8 parameter" assert data.device != torch.device("cpu"), "FP8Parameter only supports CUDA tensors" - # TODO(xrsrke): if the tensor is on cpu, then bypass quantization with torch.no_grad(): + from typing import cast + + from nanotron.config.fp8_config import FP8Args + + if constants.CONFIG is None: + sync_amax_in_weight = False + else: + fp8_config = cast(FP8Args, constants.CONFIG.fp8) + sync_amax_in_weight = fp8_config.sync_amax_in_weight + # TODO(xrsrke): support take an FP8 Tensor as data # currently we can't only quantize a tensor to FP8 after the parameter is created # because it raise "Only Tensors of floating point and complex dtype can require gradients" + # TODO(xrsrke): delete this fp32 tensor from memory after quantization self = torch.Tensor._make_subclass(cls, data, requires_grad) - self._data = FP8Tensor(data, dtype=dtype) + self._data = FP8Tensor(data, dtype=dtype, interval=interval, sync=sync_amax_in_weight) + # TODO(xrsrke): don't store fp32 raw data in memory after quantization + + if constants.ITERATION_STEP == 1: + self.orig_data = data.data + + # TODO(xrsrke): don't fixed these, take it from the FP8 recipe + fp8e4m3_scale = update_scaling_factor( + amax=torch.tensor(INITIAL_AMAX, dtype=torch.float32), + scaling_factor=torch.tensor(INITIAL_SCALING_FACTOR), + dtype=DTypes.FP8E4M3, + ) + fp8e5m2_scale = update_scaling_factor( + amax=torch.tensor(INITIAL_AMAX, dtype=torch.float32), + scaling_factor=torch.tensor(INITIAL_SCALING_FACTOR, dtype=torch.float32), + dtype=DTypes.FP8E5M2, + ) + + # TODO(xrsrke): add type hints of fp8_grad_meta to FP8Parameter + self.fp8_grad_meta = FP8GradMeta( + input_grad=FP8Meta( + amax=INITIAL_AMAX, + dtype=DTypes.FP8E4M3, + scale=fp8e4m3_scale, + interval=FP8LM_RECIPE.linear.input_grad.interval, + ), + # TODO(xrsrke): change weight_grad to data_grad + # because this is the gradient of the parameter itself + weight_grad=FP8Meta( + amax=INITIAL_AMAX, + dtype=DTypes.FP8E4M3, + scale=fp8e4m3_scale, + interval=FP8LM_RECIPE.linear.weight_grad.interval, + ), + # kfloat8_e5m2 + output_grad=FP8Meta( + amax=INITIAL_AMAX, + dtype=DTypes.FP8E5M2, + scale=fp8e5m2_scale, + interval=FP8LM_RECIPE.linear.output_grad.interval, + ), + ) + self._grad = None + return self @property @@ -35,9 +92,24 @@ def data(self) -> FP8Tensor: def data(self, data: FP8Tensor): self._data = data + # # NOTE: because pytorch don't allow to assign an int grad to a tensor + # # so we bypass it by using a property + @property + def grad(self) -> Optional[Union[torch.Tensor, FP8Tensor]]: + return self.data._grad + # return self.data.grad + + @grad.setter + def grad(self, value: Optional[Union[torch.Tensor, FP8Tensor]]): + self.data._grad = value + + @property + def dtype(self) -> torch.dtype: + return self._data.dtype + @property def fp8_meta(self) -> FP8Meta: return self.data.fp8_meta def __repr__(self) -> str: - return f"FP8Parameter({self.data}, fp8_meta={self.fp8_meta}, requires_grad={self.requires_grad}" + return f"FP8Parameter({self.data}, fp8_meta={self.fp8_meta}, requires_grad={self.requires_grad}, fp8_grad_meta={self.fp8_grad_meta})" diff --git a/src/nanotron/fp8/recipe.py b/src/nanotron/fp8/recipe.py new file mode 100644 index 00000000..afe47635 --- /dev/null +++ b/src/nanotron/fp8/recipe.py @@ -0,0 +1,81 @@ +from dataclasses import dataclass +from typing import Union + +import torch + +from nanotron.fp8.dtypes import DTypes + + +# TODO(xrsrke): rename to LowPrecisionTensorRecipe or LPTensorRecipe +@dataclass +class FP8TensorRecipe: + dtype: DTypes + margin: int + interval: int + + +@dataclass +class FP8SplitAccumulator: + output: bool + input_grad: bool + weight_grad: bool + + +@dataclass +class FP8Accumulate: + output: bool + input_grad: bool + weight_grad: bool + + +@dataclass +class FP8LinearRecipe: + accum_dtype: torch.dtype + + input: Union[FP8TensorRecipe, torch.dtype] + weight: Union[FP8TensorRecipe, torch.dtype] + # TODO(xrsrke): remove bias recipe, because we don't quantize bias + bias: Union[FP8TensorRecipe, torch.dtype] + + # NOTE: for the gradients + input_grad: Union[FP8TensorRecipe, torch.dtype] + weight_grad: Union[FP8TensorRecipe, torch.dtype] + # TODO(xrsrke): we don't need this, because the output gradients of a layer + # is the input gradients of the other layer + output_grad: Union[FP8TensorRecipe, torch.dtype] + + # TODO(xrsrke): this is a low-level implementation details + # we should hide this from high-level apis later on + split_accumulator: FP8SplitAccumulator + accumulate: FP8Accumulate + smooth_quant: bool = False + + +@dataclass +class FP8OptimRecipe: + """ + master_weight_dtype, exp_avg_dtype, exp_avg_sq_dtype are just storage dtypes + accum_dtype is the dtype for calculations, we have to cast other dtypes to this dtype + """ + + # NOTE: these are just storage dtypes + # not FP8Tensor that need to dynamically change + # during training + master_weight_dtype: Union[DTypes, torch.dtype] + accum_dtype: torch.dtype + + exp_avg_dtype: Union[DTypes, torch.dtype] + exp_avg_sq_dtype: Union[DTypes, torch.dtype] + + +@dataclass +class FP8TrainingRecipe: + # TODO(xrsrke): add initialization dtype as a part of the recipe + # currently we use float32 for initialization, then quantize it + + # TODO(xrsrke): allow disable FP8 for some specific layers like lm_head, mlp, etc. + # maybe specify fp8 in the modeling code! + + # NOTE: precision dtype for non-fp8 modules + linear: FP8LinearRecipe + optim: FP8OptimRecipe diff --git a/src/nanotron/fp8/tensor.py b/src/nanotron/fp8/tensor.py index 3f97049c..db1b670e 100644 --- a/src/nanotron/fp8/tensor.py +++ b/src/nanotron/fp8/tensor.py @@ -1,39 +1,194 @@ +from __future__ import annotations + +from abc import abstractstaticmethod +from copy import deepcopy +from typing import Optional, Union, cast + import torch import transformer_engine as te # noqa -import transformer_engine_extensions as tex +import transformer_engine_torch as tex +from nanotron import logging from nanotron.fp8.constants import DTYPE_TO_FP8_MAX, FP8_DTYPES, INITIAL_SCALING_FACTOR from nanotron.fp8.dtypes import DTypes +from nanotron.fp8.meta import FP8Meta +logger = logging.get_logger(__name__) -class FP8Tensor(torch.Tensor): - """FP8 Tensor.""" - def __new__(cls, tensor: torch.Tensor, dtype: DTypes) -> torch.Tensor: +class LowPrecisionTensor(torch.Tensor): + def __new__( + cls, + tensor: torch.Tensor, + dtype: Optional[DTypes] = None, + interval: Optional[int] = 1, + fp8_meta: Optional[FP8Meta] = None, + sync: bool = False, + ) -> torch.Tensor: assert isinstance(tensor, torch.Tensor), "tensor must be a tensor" - assert tensor.dtype not in FP8_DTYPES, "The tensor already quantized to FP8" - - # TODO(xrsrke): there is a circular import issue - # between tensor.py and meta.py fix this - from nanotron.fp8.meta import FP8Meta # TODO(xrsrke): if the tensor is on cpu, then bypass the quantization # because the current kernels only support gpu tensor assert tensor.device != torch.device("cpu"), "FP8Tensor only supports CUDA device" - assert isinstance(dtype, DTypes) - amax = tensor.abs().max().clone() - scale = update_scaling_factor(amax, torch.tensor(INITIAL_SCALING_FACTOR, dtype=torch.float32), dtype) - fp8_meta = FP8Meta(amax, scale, dtype) - fp8_tensor = convert_tensor_to_fp8(tensor, fp8_meta) + if fp8_meta is None: + assert dtype in [DTypes.FP8E4M3, DTypes.FP8E5M2, DTypes.KFLOAT16] + + with torch.no_grad(): + fp8_meta = cls._get_metadata(tensor, dtype, interval, sync=sync) + + backup_fp8_meta = deepcopy(fp8_meta) + if tensor.dtype not in FP8_DTYPES: + fp8_tensor = cls._quantize(tensor, fp8_meta) + else: + fp8_tensor = tensor # TODO(xrsrke): move update inverse scaling to FP8Meta's initialization obj = torch.Tensor._make_subclass(cls, fp8_tensor) - obj.fp8_meta = fp8_meta + # TODO(xrsrke): use a different name, because FP16Tensor also has fp8_meta + obj.fp8_meta = backup_fp8_meta + return obj + def __init__( + self, + tensor: torch.Tensor, + dtype: Optional[DTypes] = None, + interval: Optional[int] = 1, + fp8_meta: Optional[FP8Meta] = None, + sync: bool = False, + ) -> None: + pass + + @staticmethod + # @torch.no_grad() + def _get_metadata(tensor: torch.Tensor, dtype: DTypes, interval: int, sync: bool) -> "FP8Meta": + # TODO(xrsrke): there is a circular import issue + # between tensor.py and meta.py fix this + from nanotron.fp8.meta import FP8Meta + + # NOTE: detach from original computational graph + amax = tensor.amax().clone() + + scale = update_scaling_factor(amax, torch.tensor(INITIAL_SCALING_FACTOR, dtype=torch.float32), dtype) + scale = scale.clone().detach() + fp8_meta = FP8Meta(amax, scale, dtype, interval) + return fp8_meta + + @abstractstaticmethod + def _quantize(tensor: torch.Tensor, fp8_meta: "FP8Meta") -> torch.Tensor: + ... + + def mul_(self, other: torch.Tensor): + + assert isinstance(other, torch.Tensor) + assert ( + other.ndim == 0 or other.ndim == 1 + ), "FP8Tensor don't support directly do matrix multiplication in FP8. You should cast it to a higher precision format." + + other = other.squeeze() if other.ndim == 1 else other + self.fp8_meta = cast(FP8Meta, self.fp8_meta) + self.fp8_meta.scale = 1 / (self.fp8_meta.inverse_scale * other) + + def div_(self, other: torch.Tensor): + assert isinstance(other, torch.Tensor) + assert ( + other.ndim == 0 or other.ndim == 1 + ), "FP8Tensor don't support directly do matrix division in FP8. You should cast it to a higher precision format." + self.mul_(1 / other) + + def __add__(self, other: torch.Tensor): + raise ValueError( + "You can't directly add a FP8Tensor with another tensor. You should cast it to a higher precision format" + ) + + def __sub__(self, other: torch.Tensor): + raise ValueError( + "You can't directly subtract a FP8Tensor with another tensor. You should cast it to a higher precision format" + ) + + # TODO(xrsrke): need some more work to make it work with torch.equal + def __eq__(self, other: LowPrecisionTensor) -> bool: + assert isinstance( + other, self.__class__ + ), "Expected other tensor to be an instance of {self.__class__}, got {other.__class__}" + return True if self.fp8_meta == other.fp8_meta and torch.equal(self.data, other.data) else False + + # TODO(xrsrke): directly set a tensor data using tensor.data = new_data + def set_data(self, data: Union[torch.Tensor, LowPrecisionTensor, None], sync: bool = False): + assert isinstance(data, (self.__class__, torch.Tensor)), f"data must be a torch.Tensor or a {self.__class__}" + if data.__class__ in [FP8Tensor, FP16Tensor]: + assert data.dtype == self.data.dtype, "The data must have the same dtype as the tensor, got {data.dtype}" + quantized_data = data + else: + quantized_data = self.__class__( + data, dtype=self.fp8_meta.dtype, interval=self.fp8_meta.interval, sync=sync + ) + + self.data = quantized_data.data + self._orig_data_after_set_data = data + + self.fp8_meta.add_amax(quantized_data.fp8_meta.amax) + + @staticmethod + @torch.no_grad() + def from_metadata(data: torch.Tensor, metadata: "FP8Meta", sync: bool = False) -> Union[FP8Tensor, FP16Tensor]: + assert isinstance(data, (FP8Tensor, torch.Tensor)), "data must be a torch.Tensor or a FP8Tensor" + # NOTE: don't do deepcopy, because we reuse the same metadata + # for other iterations in fp8linear + amax = data.abs().max().clone() + metadata.add_amax(amax) + + quantized_data = FP8Tensor(data, metadata.dtype, metadata.interval, fp8_meta=metadata, sync=sync) + return quantized_data + + def transpose_fp8(self) -> FP8Tensor: + """Transpose the tensor.""" + transposed_t = tex.fp8_transpose(self, self.fp8_meta.te_dtype) + transposed_t.fp8_meta = self.fp8_meta + return self.__class__(transposed_t, fp8_meta=self.fp8_meta) + def __repr__(self) -> str: - return f"FP8Tensor({self}, fp8_meta={self.fp8_meta})" + if hasattr(self, "fp8_meta"): + if self.__class__ == FP16Tensor: + return f"FP16Tensor({repr(self.data)}, fp8_meta={self.fp8_meta})" + elif self.__class__ == FP8Tensor: + return f"FP8Tensor({repr(self.data)}, fp8_meta={self.fp8_meta})" + else: + raise ValueError(f"Unknown tensor class: {self.__class__}") + + return super().__repr__() + + def clone(self) -> FP8Tensor: + tensor = super().clone() + tensor.fp8_meta = deepcopy(self.fp8_meta) + return tensor + + +class FP8Tensor(LowPrecisionTensor): + """FP8 Tensor.""" + + @staticmethod + def _quantize(tensor: torch.Tensor, fp8_meta: "FP8Meta") -> torch.Tensor: + assert isinstance(tensor, torch.Tensor) + assert tensor.dtype not in FP8_DTYPES, "The tensor already quantized to FP8" + + tensor = tensor.contiguous() + return convert_tensor_to_fp8(tensor, fp8_meta) + + +class FP16Tensor(LowPrecisionTensor): + + # TODO(xrsrke): remove specifying the dtype KFLOAT16 + # in initialization + # TODO(xrsrke): change the name to lp_meta = low_precision_meta + @staticmethod + def _quantize(tensor: torch.Tensor, fp8_meta: "FP8Meta") -> torch.Tensor: + assert isinstance(tensor, torch.Tensor) + + tensor = tensor.contiguous() + # TODO(xrsrke): convert it to int8 format + return (tensor * fp8_meta.scale).to(torch.float16) def convert_torch_dtype_to_te_dtype(dtype: torch.dtype) -> tex.DType: @@ -44,10 +199,6 @@ def convert_torch_dtype_to_te_dtype(dtype: torch.dtype) -> tex.DType: torch.float32: "kFloat32", torch.float16: "kFloat16", torch.bfloat16: "kBFloat16", - # torch.fp8e5m2: "kFloat8E5M2", - # torch.fp8e4m3: "kFloat8E4M3", - # torch.int8: "kFloat8E5M2", - # torch.uint8: "kFloat8E4M3", DTypes.FP8E4M3: "kFloat8E4M3", DTypes.FP8E5M2: "kFloat8E5M2", DTypes.KFLOAT16: "kFloat16", @@ -74,6 +225,25 @@ def convert_tensor_from_fp8(tensor: torch.Tensor, meta, dtype: torch.dtype) -> t return tex.cast_from_fp8(tensor, meta.inverse_scale, tensor_dtype, output_dtype) +def convert_tensor_from_fp16(tensor: FP16Tensor, dtype: torch.dtype) -> torch.Tensor: + assert isinstance(dtype, torch.dtype) + # TODO(xrsrke): this is a hacky way to turn a fp16 tensor to a non-quantize tensor + inverse_scale = tensor.fp8_meta.inverse_scale + tensor = tensor.clone() + tensor = (tensor * inverse_scale).to(dtype) + return torch.tensor(tensor, dtype=dtype).squeeze(dim=0) + + +def _convert_tensor_from_fp16(tensor: FP16Tensor, fp8_meta, dtype: torch.dtype) -> torch.Tensor: + assert isinstance(dtype, torch.dtype) + + inverse_scale = fp8_meta.inverse_scale + tensor = tensor.clone() + tensor = (tensor * inverse_scale).to(dtype) + return torch.tensor(tensor, dtype=dtype).squeeze(dim=0) + + +# @torch.jit.script def update_scaling_factor( amax: torch.Tensor, scaling_factor: torch.Tensor, dtype: DTypes, margin: float = 0 ) -> torch.Tensor: @@ -81,16 +251,20 @@ def update_scaling_factor( Update the scaling factor to quantize a tensor to FP8. Credits: https://github.com/Azure/MS-AMP/blob/d562f0f0bcfc9b712fa0726b73428753ff1300ab/msamp/common/tensor/meta.py#L39 """ - assert amax.dtype == torch.float32 + # TODO(xrsrke): sometimes we store some params in fp16 + # make this configurable + assert amax.dtype in [torch.float32, torch.float16, torch.bfloat16], f"amax.dtype: {amax.dtype}" # TODO(xrsrke): can we use lower precision for scaling_factor? assert scaling_factor.dtype == torch.float32 # NOTE: Since fp8_max is a fixed number based on two FP8 data types, # we prefer not to take fp8_max in the input arguments. + # NOTE: create cuda tensor slows down by 7% fp8_max = torch.tensor(DTYPE_TO_FP8_MAX[dtype], dtype=torch.float32) # NOTE: torch.jit only take a concrete value rather than a DTYPE_TO_FP8_MAX[dtype], # so we create an inner function to bypass that + @torch.jit.script def _inner(amax: torch.Tensor, fp8_max: torch.Tensor, scaling_factor: torch.Tensor, margin: float): # NOTE: calculate the number of bits to shift the exponent diff --git a/src/nanotron/fp8/utils.py b/src/nanotron/fp8/utils.py index 1f0e23ea..903646a0 100644 --- a/src/nanotron/fp8/utils.py +++ b/src/nanotron/fp8/utils.py @@ -1,7 +1,18 @@ +from typing import Dict, List, Optional, Tuple + import torch import transformer_engine as te # noqa +from torch import nn + +from nanotron import logging +from nanotron.config.fp8_config import FP8Args, FP8LayerArgs +from nanotron.fp8.constants import FP8_GPU_NAMES, FP8LM_RECIPE, QTYPE_TO_DTYPE +from nanotron.fp8.dtypes import DTypes +from nanotron.fp8.linear import FP8Linear +from nanotron.fp8.meta import FP8Meta +from nanotron.models.base import NanotronModel -from nanotron.fp8.constants import FP8_GPU_NAMES +logger = logging.get_logger(__name__) def is_fp8_available() -> bool: @@ -11,3 +22,373 @@ def is_fp8_available() -> bool: return any(gpu_name in device_name for gpu_name in FP8_GPU_NAMES) else: return False + + +def get_tensor_fp8_metadata(tensor: torch.Tensor, dtype: DTypes) -> FP8Meta: + from nanotron.fp8.constants import INITIAL_SCALING_FACTOR + from nanotron.fp8.tensor import update_scaling_factor + + # NOTE: do .clone() somehow fixes nan grad, + # check `exp801_fp8_nan_debug` for more details + amax = tensor.abs().max().clone() + assert amax.dtype == torch.float32 + + scale = update_scaling_factor(amax, torch.tensor(INITIAL_SCALING_FACTOR, dtype=torch.float32), dtype) + assert scale.dtype == torch.float32 + + fp8_meta = FP8Meta(amax, scale, dtype) + return fp8_meta + + +# TODO(xrsrke): shorter name +def is_overflow_underflow_nan(tensor: torch.Tensor) -> bool: + assert isinstance(tensor, torch.Tensor) + + overflow = torch.isinf(tensor).any().item() + underflow = torch.isneginf(tensor).any().item() + nan = torch.isnan(tensor).any().item() + + return True if (overflow or underflow or nan) else False + + +def convert_linear_to_fp8(linear: nn.Linear, accum_qtype: DTypes = FP8LM_RECIPE.linear.accum_dtype) -> FP8Linear: + in_features = linear.in_features + out_features = linear.out_features + is_bias = linear.bias is not None + + fp8_linear = FP8Linear( + in_features, out_features, bias=is_bias, device=linear.weight.device, accum_qtype=accum_qtype + ) + # TODO(xrsrke): do we need clone? + fp8_linear._set_and_quantize_weights(linear.weight.data.clone()) + + if is_bias: + fp8_linear.bias.data = linear.bias.data.to(QTYPE_TO_DTYPE[accum_qtype]) + + return fp8_linear + + +def get_leaf_modules(module: nn.Module) -> List[Tuple[str, nn.Module]]: + """ + Return all the leaf modules (modules without any child modules) in a PyTorch module. + """ + leaf_modules = [] + for n, m in module.named_modules(): + if not list(m.children()): + leaf_modules.append((n, m)) + return leaf_modules + + +def convert_to_fp8_module(module: nn.Module, accum_qtype: DTypes = FP8LM_RECIPE.linear.accum_dtype) -> nn.Module: + def set_module(model, name, value): + parts = name.split(".") + module = model + for i, part in enumerate(parts): + if part.isdigit(): + if i == len(parts) - 1: + module[int(part)] = value + else: + module = module[int(part)] + else: + if i == len(parts) - 1: + setattr(module, part, value) + else: + module = getattr(module, part) + return model + + for name, child in get_leaf_modules(module): + if isinstance(child, nn.Linear): + fp8_linear = convert_linear_to_fp8(child, accum_qtype) + fp8_linear.name = name + set_module(module, name, fp8_linear) + + return module + + +def calculate_kurtosis(X): + # Calculate s + s = torch.sqrt(torch.mean(X**2, dim=0)) + + # Calculate m4 and m2 + m4 = torch.mean(s**4) + m2 = torch.mean(s**2) + + # Calculate kurtosis + kurtosis = m4 / (m2**2) + + if torch.isnan(kurtosis) and not torch.all(torch.eq(X, 0)).item(): + assert 1 == 1 + + return kurtosis + + +def compute_stas(tensor): + from nanotron.fp8.tensor import FP8Tensor, FP16Tensor + + def compute_snr(tensor): + mean = torch.mean(tensor) + std = torch.std(tensor) + snr = mean / std + return snr + + if isinstance(tensor, FP8Tensor) or isinstance(tensor, FP16Tensor): + return { + "amax": tensor.fp8_meta.amax, + "scale": tensor.fp8_meta.scale, + } + else: + return { + "mean": tensor.mean(), + "std": tensor.std(), + "var": tensor.var(), + "l1_norm": tensor.norm(p=1), + "l2_norm": tensor.norm(p=2), + "rms": tensor.pow(2).mean().sqrt(), + "min": tensor.min(), + "max": tensor.max(), + "amax": tensor.abs().max(), + "abs_mean": tensor.abs().mean(), + "kurtosis": calculate_kurtosis(tensor), + "snr": compute_snr(tensor), + } + + +def track_module_statistics(name: str, module: nn.Linear, logging: Dict[str, Dict[str, float]]): + if name not in logging: + logging[name] = {} + + def _save_output_stats(module: nn.Linear, input: torch.Tensor, output: torch.Tensor): + if hasattr(module, "weight") and module.weight is not None: + logging[name]["weight"] = compute_stas(module.weight.data) + # logging[name]["weight"] = _collect_stats(module.weight) + + if hasattr(module, "bias") and module.bias is not None: + logging[name]["bias"] = compute_stas(module.bias) + + inputs = input if isinstance(input, tuple) else (input,) + outputs = output if isinstance(output, tuple) else (output,) + + if len(inputs) > 1: + for i, inp in enumerate(inputs): + if inp.dtype == torch.long: + # NOTE: this is input ids in transformers + continue + logging[name][f"input:{i}"] = compute_stas(inp) + else: + logging[name]["input"] = compute_stas(inputs[0]) + + if len(outputs) > 1: + for i, out in enumerate(outputs): + logging[name][f"output:{i}"] = compute_stas(out) + else: + logging[name]["output"] = compute_stas(outputs[0]) + + def _save_grad_stats(module: nn.Linear, grad_input, grad_output: torch.Tensor): + if isinstance(grad_output, tuple): + for i, grad in enumerate(grad_output): + if grad is None: + continue + + logging[name][f"grad_output:{i}"] = compute_stas(grad) + else: + logging[name]["grad_output"] = compute_stas(grad_output) + + if isinstance(grad_input, tuple): + for i, grad in enumerate(grad_input): + if grad is not None: + logging[name][f"grad_input:{i}"] = compute_stas(grad) + else: + if grad_input is not None: + logging[name]["grad_input"] = compute_stas(grad_input) + + handles = [] + handles.append(module.register_forward_hook(_save_output_stats)) + handles.append(module.register_backward_hook(_save_grad_stats)) + return handles + + +def _log(model: nn.Module): + LOGGING = {} + leaf_modules = get_leaf_modules(model) + all_handles = [] + for name, module in leaf_modules: + all_handles.append(track_module_statistics(name, module, logging=LOGGING)) + + return LOGGING, all_handles + + +def convert_logs_to_flat_logs(logs, prefix): + flat_logs = {} + for module_name, components in logs.items(): + for component_name, stats in components.items(): + for stat_name, value in stats.items(): + flat_logs[f"{prefix}:{module_name}:{component_name}:{stat_name}"] = value + + return flat_logs + + +def find_fp8_config_by_module_name(target_module_name: str, config: FP8Args) -> Optional[FP8LayerArgs]: + # NOTE: either model or is_quant_all_except_first_and_last must be specified, not both + # assert config.fp8.model is not None or config.fp8.is_quant_all_except_first_and_last is not None + + # TODO(xrsrke): remove config.is_quant_all_except_first_and_last + from nanotron.fp8.constants import FP8LM_LINEAR_RECIPE + + if hasattr(config, "model") and config.model is not None: + for layer_args in config.model: + if layer_args.module_name == target_module_name.replace("pp_block.", "").replace("module.", ""): + return layer_args + # elif config.is_quant_all_except_first_and_last: + else: + + def match_layer_pattern(name, layer_idxs): + # patterns = [ + # "model.decoder.{}.pp_block.attn.qkv_proj", + # "model.decoder.{}.pp_block.attn.o_proj", + # "model.decoder.{}.pp_block.mlp.down_proj", + # "model.decoder.{}.pp_block.mlp.gate_up_proj", + # ] + patterns = [ + "model.decoder.{}.attn.qkv_proj", + "model.decoder.{}.attn.o_proj", + "model.decoder.{}.mlp.down_proj", + "model.decoder.{}.mlp.gate_up_proj", + ] + + for idx in layer_idxs: + for pattern in patterns: + if name == pattern.format(idx): + return True + + return False + + from nanotron import constants + + num_layers = constants.CONFIG.model.model_config.num_hidden_layers + assert num_layers > 2, "num_hidden_layers must be greater than 2" + # assert config.fp8_linear_config_temp is not None + + quant_layer_idxs = list(range(1, num_layers - 1)) + # NOTE: remove ".pp_block" from module name + if match_layer_pattern(target_module_name.replace(".pp_block", ""), quant_layer_idxs) is True: + from copy import deepcopy + + # config_temp = deepcopy(config.fp8_linear_config_temp) + config_temp = deepcopy(FP8LM_LINEAR_RECIPE) + config_temp.module_name = target_module_name + return config_temp + # else: + # from nanotron.fp8.constant_recipe import MODULE_NAMES_THAT_NOT_FP8 + + # if any(module_name in target_module_name for module_name in MODULE_NAMES_THAT_NOT_FP8): + # return None + # else: + # # NOTE: return default recipe + # # NOTE: based on the global setting smooth_quant to decide whether to do smooth quantization + # # or not + # recipe = FP8LM_LINEAR_RECIPE + # recipe.smooth_quant = config.smooth_quant + # log_rank( + # f"target_module_name={target_module_name}, smooth_quant={recipe.smooth_quant}", + # logger=logger, + # level=logging.INFO, + # rank=0, + # ) + + # return recipe + # return None + + +def get_modules_not_in_fp16(): + from nanotron import constants + from nanotron.fp8.constant_recipe import MODULE_NAMES_THAT_NOT_FP8 + + if constants.CONFIG is not None and hasattr(constants.CONFIG, "fp8"): + if constants.CONFIG.fp8.model is None: + # NOTE: convert all modules to fp8 axcept + name_of_modules_not_in_fp16 = MODULE_NAMES_THAT_NOT_FP8 + else: + name_of_modules_not_in_fp16 = [x.module_name for x in constants.CONFIG.fp8.model] + else: + name_of_modules_not_in_fp16 = [] + return name_of_modules_not_in_fp16 + + +def is_convert_to_fp16(module) -> bool: + from nanotron import constants + from nanotron.fp8.constant_recipe import MODULE_NAMES_THAT_NOT_FP8, MODULES_THAT_IN_FLOAT16 + + IS_CONVERT_TO_FLOAT16 = False + name_of_modules_not_in_fp16 = get_modules_not_in_fp16() + + # if hasattr(module, "name") and "lm_head" in module.name: + # assert 1 == 1 + + if constants.CONFIG is not None and constants.CONFIG.fp8.model is None: + if any(isinstance(module, m) for m in MODULES_THAT_IN_FLOAT16): + IS_CONVERT_TO_FLOAT16 = True + else: + if hasattr(module, "name") and any(n in module.name for n in MODULE_NAMES_THAT_NOT_FP8): + IS_CONVERT_TO_FLOAT16 = True + else: + if any(isinstance(module, m) for m in MODULES_THAT_IN_FLOAT16) or ( + hasattr(module, "name") and module.name not in name_of_modules_not_in_fp16 + ): + IS_CONVERT_TO_FLOAT16 = True + + return IS_CONVERT_TO_FLOAT16 + + +def convert_model_to_fp8(model: NanotronModel, config: FP8Args) -> NanotronModel: + from nanotron.fp8.utils import get_leaf_modules + + assert 1 == 1 + # NOTE: convert to FP8 + + # from nanotron import constants + from nanotron.fp8.utils import find_fp8_config_by_module_name + from nanotron.parallel.tensor_parallel.nn import ( + FP8TensorParallelColumnLinear, + FP8TensorParallelRowLinear, + TensorParallelColumnLinear, + TensorParallelRowLinear, + ) + + TP_LINEAR_CLS_TO_FP8_LINEAR_CLS = { + TensorParallelColumnLinear: FP8TensorParallelColumnLinear, + TensorParallelRowLinear: FP8TensorParallelRowLinear, + } + for name, module in get_leaf_modules(model): + if any(p.numel() > 0 for p in module.parameters()) is False: + continue + + recipe = find_fp8_config_by_module_name(name, config) + + # if isinstance(module, (TensorParallelColumnLinear, TensorParallelRowLinear)): + if recipe is not None: + print(f"Converting {name} to FP8") + module.__class__ = TP_LINEAR_CLS_TO_FP8_LINEAR_CLS[module.__class__] + # TODO(xrsrke): retrieve custom recipe + module._set_and_quantize_weights(module.weight.data) + + # assert isinstance(module.weight, NanotronParameter) + # assert module.weight.data.__class__ == FP8Tensor + # assert module.weight.data.dtype in [ + # torch.uint8, + # torch.int8, + # ], f"got {module.weight.data.dtype}, name: {name}" + else: + # NOTE: convert it to the residual stream's dtype + # for p in module.parameters(): + # p.data = p.data.to(self.config.model.dtype) + # for p in module.parameters(): + # p.data = p.data.to(dtype=config.resid_dtype) if p.data + # pass + # assert module.weight.data.__class__ == torch.Tensor + # module.to(dtype=config.resid_dtype) + # pass + # assert module.weight.data.__class__ == torch.Tensor + # NOTE: this causes param.data == NanotronParameter + assert config.resid_dtype == torch.float32, "not support datatype conversion, because of error 8" + + return model diff --git a/src/nanotron/helpers.py b/src/nanotron/helpers.py index 73ca3484..1135ae62 100644 --- a/src/nanotron/helpers.py +++ b/src/nanotron/helpers.py @@ -294,6 +294,7 @@ def merge_named_param_groups( def init_optimizer_and_grad_accumulator( parametrization_method: ParametrizationMethod, model: nn.Module, + master_weight_dtype: torch.dtype, optimizer_args: OptimizerArgs, parallel_context: ParallelContext, ) -> Tuple[BaseOptimizer, GradientAccumulator]: @@ -328,12 +329,40 @@ def basic_optimizer_builder(named_param_groups): if optimizer_args.optimizer_factory.name == "adamW": def optimizer(param_groups): + # if has_fp8_params(param_groups) is False: + # if constants.CONFIG.model.dtype != torch.int8: + # return torch.optim.AdamW( + # param_groups, + # lr=optimizer_args.learning_rate_scheduler.learning_rate, + # weight_decay=optimizer_args.weight_decay, + # eps=optimizer_args.optimizer_factory.adam_eps, + # betas=( + # optimizer_args.optimizer_factory.adam_beta1, + # optimizer_args.optimizer_factory.adam_beta2, + # ), + # fused=optimizer_args.optimizer_factory.torch_adam_is_fused, + # ) + # else: + # return FP8AdamW( + # param_groups, + # lr=optimizer_args.learning_rate_scheduler.learning_rate, + # weight_decay=optimizer_args.weight_decay, + # eps=optimizer_args.optimizer_factory.adam_eps, + # betas=( + # optimizer_args.optimizer_factory.adam_beta1, + # optimizer_args.optimizer_factory.adam_beta2, + # ), + # recipe=constants.CONFIG.fp8.optim, + # ) return torch.optim.AdamW( param_groups, lr=optimizer_args.learning_rate_scheduler.learning_rate, weight_decay=optimizer_args.weight_decay, eps=optimizer_args.optimizer_factory.adam_eps, - betas=(optimizer_args.optimizer_factory.adam_beta1, optimizer_args.optimizer_factory.adam_beta2), + betas=( + optimizer_args.optimizer_factory.adam_beta1, + optimizer_args.optimizer_factory.adam_beta2, + ), fused=optimizer_args.optimizer_factory.torch_adam_is_fused, ) @@ -365,6 +394,7 @@ def grad_optimizer_builder(named_param_groups): gradient_accumulator_builder=lambda named_params: FP32GradientAccumulator( named_parameters=named_params, grad_buckets_named_params=named_parameters, + master_dtype=master_weight_dtype, ), named_params_or_groups=named_param_groups, optimizer_builder=basic_optimizer_builder, diff --git a/src/nanotron/models/base.py b/src/nanotron/models/base.py index 14ac6908..562f8627 100644 --- a/src/nanotron/models/base.py +++ b/src/nanotron/models/base.py @@ -71,7 +71,7 @@ def get_embeddings_lm_head_tied_names(self) -> list[str]: Example for GPT2 model: ["model.token_position_embeddings.pp_block.token_embedding.weight", "model.lm_head.pp_block.weight"] """ return [] - + def get_named_params_without_weight_decay(self) -> List[str]: """Return a list of named parameters that should not have weight decay applied to them.""" return [] @@ -237,6 +237,10 @@ def build_model( return model +old_register_parameter = nn.Module.register_parameter +old_register_buffer = nn.Module.register_buffer + + # TODO @thomasw21: Should this option override user defined options? Maybe not ... right now it does. @contextmanager def init_on_device_and_dtype( diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index 88fb6bcb..0d69fe62 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -24,11 +24,11 @@ from nanotron import logging from nanotron.config import Config, LlamaConfig, ParallelismArgs from nanotron.config.models_config import RandomInit, SpectralMupInit +from nanotron.fp8.utils import is_overflow_underflow_nan from nanotron.generation.generate_store import AttachableStore from nanotron.logging import log_rank from nanotron.models import NanotronModel from nanotron.nn.activations import ACT2FN -from nanotron.nn.layer_norm import TritonRMSNorm from nanotron.parallel import ParallelContext from nanotron.parallel.parameters import NanotronParameter from nanotron.parallel.pipeline_parallel.block import PipelineBlock, TensorPointer @@ -207,6 +207,7 @@ def __init__( config: LlamaConfig, parallel_config: Optional[ParallelismArgs], tp_pg: dist.ProcessGroup, + layer_idx: int, ): super().__init__() @@ -220,6 +221,7 @@ def __init__( config.intermediate_size, # shape of gate_linear config.intermediate_size, # shape of up_linear ) + # self.gate_up_proj = TensorParallelColumnLinear( self.gate_up_proj = TensorParallelColumnLinear( config.hidden_size, 2 * config.intermediate_size, @@ -228,7 +230,8 @@ def __init__( bias=False, async_communication=tp_linear_async_communication, contiguous_chunks=gate_up_contiguous_chunks, - tp_recompute_allgather=parallel_config.tp_recompute_allgather, + name=f"model.decoder.{layer_idx}.mlp.gate_up_proj", + # tp_recompute_allgather=parallel_config.tp_recompute_allgather, ) self.down_proj = TensorParallelRowLinear( config.intermediate_size, @@ -237,6 +240,7 @@ def __init__( mode=tp_mode, bias=False, async_communication=tp_linear_async_communication and tp_mode is TensorParallelLinearMode.REDUCE_SCATTER, + name=f"model.decoder.{layer_idx}.mlp.down_proj", ) self.split_silu_mul = GLUActivation(config.hidden_act) @@ -389,7 +393,8 @@ def __init__( bias=False, async_communication=tp_linear_async_communication, contiguous_chunks=qkv_contiguous_chunks, - tp_recompute_allgather=parallel_config.tp_recompute_allgather, + name=f"model.decoder.{layer_idx}.attention.qkv_proj", + # tp_recompute_allgather=parallel_config.tp_recompute_allgather, ) # TODO(kunhao): We want to have only one version per device and not one version per layer. if config.rope_interleaved: @@ -418,6 +423,7 @@ def __init__( mode=tp_mode, bias=False, async_communication=tp_linear_async_communication, + name=f"model.decoder.{layer_idx}.attention.o_proj", ) self.attention = CoreAttention( @@ -425,6 +431,7 @@ def __init__( parallel_config=parallel_config, layer_idx=layer_idx, ) + self.layer_idx = layer_idx self.prefill_kv_len = ( config.max_position_embeddings @@ -446,6 +453,8 @@ def forward( ) # [seq_length, batch_size, n_local_q_heads * d_qk + 2 * n_local_kv_heads * d_qk] q_length, batch_size, _ = qkv_states.shape + assert is_overflow_underflow_nan(qkv_states) is False, f"layer_idx: {self.layer_idx}" + if self.is_gqa: query_states, key_states, value_states = torch.split( qkv_states, @@ -655,6 +664,10 @@ def forward( key_value_states = torch.cat([key_states.unsqueeze(0), value_states.unsqueeze(0)], dim=0) # [batch_size, seq_length, 2, num_heads, d_qk] key_value_states = key_value_states.permute(1, 2, 0, 3, 4).contiguous() + + assert is_overflow_underflow_nan(query_states) is False, f"layer_idx: {self.layer_idx}" + assert is_overflow_underflow_nan(key_value_states) is False, f"layer_idx: {self.layer_idx}" + query_states, key_value_states = self.flash_rotary_embedding(query_states, kv=key_value_states) # [batch_size, seq_length, num_heads, d_qk] key_states, value_states = torch.split(key_value_states, 1, dim=2) @@ -676,6 +689,22 @@ def forward( batch_size * kv_length, self.n_local_kv_heads, self.d_v ) # [batch_size * kv_length, self.n_heads, d_v] + # NOTE: even though in some cases, we accumulate fp8 gemm in bfloat16, + # but since the layer norm are in float32, the resulting output will be in float32 + # and flash attention don't support float32 qkv, so we have to cast it back to bfloat16 + + assert is_overflow_underflow_nan(query_states) is False, f"layer_idx: {self.layer_idx}" + assert is_overflow_underflow_nan(key_states) is False, f"layer_idx: {self.layer_idx}" + assert is_overflow_underflow_nan(value_states) is False, f"layer_idx: {self.layer_idx}" + + query_states = query_states.to(torch.bfloat16) + key_states = key_states.to(torch.bfloat16) + value_states = value_states.to(torch.bfloat16) + + assert is_overflow_underflow_nan(query_states) is False, f"layer_idx: {self.layer_idx}" + assert is_overflow_underflow_nan(key_states) is False, f"layer_idx: {self.layer_idx}" + assert is_overflow_underflow_nan(value_states) is False, f"layer_idx: {self.layer_idx}" + attention_output = self.attention( query_states=query_states, key_states=key_states, @@ -687,6 +716,14 @@ def forward( attention_output = ( attention_output.contiguous().view(batch_size, q_length, self.n_local_q_heads * self.d_v).transpose(0, 1) ) + from nanotron import constants + + if attention_output.dtype != constants.CONFIG.fp8.resid_dtype: + assert is_overflow_underflow_nan(attention_output) is False, f"layer_idx: {self.layer_idx}" + attention_output = attention_output.to(constants.CONFIG.fp8.resid_dtype) + assert is_overflow_underflow_nan(attention_output) is False, f"layer_idx: {self.layer_idx}" + + assert is_overflow_underflow_nan(attention_output) is False, f"layer_idx: {self.layer_idx}" output = self.o_proj(attention_output) return {"hidden_states": output, "sequence_mask": sequence_mask} @@ -701,7 +738,10 @@ def __init__( layer_idx: int, ): super().__init__() - self.input_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + # self.input_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + # NOTE: i got an illegal memory access was encountered when using TritonRMSNorm + # even downgrad flash_attn to 2.4.2 + self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps) self.attn = CausalSelfAttention( config=config, parallel_config=parallel_config, @@ -709,10 +749,12 @@ def __init__( layer_idx=layer_idx, ) - self.post_attention_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.mlp = MLP(config=config, parallel_config=parallel_config, tp_pg=tp_pg) + # self.post_attention_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps) + self.mlp = MLP(config=config, parallel_config=parallel_config, tp_pg=tp_pg, layer_idx=layer_idx) self.recompute_layer = parallel_config.recompute_layer + self.layer_idx = layer_idx def _core_forward( self, @@ -721,15 +763,22 @@ def _core_forward( ) -> List[Union[torch.Tensor, TensorPointer]]: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) + assert is_overflow_underflow_nan(hidden_states) is False, f"layer_idx: {self.layer_idx}" output = self.attn(hidden_states=hidden_states, sequence_mask=sequence_mask) hidden_states = output["hidden_states"] + assert is_overflow_underflow_nan(hidden_states) is False, f"layer_idx: {self.layer_idx}" hidden_states = hidden_states + residual + assert is_overflow_underflow_nan(hidden_states) is False, f"layer_idx: {self.layer_idx}" residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) + assert is_overflow_underflow_nan(hidden_states) is False, f"layer_idx: {self.layer_idx}" + hidden_states = self.mlp(hidden_states=hidden_states)["hidden_states"] + assert is_overflow_underflow_nan(hidden_states) is False, f"layer_idx: {self.layer_idx}" hidden_states = hidden_states + residual + assert is_overflow_underflow_nan(hidden_states) is False, f"layer_idx: {self.layer_idx}" return hidden_states, output["sequence_mask"] @@ -847,8 +896,10 @@ def __init__( self.final_layer_norm = PipelineBlock( p2p=self.p2p, - module_builder=TritonRMSNorm, - module_kwargs={"hidden_size": config.hidden_size, "eps": config.rms_norm_eps}, + # module_builder=TritonRMSNorm, + # module_kwargs={"hidden_size": config.hidden_size, "eps": config.rms_norm_eps}, + module_builder=nn.LayerNorm, + module_kwargs={"normalized_shape": config.hidden_size, "eps": config.rms_norm_eps}, module_input_keys={"input"}, module_output_keys={"hidden_states"}, ) # TODO @@ -865,7 +916,8 @@ def __init__( # TODO @thomasw21: refactor so that we store that default in a single place. "mode": self.tp_mode, "async_communication": tp_linear_async_communication, - "tp_recompute_allgather": parallel_config.tp_recompute_allgather, + "name": "model.lm_head", + # "tp_recompute_allgather": parallel_config.tp_recompute_allgather, }, module_input_keys={"x"}, module_output_keys={"logits"}, @@ -899,14 +951,20 @@ def forward_with_hidden_states( "hidden_states": output["input_embeds"], "sequence_mask": input_mask, } + assert is_overflow_underflow_nan(hidden_encoder_states["hidden_states"]) is False + for encoder_block in self.decoder: hidden_encoder_states = encoder_block(**hidden_encoder_states) + assert is_overflow_underflow_nan(hidden_encoder_states["hidden_states"]) is False hidden_states = self.final_layer_norm(input=hidden_encoder_states["hidden_states"])["hidden_states"] + assert is_overflow_underflow_nan(hidden_states) is False sharded_logits = self.lm_head(x=hidden_states)["logits"] + assert is_overflow_underflow_nan(sharded_logits) is False fp32_sharded_logits = self.cast_to_fp32(x=sharded_logits)["output"] + assert is_overflow_underflow_nan(fp32_sharded_logits) is False return fp32_sharded_logits, hidden_states @@ -1069,7 +1127,8 @@ def init_model_randomly(self, config: Config): continue module = model.get_submodule(module_name) - parametrizator.parametrize(param_name, module) + # parametrizator.parametrize(param_name, module) + parametrizator.parametrize(full_param_name, module) assert full_param_name not in initialized_parameters initialized_parameters.add(full_param_name) diff --git a/src/nanotron/nn/layer_norm.py b/src/nanotron/nn/layer_norm.py index ef3b4c50..7a8fcaad 100644 --- a/src/nanotron/nn/layer_norm.py +++ b/src/nanotron/nn/layer_norm.py @@ -39,9 +39,58 @@ def reset_parameters(self): def forward( self, input, residual=None, dropout_p=0.0, prenorm=False, residual_in_fp32=False, return_dropout_mask=False ): - from flash_attn.ops.triton.layer_norm import layer_norm_fn + # NOTE: fa=2.6.3 + # got the following errors: + # Traceback (most recent call last): + # File "/fsx/phuc/temp/fp8_for_nanotron/env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl + # return self._call_impl(*args, **kwargs) + # File "/fsx/phuc/temp/fp8_for_nanotron/env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl + # return forward_call(*args, **kwargs) + # File "/fsx/phuc/temp/fp8_for_nanotron/nanotron/src/nanotron/nn/layer_norm.py", line 44, in forward + # return layer_norm_fn( + # File "/fsx/phuc/temp/fp8_for_nanotron/env/lib/python3.10/site-packages/flash_attn/ops/triton/layer_norm.py", line 875, in layer_norm_fn + # return LayerNormFn.apply( + # File "/fsx/phuc/temp/fp8_for_nanotron/env/lib/python3.10/site-packages/torch/autograd/function.py", line 539, in apply + # return super().apply(*args, **kwargs) # type: ignore[misc] + # File "/fsx/phuc/temp/fp8_for_nanotron/env/lib/python3.10/site-packages/flash_attn/ops/triton/layer_norm.py", line 748, in forward + # y, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 = _layer_norm_fwd( + # File "/fsx/phuc/temp/fp8_for_nanotron/env/lib/python3.10/site-packages/flash_attn/ops/triton/layer_norm.py", line 335, in _layer_norm_fwd + # _layer_norm_fwd_1pass_kernel[(M,)]( + # File "/fsx/phuc/temp/fp8_for_nanotron/env/lib/python3.10/site-packages/triton/runtime/jit.py", line 345, in + # return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs) + # File "/fsx/phuc/temp/fp8_for_nanotron/env/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 156, in run + # timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs} + # File "/fsx/phuc/temp/fp8_for_nanotron/env/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 156, in + # timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs} + # File "/fsx/phuc/temp/fp8_for_nanotron/env/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 133, in _bench + # return do_bench(kernel_call, warmup=self.num_warmups, rep=self.num_reps, quantiles=(0.5, 0.2, 0.8)) + # File "/fsx/phuc/temp/fp8_for_nanotron/env/lib/python3.10/site-packages/triton/testing.py", line 104, in do_bench + # torch.cuda.synchronize() + # File "/fsx/phuc/temp/fp8_for_nanotron/env/lib/python3.10/site-packages/torch/cuda/__init__.py", line 783, in synchronize + # return torch._C._cuda_synchronize() + # RuntimeError: CUDA error: an illegal memory access was encountered + # CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect. + # For debugging consider passing CUDA_LAUNCH_BLOCKING=1. + # Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions. - return layer_norm_fn( + # from flash_attn.ops.triton.layer_norm import layer_norm_fn + # return layer_norm_fn( + # input, + # self.weight, + # None, + # residual=residual, + # eps=self.eps, + # dropout_p=dropout_p, + # prenorm=prenorm, + # residual_in_fp32=residual_in_fp32, + # is_rms_norm=True, + # return_dropout_mask=return_dropout_mask, + # ) + + # NOTE: fa=2.4.2 + from flash_attn.ops.triton.layernorm import rms_norm_fn + + return rms_norm_fn( input, self.weight, None, @@ -50,6 +99,6 @@ def forward( dropout_p=dropout_p, prenorm=prenorm, residual_in_fp32=residual_in_fp32, - is_rms_norm=True, + # is_rms_norm=True, # NOTE: fa=2.4.2 don't use this? wtf dao return_dropout_mask=return_dropout_mask, ) diff --git a/src/nanotron/optim/clip_grads.py b/src/nanotron/optim/clip_grads.py index d9fe211b..dd5698a2 100644 --- a/src/nanotron/optim/clip_grads.py +++ b/src/nanotron/optim/clip_grads.py @@ -4,6 +4,7 @@ import nanotron.distributed as dist from nanotron import logging +from nanotron.fp8.tensor import FP8Tensor from nanotron.optim.gradient_accumulator import GradientAccumulator from nanotron.parallel.parameters import NanotronParameter @@ -33,9 +34,10 @@ def clip_grad_norm( named_parameters = list(named_parameters) world_rank = dist.get_rank() - # assert that all params require grad for _, p in named_parameters: - assert p.requires_grad, "clip_grad_norm_ only supports Tensors that require grad" + assert p.requires_grad or isinstance( + p.data, FP8Tensor + ), "clip_grad_norm_ only supports Tensors that require grad" if grad_accumulator is None: grads = [ diff --git a/src/nanotron/optim/gradient_accumulator.py b/src/nanotron/optim/gradient_accumulator.py index 2e940744..ee532eb1 100644 --- a/src/nanotron/optim/gradient_accumulator.py +++ b/src/nanotron/optim/gradient_accumulator.py @@ -9,6 +9,8 @@ import nanotron.distributed as dist from nanotron import logging +from nanotron.fp8.tensor import FP8Tensor +from nanotron.fp8.utils import is_overflow_underflow_nan from nanotron.parallel.parameters import NanotronParameter from nanotron.utils import get_untyped_storage, tensor_from_untyped_storage @@ -60,6 +62,7 @@ def __init__( self, named_parameters: Iterator[Tuple[str, NanotronParameter]], grad_buckets_named_params: Optional[Iterator[Tuple[str, NanotronParameter]]] = None, + master_dtype: torch.dtype = torch.float32, ): """Create a gradient accumulator that will accumulate gradients in fp32. @@ -82,8 +85,14 @@ def __init__( # Assign big buffer for weights + grad in fp32 segment_index = {} length = 0 + master_params = [] for name, param in named_parameters: - if not param.requires_grad: + # NOTE: FP8 Parameter by default has requires_grad=False, + # because we want to do the backward ourself, so here we only skip + # if the parameter isn't fp8, and doesn't require grad + + if self._is_not_required_master_weights(param): + master_params.append((name, param)) continue start = length @@ -92,33 +101,77 @@ def __init__( segment_index[name] = (start, end_weight, param) length = end_weight - big_flat_buffer = torch.empty(length, dtype=torch.float, device="cuda") + big_flat_buffer = torch.empty(length, dtype=master_dtype, device="cuda") self.parameters = { name: { - "fp32": big_flat_buffer[start_weight:end_weight].view_as(param), + "master": big_flat_buffer[start_weight:end_weight].view_as(param), "half": param, } for name, (start_weight, end_weight, param) in segment_index.items() } + self.parameters.update( + { + name: { + "master": param, + "half": param, + } + for name, param in master_params + } + ) + + # NOTE: and since we pass gradient accumulator around + # and other objects access parameter from .get_parameter_for_optimizer() + # so we will keep with torch.inference_mode(): for _, elt in self.parameters.items(): - fp32_param = elt["fp32"] + master_param = elt["master"] half_param = elt["half"] # Check that fp32 weights have the same memory representation as half precision weights - assert fp32_param.stride() == half_param.stride() + assert master_param.stride() == half_param.stride() # Copy weights from half precision to full precision - fp32_param.copy_(half_param) + if not isinstance(half_param.data, FP8Tensor): + master_param.copy_(half_param) + else: + from nanotron import constants + + def find_param_name(param, named_parameters): + for name, p in named_parameters: + if p is param: + return name + return None + + p_name = find_param_name(half_param, named_parameters) + assert p_name is not None + p_data = constants.CPU_WEIGHTS[p_name] + assert p_data.dtype == torch.float32, f"Expected {p_name} to be float32, but got {p_data.dtype}" + + master_param.copy_(constants.CPU_WEIGHTS[p_name]) + + del constants.CPU_WEIGHTS[p_name] + del p_name # Set requires_grad=True - fp32_param.requires_grad = True + master_param.requires_grad = True self._is_accumulation_sync_step = False # We need the last allreduce handle to make sure it finishes before the optimizer step self.fp32_grads_allreduce_handle: Optional[torch.futures.Future] = None + def _is_not_required_master_weights(self, param: NanotronParameter): + # NOTE: There are two scenarios that we don't create master weights + # Scenario 1: Scthe first is if a parameter don't require grad + # Scenario 2: In case of fp8 training, some non-fp8 parameters are in float32 + # so there are no needed master weights + if not isinstance(param.data, FP8Tensor) and not param.requires_grad: + return True + elif param.data.dtype is torch.float32: + return True + else: + return False + def assign_param_offsets(self, param_name_to_offsets: Dict[str, Dict[int, Tuple[int, int]]], dp_rank: int): """To use only when you use with ZeRODistributedOptimizer""" self.param_name_to_offsets = { @@ -154,6 +207,13 @@ def sync_gradients_across_dp(self, dp_pg: dist.ProcessGroup, reduce_op: dist.Red else: dist.all_reduce(self._contiguous_fp32_grad_buffer, op=reduce_op, group=dp_pg) + @classmethod + def _is_accumulate_param(cls, param: NanotronParameter) -> bool: + from nanotron.fp8.tensor import FP8Tensor + + # return param.requires_grad or param.data.__class__ == FP8Tensor + return param.requires_grad or isinstance(param.data, FP8Tensor) + @staticmethod def build_grad_buffers( named_parameters: Iterator[Tuple[str, NanotronParameter]], @@ -165,8 +225,11 @@ def build_grad_buffers( Note: In ZeRO-1, we need to accumulate grads for all parameters, because we need to allreduce all parameters' grads across DP at each sync step. + Unlike master weights, we create fp32 gradient buffers for both fp8 parameters and non-fp8 parameters. """ - named_parameters = [(name, param) for name, param in named_parameters if param.requires_grad] + named_parameters = [ + (name, param) for name, param in named_parameters if FP32GradientAccumulator._is_accumulate_param(param) + ] needed_buffer_size = sum(param.numel() for _, param in named_parameters) # important to have grads zeroed initially (see `self._accumulate_grad`) @@ -178,16 +241,19 @@ def build_grad_buffers( fp32_grad_buffers = OrderedDict() # keeps order of insertion offset = 0 for name, param in named_parameters: - if not param.requires_grad: + # NOTE: because fp8 parameter by default has `requires_grad=False`, + # but we still need to accumulate grads for it + # if not param.requires_grad: + if FP32GradientAccumulator._is_accumulate_param(param) is False: continue - assert param.dtype != torch.float, f"Expected {name} not to be float" + # assert param.dtype != torch.float, f"Expected {name} not to be float" assert param.is_contiguous(), f"Expected {name} to be contiguous" next_offset = offset + param.numel() * element_size fp32_grad_buffer = tensor_from_untyped_storage( - untyped_storage=untyped_storage[offset:next_offset], dtype=torch.float + untyped_storage=untyped_storage[offset:next_offset], dtype=torch.float32 ) fp32_grad_buffers[name] = { @@ -212,11 +278,20 @@ def backward(self, loss: torch.Tensor): def _accumulate_grad(self, name: str, half_param: NanotronParameter) -> None: """Accumulate grad in fp32 and set the fp32 grad to the fp32 grad buffer, so that optimizer can update fp32 weights afterwards""" assert half_param.grad is not None, f"Expected param {name} to have gradient." + from nanotron.fp8.tensor import convert_tensor_from_fp8 + + if isinstance(half_param.data, FP8Tensor): + grad = convert_tensor_from_fp8(half_param.grad, half_param.grad.fp8_meta, torch.float32) + else: + grad = half_param.grad + + assert is_overflow_underflow_nan(grad) is False, f"Detected overflow/underflow/nan in {name} grad" + fp32_grad = self.get_grad_buffer(name=name) if self._is_accumulation_sync_step is False: # WARNING: We assume fp32_grad_bucket is already zeroed - fp32_grad.add_(half_param.grad) + fp32_grad.add_(grad) # In case _is_accumulation_sync_step = True: no need to add half gradients, because it's done in the allreduce hook # TODO @thomasw21: Is it better to set to zero instead? @@ -224,7 +299,7 @@ def _accumulate_grad(self, name: str, half_param: NanotronParameter) -> None: # In the case an optimizer decides to set it to None, we need to re-assign previous buffer if name in self.parameters: - fp32_param = self.parameters[name]["fp32"] + master_param = self.parameters[name]["master"] if hasattr(self, "param_name_to_offsets"): if name not in self.param_name_to_offsets: # When `name` isn't in `param_name_to_offsets` it means the slice is empty. @@ -233,7 +308,8 @@ def _accumulate_grad(self, name: str, half_param: NanotronParameter) -> None: grad = fp32_grad.view(-1)[start_offset:end_offset] else: grad = fp32_grad - fp32_param.grad = grad + master_param.grad = grad + assert is_overflow_underflow_nan(master_param.grad) is False @contextmanager def no_sync(self): @@ -256,11 +332,15 @@ def step(self): We need to update only the parameters that were updated by the optimizer. """ for name in self.parameters.keys(): - fp32_param = self.parameters[name]["fp32"] + master_param = self.parameters[name]["master"] half_param = self.parameters[name]["half"] + # TODO @nouamane: should we use a fused kernel to copy? # Copy weights from full precision to half precision - half_param.copy_(fp32_param) + if half_param.data.__class__ == FP8Tensor: + half_param.data.set_data(master_param, sync=False) + else: + half_param.copy_(master_param) def zero_grad(self): # Full precision gradients are reset to zero/none after the underlying `optimiser.step`, so no need to reset. @@ -276,22 +356,22 @@ def zero_grad(self): self._contiguous_fp32_grad_buffer.zero_() def get_parameter_for_optimizer(self, name: str) -> NanotronParameter: - return self.parameters[name]["fp32"] + return self.parameters[name]["master"] def get_grad_buffer(self, name: str) -> torch.Tensor: """Returns the gradient of the parameter from the appropriate grad bucket.""" return self.fp32_grad_buffers[name]["fp32_grad"] def state_dict(self) -> Dict[str, torch.Tensor]: - # We consider `fp32` parameters as a state of the gradient accumulator - return {name: elt["fp32"] for name, elt in self.parameters.items()} + # We consider master parameters as a state of the gradient accumulator + return {name: elt["master"] for name, elt in self.parameters.items()} def load_state_dict(self, state_dict: Dict[str, torch.Tensor]): assert set(state_dict.keys()) == set(self.parameters.keys()) with torch.inference_mode(): for name, elt in self.parameters.items(): - elt["fp32"].copy_(state_dict[name]) + elt["master"].copy_(state_dict[name]) @dataclasses.dataclass @@ -326,6 +406,9 @@ def get_fp32_accum_hook( # s = torch.cuda.Stream() def fp32_accum_hook(state: FP32GradBucketManager, bucket: GradBucket) -> torch.futures.Future[torch.Tensor]: + import pydevd + + pydevd.settrace(suspend=True, trace_only_current_thread=True) # nonlocal s # DDP groups grads in GradBuckets. This hook is called throughout the bwd pass, once each bucket is ready to overlap communication with computation. # See https://pytorch.org/docs/stable/ddp_comm_hooks.html#what-does-a-communication-hook-operate-on for more details. diff --git a/src/nanotron/parallel/parameters.py b/src/nanotron/parallel/parameters.py index 4aaaf011..702a1e80 100644 --- a/src/nanotron/parallel/parameters.py +++ b/src/nanotron/parallel/parameters.py @@ -1,7 +1,9 @@ import dataclasses -from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple +from copy import deepcopy +from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union import torch +from functorch.dim import tree_map from torch import nn from nanotron import distributed as dist @@ -93,7 +95,7 @@ def is_dp_sharded(self, parallel_context): class NanotronParameter(nn.Parameter): - """Base class for all parameters in Nanotronmodels + """Base class for all parameters in Nanotron models A NanotronParameter can have specific properties: - sharded: the parameter is considered to be `sharded` across multiple devices @@ -107,12 +109,37 @@ class NanotronParameter(nn.Parameter): - Even if some weights don't need their grads to be reduced, it's still useful for them to be marked as tied. For example, current serialization format requires to mark them correctly. """ + # __torch_function__ = torch._C._disabled_torch_function_impl + NANOTRON_PARAMETER_METADATA_ATTRIBUTE_NAME = "__nanotron_metadata__" NANOTRON_PARAMETER_METADATA_TIED_KEY = "tied" NANOTRON_PARAMETER_METADATA_SHARDED_KEY = "sharded" def __new__(cls, tensor: torch.Tensor, requires_grad: bool = True): - param = nn.Parameter.__new__(cls, data=tensor.data.detach(), requires_grad=requires_grad) + assert tensor.data.is_floating_point() or tensor.data.requires_grad is False + + if tensor.data.is_floating_point(): + if tensor.__class__ == nn.Parameter: + data = tensor.data + data.requires_grad = requires_grad + else: + data = tensor + else: + # NOTE: FP8 tensor has int dtype, you can't .detach() an integer tensor! + data = tensor.data + requires_grad = False + + # NOTE: this somehow makes the param has the methods of NanotronParameter + param = nn.Parameter._make_wrapper_subclass( + cls, + size=data.size(), + strides=data.stride(), + storage_offset=data.storage_offset(), + dtype=data.dtype, + layout=data.layout, + device=data.device, + requires_grad=requires_grad, + ) if isinstance(tensor, NanotronParameter): # Check that we don't inherit a weird class @@ -128,6 +155,13 @@ def __new__(cls, tensor: torch.Tensor, requires_grad: bool = True): return param + def __init__(self, tensor: Union[torch.Tensor, "FP8Tensor"]): + self._data = tensor + # NOTE: whether we will quantize this parameter + # because we need to know a parameter will be in fp8 or not + # so we create master weights of the fp32 parameters before quantizing + self._is_future_fp8 = False + def _set_metadata(self, key: str, value: Any): metadata = getattr(self, self.NANOTRON_PARAMETER_METADATA_ATTRIBUTE_NAME) @@ -181,12 +215,88 @@ def get_sharded_info(self) -> ShardedInfo: self.NANOTRON_PARAMETER_METADATA_SHARDED_KEY ] + @classmethod + def create_param_that_share_metadata(cls, tensor: torch.Tensor, param: Union[nn.Parameter, "NanotronParameter"]): + """Create a new parameter that shares the metadata and hash of the given parameter""" + # TODO(xrsrke): support deepcopy for tied parameter's metadata, because it includes an all-reduce + # which if we do deepcopy, it raises an error + metadata = deepcopy(getattr(param, NanotronParameter.NANOTRON_PARAMETER_METADATA_ATTRIBUTE_NAME, {})) + + # Copy metadata to the new parameter + new_param = NanotronParameter(tensor) + setattr(new_param, NanotronParameter.NANOTRON_PARAMETER_METADATA_ATTRIBUTE_NAME, metadata) + + # TODO(xrsrke): sync all the attributes in the param + # to the new parameter? in case, user sets some attributes + # then the new parameter is kinda lost it + return new_param + @property def is_sharded(self) -> bool: return self.NANOTRON_PARAMETER_METADATA_SHARDED_KEY in getattr( self, self.NANOTRON_PARAMETER_METADATA_ATTRIBUTE_NAME ) + def __repr__(self): + # return f"NanotronParameter({super().__repr__()})" + return "NanotronParameter()" + + @property + def data(self): + # from nanotron.fp8.parameter import FP8Parameter + return self._data + + @data.setter + def data(self, data): + self._data = data + + @classmethod + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): + from nanotron.fp8.tensor import FP8Tensor + + def unwrap(e): + return e._data if e.__class__ == NanotronParameter else e + + def wrap(e): + if not e.__class__ == NanotronParameter and e.__class__ in [torch.Tensor, FP8Tensor]: + return cls(e) + else: + return e + + unwrapped_args = tree_map(unwrap, args) + unwrapped_kwargs = tree_map(unwrap, kwargs) + + OPS_THAT_RETURN_ORIGINAL_TENSOR = [ + # NOTE: transpose operation + torch.ops.aten.t.default, + torch.ops.aten.view.default, + torch.ops.aten.detach.default, + # NOTE: F.embedding() + torch.ops.aten.embedding.default, + # NOTE: F.layer_norm() + torch.ops.aten.native_layer_norm.default, + torch.ops.aten.native_layer_norm_backward.default, + torch.ops.aten.native_layer_norm_backward.default, + # NOTE: nn.Linear + torch.ops.aten.addmm.default, + torch.ops.aten.linear.default, + # NOTE: x.to(device) + torch.ops.aten._to_copy.default, + ] + + if func == torch.ops.aten.detach.default and unwrapped_args[0].__class__ == FP8Tensor: + # NOTE: this is for parameter.data or parameter.detach() + # NOTE: because we already retrieved the data from unwrap, we don't need to do it again + data = unwrapped_args[0] + return data + else: + outputs = func(*unwrapped_args, **unwrapped_kwargs) + + if func in OPS_THAT_RETURN_ORIGINAL_TENSOR: + return outputs + else: + return tree_map(wrap, outputs) + def sanity_check(root_module: nn.Module): """Makes sure that the module is in Nanotronformat diff --git a/src/nanotron/parallel/tensor_parallel/functional.py b/src/nanotron/parallel/tensor_parallel/functional.py index e2ee3a29..7a46fe98 100644 --- a/src/nanotron/parallel/tensor_parallel/functional.py +++ b/src/nanotron/parallel/tensor_parallel/functional.py @@ -19,6 +19,8 @@ from torch.nn import functional as F import nanotron.distributed as dist +from nanotron.fp8.linear import FP8LinearMeta +from nanotron.fp8.recipe import FP8LinearRecipe from nanotron.parallel.tensor_parallel.distributed_differentiable_primitives import ( differentiable_all_reduce_sum, differentiable_identity, @@ -436,18 +438,30 @@ def column_linear( tp_mode: TensorParallelLinearMode, async_communication: bool, tp_recompute_allgather: bool = True, + metadatas: Optional[FP8LinearMeta] = None, + name: Optional[str] = None, + recipe: Optional[FP8LinearRecipe] = None, ): if async_communication: return _ColumnLinearAsyncCommunication.apply(input, weight, bias, group, tp_mode, tp_recompute_allgather) if tp_mode is TensorParallelLinearMode.ALL_REDUCE: + import nanotron.fp8.functional as fp8_functional + from nanotron.fp8.tensor import FP8Tensor + input = differentiable_identity(input, group=group) - return F.linear(input, weight, bias) - if tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: + + if isinstance(weight.data, FP8Tensor): + assert recipe is not None, "recipe must be provided for column_linear" + return fp8_functional.linear(input, weight, bias, metadatas=metadatas, recipe=recipe, name=name) + else: + return F.linear(input, weight, bias) + elif tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: return _ColumnLinearNoAsyncCommunicationReduceScatterMode.apply( input, weight, bias, group, tp_recompute_allgather ) - raise ValueError(f"Got unexpected mode: {tp_mode}.") + else: + raise ValueError(f"Got unexpected mode: {tp_mode}.") class _RowLinearAsyncCommunication(torch.autograd.Function): @@ -588,11 +602,21 @@ def row_linear( group: dist.ProcessGroup, tp_mode: TensorParallelLinearMode, async_communication: bool, + metadatas: Optional[FP8LinearMeta] = None, + recipe: Optional[FP8LinearRecipe] = None, + name: Optional[str] = None, ): if async_communication: return _RowLinearAsyncCommunication.apply(input, weight, bias, group, tp_mode) - out = F.linear(input, weight, bias) + import nanotron.fp8.functional as fp8_functional + from nanotron.fp8.tensor import FP8Tensor + + if isinstance(weight.data, FP8Tensor): + assert recipe is not None, "recipe must be provided for row_linear" + out = fp8_functional.linear(input, weight, bias, metadatas=metadatas, recipe=recipe, name=name) + else: + out = F.linear(input, weight, bias) if tp_mode is TensorParallelLinearMode.ALL_REDUCE: out = differentiable_all_reduce_sum(out, group=group) diff --git a/src/nanotron/parallel/tensor_parallel/nn.py b/src/nanotron/parallel/tensor_parallel/nn.py index 4c7325cd..debc8f06 100644 --- a/src/nanotron/parallel/tensor_parallel/nn.py +++ b/src/nanotron/parallel/tensor_parallel/nn.py @@ -19,6 +19,10 @@ from nanotron import distributed as dist from nanotron.distributed import get_global_rank +from nanotron.fp8.constants import FP8LM_LINEAR_RECIPE +from nanotron.fp8.linear import FP8Linear +from nanotron.fp8.recipe import FP8LinearRecipe +from nanotron.fp8.tensor import FP8Tensor from nanotron.parallel.parameters import NanotronParameter from nanotron.parallel.sharded_parameters import ( SplitConfig, @@ -39,19 +43,20 @@ from nanotron.parallel.tied_parameters import create_tied_parameter -class TensorParallelColumnLinear(nn.Linear): +class _BaseTensorParallelColumnLinear: def __init__( self, in_features, out_features, pg: dist.ProcessGroup, mode: TensorParallelLinearMode, - bias=True, - device=None, - dtype=None, + bias: bool = True, + device: Optional[torch.device] = None, + dtype: torch.dtype = None, async_communication: bool = False, contiguous_chunks: Optional[Tuple[int, ...]] = None, - tp_recompute_allgather: bool = True, + name: Optional[str] = None, + recipe: Optional[FP8LinearRecipe] = None, ): self.pg = pg self.world_size = pg.size() @@ -60,15 +65,19 @@ def __init__( self.in_features = in_features self.out_features = out_features // self.world_size - self.tp_recompute_allgather = tp_recompute_allgather + self.name = name - super().__init__( - in_features=self.in_features, - out_features=self.out_features, - bias=bias, - device=device, - dtype=dtype, - ) + init_args = { + "in_features": self.in_features, + "out_features": self.out_features, + "bias": bias, + "device": device, + "dtype": dtype, + } + if self.__class__ is FP8TensorParallelColumnLinear: + init_args["recipe"] = recipe + + super().__init__(**init_args) self.mode = mode self.async_communication = async_communication @@ -93,25 +102,27 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: group=self.pg, tp_mode=self.mode, async_communication=self.async_communication, - tp_recompute_allgather=self.tp_recompute_allgather, ) def extra_repr(self) -> str: return f"tp_rank={dist.get_rank(self.pg)}, {super().extra_repr()}, unsharded_out_features={self.out_features * self.world_size}" -class TensorParallelRowLinear(nn.Linear): +class _BaseTensorParallelRowLinear: def __init__( self, in_features, out_features, pg: dist.ProcessGroup, mode: TensorParallelLinearMode, - bias=True, - device=None, - dtype=None, + bias: bool = True, + device: Optional[torch.device] = None, + dtype: torch.dtype = None, async_communication: bool = False, contiguous_chunks: Optional[Tuple[int, ...]] = None, + name: Optional[str] = None, + # TODO(xrsrke): remove this from base class + recipe: Optional[FP8LinearRecipe] = None, ): self.pg = pg self.world_size = pg.size() @@ -120,17 +131,22 @@ def __init__( self.in_features = in_features // self.world_size self.out_features = out_features + self.name = name # No need to shard the bias term, only rank 0 would have it bias = dist.get_rank(self.pg) == 0 and bias - super().__init__( - in_features=self.in_features, - out_features=self.out_features, - bias=bias, - device=device, - dtype=dtype, - ) + init_args = { + "in_features": self.in_features, + "out_features": self.out_features, + "bias": bias, + "device": device, + "dtype": dtype, + } + if self.__class__ is FP8TensorParallelRowLinear: + init_args["recipe"] = recipe + + super().__init__(**init_args) self.mode = mode self.async_communication = async_communication if self.mode is TensorParallelLinearMode.ALL_REDUCE and self.async_communication: @@ -172,6 +188,84 @@ def extra_repr(self) -> str: return f"tp_rank={dist.get_rank(self.pg)}, {super().extra_repr()}, unsharded_in_features={self.in_features * self.world_size}" +class TensorParallelColumnLinear(_BaseTensorParallelColumnLinear, nn.Linear): + """Non-quantized tensor parallel column linear layer.""" + + pass + + +class TensorParallelRowLinear(_BaseTensorParallelRowLinear, nn.Linear): + """Non-quantized tensor parallel row linear layer.""" + + pass + + +class FP8TensorParallelColumnLinear(_BaseTensorParallelColumnLinear, FP8Linear): + def __init__(self, *args, recipe: FP8LinearRecipe = FP8LM_LINEAR_RECIPE, **kwargs): + super().__init__(*args, **kwargs, recipe=recipe) + self.recipe = recipe + + def __post_init__(self): + assert self.weight.data.dtype in [torch.uint8, torch.int8], f"got {self.weight.data.dtype}" + assert ( + self.metadatas is not None + ), "It seems like something went wrong in the initialization of FP8TensorParallelColumnLinear" + + def extra_repr(self) -> str: + extra = "" + + if isinstance(self.weight.data, FP8Tensor): + extra = f"{self.weight.data.fp8_meta}" + + return f"tp_rank={dist.get_rank(self.pg)}, {super().extra_repr()}, unsharded_out_features={self.out_features * self.world_size}, {extra}" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return column_linear( + input=x, + weight=self.weight, + bias=self.bias, + group=self.pg, + tp_mode=self.mode, + async_communication=self.async_communication, + metadatas=self.metadatas, + name=self.name, + recipe=self.recipe, + ) + + +class FP8TensorParallelRowLinear(_BaseTensorParallelRowLinear, FP8Linear): + def __init__(self, *args, recipe: FP8LinearRecipe = FP8LM_LINEAR_RECIPE, **kwargs): + super().__init__(*args, **kwargs, recipe=recipe) + self.recipe = recipe + + def __post_init__(self): + assert self.weight.data.dtype in [torch.uint8, torch.int8], f"got {self.weight.data.dtype}" + assert ( + self.metadatas is not None + ), "It seems like something went wrong in the initialization of FP8TensorParallelColumnLinear" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return row_linear( + input=x, + weight=self.weight, + bias=self.bias, + group=self.pg, + tp_mode=self.mode, + async_communication=self.async_communication, + metadatas=self.metadatas, + name=self.name, + recipe=self.recipe, + ) + + def extra_repr(self) -> str: + extra = "" + + if isinstance(self.weight.data, FP8Tensor): + extra = f"{self.weight.data.fp8_meta}" + + return f"tp_rank={dist.get_rank(self.pg)}, {super().extra_repr()}, unsharded_in_features={self.in_features * self.world_size}, {extra}" + + class TiedLinear(nn.Linear): def __init__( self, diff --git a/src/nanotron/sanity_checks.py b/src/nanotron/sanity_checks.py index 56ef1e2e..ed0f7daa 100644 --- a/src/nanotron/sanity_checks.py +++ b/src/nanotron/sanity_checks.py @@ -6,9 +6,12 @@ from nanotron import distributed as dist from nanotron import logging, optim from nanotron.config import Config +from nanotron.fp8.tensor import FP8Tensor +from nanotron.fp8.utils import is_overflow_underflow_nan from nanotron.logging import get_logger, log_rank from nanotron.models import NanotronModel from nanotron.optim.gradient_accumulator import GradientAccumulator +from nanotron.optim.optimizer_from_gradient_accumulator import OptimizerFromGradientAccumulator from nanotron.parallel import ParallelContext from nanotron.parallel.tied_parameters import get_tied_id_to_param @@ -60,6 +63,10 @@ def before_tbi_sanity_checks( grad_accumulator: GradientAccumulator, lr_scheduler: torch.optim.lr_scheduler.LRScheduler, ) -> None: + + # TODO(xrsrke): sanity check that _is_future_fp8 is consistent with the dtype of the parameter + # TODO(xrsrke): sanity check if the optimizer, gradient accumulator points to the correct model parameters + if not config.general.ignore_sanity_checks: # SANITY CHECK: Check that the model params are synchronized across dp for name, param in sorted(unwrapped_model.named_parameters(), key=lambda x: x[0]): @@ -132,7 +139,7 @@ def after_tbi_sanity_checks( # SANITY CHECK: Check that all parameters that required gradients, have actually a gradient # SANITY CHECK: Check for nan/inf for name, param in unwrapped_model.named_parameters(): - if not param.requires_grad: + if not param.requires_grad and not isinstance(param.data, FP8Tensor): continue if param.is_tied: @@ -163,16 +170,34 @@ def before_optim_step_sanity_checks( config: Config, parallel_context: ParallelContext, unwrapped_model: NanotronModel, + optim: OptimizerFromGradientAccumulator, grad_accumulator: GradientAccumulator, optimizer: optim.BaseOptimizer, ) -> None: + + # NOTE: sanity check that non-fp8 parameters's gradients have + # the same datatype of the residual stream's dtype + for pg in optim.param_groups: + for p in pg["params"]: + assert p.grad is not None + if isinstance(p.data, FP8Tensor): + assert p.grad.dtype in [torch.uint8, torch.int8], f"got {p.grad.dtype}" + else: + assert p.grad.dtype == config.fp8.resid_dtype + assert is_overflow_underflow_nan(p.grad) is False + + # TODO(xrsrke): we should sanity the gradients of parameters that optimizer points + # to because in the case of gradient accumulator, after accumulating gradients + # we set half_param.grad = None, and this also makes sense because + # optimizer's parameters' gradients are the one that affects updated weights + if not config.general.ignore_sanity_checks: # SANITY CHECK: Test tied weights gradients are synchronized for (name, group_ranks), param in sorted( get_tied_id_to_param(parameters=unwrapped_model.parameters(), root_module=unwrapped_model).items(), key=lambda x: x[0], ): - if not param.requires_grad: + if not param.requires_grad and not isinstance(param.data, FP8Tensor): continue if grad_accumulator is not None: @@ -249,7 +274,7 @@ def after_optim_step_sanity_checks( if not config.general.ignore_sanity_checks: # SANITY CHECK: Check that gradients is cleared for name, param in unwrapped_model.named_parameters(): - if not param.requires_grad: + if not param.requires_grad and not isinstance(param.data, FP8Tensor): continue if param.grad is not None: diff --git a/src/nanotron/scaling/parametrization.py b/src/nanotron/scaling/parametrization.py index e6241651..819617f6 100644 --- a/src/nanotron/scaling/parametrization.py +++ b/src/nanotron/scaling/parametrization.py @@ -37,6 +37,7 @@ def __init__(self, config: ModelArgs): TensorParallelColumnLinear: self._parametrize_column_linear, TensorParallelRowLinear: self._parametrize_row_linear, TritonRMSNorm: self._parametrize_layer_norm, + nn.LayerNorm: self._parametrize_layer_norm, TensorParallelEmbedding: self._parametrize_embedding, } @@ -44,35 +45,34 @@ def __init__(self, config: ModelArgs): self.num_layers = config.model_config.num_hidden_layers def _parametrize_column_linear(self, param_name: str, module: nn.Module): - assert param_name in ["weight", "bias"] + assert any(x in param_name for x in ["weight", "bias"]) - if "weight" == param_name: + if "weight" in param_name: init.normal_(module.weight, mean=0.0, std=self.std) elif "bias" == param_name: module.bias.zero_() def _parametrize_row_linear(self, param_name: str, module: nn.Module): - assert param_name in ["weight", "bias"] + assert any(x in param_name for x in ["weight", "bias"]) - if "weight" == param_name: + if "weight" in param_name: std = self.std / math.sqrt(2 * self.num_layers) init.normal_(module.weight, mean=0.0, std=std) - elif "bias" == param_name: + elif "bias" in param_name: module.bias.zero_() def _parametrize_layer_norm(self, param_name: str, module: nn.Module): - assert param_name in ["weight", "bias"] - - if "weight" == param_name: + assert any(x in param_name for x in ["weight", "bias"]) + if "weight" in param_name: # TODO @thomasw21: Sometimes we actually want 0 module.weight.fill_(1) - elif "bias" == param_name: + elif "bias" in param_name: module.bias.zero_() def _parametrize_embedding(self, param_name: str, module: nn.Module): - assert param_name in ["weight"] + assert "weight" in param_name - if "weight" == param_name: + if "weight" in param_name: init.normal_(module.weight, mean=0.0, std=self.std) diff --git a/src/nanotron/testing/llama.py b/src/nanotron/testing/llama.py new file mode 100644 index 00000000..4085db00 --- /dev/null +++ b/src/nanotron/testing/llama.py @@ -0,0 +1,142 @@ +# NOTE: I moved it from helpers/llama.py to here to here for +# tests/{folder}/test_x.py to be able to import it + +import torch +from nanotron.config import ( + AllForwardAllBackwardPipelineEngine, + CheckpointsArgs, + Config, + DataArgs, + DatasetStageArgs, + GeneralArgs, + LlamaConfig, + LoggingArgs, + LRSchedulerArgs, + ModelArgs, + OptimizerArgs, + ParallelismArgs, + TensorParallelLinearMode, + TokenizerArgs, + TokensArgs, +) +from nanotron.config.config import AdamWOptimizerArgs, PretrainDatasetsArgs +from nanotron.models import build_model +from nanotron.models.llama import LlamaForTraining +from nanotron.parallel.context import ParallelContext +from nanotron.trainer import mark_tied_parameters + +TINY_LLAMA_CONFIG = LlamaConfig( + **{ + "bos_token_id": 1, + "eos_token_id": 2, + "hidden_act": "silu", + "hidden_size": 16, + "initializer_range": 0.02, + "intermediate_size": 32, + "is_llama_config": True, + "max_position_embeddings": 128, + "num_attention_heads": 8, + "num_hidden_layers": 4, + "num_key_value_heads": 4, + "pad_token_id": None, + "pretraining_tp": 1, + "rms_norm_eps": 1e-06, + "rope_scaling": None, + "tie_word_embeddings": False, + "use_cache": True, + "vocab_size": 4096, + } +) + + +def get_llama_training_config(model_config: ModelArgs): + return Config( + model=model_config, + general=GeneralArgs(project="unittest", run="sanity_llama", seed=42), + checkpoints=CheckpointsArgs( + checkpoints_path="./checkpoints", + checkpoint_interval=10, + ), + parallelism=ParallelismArgs( + dp=1, + pp=1, + tp=2, + expert_parallel_size=2, + pp_engine="1f1b", + tp_mode="ALL_REDUCE", + tp_linear_async_communication=False, + ), + tokenizer=TokenizerArgs("gpt2"), + optimizer=OptimizerArgs( + optimizer_factory=AdamWOptimizerArgs( + adam_eps=1e-08, adam_beta1=0.9, adam_beta2=0.95, torch_adam_is_fused=True, name="adamW" + ), + zero_stage=0, + weight_decay=0.01, + clip_grad=1.0, + accumulate_grad_in_fp32=False, + learning_rate_scheduler=LRSchedulerArgs( + learning_rate=3e-4, + lr_warmup_steps=100, + lr_warmup_style="linear", + lr_decay_style="cosine", + min_decay_lr=1e-5, + ), + ), + logging=LoggingArgs(), + tokens=TokensArgs(sequence_length=16, train_steps=10, micro_batch_size=16, batch_accumulation_per_replica=1), + data_stages=[ + DatasetStageArgs( + name="train", + start_training_step=1, + data=DataArgs( + seed=42, + num_loading_workers=1, + dataset=PretrainDatasetsArgs( + hf_dataset_or_datasets="HuggingFaceH4/testing_alpaca_small", + hf_dataset_splits="train", + text_column_name="completion", + dataset_processing_num_proc_per_process=12, + ), + ), + ) + ], + ) + + +def create_llama_from_config( + model_config: LlamaConfig, + device: torch.device, + parallel_context: ParallelContext, + dtype: torch.dtype = torch.bfloat16, +) -> LlamaForTraining: + + """ + Creates and returns a nanotron model. + If `model_config` is None, then `checkpoint_path` must be set, in which case + the configuration will be loaded from such path. + If `checkpoint_path` is None, then `model_config` must be set, in which case + the model created will have random weights. + """ + + parallel_config = ParallelismArgs( + dp=parallel_context.data_parallel_size, + pp=parallel_context.pipeline_parallel_size, + tp=parallel_context.tensor_parallel_size, + pp_engine=AllForwardAllBackwardPipelineEngine(), + tp_mode=TensorParallelLinearMode.ALL_REDUCE, + tp_linear_async_communication=False, + ) + model = build_model( + model_builder=lambda: LlamaForTraining( + config=model_config, + parallel_context=parallel_context, + parallel_config=parallel_config, + random_states=None, + ), + parallel_context=parallel_context, + dtype=dtype, + device=device, + ) + mark_tied_parameters(model=model, parallel_context=parallel_context) + return model diff --git a/src/nanotron/testing/parallel.py b/src/nanotron/testing/parallel.py new file mode 100644 index 00000000..1dbe8a3c --- /dev/null +++ b/src/nanotron/testing/parallel.py @@ -0,0 +1,159 @@ +import os +import re +from inspect import signature +from typing import Callable + +import torch.cuda +import torch.multiprocessing as mp +from nanotron.parallel import ParallelContext +from packaging import version + + +def global_wrapper(rank, func, tp, pp, dp, port, kwargs): + def setup_dist_env(rank, world_size, port): + os.environ["WORLD_SIZE"] = str(world_size) + os.environ["RANK"] = str(rank) + # NOTE: since we do unit tests in a + # single node => this is fine! + os.environ["LOCAL_RANK"] = str(rank) + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = str(port) + + world_size = tp * pp * dp + setup_dist_env(rank, world_size, port) + parallel_context = ParallelContext(data_parallel_size=dp, pipeline_parallel_size=pp, tensor_parallel_size=tp) + func(parallel_context, **kwargs) + + +def init_distributed(tp: int, dp: int, pp: int): + def _init_distributed(func): + def wrapper(**kwargs): + from nanotron.utils import find_free_port + + world_size = tp * pp * dp + port = find_free_port() + + # Note that kwargs needs to be passed as part of args in a way that can be unpacked + args = (func, tp, pp, dp, port, kwargs) + mp.spawn(global_wrapper, args=args, nprocs=world_size) + + return wrapper + + return _init_distributed + + +def rerun_if_address_is_in_use(max_try: int = 500): + """ + This function reruns a wrapped function if "address already in use" occurs + in testing spawned with torch.multiprocessing + + Credits: https://github.com/hpcaitech/ColossalAI/blob/adae123df3badfb15d044bd416f0cf29f250bc86/colossalai/testing/utils.py#L157 + + Usage:: + + @rerun_if_address_is_in_use() + def test_something(): + ... + + """ + # check version + torch_version = version.parse(torch.__version__) + assert torch_version.major >= 1 + + # only torch >= 1.8 has ProcessRaisedException + if torch_version >= version.parse("1.8.0"): + exception = torch.multiprocessing.ProcessRaisedException + else: + exception = Exception + + func_wrapper = rerun_on_exception(exception_type=exception, pattern=".*Address already in use.*", max_try=max_try) + return func_wrapper + + +def rerun_on_exception(exception_type: Exception = Exception, pattern: str = None, max_try: int = 10) -> Callable: + """ + A decorator on a function to re-run when an exception occurs. + + Credits: https://github.com/hpcaitech/ColossalAI/blob/adae123df3badfb15d044bd416f0cf29f250bc86/colossalai/testing/utils.py#L71 + + Usage:: + + # rerun for all kinds of exception + @rerun_on_exception() + def test_method(): + print('hey') + raise RuntimeError('Address already in use') + + # rerun for RuntimeError only + @rerun_on_exception(exception_type=RuntimeError) + def test_method(): + print('hey') + raise RuntimeError('Address already in use') + + # rerun for maximum 10 times if Runtime error occurs + @rerun_on_exception(exception_type=RuntimeError, max_try=10) + def test_method(): + print('hey') + raise RuntimeError('Address already in use') + + # rerun for infinite times if Runtime error occurs + @rerun_on_exception(exception_type=RuntimeError, max_try=None) + def test_method(): + print('hey') + raise RuntimeError('Address already in use') + + # rerun only the exception message is matched with pattern + # for infinite times if Runtime error occurs + @rerun_on_exception(exception_type=RuntimeError, pattern="^Address.*$") + def test_method(): + print('hey') + raise RuntimeError('Address already in use') + + Args: + exception_type (Exception, Optional): The type of exception to detect for rerun + pattern (str, Optional): The pattern to match the exception message. + If the pattern is not None and matches the exception message, + the exception will be detected for rerun + max_try (int, Optional): Maximum reruns for this function. The default value is 5. + If max_try is None, it will rerun forever if exception keeps occurring + """ + + def _match_lines(lines, pattern): + for line in lines: + if re.match(pattern, line): + return True + return False + + def _wrapper(func): + def _run_until_success(*args, **kwargs): + try_count = 0 + assert max_try is None or isinstance( + max_try, int + ), f"Expected max_try to be None or int, but got {type(max_try)}" + + while max_try is None or try_count < max_try: + try: + try_count += 1 + ret = func(*args, **kwargs) + return ret + except exception_type as e: + error_lines = str(e).split("\n") + if try_count < max_try and (pattern is None or _match_lines(error_lines, pattern)): + + print("Exception is caught, retrying...") + # when pattern is not specified, we always skip the exception + # when pattern is specified, we only skip when pattern is matched + continue + else: + print("Maximum number of attempts is reached or pattern is not matched, no more retrying...") + raise e + + # Override signature + # otherwise pytest.mark.parameterize will raise the following error: + # function does not use argument xxx + sig = signature(func) + _run_until_success.__signature__ = sig + + return _run_until_success + + return _wrapper diff --git a/src/nanotron/testing/utils.py b/src/nanotron/testing/utils.py new file mode 100644 index 00000000..52009dbb --- /dev/null +++ b/src/nanotron/testing/utils.py @@ -0,0 +1,299 @@ +import shutil +import uuid +from functools import lru_cache +from pathlib import Path + + +class TestContext: + def __init__(self): + self._random_string = str(uuid.uuid1()) + self._root_dir = Path(__file__).parent.parent / ".test_cache" + self._root_dir.mkdir(parents=True, exist_ok=True) + + @lru_cache(maxsize=1) + def get_auto_remove_tmp_dir(self): + path = self._root_dir / self._random_string + path.mkdir(parents=True, exist_ok=True) + return path + + def __del__(self): + path = self.get_auto_remove_tmp_dir() + shutil.rmtree(path) + + +import contextlib +import os +import re +from inspect import signature +from typing import Any, Callable, Dict, List, Optional, Tuple + +import torch.cuda +import torch.multiprocessing as mp +from nanotron.parallel import ParallelContext +from packaging import version + + +def available_gpus(): + if not torch.cuda.is_available(): + return 0 + + device_properties = [torch.cuda.get_device_properties(i) for i in range(torch.cuda.device_count())] + + # We filter out + blacklisted_gpu_names = {"NVIDIA DGX Display"} + device_properties = [property_ for property_ in device_properties if property_.name not in blacklisted_gpu_names] + + # TODO @thomasw21: Can we do this cross node + return len(device_properties) + + +# from https://stackoverflow.com/a/34333710/9201239 +@contextlib.contextmanager +def mock_os_environ(remove_keys: List[str] = None, update_key_values: Dict[str, Any] = None): + """ + Temporarily updates the ``os.environ`` dictionary in-place. + The ``os.environ`` dictionary is updated in-place so that the modification is sure to work in all situations. + Args: + remove_keys: Environment variables to remove. + update_key_values: Dictionary of environment variables and values to add/update. + """ + env = os.environ + update_key_values = update_key_values or {} + remove_keys = remove_keys or [] + + update_keys = set(update_key_values.keys()) + remove_keys = set(remove_keys) + assert remove_keys.isdisjoint(update_keys) + + stomped = (update_keys | remove_keys) & set(env.keys()) + reverse_change = { + # Environment variables and values to restore on exit. + **{k: env[k] for k in update_keys & stomped}, + # Environment variables and values to remove on exit. + **{k: env[k] for k in remove_keys & stomped}, + } + + try: + env.update(update_key_values) + for k in remove_keys: + env.pop(k, None) + yield + finally: + env.update(reverse_change) + + +def is_dict_equal(first: Dict, second: Dict, sub_paths: Optional[List[str]] = None) -> Tuple[bool, Optional[str]]: + """Returns True or False if the dictionaries match, and an additional message when it's False""" + if sub_paths is None: + sub_paths = [] + + first_keys = set(first.keys()) + second_keys = set(second.keys()) + if first_keys != second_keys: + return False, f"Keys don't match in {'.'.join(sub_paths)}.\nCur: {first_keys}\nRef: {second_keys}" + for key in first_keys: + first_elt = first[key] + second_elt = second[key] + + if isinstance(first_elt, dict): + if not isinstance(second_elt, dict): + return ( + False, + f"Object types don't match in {'.'.join(sub_paths + [str(key)])}.\nCur: {first_elt}\nRef: {second_elt}", + ) + match, msg = is_dict_equal(first_elt, second_elt, sub_paths=sub_paths + [str(key)]) + if match is False: + return False, msg + elif isinstance(first_elt, torch.Tensor): + if not isinstance(second_elt, torch.Tensor): + return ( + False, + f"Object types don't match in {'.'.join(sub_paths + [str(key)])}.\nCur: {first_elt}\nRef: {second_elt}", + ) + try: + torch.testing.assert_close( + first_elt, + second_elt, + atol=0.0, + rtol=0.0, + msg=lambda msg: f"Tensor at {'.'.join(sub_paths + [str(key)])} don't match.\nCur: {first_elt}\nRef: {second_elt}\n{msg}", + ) + except AssertionError as error: + return False, error.args[0] + else: + if first_elt != second_elt: + return ( + False, + f"Objects at key {'.'.join(sub_paths + [str(key)])} don't match.\nCur: {first_elt}\nRef: {second_elt}", + ) + + return True, None + + +def get_all_3d_configurations(gpus: int) -> List[Tuple[int, int, int]]: + """Given a number of gpus, we want all 3d configurations possible such that pp * dp * tp = gpus""" + result = [] + for tp in range(1, gpus + 1): + if gpus % tp != 0: + continue + gpus_left_after_tp = gpus // tp + for dp in range(1, gpus_left_after_tp + 1): + if gpus_left_after_tp % dp != 0: + continue + gpus_left_after_dp = gpus_left_after_tp // dp + for pp in range(1, gpus_left_after_dp + 1): + if gpus_left_after_dp % pp != 0: + continue + if tp * dp * pp == gpus: + result.append((pp, dp, tp)) + return result + + +def rerun_if_address_is_in_use(max_try: int = 500): + """ + This function reruns a wrapped function if "address already in use" occurs + in testing spawned with torch.multiprocessing + + Credits: https://github.com/hpcaitech/ColossalAI/blob/adae123df3badfb15d044bd416f0cf29f250bc86/colossalai/testing/utils.py#L157 + + Usage:: + + @rerun_if_address_is_in_use() + def test_something(): + ... + + """ + # check version + torch_version = version.parse(torch.__version__) + assert torch_version.major >= 1 + + # only torch >= 1.8 has ProcessRaisedException + if torch_version >= version.parse("1.8.0"): + exception = torch.multiprocessing.ProcessRaisedException + else: + exception = Exception + + func_wrapper = rerun_on_exception(exception_type=exception, pattern=".*Address already in use.*", max_try=max_try) + return func_wrapper + + +def rerun_on_exception(exception_type: Exception = Exception, pattern: str = None, max_try: int = 10) -> Callable: + """ + A decorator on a function to re-run when an exception occurs. + + Credits: https://github.com/hpcaitech/ColossalAI/blob/adae123df3badfb15d044bd416f0cf29f250bc86/colossalai/testing/utils.py#L71 + + Usage:: + + # rerun for all kinds of exception + @rerun_on_exception() + def test_method(): + print('hey') + raise RuntimeError('Address already in use') + + # rerun for RuntimeError only + @rerun_on_exception(exception_type=RuntimeError) + def test_method(): + print('hey') + raise RuntimeError('Address already in use') + + # rerun for maximum 10 times if Runtime error occurs + @rerun_on_exception(exception_type=RuntimeError, max_try=10) + def test_method(): + print('hey') + raise RuntimeError('Address already in use') + + # rerun for infinite times if Runtime error occurs + @rerun_on_exception(exception_type=RuntimeError, max_try=None) + def test_method(): + print('hey') + raise RuntimeError('Address already in use') + + # rerun only the exception message is matched with pattern + # for infinite times if Runtime error occurs + @rerun_on_exception(exception_type=RuntimeError, pattern="^Address.*$") + def test_method(): + print('hey') + raise RuntimeError('Address already in use') + + Args: + exception_type (Exception, Optional): The type of exception to detect for rerun + pattern (str, Optional): The pattern to match the exception message. + If the pattern is not None and matches the exception message, + the exception will be detected for rerun + max_try (int, Optional): Maximum reruns for this function. The default value is 5. + If max_try is None, it will rerun forever if exception keeps occurring + """ + + def _match_lines(lines, pattern): + for line in lines: + if re.match(pattern, line): + return True + return False + + def _wrapper(func): + def _run_until_success(*args, **kwargs): + try_count = 0 + assert max_try is None or isinstance( + max_try, int + ), f"Expected max_try to be None or int, but got {type(max_try)}" + + while max_try is None or try_count < max_try: + try: + try_count += 1 + ret = func(*args, **kwargs) + return ret + except exception_type as e: + error_lines = str(e).split("\n") + if try_count < max_try and (pattern is None or _match_lines(error_lines, pattern)): + + print("Exception is caught, retrying...") + # when pattern is not specified, we always skip the exception + # when pattern is specified, we only skip when pattern is matched + continue + else: + print("Maximum number of attempts is reached or pattern is not matched, no more retrying...") + raise e + + # Override signature + # otherwise pytest.mark.parameterize will raise the following error: + # function does not use argument xxx + sig = signature(func) + _run_until_success.__signature__ = sig + + return _run_until_success + + return _wrapper + + +def global_wrapper(rank, func, tp, pp, dp, port, kwargs): + def setup_dist_env(rank, world_size, port): + os.environ["WORLD_SIZE"] = str(world_size) + os.environ["RANK"] = str(rank) + # NOTE: since we do unit tests in a + # single node => this is fine! + os.environ["LOCAL_RANK"] = str(rank) + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = str(port) + + world_size = tp * pp * dp + setup_dist_env(rank, world_size, port) + parallel_context = ParallelContext(data_parallel_size=dp, pipeline_parallel_size=pp, tensor_parallel_size=tp) + func(parallel_context, **kwargs) + + +def init_distributed(tp: int, dp: int, pp: int): + def _init_distributed(func): + def wrapper(**kwargs): + from nanotron.utils import find_free_port + + world_size = tp * pp * dp + port = find_free_port() + + # Note that kwargs needs to be passed as part of args in a way that can be unpacked + args = (func, tp, pp, dp, port, kwargs) + mp.spawn(global_wrapper, args=args, nprocs=world_size) + + return wrapper + + return _init_distributed diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 94b03c6e..38444623 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -37,6 +37,8 @@ ) from nanotron.constants import MODEL_CONFIG_FILE_NAME from nanotron.dataloader import sanity_check_dataloader +from nanotron.fp8.tensor import FP8Tensor +from nanotron.fp8.utils import convert_model_to_fp8 from nanotron.helpers import ( _vocab_size_with_padding, compute_remain_train_steps_of_a_data_stage_from_ckp, @@ -113,6 +115,15 @@ wandb = None +def print_sanity_params(model): + for n, p in model.named_parameters(): + print( + n, + p.__class__.__name__, + f"p.requires_grad: {p.requires_grad}, p.dtype: {p.dtype}, p.data.dtype: {p.data.dtype}", + ) + + class DistributedTrainer: def __init__( self, @@ -135,6 +146,12 @@ def __init__( self.config = get_config_from_file( config_or_config_file, config_class=config_class, model_config_class=model_config_class ) + assert ( + self.config.model.dtype == torch.int8 and self.config.optimizer.accumulate_grad_in_fp32 is True + ), "FP8 training must enable gradient accumulation" + from nanotron import constants + + constants.CONFIG = self.config self.model_config = self.config.model.model_config if model_class is not None: CONFIG_TO_MODEL_CLASS[self.model_config.__class__.__name__] = model_class @@ -152,6 +169,7 @@ def __init__( ) self.pre_init() + self.init_checkpoint_path = parse_ckpt_path(config=self.config, parallel_context=self.parallel_context) # Set log levels set_ranks_logging_level(parallel_context=self.parallel_context, logging_config=self.config.logging) @@ -183,13 +201,56 @@ def __init__( else ParametrizationMethod.STANDARD ) + from nanotron.fp8.utils import find_fp8_config_by_module_name, get_leaf_modules + + for module_name, module in get_leaf_modules(self.model): + if any(p.numel() > 0 for p in module.parameters()) is False: + continue + + recipe = find_fp8_config_by_module_name(module_name, constants.CONFIG.fp8) + if recipe is not None: + module.weight._is_future_fp8 = True + + # NOTE: make a copy of FP8 parameter on CPU + # from nanotron import constants + for n, p in self.model.named_parameters(): + if hasattr(p, "_is_future_fp8") and p._is_future_fp8 is True: + constants.CPU_WEIGHTS[n.replace("module.", "")] = p.data.cpu().clone() + + # NOTE: sanity check all hash are different + # param_hash = [] + # for p in self.model.parameters(): + # assert hash(p) not in param_hash + # param_hash.append(hash(p)) + + # NOTE: if we cast model to FP8 before wrapping it with NanotronParameter, + # then we can create a NanotronParameter that has dtype=[torch.int8, torch.uint8] + # which then it allows us to assign [torch.int8, torch.uint8] gradients to the parameter + # otherwise, it would raise: + # "attempting to assign a gradient with dtype + # 'unsigned char' to a tensor with dtype 'float'. + # Please ensure that the gradient and the tensor have the same dtype" + # NOTE: the reason that we cast after initializing the optimizer is that + # we want to create some master weights for fp8 parameters, before quantizing them + self.model = convert_model_to_fp8(self.model, config=constants.CONFIG.fp8) + + # NOTE: convert non-fp8 parameters to the residual stream's dtype + # Init optimizer self.optimizer, self.grad_accumulator = init_optimizer_and_grad_accumulator( parametrization_method=parametrization_method, model=self.model, + master_weight_dtype=self.config.optimizer.master_weight_dtype, optimizer_args=self.config.optimizer, parallel_context=self.parallel_context, ) + # NOTE: quantize optimizer states + # add hook to dequantize optimizer states before .step() + # add hook step to recompute lr + # add post_step hook to quantize optimizer states + + assert 1 == 1 + if self.init_checkpoint_path is not None and self.config.checkpoints.load_optimizer: load_optimizer( optimizer=self.optimizer, @@ -260,7 +321,7 @@ def __init__( self.post_init() def pre_init(self): - self.init_checkpoint_path = parse_ckpt_path(config=self.config, parallel_context=self.parallel_context) + pass def post_init(self): # S3 Mover and save initial state @@ -528,10 +589,11 @@ def training_step( # Clip gradients if self.config.optimizer.clip_grad is not None: # Unwrap DDP + # NOTE: FP8 parameter's requires_grad is set to False by default named_parameters = [ (name, param) for name, param in self.unwrapped_model.get_named_params_with_correct_tied() - if param.requires_grad + if param.requires_grad or isinstance(param.data, FP8Tensor) ] self.grad_norm_unclipped = clip_grad_norm( mp_pg=self.parallel_context.mp_pg, @@ -540,6 +602,10 @@ def training_step( max_norm=self.config.optimizer.clip_grad, ) + before_optim_step_sanity_checks( + self.config, self.parallel_context, self.unwrapped_model, self.optimizer, self.grad_accumulator + ) + # Compute DP average loss and overlap with optimizer step if isinstance(outputs[0]["loss"], torch.Tensor): # This is an average on only one data rank. @@ -552,6 +618,7 @@ def training_step( loss_avg = None handle = None + # NOTE: sanity check that parameters has gradient # Move optimizer states back to GPU before optimizer step if ( self.init_checkpoint_path is not None @@ -706,6 +773,7 @@ def init_model(self) -> Union[NanotronModel, DistributedDataParallel]: model = self._init_model_instance() model = self._load_model_checkpoint(model) + return model def _init_model_instance(self) -> NanotronModel: @@ -787,10 +855,11 @@ def _init_model( config.optimizer.accumulate_grad_in_fp32 and config.optimizer.zero_stage > 0 ) + model_init_dtype = config.fp8.resid_dtype if config.model.dtype == torch.int8 else config.model.dtype # Build model and set pp ranks model = build_model( parallel_context=parallel_context, - dtype=config.model.dtype, + dtype=model_init_dtype, target_pp_ranks=target_pp_ranks, model_builder=model_builder, ) @@ -844,6 +913,16 @@ def _init_model( # Check that the model has at least one grad. Necessary for DDP check_model_has_grad(model=model, parallel_context=parallel_context) # TODO @thomasw21: DDP doesn't support broadcasting complex buffers (and we don't really need that broadcasting anyway) + # if self.config.model.dtype == torch.int8: + # raise NotImplementedError + # model = FP8DistributedDataParallel(model, self.parallel_context) + # else: + # model = DistributedDataParallel( + # model, + # process_group=parallel_context.dp_pg, + # broadcast_buffers=False, + # bucket_cap_mb=config.model.ddp_bucket_cap_mb, + # ) model = DistributedDataParallel( model, process_group=parallel_context.dp_pg, diff --git a/tests/fp8/_test_fp8_parameter.py b/tests/fp8/_test_fp8_parameter.py new file mode 100644 index 00000000..1d464615 --- /dev/null +++ b/tests/fp8/_test_fp8_parameter.py @@ -0,0 +1,151 @@ +import pytest +import torch +from nanotron.constants import CHECKPOINT_VERSION +from nanotron.fp8.constants import FP8_DTYPES +from nanotron.fp8.dtypes import DTypes +from nanotron.fp8.meta import FP8Meta +from nanotron.fp8.parameter import FP8Parameter +from nanotron.fp8.tensor import FP8Tensor +from nanotron.parallel import ParallelContext +from nanotron.parallel.parameters import NanotronParameter +from nanotron.parallel.sharded_parameters import SplitConfig, create_sharded_parameter_from_config +from nanotron.serialize.metadata import TensorMetadata +from nanotron.testing.parallel import init_distributed, rerun_if_address_is_in_use +from torch import nn + + +def create_sharded_fp8_parameter(param: nn.Parameter, parallel_context: ParallelContext): + split_config = SplitConfig( + split_dim=0, + contiguous_chunks=(8, 8), + ) + param = create_sharded_parameter_from_config(parameter=param, pg=parallel_context.tp_pg, split_config=split_config) + return param + + +@pytest.mark.parametrize("dtype", [DTypes.FP8E4M3, DTypes.FP8E5M2]) +def test_create_fp8_parameter(dtype): + tensor = torch.randn(16, 16, device="cuda", dtype=torch.float32) + + fp8_parameter = FP8Parameter(tensor, dtype) + + assert isinstance(fp8_parameter.data, FP8Tensor) + assert fp8_parameter.requires_grad is True + assert fp8_parameter.grad is None + assert fp8_parameter.dtype in FP8_DTYPES + + assert isinstance(fp8_parameter.fp8_meta, FP8Meta) + assert isinstance(fp8_parameter.data.fp8_meta, FP8Meta) + assert fp8_parameter.data.fp8_meta is fp8_parameter.fp8_meta + + +def test_fp8_parameter_grad_metadata(): + GRAD_META = ["input_grad", "weight_grad", "output_grad"] + tensor = torch.randn(16, 16, device="cuda", dtype=torch.float32) + fp8_parameter = FP8Parameter(tensor, DTypes.FP8E4M3) + + assert all(hasattr(fp8_parameter.fp8_grad_meta, attr) for attr in GRAD_META) + assert all(isinstance(getattr(fp8_parameter.fp8_grad_meta, attr), FP8Meta) for attr in GRAD_META) + + +@pytest.mark.parametrize("dtype", [DTypes.FP8E4M3, DTypes.FP8E5M2]) +@pytest.mark.parametrize("grad_dtype", [DTypes.FP8E4M3, DTypes.FP8E5M2]) +def test_setting_fp8_gradient_to_fp8_parameter(dtype, grad_dtype): + fp8_parameter = FP8Parameter(torch.randn(16, 16, device="cuda"), dtype) + fp8_grad = FP8Tensor(torch.randn(16, 16, device="cuda"), dtype=grad_dtype) + + fp8_parameter.grad = fp8_grad + + assert torch.equal(fp8_parameter.grad, fp8_parameter.data.grad) + assert id(fp8_parameter.grad) == id(fp8_parameter.data.grad) + assert fp8_parameter.grad.data_ptr() == fp8_parameter.data.grad.data_ptr() + + +@pytest.mark.parametrize("dtype", [DTypes.FP8E4M3, DTypes.FP8E5M2]) +def test_fp8_parameter_storage_memory(dtype): + data = torch.randn(16, 16, device="cuda", dtype=torch.float32) + fp8_parameter = FP8Parameter(data, dtype) + + assert id(fp8_parameter.data) != id(data) + assert fp8_parameter.data_ptr() == data.data_ptr() + assert fp8_parameter.data.data_ptr() != data.data_ptr() + + +@pytest.mark.parametrize("dtype", [DTypes.FP8E4M3, DTypes.FP8E5M2]) +def test_set_data_in_fp8_parameter(dtype): + data = torch.randn(16, 16, device="cuda", dtype=torch.float32) + fp8_parameter = FP8Parameter(data, dtype) + + new_data = torch.randn(16, 16, device="cuda", dtype=torch.float32) + new_fp8_data = FP8Tensor(new_data, dtype=dtype) + + fp8_parameter.data = new_fp8_data + + assert fp8_parameter.data is new_fp8_data + assert torch.equal(fp8_parameter.data, new_fp8_data) + assert fp8_parameter.data.data_ptr() == new_fp8_data.data_ptr() + + assert fp8_parameter.fp8_meta is new_fp8_data.fp8_meta + + +@pytest.mark.parametrize("dtype", [DTypes.FP8E4M3, DTypes.FP8E5M2]) +def test_set_gradient_in_fp8_parameter(dtype): + data = torch.randn(16, 16, device="cuda", dtype=torch.float32) + fp8_parameter = FP8Parameter(data, dtype) + + grad = torch.randn(16, 16, device="cuda", dtype=torch.float32) + fp8_grad = FP8Tensor(grad, dtype=dtype) + + fp8_parameter.grad = fp8_grad + + assert fp8_parameter.grad is fp8_grad + assert torch.equal(fp8_parameter.grad, fp8_grad) + assert fp8_parameter.grad.data_ptr() == fp8_grad.data_ptr() + + assert fp8_parameter.data.grad is fp8_parameter.grad + + +@pytest.mark.parametrize("dtype", [DTypes.FP8E4M3, DTypes.FP8E5M2]) +@rerun_if_address_is_in_use() +def test_create_sharded_fp8_parameter(dtype): + init_distributed(tp=2, dp=1, pp=1)(_test_create_sharded_fp8_parameter)(dtype=dtype) + + +def _test_create_sharded_fp8_parameter(parallel_context: ParallelContext, dtype: DTypes): + data = torch.randn(16, 64, device="cuda") + param = FP8Parameter(data, dtype) + + param = create_sharded_fp8_parameter(param, parallel_context) + sharded_info = param.get_sharded_info() + + assert isinstance(param, NanotronParameter) + assert isinstance(param.data, FP8Tensor) + assert isinstance(param.data.fp8_meta, FP8Meta) + + metadata = TensorMetadata( + version=CHECKPOINT_VERSION, + local_global_slices_pairs=sharded_info.local_global_slices_pairs, + unsharded_shape=sharded_info.unsharded_shape, + ) + metadata_str_dict = metadata.to_str_dict() + # Assert metadata_str_dict is Dict[str, str] + assert isinstance(metadata_str_dict, dict) + assert all(isinstance(key, str) for key in metadata_str_dict.keys()) + assert all(isinstance(value, str) for value in metadata_str_dict.values()) + + metadata_from_str_dict = TensorMetadata.from_str_dict(metadata_str_dict) + assert metadata == metadata_from_str_dict + + parallel_context.destroy() + + +# TODO(xrsrke): add test for preventing torch autograd do the backward pass +# on a FP8Parameter + +# TODO(xrsrke): test CPU parameter + + +# TODO(xrsrke): test convert model to FP8 +# include the FP8's NanotronParameter's dtype and requires_grad + +# TODO(xrsrke): test set FP8 gradients to FP8 NanotronParameter diff --git a/tests/fp8/_test_linear.py b/tests/fp8/_test_linear.py new file mode 100644 index 00000000..4c53e00d --- /dev/null +++ b/tests/fp8/_test_linear.py @@ -0,0 +1,260 @@ +from copy import deepcopy +from functools import partial, reduce + +import pytest +import torch +from nanotron.fp8.constants import FP8_DTYPES, QTYPE_TO_DTYPE +from nanotron.fp8.dtypes import DTypes +from nanotron.fp8.linear import FP8Linear, FP8LinearMeta +from nanotron.fp8.parameter import FP8Parameter +from nanotron.fp8.recipe import FP8LinearRecipe +from nanotron.fp8.tensor import FP8Tensor, convert_tensor_from_fp8 +from nanotron.fp8.utils import convert_linear_to_fp8, convert_to_fp8_module, is_overflow_underflow_nan + +# from timm.models.layers import trunc_normal_ +from torch import nn + + +@pytest.mark.parametrize("accum_qtype", [DTypes.KFLOAT32, DTypes.KFLOAT16]) +@pytest.mark.parametrize("bias", [True, False]) +def test_create_an_fp8_linear_parameters(bias, accum_qtype): + fp8_linear = FP8Linear(64, 64, bias=bias, device="cuda", accum_qtype=accum_qtype) + + assert isinstance(fp8_linear.weight, FP8Parameter) + assert isinstance(fp8_linear.bias, torch.Tensor) if bias else True + assert isinstance(fp8_linear.recipe, FP8LinearRecipe) + assert isinstance(fp8_linear.metadatas, FP8LinearMeta) + + +def test_fp8_linear_parameters(): + ref_linear = nn.Linear(16, 16, device="cuda") + fp8_linear = convert_linear_to_fp8(deepcopy(ref_linear), accum_qtype=DTypes.KFLOAT32) + + assert len(list(ref_linear.parameters())) == len(list(fp8_linear.parameters())) + assert all(p is not None for p in fp8_linear.parameters()) + assert isinstance(fp8_linear.weight, FP8Parameter) + assert isinstance(fp8_linear.bias, torch.Tensor) + assert all(p.requires_grad for p in fp8_linear.parameters()) is True + + +# @pytest.mark.skip +@pytest.mark.parametrize("n_layers", [1, 2]) +@pytest.mark.parametrize( + "input", + [ + torch.randn(64, 64, device="cuda", dtype=torch.float32), # [B, H] + torch.randn(16, 64, device="cuda", dtype=torch.float32), # [B, H] + torch.randn(16, 32, 64, device="cuda", dtype=torch.float32), # [B, N, H] + torch.randn(64, 64, 64, device="cuda", dtype=torch.float32), # [B, N, H] + ], +) +# @pytest.mark.parametrize("is_bias", [True, False]) +# @pytest.mark.parametrize("accum_qtype", [DTypes.KFLOAT32, DTypes.KFLOAT16, torch.bfloat16]) +@pytest.mark.parametrize("is_bias", [False]) +@pytest.mark.parametrize("accum_qtype", [torch.bfloat16]) +def test_fp8_linear_forward_pass(n_layers, input, is_bias, accum_qtype): + HIDDEN_SIZE = 64 + + ref_input = input.detach().clone() + ref_linear = nn.Sequential( + *[ + layer + for _ in range(n_layers) + for layer in (nn.Linear(HIDDEN_SIZE, HIDDEN_SIZE, bias=is_bias, device="cuda"), nn.ReLU()) + ] + ) + + fp8_linear = convert_to_fp8_module(deepcopy(ref_linear), accum_qtype) + + ref_output = ref_linear(ref_input) + output = fp8_linear(input) + + assert isinstance(output, torch.Tensor) + assert output.dtype == QTYPE_TO_DTYPE[accum_qtype] + + # NOTE: this threshold is from fp8-lm, the paper shows that this is fine + torch.testing.assert_allclose(output, ref_output, rtol=0, atol=0.1) + + +# TODO(xrsrke): add cases where the input requires and don't require grad +# @pytest.mark.skip("we already test this in the test_tensor_parallel") +@pytest.mark.parametrize("n_layers", [1, 2]) +@pytest.mark.parametrize( + "input", + [ + torch.randn(64, 64, device="cuda", dtype=torch.float32), # [B, H] + # torch.randn(16, 64, device="cuda", dtype=torch.float32), # [B, H] + # torch.randn(16, 32, 64, device="cuda", dtype=torch.float32), # [B, N, H] + # torch.randn(64, 64, 64, device="cuda", dtype=torch.float32), # [B, N, H] + ], +) +# @pytest.mark.parametrize( +# "init_method", +# [ +# lambda weight: trunc_normal_(weight, std=0.02), +# lambda weight: trunc_normal_(weight, std=math.sqrt(1 / 64)), +# lambda weight: trunc_normal_(weight, std=math.sqrt(1 / 64 * 4)), +# lambda weight: trunc_normal_(weight, std=1), +# ], +# ) +# @pytest.mark.parametrize("is_bias", [True, False]) +# @pytest.mark.skip +# @pytest.mark.parametrize("accum_qtype", [DTypes.KFLOAT32, DTypes.KFLOAT16]) +@pytest.mark.parametrize("accum_qtype", [torch.bfloat16]) +def test_fp8_linear_backward_pass(n_layers, input, accum_qtype): + is_bias = False + + HIDDEN_SIZE = 64 + + ref_input = input.detach().clone().requires_grad_(True) + # ref_linear = nn.Linear(HIDDEN_SIZE, INTERDIM_SIZE, device="cuda", dtype=torch.float32) + ref_linear = nn.Sequential( + *[ + layer + for _ in range(n_layers) + for layer in (nn.Linear(HIDDEN_SIZE, HIDDEN_SIZE, bias=is_bias, device="cuda"), nn.ReLU()) + ] + ) + + # trunc_normal_(ref_linear.weight, std=0.02) + # trunc_normal_(ref_linear.weight, std=math.sqrt(1 / (HIDDEN_SIZE))) + + # fp8_linear = convert_linear_to_fp8(deepcopy(ref_linear), accum_qtype) + fp8_linear = convert_to_fp8_module(deepcopy(ref_linear), accum_qtype) + + ref_linear(ref_input).sum().backward() + + fp8_linear(input).sum().backward() + + for ref_p, p in zip(ref_linear.parameters(), fp8_linear.parameters()): + if p.requires_grad is False: + assert p.grad is None + continue + + if isinstance(p, FP8Parameter): + assert isinstance(p.grad, FP8Tensor) + assert p.grad.dtype in FP8_DTYPES + grad = convert_tensor_from_fp8(p.grad, p.grad.fp8_meta, torch.float32) + else: + assert isinstance(p.grad, torch.Tensor) + assert p.grad.dtype == QTYPE_TO_DTYPE[accum_qtype] + + assert is_overflow_underflow_nan(grad) is False + if p.ndim > 1: + # NOTE: these weight threshold is tuned from the FP8-LM implementation + # TODO(xrsrke): tune what is the minimum threshold for this to correctly converge + torch.testing.assert_allclose(grad, ref_p.grad, rtol=0.06, atol=0.1) + else: + torch.testing.assert_allclose(grad, ref_p.grad) + + # assert isinstance(fp8_linear.weight.grad, FP8Tensor) + # assert fp8_linear.weight.grad.dtype in FP8_DTYPES + + # assert isinstance(fp8_linear.bias.grad, torch.Tensor) + # assert fp8_linear.bias.grad.dtype == QTYPE_TO_DTYPE[accum_qtype] + + # # TODO(xrsrke): investigate why input.grad is so high tolerance + # # assert torch.allclose(input.grad, ref_input.grad, 0.2, 0.2) if input_requires_grad else True + + # # NOTE: these weight threshold is tuned from the FP8-LM implementation + # # TODO(xrsrke): tune what is the minimum threshold for this to correctly converge + # weight_grad = convert_tensor_from_fp8(fp8_linear.weight.grad, fp8_linear.weight.grad.fp8_meta, torch.float32) + # torch.testing.assert_allclose(weight_grad, ref_linear.weight.grad, rtol=0.06, atol=0.1) + # torch.testing.assert_allclose(fp8_linear.bias.grad, ref_linear.bias.grad) + + +@pytest.mark.parametrize("accum_qtype", [DTypes.KFLOAT32, DTypes.KFLOAT16]) +def test_fp8_modules_trigger_the_entire_computational_graph(accum_qtype): + HIDDEN_SIZE = 16 + TIMELINE = [] + + def backward_hook(module, grad_input, grad_output, idx): + TIMELINE.append(f"{module.__class__.__name__}.{idx}.backward") + + class Logger(nn.Module): + def __init__(self, idx: int, module: nn.Linear): + super().__init__() + module.register_backward_hook(partial(backward_hook, idx=idx)) + self.module = module + self.idx = idx + + def forward(self, input): + TIMELINE.append(f"{self.module.__class__.__name__}.{self.idx}.forward") + return self.module(input) + + input = torch.randn(HIDDEN_SIZE, HIDDEN_SIZE, device="cuda", dtype=torch.float32) + fp8_linear = nn.Sequential( + nn.Linear(HIDDEN_SIZE, HIDDEN_SIZE, device="cuda", dtype=torch.float32), + nn.ReLU(), + nn.Linear(HIDDEN_SIZE, HIDDEN_SIZE, device="cuda", dtype=torch.float32), + nn.ReLU(), + ) + fp8_linear = convert_to_fp8_module(fp8_linear, accum_qtype) + fp8_linear = nn.ModuleList([Logger(idx, module) for idx, module in enumerate(fp8_linear)]) + + output = reduce(lambda x, module: module(x), fp8_linear, input) + scalar = torch.randn(1, device="cuda", dtype=output.dtype) + (output.sum() * scalar).backward() + + assert TIMELINE == [ + "FP8Linear.0.forward", + "ReLU.1.forward", + "FP8Linear.2.forward", + "ReLU.3.forward", + "ReLU.3.backward", + "FP8Linear.2.backward", + "ReLU.1.backward", + "FP8Linear.0.backward", + ] + + for p in fp8_linear.parameters(): + if p.requires_grad is True: + assert is_overflow_underflow_nan(p.grad) is False + + +# NOTE: it seems that dynamic quantization should be in test_tensor +# but we only do this if we are in training => test it in a linear +@pytest.mark.parametrize("interval", [1, 5, 10]) +def test_deplay_quantization(interval): + # NOTE: test delay quantization (window size) + # NOTE: test overflow, underflow, zeros + # NOTE: test reduce/increase exponent bits + + HIDDEN_SIZE = 16 + N_STEPS = 4 + + input = torch.randn(HIDDEN_SIZE, HIDDEN_SIZE, device="cuda", dtype=torch.float32) + fp8_linear = FP8Linear(HIDDEN_SIZE, HIDDEN_SIZE, device="cuda") + + for _ in range(N_STEPS): + output = fp8_linear(input) + output.sum().backward() + + +@pytest.mark.skip +@pytest.mark.parametrize("input_shape", [(16, 15), (15, 16), (15, 15)]) +@pytest.mark.parametrize("is_bias", [True, False]) +@pytest.mark.parametrize("accum_qtype", [DTypes.KFLOAT32, DTypes.KFLOAT16]) +def test_fp8_linear_padding(input_shape, is_bias, accum_qtype): + input = torch.randn(**input_shape) + ref_input = input.detach().clone() + ref_linear = nn.Linear(16, 16, bias=is_bias, device="cuda") + fp8_linear = convert_linear_to_fp8(deepcopy(ref_linear), accum_qtype) + + ref_output = ref_linear(ref_input) + output = fp8_linear(input) + + assert isinstance(output, torch.Tensor) + assert output.dtype == QTYPE_TO_DTYPE[accum_qtype] + + # NOTE: this threshold is from fp8-lm, the paper shows that this is fine + torch.testing.assert_allclose(output, ref_output, rtol=0, atol=0.1) + + +# TODO(xrsrke): test if FP8Linear has all the methods of a torch.nn.Linear + + +# TODO(xrsrke): test only calculating the gradients of the weight, bias, or input based +# on the requires_grad of the input, weight, or bias + +# TODO(xrsrke): test automatic padding if a input/weight shape isn't divisible by 16 diff --git a/tests/fp8/_test_tensor.py b/tests/fp8/_test_tensor.py new file mode 100644 index 00000000..237e89d5 --- /dev/null +++ b/tests/fp8/_test_tensor.py @@ -0,0 +1,468 @@ +from copy import deepcopy +from typing import cast + +import numpy as np +import pytest +import torch +import transformer_engine as te # noqa +import transformer_engine_torch as tex +from nanotron.fp8.constants import ( + FP8_ATOL_THRESHOLD, + FP8_RTOL_THRESHOLD, + FP16_ATOL_THRESHOLD, + FP16_RTOL_THRESHOLD, + QTYPE_TO_DTYPE, +) +from nanotron.fp8.dtypes import DTypes +from nanotron.fp8.meta import FP8Meta +from nanotron.fp8.tensor import FP8Tensor, FP16Tensor, convert_tensor_from_fp8, convert_tensor_from_fp16 +from nanotron.testing.utils import TestContext +from utils import fail_if_expect_to_fail + + +@pytest.mark.parametrize( + "tensor_cls, dtype", [(FP8Tensor, DTypes.FP8E4M3), (FP8Tensor, DTypes.FP8E5M2), (FP16Tensor, DTypes.KFLOAT16)] +) +@pytest.mark.parametrize("interval", [1, 5]) +def test_fp8_and_fp16_metadata(tensor_cls, dtype, interval): + tensor = torch.randn((4, 4), dtype=torch.float32, device="cuda") + ref_tensor = deepcopy(tensor) + + fp8_tensor = tensor_cls(tensor, dtype=dtype, interval=interval) + fp8_meta = cast(FP8Meta, fp8_tensor.fp8_meta) + + # TODO(xrsrke): remove the fixed 1 factor + # it couples with the current implementation of FP8Meta + # because we initialize scale with 1 + assert fp8_meta.amax == ref_tensor.abs().max() + assert isinstance(fp8_meta.inverse_scale, torch.Tensor) + assert fp8_meta.scale != 0.1 and fp8_meta.scale != 1.0 + assert isinstance(fp8_meta.te_dtype, tex.DType) + assert fp8_meta.interval == interval + + +@pytest.mark.parametrize("size", [4, 8, 16, 64]) +@pytest.mark.parametrize("dtype", [DTypes.FP8E4M3, DTypes.FP8E5M2]) +def test_quantize_and_dequantize_tensor_in_fp8(size, dtype): + tensor = torch.randn((size, size), dtype=torch.float32, device="cuda") + ref_tensor = deepcopy(tensor) + fp8_tensor = FP8Tensor(tensor, dtype=dtype) + + assert not np.array_equal(fp8_tensor.cpu().numpy(), ref_tensor.cpu().numpy()) + + tensor = convert_tensor_from_fp8(fp8_tensor, fp8_tensor.fp8_meta, torch.float32) + # NOTE: sometimes type(tensor) is FP8Tensor, but it still passes, so we directly check the class name + # to make sure it's a torch.Tensor + assert tensor.__class__.__name__ == torch.Tensor.__name__ + assert tensor.dtype == ref_tensor.dtype + + torch.testing.assert_close(tensor, ref_tensor, rtol=FP8_RTOL_THRESHOLD, atol=FP8_ATOL_THRESHOLD) + + +# @pytest.mark.parametrize("interval", [1, 5]) +@pytest.mark.parametrize("dtype", [DTypes.FP8E4M3, DTypes.FP8E5M2]) +def test_create_fp8_tensor_from_metadata(dtype): + INTERVAL = 5 + TOTAL_STEPS, REMAINING_STEPS = 20, 16 + tensor = torch.randn((16, 16), dtype=torch.float32, device="cuda") + fp8_tensor = FP8Tensor(tensor, dtype=dtype, interval=INTERVAL) + + new_values = [torch.randn((16, 16), dtype=torch.float32, device="cuda") for i in range(TOTAL_STEPS)] + + for i in range(TOTAL_STEPS): + if TOTAL_STEPS - REMAINING_STEPS == i: + current_tensor = fp8_tensor.orig_data + fp8_meta = deepcopy(fp8_tensor.fp8_meta) + + fp8_tensor.set_data(new_values[i]) + + resumed_fp8_tensor = FP8Tensor.from_metadata(current_tensor, fp8_meta) + for i in range(TOTAL_STEPS - REMAINING_STEPS, TOTAL_STEPS): + resumed_fp8_tensor.set_data(new_values[i]) + + # NOTE: we expect a resume tensor to have the state trajectory of the original tensor + assert resumed_fp8_tensor == fp8_tensor + + +@pytest.mark.parametrize("size", [4, 8, 16, 64]) +def test_quantize_and_dequantize_tensor_in_fp16(size): + tensor = torch.randn((size, size), dtype=torch.float32, device="cuda") + ref_tensor = deepcopy(tensor) + + fp16_tensor = FP16Tensor(tensor, dtype=DTypes.KFLOAT16) + + assert not np.array_equal(fp16_tensor.cpu().numpy(), ref_tensor.cpu().numpy()) + + tensor = convert_tensor_from_fp16(fp16_tensor, torch.float32) + # NOTE: sometimes type(tensor) is FP16Tensor, but it still passes + assert tensor.__class__.__name__ == torch.Tensor.__name__ + assert tensor.dtype == ref_tensor.dtype + + # NOTE: this tolerance is from FP8-LM's implementation + # reference: https://github.com/Azure/MS-AMP/blob/9ac98df5371f3d4174d8f103a1932b3a41a4b8a3/tests/common/tensor/test_cast.py#L35 + torch.testing.assert_close(tensor, ref_tensor, rtol=FP16_RTOL_THRESHOLD, atol=FP16_ATOL_THRESHOLD) + + +@pytest.mark.parametrize( + "tensor_cls, dtype", [(FP8Tensor, DTypes.FP8E4M3), (FP8Tensor, DTypes.FP8E5M2), (FP16Tensor, DTypes.KFLOAT16)] +) +def test_fp8_and_fp16_tensor_repr(tensor_cls, dtype): + tensor = torch.randn((64, 64), dtype=torch.float32, device="cuda:0") + fp8_tensor = tensor_cls(tensor, dtype) + + # NOTE: in some cases, it causes an infinite loop + # in repr(tensor), so just check if it doesn't loop + assert isinstance(repr(fp8_tensor), str) + + +@pytest.mark.parametrize( + "tensor_cls, dtype, expected_dtype", + [ + (FP8Tensor, DTypes.FP8E4M3, torch.uint8), + (FP8Tensor, DTypes.FP8E5M2, torch.uint8), + (FP16Tensor, DTypes.KFLOAT16, torch.float16), + ], +) +def test_fp8_and_fp16_tensor_attrs(tensor_cls, dtype, expected_dtype): + tensor = torch.randn((64, 64), dtype=torch.float32, device="cuda:0") + ref_tensor = tensor.detach().clone() + + fp8_tensor = tensor_cls(tensor, dtype) + + assert isinstance(fp8_tensor, tensor_cls) + assert isinstance(fp8_tensor.fp8_meta, FP8Meta) + assert fp8_tensor.dtype == expected_dtype + assert fp8_tensor.device == ref_tensor.device + assert fp8_tensor.shape == ref_tensor.shape + assert fp8_tensor.numel() == ref_tensor.numel() + assert fp8_tensor.device == ref_tensor.device + + +@pytest.mark.parametrize( + "tensor_cls, dtype", + [ + (FP8Tensor, DTypes.FP8E4M3), + (FP8Tensor, DTypes.FP8E5M2), + (FP16Tensor, DTypes.KFLOAT16), + ], +) +@pytest.mark.parametrize( + "scale", + [ + torch.ones(1, device="cuda:0").squeeze() * 2, # an random scalar + torch.ones(1, device="cuda:0") * 2, + torch.ones(4, 4, device="cuda:0") * 2, + ], +) +def test_multiple_fp8_tensor(tensor_cls, dtype, scale): + RTOL, ATOL = ( + (FP8_RTOL_THRESHOLD, FP8_ATOL_THRESHOLD) + if tensor_cls == FP8Tensor + else (FP16_RTOL_THRESHOLD, FP16_ATOL_THRESHOLD) + ) + tensor = torch.randn((4, 4), dtype=torch.float32, device="cuda:0") + ref_tensor = tensor.detach().clone() + + fp8_tensor = tensor_cls(deepcopy(tensor), dtype) + ref_fp8_tensor = fp8_tensor.clone() + + with fail_if_expect_to_fail(expect_to_fail=scale.ndim > 1): + fp8_tensor.mul_(scale) + + assert torch.equal(fp8_tensor, ref_fp8_tensor) + assert fp8_tensor.fp8_meta.scale != ref_fp8_tensor.fp8_meta.scale + + if isinstance(fp8_tensor, FP8Tensor): + # NOTE: with the current implementation, we only scale the metadata + # not the tensor itself, so we expect the tensor to be the same + tensor = convert_tensor_from_fp8(fp8_tensor, fp8_tensor.fp8_meta, torch.float32) + else: + tensor = convert_tensor_from_fp16(fp8_tensor, torch.float32) + + torch.testing.assert_allclose(tensor, ref_tensor * scale, rtol=RTOL, atol=ATOL) + + +@pytest.mark.parametrize( + "tensor_cls, dtype", + [ + (FP8Tensor, DTypes.FP8E4M3), + (FP8Tensor, DTypes.FP8E5M2), + (FP16Tensor, DTypes.KFLOAT16), + ], +) +@pytest.mark.parametrize( + "scale", + [ + torch.ones(1, device="cuda:0").squeeze() * 2, # an random scalar + torch.ones(1, device="cuda:0") * 2, + torch.ones(4, 4, device="cuda:0") * 2, + ], +) +def test_divide_fp8_tensor(tensor_cls, dtype, scale): + # NOTE: the reason that we use 2 as the scale is because + # the purpose of this test is to test whether we scale the magnitude + # of the tensor, so if use other values from normal distribution, + # some values could lead to quantization error, and for this test we don't + # test the quantization error + tensor = torch.randn((64, 64), dtype=torch.float32, device="cuda:0") + ref_tensor = tensor.detach().clone() + + fp8_tensor = tensor_cls(deepcopy(tensor), dtype) + ref_fp8_tensor = fp8_tensor.clone() + + with fail_if_expect_to_fail(expect_to_fail=scale.ndim > 1): + fp8_tensor.div_(scale) + + assert torch.equal(fp8_tensor, ref_fp8_tensor) + assert fp8_tensor.fp8_meta.scale != ref_fp8_tensor.fp8_meta.scale + + if isinstance(fp8_tensor, FP8Tensor): + tensor = convert_tensor_from_fp8(fp8_tensor, fp8_tensor.fp8_meta, torch.float32) + # NOTE: use the same tolerance as test_quantize_and_dequantize_tensor_in_fp8 + torch.testing.assert_allclose(tensor, ref_tensor / scale, rtol=FP8_RTOL_THRESHOLD, atol=FP8_ATOL_THRESHOLD) + else: + tensor = convert_tensor_from_fp16(fp8_tensor, torch.float32) + torch.testing.assert_close(tensor, ref_tensor / scale, rtol=FP16_RTOL_THRESHOLD, atol=FP16_ATOL_THRESHOLD) + + +@pytest.mark.parametrize( + "tensor_cls, dtype", + [ + (FP8Tensor, DTypes.FP8E4M3), + (FP8Tensor, DTypes.FP8E5M2), + (FP16Tensor, DTypes.KFLOAT16), + ], +) +def test_add_fp8_tensor(tensor_cls, dtype): + tensor = torch.randn((64, 64), dtype=torch.float32, device="cuda:0") + fp8_tensor = tensor_cls(deepcopy(tensor), dtype) + + with pytest.raises(ValueError): + fp8_tensor + 1 + + +@pytest.mark.parametrize( + "tensor_cls, dtype", + [ + (FP8Tensor, DTypes.FP8E4M3), + (FP8Tensor, DTypes.FP8E5M2), + (FP16Tensor, DTypes.KFLOAT16), + ], +) +def test_subtract_fp8_tensor(tensor_cls, dtype): + tensor = torch.randn((64, 64), dtype=torch.float32, device="cuda:0") + fp8_tensor = tensor_cls(deepcopy(tensor), dtype) + + with pytest.raises(ValueError): + fp8_tensor - 1 + + +@pytest.mark.parametrize( + "tensor_cls, dtype", + [ + (FP8Tensor, DTypes.FP8E4M3), + (FP8Tensor, DTypes.FP8E5M2), + (FP16Tensor, DTypes.KFLOAT16), + ], +) +def test_clone_fp8_tensor(tensor_cls, dtype): + tensor = torch.randn((64, 64), dtype=torch.float32, device="cuda:0") + fp8_tensor = tensor_cls(deepcopy(tensor), dtype) + + cloned_fp8_tensor = fp8_tensor.clone() + + assert isinstance(cloned_fp8_tensor, tensor_cls) + assert id(cloned_fp8_tensor) != id(fp8_tensor) + assert cloned_fp8_tensor.device == fp8_tensor.device + + assert torch.equal(cloned_fp8_tensor, fp8_tensor) + assert cloned_fp8_tensor.data_ptr() != fp8_tensor.data_ptr() + assert cloned_fp8_tensor.data.data_ptr() != fp8_tensor.data.data_ptr() + + assert cloned_fp8_tensor.fp8_meta == fp8_tensor.fp8_meta + assert id(cloned_fp8_tensor.fp8_meta) != id(fp8_tensor.fp8_meta) + + +@pytest.mark.parametrize( + "tensor_cls, dtype", + [ + (FP8Tensor, DTypes.FP8E4M3), + (FP8Tensor, DTypes.FP8E5M2), + # (FP16Tensor, DTypes.KFLOAT16), + ], +) +def test_transpose_fp8_tensor(tensor_cls, dtype): + tensor = torch.randn((16, 16), dtype=torch.float32, device="cuda:0") + fp8_tensor = tensor_cls(deepcopy(tensor), dtype) + ref_fp8_tensor = tensor_cls(deepcopy(tensor), dtype) + + transposed_fp8_tensor = fp8_tensor.transpose_fp8() + # transposed_fp8_tensor = fp8_tensor.t() + # transposed_fp8_tensor = torch.transpose(fp8_tensor, 0, 1) + + # NOTE: we expect the original tensor to be the same + assert fp8_tensor == ref_fp8_tensor + + if tensor_cls == FP8Tensor: + assert isinstance(transposed_fp8_tensor, FP8Tensor) + ref_transposed = convert_tensor_from_fp8(fp8_tensor, fp8_tensor.fp8_meta, torch.float32).T + dequant_transposed_fp8_tensor = convert_tensor_from_fp8( + transposed_fp8_tensor, transposed_fp8_tensor.fp8_meta, torch.float32 + ) + else: + assert isinstance(transposed_fp8_tensor, FP16Tensor) + dequant_transposed_fp8_tensor = convert_tensor_from_fp16(transposed_fp8_tensor, torch.float32) + ref_transposed = convert_tensor_from_fp16(fp8_tensor, torch.float32).T + + torch.testing.assert_close(dequant_transposed_fp8_tensor, ref_transposed) + + +@pytest.mark.parametrize( + "tensor_cls, dtype", + [ + (FP8Tensor, DTypes.FP8E4M3), + (FP8Tensor, DTypes.FP8E5M2), + (FP16Tensor, DTypes.KFLOAT16), + ], +) +def test_determistic_quantization(tensor_cls, dtype): + tensor = torch.randn((64, 64), dtype=torch.float32, device="cuda:0") + fp8_tensor = tensor_cls(deepcopy(tensor), dtype) + ref_fp8_tensor = tensor_cls(deepcopy(tensor), dtype) + + assert torch.equal(fp8_tensor, ref_fp8_tensor) + assert fp8_tensor.fp8_meta == ref_fp8_tensor.fp8_meta + + +@pytest.mark.parametrize( + "tensor_cls, dtype", [(FP8Tensor, DTypes.FP8E4M3), (FP8Tensor, DTypes.FP8E5M2), (FP16Tensor, DTypes.KFLOAT16)] +) +def test_fp8_and_fp16_tensor_storage_memory(tensor_cls, dtype): + tensor = torch.randn((4, 4), dtype=torch.float32, device="cuda") + ref_tensor = deepcopy(tensor) + + fp8_tensor = tensor_cls(tensor, dtype=dtype) + + assert id(fp8_tensor) != id(ref_tensor) + + assert isinstance(fp8_tensor.data, torch.Tensor) + assert id(fp8_tensor.data) != id(ref_tensor.data) + assert fp8_tensor.data_ptr() == fp8_tensor.data.data_ptr() + assert fp8_tensor.data.data_ptr() != ref_tensor.data_ptr() + + +@pytest.mark.parametrize( + "tensor_cls, dtype", [(FP8Tensor, DTypes.FP8E4M3), (FP8Tensor, DTypes.FP8E5M2), (FP16Tensor, DTypes.KFLOAT16)] +) +@pytest.mark.parametrize("is_quantized", [True, False]) +def test_setting_new_data_for_fp8_and_fp16_tensor(tensor_cls, dtype, is_quantized): + RTOL, ATOL = ( + (FP8_RTOL_THRESHOLD, FP8_ATOL_THRESHOLD) + if tensor_cls == FP8Tensor + else (FP16_RTOL_THRESHOLD, FP16_ATOL_THRESHOLD) + ) + + tensor = torch.randn((4, 4), dtype=torch.float32, device="cuda") + fp8_tensor = tensor_cls(tensor, dtype=dtype) + + new_data = torch.randn(fp8_tensor.shape, dtype=torch.float32, device="cuda") * 2 + ref_new_data = deepcopy(new_data) + expected_quantized_tensor = tensor_cls(ref_new_data, dtype=dtype) + + new_data = tensor_cls(new_data, dtype=dtype) if is_quantized else new_data + fp8_tensor.set_data(new_data) + + assert fp8_tensor.data.dtype == QTYPE_TO_DTYPE[dtype] + assert torch.equal(fp8_tensor, expected_quantized_tensor) + + if is_quantized: + if tensor_cls == FP8Tensor: + dequantized_tensor = convert_tensor_from_fp8(fp8_tensor, fp8_tensor.fp8_meta, torch.float32) + else: + dequantized_tensor = convert_tensor_from_fp16(fp8_tensor, torch.float32) + + assert torch.allclose(dequantized_tensor, ref_new_data, rtol=RTOL, atol=ATOL) + assert fp8_tensor.data.data_ptr() == new_data.data.data_ptr() + + +# @pytest.mark.parametrize( +# "tensor_cls, dtype", [(FP8Tensor, DTypes.FP8E4M3), (FP8Tensor, DTypes.FP8E5M2), (FP16Tensor, DTypes.KFLOAT16)] +# ) +# @pytest.mark.parametrize("is_quantized", [True, False]) +# def test_setting_None_data_for_fp8_and_fp16_tensor(tensor_cls, dtype, is_quantized): +# tensor = torch.randn((4, 4), dtype=torch.float32, device="cuda") +# fp8_tensor = tensor_cls(tensor, dtype=dtype) + +# fp8_tensor.set_data(None) + +# assert fp8_tensor is None +# assert fp8_tensor.data is None + +# @pytest.mark.parametrize( +# "tensor_cls, dtype", [(FP8Tensor, DTypes.FP8E4M3), (FP8Tensor, DTypes.FP8E5M2), (FP16Tensor, DTypes.KFLOAT16)] +# ) +# def test_quantize_overflow_fp8_and_fp16_tensor(tensor_cls, dtype): +# tensor = torch.randn((64, 64), dtype=torch.float32, device="cuda:0") +# tensor[0, 0] = torch.tensor(float("inf")) +# fp8_tensor = tensor_cls(tensor, dtype) + + +@pytest.mark.parametrize( + "tensor_cls, dtype", [(FP8Tensor, DTypes.FP8E4M3), (FP8Tensor, DTypes.FP8E5M2), (FP16Tensor, DTypes.KFLOAT16)] +) +def test_zero_out_data_of_fp8_and_fp16_tensor(tensor_cls, dtype): + tensor = torch.randn((4, 4), dtype=torch.float32, device="cuda") + fp8_tensor = tensor_cls(tensor, dtype=dtype) + + fp8_tensor.zero_() + + assert torch.equal(fp8_tensor, torch.zeros_like(fp8_tensor)) + + if tensor_cls == FP8Tensor: + dequantized_tensor = convert_tensor_from_fp8(fp8_tensor, fp8_tensor.fp8_meta, torch.float32) + else: + dequantized_tensor = convert_tensor_from_fp16(fp8_tensor, torch.float32) + assert torch.equal(dequantized_tensor, torch.zeros_like(tensor)) + + +# NOTE: add testing based on tensor metadata +@pytest.mark.parametrize("is_meta_the_same", [True, False]) +@pytest.mark.parametrize("dtype", [DTypes.FP8E4M3, DTypes.FP8E5M2]) +def test_fp8_and_fp16_tensor_equality_based_on_tensor_value(is_meta_the_same, dtype): + # TODO(xrsrke): support torch.equal for FP8Tensor + tensor = torch.randn((4, 4), dtype=torch.float32, device="cuda") + ref_tensor = deepcopy(tensor) + + fp8_tensor = FP8Tensor(tensor, dtype=dtype) + ref_fp8_tensor = FP8Tensor(ref_tensor, dtype=dtype) + + if not is_meta_the_same: + fp8_tensor.fp8_meta.scale = ref_fp8_tensor.fp8_meta.scale * 2 + + assert (fp8_tensor == ref_fp8_tensor) is is_meta_the_same + + new_data = torch.randn(tensor.shape, dtype=torch.float32, device="cuda") + ref_fp8_tensor.set_data(new_data) + + assert not fp8_tensor == ref_fp8_tensor + + +# TODO(xrsrke): test it has all the methods of torch.Tensor + +# TODO(xrsrke): test it has all the attributes of its input tensor + + +@pytest.mark.parametrize( + "tensor_cls, dtype", [(FP8Tensor, DTypes.FP8E4M3), (FP8Tensor, DTypes.FP8E5M2), (FP16Tensor, DTypes.KFLOAT16)] +) +def test_serialize_fp8_tensor(tensor_cls, dtype): + test_context = TestContext() + store_folder = test_context.get_auto_remove_tmp_dir() + tensor = torch.randn((4, 4), dtype=torch.float32, device="cuda") + + fp8_tensor = tensor_cls(tensor, dtype=dtype) + + torch.save(fp8_tensor, f"{store_folder}/fp8_tensor.pt") + torch.load(f"{store_folder}/fp8_tensor.pt") + + assert 1 == 1 diff --git a/tests/fp8/test_fp8_meta.py b/tests/fp8/test_fp8_meta.py new file mode 100644 index 00000000..13a3db77 --- /dev/null +++ b/tests/fp8/test_fp8_meta.py @@ -0,0 +1,66 @@ +import pytest +import torch +import transformer_engine as te # noqa +import transformer_engine_torch as tex +from nanotron.fp8.dtypes import DTypes +from nanotron.fp8.meta import FP8Meta + + +@pytest.mark.parametrize("dtype", [DTypes.FP8E4M3, DTypes.FP8E5M2, DTypes.KFLOAT16]) +def test_fp8_meta(dtype): + AMAX = torch.randn(1, dtype=torch.float32) * 3 + SCALE = torch.randn(1, dtype=torch.float32) + INTERVAL = 5 + + fp8_meta = FP8Meta(amax=AMAX, scale=SCALE, dtype=dtype, interval=INTERVAL) + + assert torch.equal(fp8_meta.amax, AMAX) + assert torch.equal(fp8_meta.scale, SCALE) + assert torch.equal(fp8_meta.inverse_scale, 1 / fp8_meta.scale) + assert isinstance(fp8_meta.fp8_max, float) + assert isinstance(fp8_meta.te_dtype, tex.DType) + + +@pytest.mark.parametrize("interval", [1, 5]) +@pytest.mark.parametrize("dtype", [DTypes.FP8E4M3, DTypes.FP8E5M2, DTypes.KFLOAT16]) +def test_fp8_meta_for_delayed_scaling(interval, dtype): + AMAX = torch.randn(1, dtype=torch.float32) * 3 + SCALE = torch.randn(1, dtype=torch.float32) + + fp8_meta = FP8Meta(amax=AMAX, scale=SCALE, dtype=dtype, interval=interval) + + assert fp8_meta.is_delayed_scaling is (interval > 1) + + +@pytest.mark.parametrize("dtype", [DTypes.FP8E4M3, DTypes.FP8E5M2, DTypes.KFLOAT16]) +@pytest.mark.parametrize( + "modifications", + [ + # NOTE: no modifications + {}, + {"amax": torch.randn(1, dtype=torch.float32)}, + {"scale": torch.randn(1, dtype=torch.float32)}, + {"dtype": True}, + ], +) +def test_fp8_meta_equality(dtype, modifications): + def modify_fp8_meta(fp8_meta, modifications): + if "dtype" in modifications: + modifications["dtype"] = next(d for d in [DTypes.FP8E5M2, DTypes.FP8E4M3, DTypes.KFLOAT16] if d != dtype) + + for attr, new_value in modifications.items(): + setattr(fp8_meta, attr, new_value) + + AMAX = torch.randn(1, dtype=torch.float32) * 3 + SCALE = torch.randn(1, dtype=torch.float32) + INTERVAL = 5 + + fp8_meta = FP8Meta(amax=AMAX, scale=SCALE, dtype=dtype, interval=INTERVAL) + ref_fp8_meta = FP8Meta(amax=AMAX, scale=SCALE, dtype=dtype, interval=INTERVAL) + + modify_fp8_meta(ref_fp8_meta, modifications) + + if not modifications: + assert fp8_meta == ref_fp8_meta + else: + assert fp8_meta != ref_fp8_meta diff --git a/tests/fp8/test_fp8_model.py b/tests/fp8/test_fp8_model.py new file mode 100644 index 00000000..41cff327 --- /dev/null +++ b/tests/fp8/test_fp8_model.py @@ -0,0 +1,64 @@ +import pytest +import torch +from nanotron.config import ModelArgs, RandomInit +from nanotron.config.fp8_config import FP8Args +from nanotron.fp8.tensor import FP8Tensor +from nanotron.fp8.utils import convert_model_to_fp8 +from nanotron.parallel import ParallelContext +from nanotron.parallel.parameters import NanotronParameter +from nanotron.testing.llama import TINY_LLAMA_CONFIG, create_llama_from_config, get_llama_training_config +from nanotron.testing.utils import init_distributed, rerun_if_address_is_in_use +from torch import nn + + +# NOTE: fp8 quantization should be parametrization-method-agnotic +@pytest.mark.parametrize("tp,dp,pp", [(1, 1, 1), (2, 1, 1), (1, 1, 2), (2, 1, 2)]) +@rerun_if_address_is_in_use() +def test_initialize_fp8_model(tp: int, dp: int, pp: int): + fp8_config = FP8Args() + init_distributed(tp=tp, dp=dp, pp=pp)(_test_initialize_fp8_model)(fp8_config=fp8_config) + + +def _test_initialize_fp8_model(parallel_context: ParallelContext, fp8_config: FP8Args): + model_args = ModelArgs(init_method=RandomInit(std=1.0), model_config=TINY_LLAMA_CONFIG) + config = get_llama_training_config(model_args) + llama = create_llama_from_config( + model_config=TINY_LLAMA_CONFIG, + device=torch.device("cuda"), + parallel_context=parallel_context, + dtype=torch.float32, + ) + llama.init_model_randomly(config=config) + + llama = convert_model_to_fp8(llama, config=fp8_config) + + assert 1 == 1 + # NOTE: test the default recipe in fp8's nanotron + from nanotron.fp8.utils import find_fp8_config_by_module_name, get_leaf_modules + + for name, module in get_leaf_modules(llama): + recipe = find_fp8_config_by_module_name(name, fp8_config) + + assert all(p.__class__ == NanotronParameter for p in module.parameters()) + if recipe is None: + assert all( + p.dtype == fp8_config.resid_dtype for p in module.parameters() + ), f"name: {name}, __class__: {module.weight.data.__class__}" + # try: + # assert all( + # p.data.__class__ == nn.Parameter for p in module.parameters() + # ), f"name: {name}, __class__: {module.weight.data.__class__}" + # except: + # assert 1 == 1 + assert all( + p.data.__class__ == nn.Parameter for p in module.parameters() + ), f"name: {name}, __class__: {module.weight.data.__class__}" + else: + assert all( + isinstance(p.data.__class__, FP8Tensor) for p in module.parameters() + ), f"name: {name}, __class__: {module.weight.data.__class__}" + assert all( + p.dtype in [torch.int8, torch.uint8] for p in module.parameters() + ), f"name: {name}, __class__: {module.weight.data.__class__}" + # NOTE: check the expected parameters have fp8 dtype + # NOTE: check the dtype of non-fp8 parameters diff --git a/tests/fp8/test_fp8_nanotron_parameter.py b/tests/fp8/test_fp8_nanotron_parameter.py new file mode 100644 index 00000000..4925e6ad --- /dev/null +++ b/tests/fp8/test_fp8_nanotron_parameter.py @@ -0,0 +1,151 @@ +import pytest +import torch +from nanotron.constants import CHECKPOINT_VERSION +from nanotron.fp8.dtypes import DTypes +from nanotron.fp8.tensor import FP8Tensor +from nanotron.parallel import ParallelContext +from nanotron.parallel.parameters import NanotronParameter +from nanotron.parallel.sharded_parameters import SplitConfig, create_sharded_parameter_from_config +from nanotron.serialize.metadata import TensorMetadata +from nanotron.testing.parallel import init_distributed, rerun_if_address_is_in_use +from torch import nn + + +def create_sharded_fp8_parameter(param: nn.Parameter, parallel_context: ParallelContext): + split_config = SplitConfig( + split_dim=0, + contiguous_chunks=(8, 8), + ) + param = create_sharded_parameter_from_config(parameter=param, pg=parallel_context.tp_pg, split_config=split_config) + return param + + +@pytest.mark.parametrize("dtype", [DTypes.FP8E4M3, DTypes.FP8E5M2]) +def test_create_fp8_parameter(dtype): + tensor = FP8Tensor(torch.randn(16, 16, device="cuda", dtype=torch.float32), dtype) + parameter = NanotronParameter(tensor) + + assert parameter.requires_grad is False + assert parameter.dtype == parameter.data.dtype + assert parameter.data is tensor + + # assert isinstance(parameter, FP8Meta) + # assert isinstance(fp8_parameter.data.fp8_meta, FP8Meta) + # assert fp8_parameter.data.fp8_meta is fp8_parameter.fp8_meta + + +# def test_fp8_parameter_grad_metadata(): +# GRAD_META = ["input_grad", "weight_grad", "output_grad"] +# tensor = torch.randn(16, 16, device="cuda", dtype=torch.float32) +# fp8_parameter = FP8Parameter(tensor, DTypes.FP8E4M3) + +# assert all(hasattr(fp8_parameter.fp8_grad_meta, attr) for attr in GRAD_META) +# assert all(isinstance(getattr(fp8_parameter.fp8_grad_meta, attr), FP8Meta) for attr in GRAD_META) + + +@pytest.mark.parametrize("dtype", [DTypes.FP8E4M3, DTypes.FP8E5M2]) +@pytest.mark.parametrize("grad_dtype", [DTypes.FP8E4M3, DTypes.FP8E5M2]) +def test_setting_fp8_gradient_to_fp8_parameter(dtype, grad_dtype): + tensor = FP8Tensor(torch.randn(16, 16, device="cuda", dtype=torch.float32), dtype) + parameter = NanotronParameter(tensor) + + fp8_grad = FP8Tensor(torch.randn(16, 16, device="cuda"), dtype=grad_dtype) + + parameter.grad = fp8_grad + + assert parameter.grad is fp8_grad + # assert torch.equal(fp8_parameter.grad, fp8_parameter.data.grad) + # assert id(fp8_parameter.grad) == id(fp8_parameter.data.grad) + # assert fp8_parameter.grad.data_ptr() == fp8_parameter.data.grad.data_ptr() + + +# @pytest.mark.parametrize("dtype", [DTypes.FP8E4M3, DTypes.FP8E5M2]) +# def test_fp8_parameter_storage_memory(dtype): +# data = torch.randn(16, 16, device="cuda", dtype=torch.float32) +# fp8_parameter = FP8Parameter(data, dtype) + +# assert id(fp8_parameter.data) != id(data) +# assert fp8_parameter.data_ptr() == data.data_ptr() +# assert fp8_parameter.data.data_ptr() != data.data_ptr() + + +@pytest.mark.parametrize("dtype", [DTypes.FP8E4M3, DTypes.FP8E5M2]) +def test_set_data_in_fp8_parameter(dtype): + # data = torch.randn(16, 16, device="cuda", dtype=torch.float32) + # fp8_parameter = FP8Parameter(data, dtype) + tensor = FP8Tensor(torch.randn(16, 16, device="cuda", dtype=torch.float32), dtype) + parameter = NanotronParameter(tensor) + + new_data = FP8Tensor(torch.randn(16, 16, device="cuda", dtype=torch.float32), dtype=dtype) + + parameter.data = new_data + + assert parameter.data is new_data + + # TODO(xrsrke): expect changes in metadata + + +# @pytest.mark.parametrize("dtype", [DTypes.FP8E4M3, DTypes.FP8E5M2]) +# def test_set_gradient_in_fp8_parameter(dtype): +# data = torch.randn(16, 16, device="cuda", dtype=torch.float32) +# fp8_parameter = FP8Parameter(data, dtype) + +# grad = torch.randn(16, 16, device="cuda", dtype=torch.float32) +# fp8_grad = FP8Tensor(grad, dtype=dtype) + +# fp8_parameter.grad = fp8_grad + +# assert fp8_parameter.grad is fp8_grad +# assert torch.equal(fp8_parameter.grad, fp8_grad) +# assert fp8_parameter.grad.data_ptr() == fp8_grad.data_ptr() + +# assert fp8_parameter.data.grad is fp8_parameter.grad + + +@pytest.mark.parametrize("dtype", [DTypes.FP8E4M3, DTypes.FP8E5M2]) +@rerun_if_address_is_in_use() +def test_create_sharded_fp8_parameter(dtype): + init_distributed(tp=2, dp=1, pp=1)(_test_create_sharded_fp8_parameter)(dtype=dtype) + + +def _test_create_sharded_fp8_parameter(parallel_context: ParallelContext, dtype: DTypes): + tensor = FP8Tensor(torch.randn(16, 16, device="cuda", dtype=torch.float32), dtype) + parameter = NanotronParameter(tensor) + + # data = torch.randn(16, 64, device="cuda") + # param = FP8Parameter(data, dtype) + + param = create_sharded_fp8_parameter(parameter, parallel_context) + sharded_info = param.get_sharded_info() + + assert isinstance(param, NanotronParameter) + assert isinstance(param.data, FP8Tensor) + # assert isinstance(param.data.fp8_meta, FP8Meta) + + metadata = TensorMetadata( + version=CHECKPOINT_VERSION, + local_global_slices_pairs=sharded_info.local_global_slices_pairs, + unsharded_shape=sharded_info.unsharded_shape, + ) + metadata_str_dict = metadata.to_str_dict() + # Assert metadata_str_dict is Dict[str, str] + assert isinstance(metadata_str_dict, dict) + assert all(isinstance(key, str) for key in metadata_str_dict.keys()) + assert all(isinstance(value, str) for value in metadata_str_dict.values()) + + metadata_from_str_dict = TensorMetadata.from_str_dict(metadata_str_dict) + assert metadata == metadata_from_str_dict + + parallel_context.destroy() + + +# TODO(xrsrke): add test for preventing torch autograd do the backward pass +# on a FP8Parameter + +# TODO(xrsrke): test CPU parameter + + +# TODO(xrsrke): test convert model to FP8 +# include the FP8's NanotronParameter's dtype and requires_grad + +# TODO(xrsrke): test set FP8 gradients to FP8 NanotronParameter diff --git a/tests/fp8/test_fp8_optimizer.py b/tests/fp8/test_fp8_optimizer.py new file mode 100644 index 00000000..9c1a0405 --- /dev/null +++ b/tests/fp8/test_fp8_optimizer.py @@ -0,0 +1,56 @@ +import pytest +import torch +from nanotron.optim.gradient_accumulator import FP32GradientAccumulator +from nanotron.optim.named_optimizer import NamedOptimizer +from nanotron.optim.optimizer_from_gradient_accumulator import ( + OptimizerFromGradientAccumulator, +) +from nanotron.parallel.parameters import NanotronParameter, sanity_check +from torch import nn + + +@pytest.mark.parametrize("half_precision", [torch.float16, torch.bfloat16]) +def test_optimizer_can_step_gradient_in_fp32(half_precision: torch.dtype): + model = nn.Linear(3, 2, bias=False, dtype=half_precision, device="cuda") + original_weight = model.weight.detach().clone() + + # Create Nanotron Parameter + model.weight = NanotronParameter(model.weight) + + # Add optimizer + optimizer = OptimizerFromGradientAccumulator( + gradient_accumulator_builder=lambda named_params: FP32GradientAccumulator(named_parameters=named_params), + named_params_or_groups=model.named_parameters(), + optimizer_builder=lambda named_param_groups: NamedOptimizer( + named_params_or_groups=named_param_groups, + optimizer_builder=lambda param_groups: torch.optim.AdamW(param_groups), + ), + ) + accumulator = optimizer.gradient_accumulator + + # Check that our model is a valid model + sanity_check(model) + + # Compute backward + input = torch.randn(5, 3, dtype=half_precision, device="cuda") + accumulator.backward(model(input).sum()) + + # Check that we have an high precision gradient and that the low precision one is cleared + assert accumulator.parameters["weight"]["fp32"].grad.dtype == torch.float + if model.weight.grad is not None: + # We check that it's zero + torch.testing.assert_close(model.weight.grad, torch.zeros_like(model.weight.grad), atol=1e-6, rtol=1e-7) + + optimizer.step() + optimizer.zero_grad() + + # Check that we don't have gradients anymore and that it's set to `None` + assert accumulator.parameters["weight"]["fp32"].grad is None + assert model.weight.grad is None + + # Check that gradients have been set to zero + fp32_grad = accumulator.get_grad_buffer(name="weight") + torch.testing.assert_close(fp32_grad, torch.zeros_like(fp32_grad), atol=1e-6, rtol=1e-7) + + # weights has been updates + assert not torch.allclose(original_weight, model.weight) diff --git a/tests/fp8/test_fp8_parameter.py b/tests/fp8/test_fp8_parameter.py deleted file mode 100644 index 43343248..00000000 --- a/tests/fp8/test_fp8_parameter.py +++ /dev/null @@ -1,21 +0,0 @@ -import torch -from nanotron.fp8 import DTypes, FP8Parameter, FP8Tensor -from nanotron.fp8.meta import FP8Meta - - -def test_create_fp8_parameter(): - # TODO(xrsrke): test FP8E5M2 format - # TODO(xrsrke): test take a cpu tensor - tensor = torch.randn(16, 16, device="cuda", dtype=torch.float32) - - fp8_parameter = FP8Parameter(tensor, DTypes.FP8E4M3) - - assert isinstance(fp8_parameter.data, FP8Tensor) - assert fp8_parameter.requires_grad is True - assert fp8_parameter.grad is None - assert isinstance(fp8_parameter.fp8_meta, FP8Meta) - assert isinstance(fp8_parameter.data.fp8_meta, FP8Meta) - - -# TODO(xrsrke): add test for preventing torch autograd do the backward pass -# on a FP8Parameter diff --git a/tests/fp8/test_fp8_tensor_parallel.py b/tests/fp8/test_fp8_tensor_parallel.py new file mode 100644 index 00000000..0e25eda7 --- /dev/null +++ b/tests/fp8/test_fp8_tensor_parallel.py @@ -0,0 +1,459 @@ +import os + +# from nanotron import distributed as dist +import nanotron.fp8.distributed as dist + +# import torch.distributed as dist +import pytest +import torch +from nanotron.distributed import get_global_rank +from nanotron.fp8.linear import FP8LinearMeta +from nanotron.fp8.recipe import FP8LinearRecipe +from nanotron.fp8.tensor import FP8Tensor, convert_tensor_from_fp8 +from nanotron.parallel import ParallelContext +from nanotron.parallel.parameters import NanotronParameter +from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode +from nanotron.parallel.tensor_parallel.nn import ( + FP8TensorParallelColumnLinear, + FP8TensorParallelRowLinear, +) +from nanotron.sanity_checks import assert_tensor_synced_across_pg +from nanotron.testing.parallel import init_distributed, rerun_if_address_is_in_use +from torch import nn + +# TODO(xrsrke): add test where we test the apis of fp8 parallel linear + + +@pytest.mark.parametrize("tp,dp,pp", [[1, 1, 1], [2, 1, 1]]) +@rerun_if_address_is_in_use() +def test_fp8_column_linear_metadata( + tp: int, + dp: int, + pp: int, +): + init_distributed(tp=tp, dp=dp, pp=pp)(_test_fp8_column_linear_metadata)() + + +def _test_fp8_column_linear_metadata( + parallel_context: ParallelContext, +): + # NOTE: divisible by 16 for TP + in_features = 32 + out_features_per_tp_rank = 16 + + out_features = parallel_context.tp_pg.size() * out_features_per_tp_rank + + column_linear = FP8TensorParallelColumnLinear( + in_features=in_features, + out_features=out_features, + pg=parallel_context.tp_pg, + mode=TensorParallelLinearMode.ALL_REDUCE, + device="cuda", + async_communication=False, + bias=False, + ) + + assert isinstance(column_linear.weight, NanotronParameter) + # assert isinstance(column_linear.weight.data, FP8Tensor) + assert isinstance(column_linear.weight.data, FP8Tensor) + assert isinstance(column_linear.recipe, FP8LinearRecipe) + assert isinstance(column_linear.metadatas, FP8LinearMeta) + + parallel_context.destroy() + + +# TODO(xrsrke): support gradient flow to bias +@pytest.mark.parametrize("tp,dp,pp", [[1, 1, 1], [2, 1, 1]]) +@pytest.mark.parametrize("tp_mode", [TensorParallelLinearMode.ALL_REDUCE]) +@pytest.mark.parametrize("async_communication", [False]) +@pytest.mark.parametrize("with_bias", [False]) +@rerun_if_address_is_in_use() +def test_column_linear( + tp: int, + dp: int, + pp: int, + tp_mode: TensorParallelLinearMode, + async_communication: bool, + with_bias: bool, +): + if tp_mode is TensorParallelLinearMode.ALL_REDUCE and async_communication: + pytest.skip("ALL_REDUCE mode does not support async communication") + + init_distributed(tp=tp, dp=dp, pp=pp)(_test_column_linear)( + tp_mode=tp_mode, + async_communication=async_communication, + with_bias=with_bias, + ) + + +def _test_column_linear( + parallel_context: ParallelContext, + tp_mode: TensorParallelLinearMode, + async_communication: bool, + with_bias: bool, +): + if async_communication: + os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" + + # NOTE: divisible by 16 for TP + in_features = 32 + out_features_per_tp_rank = 16 + + out_features = parallel_context.tp_pg.size() * out_features_per_tp_rank + + # Sharded + column_linear = FP8TensorParallelColumnLinear( + in_features=in_features, + out_features=out_features, + pg=parallel_context.tp_pg, + mode=tp_mode, + device="cuda", + async_communication=async_communication, + bias=with_bias, + ) + + # Un-sharded + reference_linear = nn.Linear(in_features=in_features, out_features=out_features, bias=with_bias, device="cuda") + + # Copy weights/bias from sharded to un-sharded + with torch.inference_mode(): + # weight = column_linear.weight.data + # weight = convert_tensor_from_fp8(weight, weight.fp8_meta, torch.bfloat16), + dist.all_gather( + tensor_list=list(reference_linear.weight.split(out_features_per_tp_rank, dim=0)), + # tensor=column_linear.weight.data, + tensor=column_linear.weight.data, + group=parallel_context.tp_pg, + ) + + if with_bias is True: + # TODO(xrsrke): support if bias is in FP8 + # bias = column_linear.bias.data + bias = column_linear.bias.data + bias = bias.to(reference_linear.bias.dtype) if bias.dtype != reference_linear.bias.dtype else bias + dist.all_gather( + tensor_list=list(reference_linear.bias.split(out_features_per_tp_rank, dim=0)), + tensor=bias, + group=parallel_context.tp_pg, + ) + + # TODO(xrsrke) + if with_bias is True: + assert column_linear.bias.requires_grad is (with_bias is True) + # assert column_linear.bias.data.__class__ == torch.Tensor + # assert get_data_from_param(column_linear.bias).__class__ == nn.Parameter + assert isinstance(column_linear.bias, nn.Parameter) + # assert column_linear.bias.data.requires_grad is (with_bias is True) + + # Generate random input + random_input: torch.Tensor + sharded_random_input: torch.Tensor + if tp_mode is TensorParallelLinearMode.ALL_REDUCE: + # batch_size = 5 + batch_size = 16 + random_input = torch.randn(batch_size, in_features, device="cuda") + # synchronize random_input across tp + dist.all_reduce(random_input, op=dist.ReduceOp.AVG, group=parallel_context.tp_pg) + sharded_random_input = random_input + elif tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: + # sharded_batch_size = 5 + sharded_batch_size = 16 + sharded_random_input = torch.randn(sharded_batch_size, in_features, device="cuda") + if parallel_context.tp_pg.size() > 1: + random_input = torch.empty( + sharded_batch_size * parallel_context.tp_pg.size(), + *(sharded_random_input.shape[1:]), + device=sharded_random_input.device, + dtype=sharded_random_input.dtype, + ) + dist.all_gather_into_tensor(random_input, sharded_random_input, group=parallel_context.tp_pg) + else: + random_input = sharded_random_input + else: + ValueError(f"Unsupported mode: {tp_mode}") + + dist.barrier() + assert_tensor_synced_across_pg(random_input, pg=parallel_context.tp_pg) + + # It's important that `random_input` and `sharded_random_input` are two separate tensors with separate storage + sharded_random_input = sharded_random_input.clone() + sharded_random_input = sharded_random_input.contiguous() + random_input.requires_grad = True + sharded_random_input.requires_grad = True + + # Test that we get the same output after forward pass + sharded_output = column_linear(sharded_random_input) + reference_output = reference_linear(random_input) + + hidden_dim_slice = slice( + dist.get_rank(parallel_context.tp_pg) * out_features_per_tp_rank, + (dist.get_rank(parallel_context.tp_pg) + 1) * out_features_per_tp_rank, + ) + + torch.testing.assert_close( + # convert_tensor_from_fp8(column_linear.weight.data, column_linear.weight.data.fp8_meta, torch.bfloat16), + convert_tensor_from_fp8( + # get_data_from_param(column_linear.weight), + # get_data_from_param(column_linear.weight).fp8_meta, + column_linear.weight.data, + column_linear.weight.data.fp8_meta, + torch.bfloat16, + ), + reference_linear.weight[hidden_dim_slice].to(torch.bfloat16), + rtol=0.1, + atol=0.1, + ) + + # reference_output = ReferenceLinear.apply(random_input, reference_linear.weight, reference_linear.bias) + + # TODO @thomasw21: Tune tolerance + try: + torch.testing.assert_close( + sharded_output, + # TODO(xrsrke): retrieve accumulation precision from recipe + # NOTE: before the reference_output.to(torch.bfloat16) + # reference_output[ + # :, + # dist.get_rank(parallel_context.tp_pg) + # * out_features_per_tp_rank : (dist.get_rank(parallel_context.tp_pg) + 1) + # * out_features_per_tp_rank, + # ].to(torch.bfloat16), + reference_output[:, hidden_dim_slice].to(torch.bfloat16), + rtol=0, + atol=0.1, + ) + except BaseException as e: + print(f"Rank {dist.get_rank(parallel_context.tp_pg)}: FAIL.") + dist.barrier() + raise e + + print(f"Rank {dist.get_rank(parallel_context.tp_pg)}: SUCCESS.") + dist.barrier() + + # Test that we get the same gradient after backward pass + sharded_output.sum().backward() + reference_output.sum().backward() + + # TODO(xrsrke): retrieve accumulation precision from recipe + assert sharded_output.dtype == torch.bfloat16 + # NOTE(xrsrke): we expect the output is a raw torch.Tensor, not FP8Paramter, or NanotronParameter + # assert isinstance(sharded_output, torch.Tensor) + assert sharded_output.__class__ == torch.Tensor + assert sharded_output.requires_grad is True + + if tp_mode is TensorParallelLinearMode.ALL_REDUCE: + torch.testing.assert_close(sharded_random_input.grad, random_input.grad, rtol=0.1, atol=0.1) + elif tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: + batch_dim_slice = slice( + dist.get_rank(parallel_context.tp_pg) * sharded_batch_size, + (dist.get_rank(parallel_context.tp_pg) + 1) * sharded_batch_size, + ) + torch.testing.assert_close( + sharded_random_input.grad, + random_input.grad[batch_dim_slice], + ) + else: + ValueError(f"Unsupported mode: {tp_mode}") + + if with_bias is True: + torch.testing.assert_close( + column_linear.bias.grad, + reference_linear.bias.grad[hidden_dim_slice], + ) + + # if isinstance(column_linear.weight.data, FP8Tensor): + # # grad = column_linear.weight.data._temp_grad + # # grad = convert_tensor_from_fp8(grad, column_linear.weight.data._temp_grad.fp8_meta, torch.bfloat16) + # grad = convert_tensor_from_fp8(column_linear.weight.grad, column_linear.weight.grad.fp8_meta, torch.bfloat16) + # else: + # # grad = column_linear.weight.grad + # grad = column_linear.weight.grad + grad = convert_tensor_from_fp8(column_linear.weight.grad, column_linear.weight.grad.fp8_meta, torch.bfloat16) + + torch.testing.assert_close( + grad, + reference_linear.weight.grad[hidden_dim_slice].to(torch.bfloat16), + # rtol=0.1, atol=0.1 + rtol=0.2, + atol=0.2, + ) + + parallel_context.destroy() + + +# TODO(xrsrke): support gradient flow to bias + + +@pytest.mark.parametrize("tp,dp,pp", [[1, 1, 1], [2, 1, 1]]) +@pytest.mark.parametrize("tp_mode", [TensorParallelLinearMode.ALL_REDUCE]) +@pytest.mark.parametrize("async_communication", [False]) +@pytest.mark.parametrize("with_bias", [False]) +@rerun_if_address_is_in_use() +def test_row_linear( + tp: int, dp: int, pp: int, tp_mode: TensorParallelLinearMode, async_communication: bool, with_bias: bool +): + if tp_mode is TensorParallelLinearMode.ALL_REDUCE and async_communication: + pytest.skip("ALL_REDUCE mode does not support async communication") + + init_distributed(tp=tp, dp=dp, pp=pp)(_test_row_linear)( + tp_mode=tp_mode, async_communication=async_communication, with_bias=with_bias + ) + + +def _test_row_linear( + parallel_context: ParallelContext, tp_mode: TensorParallelLinearMode, async_communication: bool, with_bias: bool +): + if async_communication: + os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" + + out_features = 16 + in_features_per_rank = 32 + + in_features = parallel_context.tp_pg.size() * in_features_per_rank + dist.get_rank(parallel_context.tp_pg) + + # Sharded + row_linear = FP8TensorParallelRowLinear( + in_features=in_features, + out_features=out_features, + pg=parallel_context.tp_pg, + mode=tp_mode, + device="cuda", + async_communication=async_communication, + bias=with_bias, + ) + + # Un-sharded + reference_linear = nn.Linear(in_features=in_features, out_features=out_features, bias=with_bias, device="cuda") + + # Copy weights/bias from sharded to un-sharded + # NOTE(xrsrke): dont' use torch.inference_mode because got "Cannot set version_counter for inference tensor" + # https://github.com/pytorch/pytorch/issues/112024 + dist.all_reduce(tensor=reference_linear.weight, op=dist.ReduceOp.SUM, group=parallel_context.tp_pg) + + sharded_weight = reference_linear.weight[ + :, + dist.get_rank(parallel_context.tp_pg) + * in_features_per_rank : (dist.get_rank(parallel_context.tp_pg) + 1) + * in_features_per_rank, + ] + # row_linear.weight.data.set_data(sharded_weight) + # get_data_from_param(row_linear.weight).set_data(sharded_weight) + row_linear._set_and_quantize_weights(sharded_weight) + + if with_bias is True: + # broadcast bias from rank 0, and the other don't have bias + if dist.get_rank(parallel_context.tp_pg) == 0: + # row_linear.bias.data.copy_(reference_linear.bias) + # get_data_from_param(row_linear.bias).copy_(reference_linear.bias) + row_linear.bias.data.copy_(reference_linear.bias.data) + + dist.broadcast( + tensor=reference_linear.bias, + src=get_global_rank(group=parallel_context.tp_pg, group_rank=0), + group=parallel_context.tp_pg, + ) + + # Generate random input + if tp_mode is TensorParallelLinearMode.ALL_REDUCE: + batch_size = 16 + elif tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: + batch_size = 16 * parallel_context.tp_pg.size() + else: + raise ValueError() + + random_input = torch.randn(batch_size, in_features, device="cuda") + # synchronize random_input across tp + dist.all_reduce(random_input, op=dist.ReduceOp.AVG, group=parallel_context.tp_pg) + + assert_tensor_synced_across_pg(random_input, pg=parallel_context.tp_pg) + + # Row linear receives as input sharded input + random_sharded_input = random_input[ + :, + dist.get_rank(parallel_context.tp_pg) + * in_features_per_rank : (dist.get_rank(parallel_context.tp_pg) + 1) + * in_features_per_rank, + ] + + start_idx = dist.get_rank(parallel_context.tp_pg) * in_features_per_rank + end_idx = (dist.get_rank(parallel_context.tp_pg) + 1) * in_features_per_rank + sharded_portion = (slice(None), slice(start_idx, end_idx)) + torch.testing.assert_close( + # convert_tensor_from_fp8(row_linear.weight.data, row_linear.weight.data.fp8_meta, torch.bfloat16), + convert_tensor_from_fp8( + # get_data_from_param(row_linear.weight), get_data_from_param(row_linear.weight).fp8_meta, torch.bfloat16 + row_linear.weight.data, + row_linear.weight.data.fp8_meta, + torch.bfloat16, + ), + reference_linear.weight.to(torch.bfloat16)[sharded_portion], + rtol=0.1, + atol=0.1, + ) + + # Test that we get the same output after forward pass + # TODO @kunhao: We may want to have our custom error type + reference_output = reference_linear(random_input) + # reference_output = ReferenceLinear.apply(random_input, reference_linear.weight, reference_linear.bias) + sharded_output = row_linear(random_sharded_input) + + assert sharded_output.dtype == torch.bfloat16 + # NOTE(xrsrke): we expect the output is a raw torch.Tensor, not FP8Paramter, or NanotronParameter + assert sharded_output.__class__ == torch.Tensor + assert sharded_output.requires_grad is True + + if tp_mode is TensorParallelLinearMode.ALL_REDUCE: + sharded_reference_output = reference_output + elif tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: + assert batch_size % parallel_context.tp_pg.size() == 0 + sharded_batch_size = batch_size // parallel_context.tp_pg.size() + sharded_reference_output = reference_output[ + dist.get_rank(parallel_context.tp_pg) + * sharded_batch_size : (dist.get_rank(parallel_context.tp_pg) + 1) + * sharded_batch_size + ] + else: + raise ValueError(f"Unsupported mode: {tp_mode}") + + # TODO @thomasw21: Tune tolerance + torch.testing.assert_close(sharded_output, sharded_reference_output.to(torch.bfloat16), rtol=0.2, atol=0.2) + + # Test that we get the same gradient after backward pass + sharded_output.sum().backward() + reference_output.sum().backward() + + if with_bias is True: + if dist.get_rank(parallel_context.tp_pg) == 0: + torch.testing.assert_close( + row_linear.bias.grad, + reference_linear.bias.grad, + ) + else: + assert row_linear.bias is None + + # if isinstance(row_linear.weight.data, FP8Tensor): + # if isinstance(get_data_from_param(row_linear.weight), FP8Tensor): + # # grad = row_linear.weight.data._temp_grad + # # grad = convert_tensor_from_fp8(grad, row_linear.weight.data._temp_grad.fp8_meta, torch.bfloat16) + # # grad = get_grad_from_parameter(row_linear.weight) + # # grad = row_linear.weight.grad + # grad = convert_tensor_from_fp8(row_linear.weight.grad, row_linear.weight.grad.fp8_meta, torch.bfloat16) + # else: + # # grad = row_linear.weight.grad + # grad = get_grad_from_parameter(row_linear.weight) + grad = convert_tensor_from_fp8(row_linear.weight.grad, row_linear.weight.grad.fp8_meta, torch.bfloat16) + + torch.testing.assert_close( + grad, + reference_linear.weight.grad[ + :, + dist.get_rank(parallel_context.tp_pg) + * in_features_per_rank : (dist.get_rank(parallel_context.tp_pg) + 1) + * in_features_per_rank, + ].to(torch.bfloat16), + rtol=0.2, + atol=0.2, + ) + + parallel_context.destroy() diff --git a/tests/fp8/test_fp8_utils.py b/tests/fp8/test_fp8_utils.py new file mode 100644 index 00000000..209c8c84 --- /dev/null +++ b/tests/fp8/test_fp8_utils.py @@ -0,0 +1,125 @@ +import pytest +import torch +from nanotron.fp8.dtypes import DTypes +from nanotron.fp8.linear import FP8Linear +from nanotron.fp8.parameter import FP8Parameter +from nanotron.fp8.utils import ( + _log, + convert_linear_to_fp8, + convert_to_fp8_module, + get_leaf_modules, + is_overflow_underflow_nan, +) +from torch import nn + + +@pytest.mark.parametrize("accum_dtype", [DTypes.KFLOAT32, DTypes.KFLOAT16]) +def test_convert_linear_to_fp8(accum_dtype): + linear = nn.Linear(16, 16, device="cuda") + fp8_linear = convert_linear_to_fp8(linear, accum_dtype) + + assert isinstance(fp8_linear, FP8Linear) + assert isinstance(fp8_linear.weight, FP8Parameter) + assert isinstance(fp8_linear.bias, nn.Parameter) + + +@pytest.mark.parametrize("accum_dtype", [DTypes.KFLOAT32, DTypes.KFLOAT16]) +def test_convert_module_to_fp8(accum_dtype): + HIDDEN_SIZE = 16 + input = torch.randn(1, HIDDEN_SIZE, device="cuda") + + model = nn.Sequential( + nn.Linear(16, 16, device="cuda"), + nn.ReLU(), + nn.Linear(16, 16, device="cuda"), + nn.ReLU(), + ) + + fp8_model = convert_to_fp8_module(model, accum_dtype) + + assert fp8_model(input).shape == model(input).shape + + ref_modules = get_leaf_modules(model) + fp8_modules = get_leaf_modules(fp8_model) + + for (ref_name, ref_module), (fp8_name, fp8_module) in zip(ref_modules, fp8_modules): + assert ref_name == fp8_name + + if not isinstance(ref_module, nn.Linear): + assert isinstance(fp8_module, type(ref_module)) + else: + assert isinstance(fp8_module, FP8Linear) + + assert ref_module.weight.shape == fp8_module.weight.shape + assert ref_module.weight.numel() == fp8_module.weight.numel() + assert ref_module.weight.requires_grad == fp8_module.weight.requires_grad + assert ref_module.weight.device == fp8_module.weight.device + + assert ref_module.bias.shape == fp8_module.bias.shape + assert ref_module.bias.numel() == fp8_module.bias.numel() + assert ref_module.bias.requires_grad == fp8_module.bias.requires_grad + assert ref_module.bias.device == fp8_module.bias.device + + +@pytest.mark.parametrize( + "tensor, expected_output", + [ + [torch.tensor(0), False], + [torch.tensor(1.0), False], + [torch.tensor(float("inf")), True], + [torch.tensor(float("-inf")), True], + [torch.tensor(float("nan")), True], + ], +) +def test_detect_overflow_underflow_nan(tensor, expected_output): + output = is_overflow_underflow_nan(tensor) + assert output == expected_output + + +def test_track_module_statistics(): + class FP8Model(nn.Module): + def __init__(self): + super(FP8Model, self).__init__() + self.fin = FP8Linear(32, 32, device="cuda") + self.relu = nn.ReLU() + self.fout = FP8Linear(32, 32, device="cuda") + + def forward(self, x): + return self.fout(self.relu(self.fin(x))) + + input = torch.randn(32, 32, device="cuda") + model = FP8Model() + + logs, _ = _log(model) + + for _ in range(1): + model(input).sum().backward() + + # NOTE: now merge module_name:x:statistic into a flat dictionary + assert logs.keys() == {"fin", "relu", "fout"} + assert logs["fin"].keys() == { + "weight", + "bias", + "input", + "output", + "grad_output:0", + "grad_input:0", + "grad_input:1", + } + + +# @pytest.mark.skip +# def test_track_fp8_optimizer(): +# input = torch.randn(16, 16, device="cuda") +# linear = nn.Linear(16, 16, device="cuda") +# fp8_linear = convert_linear_to_fp8(linear, accum_qtype=DTypes.KFLOAT16) +# fp8_optim = FP8Adam(fp8_linear.parameters()) + +# logs = track_optimizer(fp8_optim) + +# for _ in range(1): +# fp8_linear(input).sum().backward() +# fp8_optim.step() +# fp8_optim.zero_grad() + +# assert logs.keys() == ["master_weights", "optimizer_states"] diff --git a/tests/fp8/test_linear.py b/tests/fp8/test_linear.py index f88ee558..e72b38bb 100644 --- a/tests/fp8/test_linear.py +++ b/tests/fp8/test_linear.py @@ -1,83 +1,288 @@ +from copy import deepcopy +from functools import partial, reduce + import pytest import torch -from nanotron.fp8 import DTypes, FP8Linear, FP8Parameter, FP8Tensor +from nanotron.fp8.constants import FP8_DTYPES, QTYPE_TO_DTYPE +from nanotron.fp8.dtypes import DTypes +from nanotron.fp8.linear import FP8Linear, FP8LinearMeta +from nanotron.fp8.recipe import FP8LinearRecipe +from nanotron.fp8.tensor import FP8Tensor, convert_tensor_from_fp8 +from nanotron.fp8.utils import convert_linear_to_fp8, convert_to_fp8_module, is_overflow_underflow_nan +from nanotron.parallel.parameters import NanotronParameter + +# from timm.models.layers import trunc_normal_ from torch import nn -from torch.optim import Adam -@pytest.mark.parametrize("is_bias", [True, False]) -def test_fp8_linear_forward_pass(is_bias): - input = torch.randn(16, 16, device="cuda", dtype=torch.float32) - ref_input = input.detach().clone() - ref_linear = nn.Linear(16, 16, bias=is_bias, device="cuda", dtype=torch.float32) +@pytest.mark.parametrize("accum_qtype", [DTypes.KFLOAT32, DTypes.KFLOAT16]) +@pytest.mark.parametrize("bias", [True, False]) +def test_create_an_fp8_linear_parameters(bias, accum_qtype): + fp8_linear = FP8Linear(64, 64, bias=bias, device="cuda", accum_qtype=accum_qtype) + + assert isinstance(fp8_linear.weight, NanotronParameter) + if bias: + assert isinstance(fp8_linear.bias, torch.Tensor) + assert fp8_linear.bias.requires_grad is True + + assert isinstance(fp8_linear.recipe, FP8LinearRecipe) + assert isinstance(fp8_linear.metadatas, FP8LinearMeta) + + for p in fp8_linear.parameters(): + if p.ndim == 1: + # NOTE: not quantize biases + assert isinstance(p.data, torch.Tensor) + else: + assert isinstance(p.data, FP8Tensor) - fp8_linear = FP8Linear(16, 16, bias=is_bias, device="cuda:0") - fp8_linear.weight = FP8Parameter(ref_linear.weight.detach().clone(), DTypes.FP8E4M3) - if is_bias: - fp8_linear.bias.data = ref_linear.bias.detach().clone() +# def test_fp8_linear_parameters(): +# ref_linear = nn.Linear(16, 16, device="cuda") +# fp8_linear = convert_linear_to_fp8(deepcopy(ref_linear), accum_qtype=DTypes.KFLOAT32) + +# assert len(list(ref_linear.parameters())) == len(list(fp8_linear.parameters())) +# assert all(p is not None for p in fp8_linear.parameters()) +# # assert isinstance(fp8_linear.weight, FP8Parameter) +# # assert isinstance(fp8_linear.bias, torch.Tensor) +# assert all(p.requires_grad for p in fp8_linear.parameters()) is True + + +# NOTE: sometimes the assert output just fails, if you rerun it, +# then it will pass, i also observed this in the FP8-LM implementation +# @pytest.mark.skip +@pytest.mark.parametrize("n_layers", [1, 2]) +@pytest.mark.parametrize( + "input", + [ + torch.randn(64, 64, device="cuda", dtype=torch.float32), # [B, H] + torch.randn(16, 64, device="cuda", dtype=torch.float32), # [B, H] + torch.randn(16, 32, 64, device="cuda", dtype=torch.float32), # [B, N, H] + torch.randn(64, 64, 64, device="cuda", dtype=torch.float32), # [B, N, H] + ], +) +# @pytest.mark.parametrize("is_bias", [True, False]) +# @pytest.mark.parametrize("accum_qtype", [DTypes.KFLOAT32, DTypes.KFLOAT16, torch.bfloat16]) +@pytest.mark.parametrize("is_bias", [False]) +@pytest.mark.parametrize("accum_qtype", [torch.bfloat16]) +def test_fp8_linear_forward_pass(n_layers, input, is_bias, accum_qtype): + HIDDEN_SIZE = 64 + + ref_input = input.detach().clone() + ref_linear = nn.Sequential( + *[ + layer + for _ in range(n_layers) + for layer in (nn.Linear(HIDDEN_SIZE, HIDDEN_SIZE, bias=is_bias, device="cuda"), nn.ReLU()) + ] + ) + + fp8_linear = convert_to_fp8_module(deepcopy(ref_linear), accum_qtype) ref_output = ref_linear(ref_input) output = fp8_linear(input) assert isinstance(output, torch.Tensor) - assert output.dtype == torch.float32 - assert torch.allclose(output, ref_output, rtol=0, atol=0.1) + assert output.dtype == QTYPE_TO_DTYPE[accum_qtype] + + # NOTE: this threshold is from fp8-lm, the paper shows that this is fine + torch.testing.assert_allclose(output, ref_output, rtol=0, atol=0.1) # TODO(xrsrke): add cases where the input requires and don't require grad -@pytest.mark.parametrize("input_requires_grad", [True, False]) -@pytest.mark.parametrize("device", ["cuda:0", "cpu"]) -def test_fp8_linear_backward_pass(input_requires_grad, device): - input = torch.randn(16, 16, device=device, dtype=torch.float32, requires_grad=input_requires_grad) +# @pytest.mark.skip("we already test this in the test_tensor_parallel") +@pytest.mark.parametrize("n_layers", [1, 2]) +@pytest.mark.parametrize( + "input", + [ + torch.randn(64, 64, device="cuda", dtype=torch.float32), # [B, H] + # torch.randn(16, 64, device="cuda", dtype=torch.float32), # [B, H] + # torch.randn(16, 32, 64, device="cuda", dtype=torch.float32), # [B, N, H] + # torch.randn(64, 64, 64, device="cuda", dtype=torch.float32), # [B, N, H] + ], +) +# @pytest.mark.parametrize( +# "init_method", +# [ +# lambda weight: trunc_normal_(weight, std=0.02), +# lambda weight: trunc_normal_(weight, std=math.sqrt(1 / 64)), +# lambda weight: trunc_normal_(weight, std=math.sqrt(1 / 64 * 4)), +# lambda weight: trunc_normal_(weight, std=1), +# ], +# ) +# @pytest.mark.parametrize("is_bias", [True, False]) +# @pytest.mark.skip +# @pytest.mark.parametrize("accum_qtype", [DTypes.KFLOAT32, DTypes.KFLOAT16]) +@pytest.mark.parametrize("accum_qtype", [torch.bfloat16]) +def test_fp8_linear_backward_pass(n_layers, input, accum_qtype): + is_bias = False + + HIDDEN_SIZE = 64 + ref_input = input.detach().clone().requires_grad_(True) - ref_linear = nn.Linear(16, 16, device=device, dtype=torch.float32) - fp8_linear = FP8Linear(16, 16, device=device) + # ref_linear = nn.Linear(HIDDEN_SIZE, INTERDIM_SIZE, device="cuda", dtype=torch.float32) + ref_linear = nn.Sequential( + *[ + layer + for _ in range(n_layers) + for layer in (nn.Linear(HIDDEN_SIZE, HIDDEN_SIZE, bias=is_bias, device="cuda"), nn.ReLU()) + ] + ) + + # trunc_normal_(ref_linear.weight, std=0.02) + # trunc_normal_(ref_linear.weight, std=math.sqrt(1 / (HIDDEN_SIZE))) - if device == "cpu": - fp8_linear.weight.data = ref_linear.weight.detach().clone() - else: - fp8_linear.weight.data = FP8Tensor(ref_linear.weight.detach().clone(), dtype=DTypes.FP8E4M3) - fp8_linear.bias.data = ref_linear.bias.detach().clone() + # fp8_linear = convert_linear_to_fp8(deepcopy(ref_linear), accum_qtype) + fp8_linear = convert_to_fp8_module(deepcopy(ref_linear), accum_qtype) ref_linear(ref_input).sum().backward() fp8_linear(input).sum().backward() - # TODO(xrsrke): investigate why input.grad is so high tolerance - # assert torch.allclose(input.grad, ref_input.grad, 0.2, 0.2) if input_requires_grad else True - assert torch.allclose(fp8_linear.weight.grad, ref_linear.weight.grad, 0.1, 0.1) - assert torch.allclose(fp8_linear.bias.grad, ref_linear.bias.grad, 0, 0.1) + for ref_p, p in zip(ref_linear.parameters(), fp8_linear.parameters()): + if isinstance(p.data, FP8Tensor): + # NOTE: bypass torch autograd for FP8Tensor + # so we can compute the gradients ourself + assert p.requires_grad is False + assert isinstance(p.grad, FP8Tensor) + assert p.grad.dtype in FP8_DTYPES + else: + assert p.requires_grad is True + assert isinstance(p.grad, torch.Tensor) + assert p.grad.dtype == QTYPE_TO_DTYPE[accum_qtype] -# TODO(xrsrke): test if FP8Linear has all the methods of a torch.nn.Linear + # if p.requires_grad is False: + # if isinstance(p.data, FP8Tensor): + # assert isinstance(p.grad, FP8Tensor) + # else: + # assert p.grad is None + # if isinstance(p.data, FP8Tensor): + # # assert isinstance(p.grad, FP8Tensor) + # grad = convert_tensor_from_fp8(p.grad, p.grad.fp8_meta, torch.float32) + # else: + # # assert isinstance(p.grad, torch.Tensor) + grad = ( + convert_tensor_from_fp8(p.grad, p.grad.fp8_meta, torch.float32) + if isinstance(p.data, FP8Tensor) + else p.grad + ) -def test_fp8_linear_attrs(): - fp8_linear = FP8Linear(16, 16, device="cuda:0") + assert is_overflow_underflow_nan(grad) is False + if p.ndim > 1: + # NOTE: these weight threshold is tuned from the FP8-LM implementation + # TODO(xrsrke): tune what is the minimum threshold for this to correctly converge + torch.testing.assert_allclose(grad, ref_p.grad, rtol=0.06, atol=0.1) + else: + torch.testing.assert_allclose(grad, ref_p.grad) - assert next(fp8_linear.parameters()) is not None - assert all(p.requires_grad for p in fp8_linear.parameters()) is True + # assert isinstance(fp8_linear.weight.grad, FP8Tensor) + # assert fp8_linear.weight.grad.dtype in FP8_DTYPES + # assert isinstance(fp8_linear.bias.grad, torch.Tensor) + # assert fp8_linear.bias.grad.dtype == QTYPE_TO_DTYPE[accum_qtype] -# TODO(xrsrke): test only calculating the gradients of the weight, bias, or input based -# on the requires_grad of the input, weight, or bias + # # TODO(xrsrke): investigate why input.grad is so high tolerance + # # assert torch.allclose(input.grad, ref_input.grad, 0.2, 0.2) if input_requires_grad else True + + # # NOTE: these weight threshold is tuned from the FP8-LM implementation + # # TODO(xrsrke): tune what is the minimum threshold for this to correctly converge + # weight_grad = convert_tensor_from_fp8(fp8_linear.weight.grad, fp8_linear.weight.grad.fp8_meta, torch.float32) + # torch.testing.assert_allclose(weight_grad, ref_linear.weight.grad, rtol=0.06, atol=0.1) + # torch.testing.assert_allclose(fp8_linear.bias.grad, ref_linear.bias.grad) + + +@pytest.mark.parametrize("accum_qtype", [DTypes.KFLOAT32, DTypes.KFLOAT16]) +def test_fp8_modules_trigger_the_entire_computational_graph(accum_qtype): + HIDDEN_SIZE = 16 + TIMELINE = [] + def backward_hook(module, grad_input, grad_output, idx): + TIMELINE.append(f"{module.__class__.__name__}.{idx}.backward") -def test_fp8_model_bwd(): - HIDEEN_SIZE = 128 - N_LAYERS = 5 - N_EPOCHS = 3 + class Logger(nn.Module): + def __init__(self, idx: int, module: nn.Linear): + super().__init__() + module.register_backward_hook(partial(backward_hook, idx=idx)) + self.module = module + self.idx = idx - input = torch.randn(HIDEEN_SIZE, HIDEEN_SIZE, device="cuda", requires_grad=True) + def forward(self, input): + TIMELINE.append(f"{self.module.__class__.__name__}.{self.idx}.forward") + return self.module(input) - model = nn.Sequential( - *[nn.Sequential(FP8Linear(HIDEEN_SIZE, HIDEEN_SIZE, device="cuda"), nn.ReLU()) for _ in range(N_LAYERS)] + input = torch.randn(HIDDEN_SIZE, HIDDEN_SIZE, device="cuda", dtype=torch.float32) + fp8_linear = nn.Sequential( + nn.Linear(HIDDEN_SIZE, HIDDEN_SIZE, device="cuda", dtype=torch.float32), + nn.ReLU(), + nn.Linear(HIDDEN_SIZE, HIDDEN_SIZE, device="cuda", dtype=torch.float32), + nn.ReLU(), ) - optim = Adam(model.parameters(), lr=1e-3) + fp8_linear = convert_to_fp8_module(fp8_linear, accum_qtype) + fp8_linear = nn.ModuleList([Logger(idx, module) for idx, module in enumerate(fp8_linear)]) - for _ in range(N_EPOCHS): - optim.zero_grad() - model(input).sum().backward() - optim.step() + output = reduce(lambda x, module: module(x), fp8_linear, input) + scalar = torch.randn(1, device="cuda", dtype=output.dtype) + (output.sum() * scalar).backward() + + assert TIMELINE == [ + "FP8Linear.0.forward", + "ReLU.1.forward", + "FP8Linear.2.forward", + "ReLU.3.forward", + "ReLU.3.backward", + "FP8Linear.2.backward", + "ReLU.1.backward", + "FP8Linear.0.backward", + ] + + for p in fp8_linear.parameters(): + if p.requires_grad is True: + assert is_overflow_underflow_nan(p.grad) is False + + +# NOTE: it seems that dynamic quantization should be in test_tensor +# but we only do this if we are in training => test it in a linear +@pytest.mark.parametrize("interval", [1, 5, 10]) +def test_deplay_quantization(interval): + # NOTE: test delay quantization (window size) + # NOTE: test overflow, underflow, zeros + # NOTE: test reduce/increase exponent bits + + HIDDEN_SIZE = 16 + N_STEPS = 4 + + input = torch.randn(HIDDEN_SIZE, HIDDEN_SIZE, device="cuda", dtype=torch.float32) + fp8_linear = FP8Linear(HIDDEN_SIZE, HIDDEN_SIZE, device="cuda") + + for _ in range(N_STEPS): + output = fp8_linear(input) + output.sum().backward() + + +@pytest.mark.skip +@pytest.mark.parametrize("input_shape", [(16, 15), (15, 16), (15, 15)]) +@pytest.mark.parametrize("is_bias", [True, False]) +@pytest.mark.parametrize("accum_qtype", [DTypes.KFLOAT32, DTypes.KFLOAT16]) +def test_fp8_linear_padding(input_shape, is_bias, accum_qtype): + input = torch.randn(**input_shape) + ref_input = input.detach().clone() + ref_linear = nn.Linear(16, 16, bias=is_bias, device="cuda") + fp8_linear = convert_linear_to_fp8(deepcopy(ref_linear), accum_qtype) + + ref_output = ref_linear(ref_input) + output = fp8_linear(input) + + assert isinstance(output, torch.Tensor) + assert output.dtype == QTYPE_TO_DTYPE[accum_qtype] + + # NOTE: this threshold is from fp8-lm, the paper shows that this is fine + torch.testing.assert_allclose(output, ref_output, rtol=0, atol=0.1) + + +# TODO(xrsrke): test if FP8Linear has all the methods of a torch.nn.Linear + + +# TODO(xrsrke): test only calculating the gradients of the weight, bias, or input based +# on the requires_grad of the input, weight, or bias - assert all(p.grad is not None for p in model.parameters()) +# TODO(xrsrke): test automatic padding if a input/weight shape isn't divisible by 16 diff --git a/tests/fp8/test_new_tensor.py b/tests/fp8/test_new_tensor.py new file mode 100644 index 00000000..6fe5bcbc --- /dev/null +++ b/tests/fp8/test_new_tensor.py @@ -0,0 +1,463 @@ +from copy import deepcopy +from typing import cast + +import numpy as np +import pytest +import torch +from nanotron.fp8.constants import ( + FP8_ATOL_THRESHOLD, + FP8_RTOL_THRESHOLD, + FP16_ATOL_THRESHOLD, + FP16_RTOL_THRESHOLD, + QTYPE_TO_DTYPE, +) +from nanotron.fp8.dtypes import DTypes +from nanotron.fp8.meta import FP8Meta +from nanotron.fp8.tensor import FP8Tensor, FP16Tensor, convert_tensor_from_fp8, convert_tensor_from_fp16 +from nanotron.testing.utils import TestContext + + +@pytest.mark.parametrize( + "tensor_cls, dtype", [(FP8Tensor, DTypes.FP8E4M3), (FP8Tensor, DTypes.FP8E5M2), (FP16Tensor, DTypes.KFLOAT16)] +) +@pytest.mark.parametrize("interval", [1, 5]) +def test_fp8_and_fp16_metadata(tensor_cls, dtype, interval): + import transformer_engine as te # noqa + import transformer_engine_torch as tex + + tensor = torch.randn((4, 4), dtype=torch.float32, device="cuda") + ref_tensor = deepcopy(tensor) + + fp8_tensor = tensor_cls(tensor, dtype=dtype, interval=interval) + fp8_meta = cast(FP8Meta, fp8_tensor.fp8_meta) + + # TODO(xrsrke): remove the fixed 1 factor + # it couples with the current implementation of FP8Meta + # because we initialize scale with 1 + # assert fp8_meta.amax == ref_tensor.abs().max() + assert torch.equal(fp8_meta.amax, ref_tensor.amax()) + assert isinstance(fp8_meta.inverse_scale, torch.Tensor) + assert fp8_meta.scale != 0.1 and fp8_meta.scale != 1.0 + assert isinstance(fp8_meta.te_dtype, tex.DType) + assert fp8_meta.interval == interval + + +@pytest.mark.parametrize("size", [4, 8, 16, 64]) +@pytest.mark.parametrize("dtype", [DTypes.FP8E4M3, DTypes.FP8E5M2]) +def test_quantize_and_dequantize_tensor_in_fp8(size, dtype): + tensor = torch.randn((size, size), dtype=torch.float32, device="cuda") + ref_tensor = deepcopy(tensor) + fp8_tensor = FP8Tensor(tensor, dtype=dtype) + + assert not np.array_equal(fp8_tensor.cpu().numpy(), ref_tensor.cpu().numpy()) + + tensor = convert_tensor_from_fp8(fp8_tensor, fp8_tensor.fp8_meta, torch.float32) + # tensor = fp8_tensor.to(torch.float32) + # NOTE: sometimes type(tensor) is FP8Tensor, but it still passes, so we directly check the class name + # to make sure it's a torch.Tensor + assert tensor.__class__ == torch.Tensor + assert tensor.dtype == ref_tensor.dtype + + torch.testing.assert_close(tensor, ref_tensor, rtol=FP8_RTOL_THRESHOLD, atol=FP8_ATOL_THRESHOLD) + + +# @pytest.mark.parametrize("interval", [1, 5]) +@pytest.mark.parametrize("dtype", [DTypes.FP8E4M3, DTypes.FP8E5M2]) +def test_create_fp8_tensor_from_metadata(dtype): + INTERVAL = 5 + TOTAL_STEPS, REMAINING_STEPS = 20, 16 + tensor = torch.randn((16, 16), dtype=torch.float32, device="cuda") + fp8_tensor = FP8Tensor(tensor, dtype=dtype, interval=INTERVAL) + + new_values = [torch.randn((16, 16), dtype=torch.float32, device="cuda") for i in range(TOTAL_STEPS)] + + for i in range(TOTAL_STEPS): + if TOTAL_STEPS - REMAINING_STEPS == i: + current_tensor = new_values[i - 1] + fp8_meta = deepcopy(fp8_tensor.fp8_meta) + + fp8_tensor.data = new_values[i] + + resumed_fp8_tensor = FP8Tensor.from_metadata(current_tensor, fp8_meta) + for i in range(TOTAL_STEPS - REMAINING_STEPS, TOTAL_STEPS): + resumed_fp8_tensor.data = new_values[i] + + # NOTE: we expect a resume tensor to have the state trajectory of the original tensor + assert resumed_fp8_tensor.fp8_meta == fp8_tensor.fp8_meta + assert resumed_fp8_tensor == fp8_tensor + + +@pytest.mark.parametrize("size", [4, 8, 16, 64]) +def test_quantize_and_dequantize_tensor_in_fp16(size): + tensor = torch.randn((size, size), dtype=torch.float32, device="cuda") + ref_tensor = deepcopy(tensor) + + fp16_tensor = FP16Tensor(tensor, dtype=DTypes.KFLOAT16) + + assert not np.array_equal(fp16_tensor.cpu().numpy(), ref_tensor.cpu().numpy()) + + tensor = convert_tensor_from_fp16(fp16_tensor, torch.float32) + # tensor = fp16_tensor.to(torch.float32) + # NOTE: sometimes type(tensor) is FP16Tensor, but it still passes + assert tensor.__class__ == torch.Tensor + assert tensor.dtype == ref_tensor.dtype + + # NOTE: this tolerance is from FP8-LM's implementation + # reference: https://github.com/Azure/MS-AMP/blob/9ac98df5371f3d4174d8f103a1932b3a41a4b8a3/tests/common/tensor/test_cast.py#L35 + torch.testing.assert_close(tensor, ref_tensor, rtol=FP16_RTOL_THRESHOLD, atol=FP16_ATOL_THRESHOLD) + + +@pytest.mark.parametrize( + "tensor_cls, dtype", [(FP8Tensor, DTypes.FP8E4M3), (FP8Tensor, DTypes.FP8E5M2), (FP16Tensor, DTypes.KFLOAT16)] +) +def test_fp8_and_fp16_tensor_repr(tensor_cls, dtype): + tensor = torch.randn((64, 64), dtype=torch.float32, device="cuda:0") + fp8_tensor = tensor_cls(tensor, dtype) + + # NOTE: in some cases, it causes an infinite loop + # in repr(tensor), so just check if it doesn't loop + assert isinstance(repr(fp8_tensor), str) + + +@pytest.mark.parametrize( + "tensor_cls, dtype, expected_dtype", + [ + (FP8Tensor, DTypes.FP8E4M3, torch.uint8), + (FP8Tensor, DTypes.FP8E5M2, torch.uint8), + (FP16Tensor, DTypes.KFLOAT16, torch.float16), + ], +) +def test_fp8_and_fp16_tensor_attrs(tensor_cls, dtype, expected_dtype): + tensor = torch.randn((64, 64), dtype=torch.float32, device="cuda:0") + ref_tensor = tensor.detach().clone() + + fp8_tensor = tensor_cls(tensor, dtype) + + assert isinstance(fp8_tensor, tensor_cls) + assert isinstance(fp8_tensor.fp8_meta, FP8Meta) + assert fp8_tensor.dtype == expected_dtype + assert fp8_tensor.device == ref_tensor.device + assert fp8_tensor.shape == ref_tensor.shape + assert fp8_tensor.numel() == ref_tensor.numel() + assert fp8_tensor.device == ref_tensor.device + + +# @pytest.mark.parametrize( +# "tensor_cls, dtype", +# [ +# (FP8Tensor, DTypes.FP8E4M3), +# (FP8Tensor, DTypes.FP8E5M2), +# (FP16Tensor, DTypes.KFLOAT16), +# ], +# ) +# @pytest.mark.parametrize( +# "scale", +# [ +# torch.ones(1, device="cuda:0").squeeze() * 2, # an random scalar +# torch.ones(1, device="cuda:0") * 2, +# torch.ones(4, 4, device="cuda:0") * 2, +# ], +# ) +# def test_multiple_fp8_tensor(tensor_cls, dtype, scale): +# RTOL, ATOL = ( +# (FP8_RTOL_THRESHOLD, FP8_ATOL_THRESHOLD) +# if tensor_cls == FP8Tensor +# else (FP16_RTOL_THRESHOLD, FP16_ATOL_THRESHOLD) +# ) +# tensor = torch.randn((4, 4), dtype=torch.float32, device="cuda:0") +# ref_tensor = tensor.detach().clone() + +# fp8_tensor = tensor_cls(deepcopy(tensor), dtype) +# ref_fp8_tensor = fp8_tensor.clone() + +# with fail_if_expect_to_fail(expect_to_fail=scale.ndim > 1): +# fp8_tensor.mul_(scale) + +# assert torch.equal(fp8_tensor, ref_fp8_tensor) +# assert fp8_tensor.fp8_meta.scale != ref_fp8_tensor.fp8_meta.scale + +# if isinstance(fp8_tensor, FP8Tensor): +# # NOTE: with the current implementation, we only scale the metadata +# # not the tensor itself, so we expect the tensor to be the same +# tensor = convert_tensor_from_fp8(fp8_tensor, fp8_tensor.fp8_meta, torch.float32) +# else: +# tensor = convert_tensor_from_fp16(fp8_tensor, torch.float32) + +# torch.testing.assert_allclose(tensor, ref_tensor * scale, rtol=RTOL, atol=ATOL) + + +# @pytest.mark.parametrize( +# "tensor_cls, dtype", +# [ +# (FP8Tensor, DTypes.FP8E4M3), +# (FP8Tensor, DTypes.FP8E5M2), +# (FP16Tensor, DTypes.KFLOAT16), +# ], +# ) +# @pytest.mark.parametrize( +# "scale", +# [ +# torch.ones(1, device="cuda:0").squeeze() * 2, # an random scalar +# torch.ones(1, device="cuda:0") * 2, +# torch.ones(4, 4, device="cuda:0") * 2, +# ], +# ) +# def test_divide_fp8_tensor(tensor_cls, dtype, scale): +# # NOTE: the reason that we use 2 as the scale is because +# # the purpose of this test is to test whether we scale the magnitude +# # of the tensor, so if use other values from normal distribution, +# # some values could lead to quantization error, and for this test we don't +# # test the quantization error +# tensor = torch.randn((64, 64), dtype=torch.float32, device="cuda:0") +# ref_tensor = deepcopy(tensor) + +# fp8_tensor = tensor_cls(deepcopy(tensor), dtype) +# ref_fp8_tensor = fp8_tensor.clone() + +# with fail_if_expect_to_fail(expect_to_fail=scale.ndim > 1): +# fp8_tensor.div_(scale) + +# assert torch.equal(fp8_tensor, ref_fp8_tensor) +# assert fp8_tensor.fp8_meta.scale != ref_fp8_tensor.fp8_meta.scale + +# if isinstance(fp8_tensor, FP8Tensor): +# tensor = convert_tensor_from_fp8(fp8_tensor, fp8_tensor.fp8_meta, torch.float32) +# # NOTE: use the same tolerance as test_quantize_and_dequantize_tensor_in_fp8 +# torch.testing.assert_allclose(tensor, ref_tensor / scale, rtol=FP8_RTOL_THRESHOLD, atol=FP8_ATOL_THRESHOLD) +# else: +# tensor = convert_tensor_from_fp16(fp8_tensor, torch.float32) +# torch.testing.assert_close(tensor, ref_tensor / scale, rtol=FP16_RTOL_THRESHOLD, atol=FP16_ATOL_THRESHOLD) + + +@pytest.mark.parametrize( + "tensor_cls, dtype", + [ + (FP8Tensor, DTypes.FP8E4M3), + (FP8Tensor, DTypes.FP8E5M2), + (FP16Tensor, DTypes.KFLOAT16), + ], +) +def test_add_fp8_tensor(tensor_cls, dtype): + tensor = torch.randn((64, 64), dtype=torch.float32, device="cuda:0") + fp8_tensor = tensor_cls(deepcopy(tensor), dtype) + + with pytest.raises(ValueError): + fp8_tensor + 1 + + +@pytest.mark.parametrize( + "tensor_cls, dtype", + [ + (FP8Tensor, DTypes.FP8E4M3), + (FP8Tensor, DTypes.FP8E5M2), + (FP16Tensor, DTypes.KFLOAT16), + ], +) +def test_subtract_fp8_tensor(tensor_cls, dtype): + tensor = torch.randn((64, 64), dtype=torch.float32, device="cuda:0") + fp8_tensor = tensor_cls(deepcopy(tensor), dtype) + + with pytest.raises(ValueError): + fp8_tensor - 1 + + +@pytest.mark.parametrize( + "tensor_cls, dtype", + [ + (FP8Tensor, DTypes.FP8E4M3), + (FP8Tensor, DTypes.FP8E5M2), + (FP16Tensor, DTypes.KFLOAT16), + ], +) +def test_clone_fp8_tensor(tensor_cls, dtype): + tensor = torch.randn((64, 64), dtype=torch.float32, device="cuda:0") + fp8_tensor = tensor_cls(deepcopy(tensor), dtype) + + cloned_fp8_tensor = fp8_tensor.clone() + + assert isinstance(cloned_fp8_tensor, tensor_cls) + assert id(cloned_fp8_tensor) != id(fp8_tensor) + assert cloned_fp8_tensor.device == fp8_tensor.device + + assert torch.equal(cloned_fp8_tensor, fp8_tensor) + assert cloned_fp8_tensor.data_ptr() != fp8_tensor.data_ptr() + assert cloned_fp8_tensor.data.data_ptr() != fp8_tensor.data.data_ptr() + + assert cloned_fp8_tensor.fp8_meta == fp8_tensor.fp8_meta + assert id(cloned_fp8_tensor.fp8_meta) != id(fp8_tensor.fp8_meta) + + +@pytest.mark.parametrize( + "tensor_cls, dtype", + [ + (FP8Tensor, DTypes.FP8E4M3), + (FP8Tensor, DTypes.FP8E5M2), + # (FP16Tensor, DTypes.KFLOAT16), + ], +) +def test_transpose_fp8_tensor(tensor_cls, dtype): + tensor = torch.randn((16, 16), dtype=torch.float32, device="cuda:0") + ref_transposed_tensor = deepcopy(tensor).T + fp8_tensor = tensor_cls(tensor, dtype) + + transposed_fp8_tensor = fp8_tensor.transpose_fp8() + + assert isinstance(transposed_fp8_tensor, FP8Tensor) + + dequant_transposed_fp8_tensor = convert_tensor_from_fp8( + transposed_fp8_tensor, transposed_fp8_tensor.fp8_meta, torch.float32 + ) + torch.testing.assert_close( + dequant_transposed_fp8_tensor, ref_transposed_tensor, rtol=FP8_RTOL_THRESHOLD, atol=FP8_ATOL_THRESHOLD + ) + + +@pytest.mark.parametrize( + "tensor_cls, dtype", + [ + (FP8Tensor, DTypes.FP8E4M3), + (FP8Tensor, DTypes.FP8E5M2), + (FP16Tensor, DTypes.KFLOAT16), + ], +) +def test_determistic_quantization(tensor_cls, dtype): + tensor = torch.randn((64, 64), dtype=torch.float32, device="cuda:0") + fp8_tensor = tensor_cls(deepcopy(tensor), dtype) + ref_fp8_tensor = tensor_cls(deepcopy(tensor), dtype) + + assert torch.equal(fp8_tensor, ref_fp8_tensor) + assert fp8_tensor.fp8_meta == ref_fp8_tensor.fp8_meta + + +@pytest.mark.parametrize( + "tensor_cls, dtype", [(FP8Tensor, DTypes.FP8E4M3), (FP8Tensor, DTypes.FP8E5M2), (FP16Tensor, DTypes.KFLOAT16)] +) +def test_fp8_and_fp16_tensor_storage_memory(tensor_cls, dtype): + tensor = torch.randn((4, 4), dtype=torch.float32, device="cuda") + ref_tensor = deepcopy(tensor) + + fp8_tensor = tensor_cls(tensor, dtype=dtype) + + assert id(fp8_tensor) != id(ref_tensor) + + assert isinstance(fp8_tensor.data, torch.Tensor) + assert id(fp8_tensor.data) != id(ref_tensor.data) + assert fp8_tensor.data_ptr() == fp8_tensor.data.data_ptr() + assert fp8_tensor.data.data_ptr() != ref_tensor.data_ptr() + + +@pytest.mark.parametrize( + "tensor_cls, dtype", [(FP8Tensor, DTypes.FP8E4M3), (FP8Tensor, DTypes.FP8E5M2), (FP16Tensor, DTypes.KFLOAT16)] +) +@pytest.mark.parametrize("is_quantized", [True, False]) +def test_setting_new_data_for_fp8_and_fp16_tensor(tensor_cls, dtype, is_quantized): + RTOL, ATOL = ( + (FP8_RTOL_THRESHOLD, FP8_ATOL_THRESHOLD) + if tensor_cls == FP8Tensor + else (FP16_RTOL_THRESHOLD, FP16_ATOL_THRESHOLD) + ) + + tensor = torch.randn((4, 4), dtype=torch.float32, device="cuda") + quant_tensor = tensor_cls(tensor, dtype=dtype) + + new_data = torch.randn(quant_tensor.shape, dtype=torch.float32, device="cuda") * 2 + ref_new_data = deepcopy(new_data) + expected_quantized_tensor = tensor_cls(ref_new_data, dtype=dtype) + + new_data = tensor_cls(new_data, dtype=dtype) if is_quantized else new_data + quant_tensor.data = new_data + assert dequant_tensor.data.data_ptr() == new_data.data.data_ptr() + + assert quant_tensor.data.dtype == QTYPE_TO_DTYPE[dtype] + assert torch.equal(quant_tensor, expected_quantized_tensor) + + if is_quantized: + # if tensor_cls == FP8Tensor: + # dequant_tensor = convert_tensor_from_fp8(quant_tensor, fp8_tensor.fp8_meta, torch.float32) + # else: + # dequantized_tensor = convert_tensor_from_fp16(fp8_tensor, torch.float32) + + dequant_tensor = quant_tensor.to(torch.float32) + assert torch.allclose(dequant_tensor, ref_new_data, rtol=RTOL, atol=ATOL) + + +# @pytest.mark.parametrize( +# "tensor_cls, dtype", [(FP8Tensor, DTypes.FP8E4M3), (FP8Tensor, DTypes.FP8E5M2), (FP16Tensor, DTypes.KFLOAT16)] +# ) +# @pytest.mark.parametrize("is_quantized", [True, False]) +# def test_setting_None_data_for_fp8_and_fp16_tensor(tensor_cls, dtype, is_quantized): +# tensor = torch.randn((4, 4), dtype=torch.float32, device="cuda") +# fp8_tensor = tensor_cls(tensor, dtype=dtype) + +# fp8_tensor.set_data(None) + +# assert fp8_tensor is None +# assert fp8_tensor.data is None + +# @pytest.mark.parametrize( +# "tensor_cls, dtype", [(FP8Tensor, DTypes.FP8E4M3), (FP8Tensor, DTypes.FP8E5M2), (FP16Tensor, DTypes.KFLOAT16)] +# ) +# def test_quantize_overflow_fp8_and_fp16_tensor(tensor_cls, dtype): +# tensor = torch.randn((64, 64), dtype=torch.float32, device="cuda:0") +# tensor[0, 0] = torch.tensor(float("inf")) +# fp8_tensor = tensor_cls(tensor, dtype) + + +@pytest.mark.parametrize( + "tensor_cls, dtype", [(FP8Tensor, DTypes.FP8E4M3), (FP8Tensor, DTypes.FP8E5M2), (FP16Tensor, DTypes.KFLOAT16)] +) +def test_zero_out_data_of_fp8_and_fp16_tensor(tensor_cls, dtype): + tensor = torch.randn((4, 4), dtype=torch.float32, device="cuda") + quant_tensor = tensor_cls(tensor, dtype=dtype) + + quant_tensor.zero_() + + assert torch.equal(quant_tensor, torch.zeros_like(quant_tensor)) + + dequant_tensor = quant_tensor.to(torch.float32) + assert torch.equal(dequant_tensor, torch.zeros_like(tensor)) + + +# NOTE: add testing based on tensor metadata +@pytest.mark.parametrize("is_meta_the_same", [True, False]) +@pytest.mark.parametrize("dtype", [DTypes.FP8E4M3, DTypes.FP8E5M2]) +def test_fp8_and_fp16_tensor_equality_based_on_tensor_value(is_meta_the_same, dtype): + # TODO(xrsrke): support torch.equal for FP8Tensor + tensor = torch.randn((4, 4), dtype=torch.float32, device="cuda") + ref_tensor = deepcopy(tensor) + + fp8_tensor = FP8Tensor(tensor, dtype=dtype) + ref_fp8_tensor = FP8Tensor(ref_tensor, dtype=dtype) + + if not is_meta_the_same: + fp8_tensor.fp8_meta.scale = ref_fp8_tensor.fp8_meta.scale * 2 + + assert (fp8_tensor == ref_fp8_tensor) is is_meta_the_same + + new_data = torch.randn(tensor.shape, dtype=torch.float32, device="cuda") + ref_fp8_tensor.data = new_data + + assert not fp8_tensor == ref_fp8_tensor + + +# TODO(xrsrke): test it has all the methods of torch.Tensor + +# TODO(xrsrke): test it has all the attributes of its input tensor + + +@pytest.mark.parametrize( + "tensor_cls, dtype", [(FP8Tensor, DTypes.FP8E4M3), (FP8Tensor, DTypes.FP8E5M2), (FP16Tensor, DTypes.KFLOAT16)] +) +def test_serialize_fp8_tensor(tensor_cls, dtype): + test_context = TestContext() + store_folder = test_context.get_auto_remove_tmp_dir() + tensor = torch.randn((4, 4), dtype=torch.float32, device="cuda") + + fp8_tensor = tensor_cls(tensor, dtype=dtype) + + torch.save(fp8_tensor, f"{store_folder}/fp8_tensor.pt") + torch.load(f"{store_folder}/fp8_tensor.pt") + + +# TODO(xrsrke): add test for fp8_tensor.transpose() API +# TODO(xrsrke): add test for fp8_tensor.to(dtype) API diff --git a/tests/fp8/test_tensor.py b/tests/fp8/test_tensor.py deleted file mode 100644 index 7e84095f..00000000 --- a/tests/fp8/test_tensor.py +++ /dev/null @@ -1,60 +0,0 @@ -from copy import deepcopy - -import numpy as np -import pytest -import torch -import transformer_engine as te # noqa -import transformer_engine_extensions as tex -from nanotron.fp8 import DTypes, FP8Tensor -from nanotron.fp8.meta import FP8Meta -from nanotron.fp8.tensor import convert_tensor_from_fp8 - - -@pytest.mark.parametrize("size", [4, 8, 16, 64]) -def test_quantize_and_dequantize_tensor_in_fp8(size): - tensor = torch.randn((size, size), dtype=torch.float32, device="cuda") - ref_tensor = deepcopy(tensor) - - fp8_tensor = FP8Tensor(tensor, dtype=DTypes.FP8E4M3) - - assert isinstance(fp8_tensor, FP8Tensor) - assert isinstance(fp8_tensor.fp8_meta, FP8Meta) - assert fp8_tensor.device == ref_tensor.device - assert fp8_tensor.dtype == torch.uint8 - assert fp8_tensor.shape == ref_tensor.shape - assert fp8_tensor.numel() == ref_tensor.numel() - assert not np.array_equal(fp8_tensor.cpu().numpy(), ref_tensor.cpu().numpy()) - - # TODO(xrsrke): remove the fixed 1 factor - # it couples with the current implementation of FP8Meta - # because we initialize scale with 1 - assert fp8_tensor.fp8_meta.amax == ref_tensor.abs().max() - assert isinstance(fp8_tensor.fp8_meta.inverse_scale, torch.Tensor) - assert fp8_tensor.fp8_meta.scale != 0.1 and fp8_tensor.fp8_meta.scale != 1 - assert isinstance(fp8_tensor.fp8_meta.te_dtype, tex.DType) - - tensor = convert_tensor_from_fp8(fp8_tensor, fp8_tensor.fp8_meta, torch.float32) - assert isinstance(tensor, torch.Tensor) - assert tensor.dtype == ref_tensor.dtype - assert torch.allclose(tensor, ref_tensor, rtol=1e-1, atol=1e-1) - - -def test_fp8_tensor_attrs(): - SIZE = 64 - tensor = torch.randn((SIZE, SIZE), dtype=torch.float32, device="cuda:0") - ref_tensor = tensor.detach().clone() - - fp8_tensor = FP8Tensor(tensor, DTypes.FP8E4M3) - - assert isinstance(fp8_tensor, FP8Tensor) - assert isinstance(fp8_tensor.fp8_meta, FP8Meta) - assert fp8_tensor.device == ref_tensor.device - assert fp8_tensor.dtype == torch.uint8 - assert fp8_tensor.shape == ref_tensor.shape - assert fp8_tensor.numel() == ref_tensor.numel() - assert fp8_tensor.device == ref_tensor.device - - -# TODO(xrsrke): test it has all the methods of torch.Tensor - -# TODO(xrsrke): test it has all the attributes of its input tensor diff --git a/tests/fp8/utils.py b/tests/fp8/utils.py new file mode 100644 index 00000000..43194389 --- /dev/null +++ b/tests/fp8/utils.py @@ -0,0 +1,25 @@ +import importlib +import sys +from contextlib import contextmanager +from pathlib import Path + +import pytest + + +@contextmanager +def fail_if_expect_to_fail(expect_to_fail: bool): + try: + yield + except AssertionError as e: + if expect_to_fail is True: + pytest.xfail("Failed successfully") + else: + raise e + + +def set_system_path(): + package = importlib.import_module("nanotron") + # NOTE: Path(package.__file__).parent = .../nanotron/src/nanotron + # we want .../nanotron + package_path = Path(package.__file__).parent.parent.parent + sys.path.append(str(package_path)) diff --git a/tests/helpers/context.py b/tests/helpers/context.py index 77fa5d70..e69de29b 100644 --- a/tests/helpers/context.py +++ b/tests/helpers/context.py @@ -1,21 +0,0 @@ -import shutil -import uuid -from functools import lru_cache -from pathlib import Path - - -class TestContext: - def __init__(self): - self._random_string = str(uuid.uuid1()) - self._root_dir = Path(__file__).parent.parent / ".test_cache" - self._root_dir.mkdir(parents=True, exist_ok=True) - - @lru_cache(maxsize=1) - def get_auto_remove_tmp_dir(self): - path = self._root_dir / self._random_string - path.mkdir(parents=True, exist_ok=True) - return path - - def __del__(self): - path = self.get_auto_remove_tmp_dir() - shutil.rmtree(path) diff --git a/tests/helpers/llama.py b/tests/helpers/llama.py index 3f94031f..09da194a 100644 --- a/tests/helpers/llama.py +++ b/tests/helpers/llama.py @@ -1,137 +1,137 @@ -import torch -from nanotron.config import ( - AllForwardAllBackwardPipelineEngine, - CheckpointsArgs, - Config, - DataArgs, - DatasetStageArgs, - GeneralArgs, - LlamaConfig, - LoggingArgs, - LRSchedulerArgs, - ModelArgs, - OptimizerArgs, - ParallelismArgs, - TensorParallelLinearMode, - TokenizerArgs, - TokensArgs, -) -from nanotron.config.config import PretrainDatasetsArgs -from nanotron.models import build_model -from nanotron.models.llama import LlamaForTraining -from nanotron.parallel.context import ParallelContext -from nanotron.trainer import mark_tied_parameters +# import torch +# from nanotron.config import ( +# AllForwardAllBackwardPipelineEngine, +# CheckpointsArgs, +# Config, +# DataArgs, +# DatasetStageArgs, +# GeneralArgs, +# LlamaConfig, +# LoggingArgs, +# LRSchedulerArgs, +# ModelArgs, +# OptimizerArgs, +# ParallelismArgs, +# TensorParallelLinearMode, +# TokenizerArgs, +# TokensArgs, +# ) +# from nanotron.config.config import PretrainDatasetsArgs +# from nanotron.models import build_model +# from nanotron.models.llama import LlamaForTraining +# from nanotron.parallel.context import ParallelContext +# from nanotron.trainer import mark_tied_parameters -TINY_LLAMA_CONFIG = LlamaConfig( - **{ - "bos_token_id": 1, - "eos_token_id": 2, - "hidden_act": "silu", - "hidden_size": 16, - "initializer_range": 0.02, - "intermediate_size": 32, - "is_llama_config": True, - "max_position_embeddings": 128, - "num_attention_heads": 8, - "num_hidden_layers": 4, - "num_key_value_heads": 4, - "pad_token_id": None, - "pretraining_tp": 1, - "rms_norm_eps": 1e-06, - "rope_scaling": None, - "tie_word_embeddings": False, - "use_cache": True, - "vocab_size": 4096, - } -) +# TINY_LLAMA_CONFIG = LlamaConfig( +# **{ +# "bos_token_id": 1, +# "eos_token_id": 2, +# "hidden_act": "silu", +# "hidden_size": 16, +# "initializer_range": 0.02, +# "intermediate_size": 32, +# "is_llama_config": True, +# "max_position_embeddings": 128, +# "num_attention_heads": 8, +# "num_hidden_layers": 4, +# "num_key_value_heads": 4, +# "pad_token_id": None, +# "pretraining_tp": 1, +# "rms_norm_eps": 1e-06, +# "rope_scaling": None, +# "tie_word_embeddings": False, +# "use_cache": True, +# "vocab_size": 4096, +# } +# ) -def get_llama_training_config(model_config: ModelArgs): - return Config( - model=model_config, - general=GeneralArgs(project="unittest", run="sanity_llama", seed=42), - checkpoints=CheckpointsArgs( - checkpoints_path="./checkpoints", - checkpoint_interval=10, - ), - parallelism=ParallelismArgs( - dp=1, - pp=1, - tp=2, - expert_parallel_size=2, - pp_engine="1f1b", - tp_mode="ALL_REDUCE", - tp_linear_async_communication=False, - ), - tokenizer=TokenizerArgs("gpt2"), - optimizer=OptimizerArgs( - zero_stage=0, - weight_decay=0.01, - clip_grad=1.0, - accumulate_grad_in_fp32=False, - adam_eps=1e-08, - adam_beta1=0.9, - adam_beta2=0.95, - torch_adam_is_fused=True, - learning_rate_scheduler=LRSchedulerArgs( - learning_rate=3e-4, - lr_warmup_steps=100, - lr_warmup_style="linear", - lr_decay_style="cosine", - min_decay_lr=1e-5, - ), - ), - logging=LoggingArgs(), - tokens=TokensArgs(sequence_length=16, train_steps=10, micro_batch_size=16, batch_accumulation_per_replica=1), - data_stages=[ - DatasetStageArgs( - name="train", - start_training_step=1, - data=DataArgs( - seed=42, - num_loading_workers=1, - dataset=PretrainDatasetsArgs( - hf_dataset_or_datasets="HuggingFaceH4/testing_alpaca_small", - hf_dataset_splits="train", - text_column_name="completion", - dataset_processing_num_proc_per_process=12, - ), - ), - ) - ], - ) +# def get_llama_training_config(model_config: ModelArgs): +# return Config( +# model=model_config, +# general=GeneralArgs(project="unittest", run="sanity_llama", seed=42), +# checkpoints=CheckpointsArgs( +# checkpoints_path="./checkpoints", +# checkpoint_interval=10, +# ), +# parallelism=ParallelismArgs( +# dp=1, +# pp=1, +# tp=2, +# expert_parallel_size=2, +# pp_engine="1f1b", +# tp_mode="ALL_REDUCE", +# tp_linear_async_communication=False, +# ), +# tokenizer=TokenizerArgs("gpt2"), +# optimizer=OptimizerArgs( +# zero_stage=0, +# weight_decay=0.01, +# clip_grad=1.0, +# accumulate_grad_in_fp32=False, +# adam_eps=1e-08, +# adam_beta1=0.9, +# adam_beta2=0.95, +# torch_adam_is_fused=True, +# learning_rate_scheduler=LRSchedulerArgs( +# learning_rate=3e-4, +# lr_warmup_steps=100, +# lr_warmup_style="linear", +# lr_decay_style="cosine", +# min_decay_lr=1e-5, +# ), +# ), +# logging=LoggingArgs(), +# tokens=TokensArgs(sequence_length=16, train_steps=10, micro_batch_size=16, batch_accumulation_per_replica=1), +# data_stages=[ +# DatasetStageArgs( +# name="train", +# start_training_step=1, +# data=DataArgs( +# seed=42, +# num_loading_workers=1, +# dataset=PretrainDatasetsArgs( +# hf_dataset_or_datasets="HuggingFaceH4/testing_alpaca_small", +# hf_dataset_splits="train", +# text_column_name="completion", +# dataset_processing_num_proc_per_process=12, +# ), +# ), +# ) +# ], +# ) -def create_llama_from_config( - model_config: LlamaConfig, device: torch.device, parallel_context: ParallelContext -) -> LlamaForTraining: +# def create_llama_from_config( +# model_config: LlamaConfig, device: torch.device, parallel_context: ParallelContext +# ) -> LlamaForTraining: - """ - Creates and returns a nanotron model. - If `model_config` is None, then `checkpoint_path` must be set, in which case - the configuration will be loaded from such path. - If `checkpoint_path` is None, then `model_config` must be set, in which case - the model created will have random weights. - """ +# """ +# Creates and returns a nanotron model. +# If `model_config` is None, then `checkpoint_path` must be set, in which case +# the configuration will be loaded from such path. +# If `checkpoint_path` is None, then `model_config` must be set, in which case +# the model created will have random weights. +# """ - parallel_config = ParallelismArgs( - dp=parallel_context.data_parallel_size, - pp=parallel_context.pipeline_parallel_size, - tp=parallel_context.tensor_parallel_size, - pp_engine=AllForwardAllBackwardPipelineEngine(), - tp_mode=TensorParallelLinearMode.ALL_REDUCE, - tp_linear_async_communication=False, - ) - model = build_model( - model_builder=lambda: LlamaForTraining( - config=model_config, - parallel_context=parallel_context, - parallel_config=parallel_config, - random_states=None, - ), - parallel_context=parallel_context, - dtype=torch.bfloat16, - device=device, - ) - mark_tied_parameters(model=model, parallel_context=parallel_context) - return model +# parallel_config = ParallelismArgs( +# dp=parallel_context.data_parallel_size, +# pp=parallel_context.pipeline_parallel_size, +# tp=parallel_context.tensor_parallel_size, +# pp_engine=AllForwardAllBackwardPipelineEngine(), +# tp_mode=TensorParallelLinearMode.ALL_REDUCE, +# tp_linear_async_communication=False, +# ) +# model = build_model( +# model_builder=lambda: LlamaForTraining( +# config=model_config, +# parallel_context=parallel_context, +# parallel_config=parallel_config, +# random_states=None, +# ), +# parallel_context=parallel_context, +# dtype=torch.bfloat16, +# device=device, +# ) +# mark_tied_parameters(model=model, parallel_context=parallel_context) +# return model diff --git a/tests/helpers/utils.py b/tests/helpers/utils.py index d0fb01b5..e69de29b 100644 --- a/tests/helpers/utils.py +++ b/tests/helpers/utils.py @@ -1,276 +0,0 @@ -import contextlib -import os -import re -from inspect import signature -from typing import Any, Callable, Dict, List, Optional, Tuple - -import torch.cuda -import torch.multiprocessing as mp -from nanotron.parallel import ParallelContext -from packaging import version - - -def available_gpus(): - if not torch.cuda.is_available(): - return 0 - - device_properties = [torch.cuda.get_device_properties(i) for i in range(torch.cuda.device_count())] - - # We filter out - blacklisted_gpu_names = {"NVIDIA DGX Display"} - device_properties = [property_ for property_ in device_properties if property_.name not in blacklisted_gpu_names] - - # TODO @thomasw21: Can we do this cross node - return len(device_properties) - - -# from https://stackoverflow.com/a/34333710/9201239 -@contextlib.contextmanager -def mock_os_environ(remove_keys: List[str] = None, update_key_values: Dict[str, Any] = None): - """ - Temporarily updates the ``os.environ`` dictionary in-place. - The ``os.environ`` dictionary is updated in-place so that the modification is sure to work in all situations. - Args: - remove_keys: Environment variables to remove. - update_key_values: Dictionary of environment variables and values to add/update. - """ - env = os.environ - update_key_values = update_key_values or {} - remove_keys = remove_keys or [] - - update_keys = set(update_key_values.keys()) - remove_keys = set(remove_keys) - assert remove_keys.isdisjoint(update_keys) - - stomped = (update_keys | remove_keys) & set(env.keys()) - reverse_change = { - # Environment variables and values to restore on exit. - **{k: env[k] for k in update_keys & stomped}, - # Environment variables and values to remove on exit. - **{k: env[k] for k in remove_keys & stomped}, - } - - try: - env.update(update_key_values) - for k in remove_keys: - env.pop(k, None) - yield - finally: - env.update(reverse_change) - - -def is_dict_equal(first: Dict, second: Dict, sub_paths: Optional[List[str]] = None) -> Tuple[bool, Optional[str]]: - """Returns True or False if the dictionaries match, and an additional message when it's False""" - if sub_paths is None: - sub_paths = [] - - first_keys = set(first.keys()) - second_keys = set(second.keys()) - if first_keys != second_keys: - return False, f"Keys don't match in {'.'.join(sub_paths)}.\nCur: {first_keys}\nRef: {second_keys}" - for key in first_keys: - first_elt = first[key] - second_elt = second[key] - - if isinstance(first_elt, dict): - if not isinstance(second_elt, dict): - return ( - False, - f"Object types don't match in {'.'.join(sub_paths + [str(key)])}.\nCur: {first_elt}\nRef: {second_elt}", - ) - match, msg = is_dict_equal(first_elt, second_elt, sub_paths=sub_paths + [str(key)]) - if match is False: - return False, msg - elif isinstance(first_elt, torch.Tensor): - if not isinstance(second_elt, torch.Tensor): - return ( - False, - f"Object types don't match in {'.'.join(sub_paths + [str(key)])}.\nCur: {first_elt}\nRef: {second_elt}", - ) - try: - torch.testing.assert_close( - first_elt, - second_elt, - atol=0.0, - rtol=0.0, - msg=lambda msg: f"Tensor at {'.'.join(sub_paths + [str(key)])} don't match.\nCur: {first_elt}\nRef: {second_elt}\n{msg}", - ) - except AssertionError as error: - return False, error.args[0] - else: - if first_elt != second_elt: - return ( - False, - f"Objects at key {'.'.join(sub_paths + [str(key)])} don't match.\nCur: {first_elt}\nRef: {second_elt}", - ) - - return True, None - - -def get_all_3d_configurations(gpus: int) -> List[Tuple[int, int, int]]: - """Given a number of gpus, we want all 3d configurations possible such that pp * dp * tp = gpus""" - result = [] - for tp in range(1, gpus + 1): - if gpus % tp != 0: - continue - gpus_left_after_tp = gpus // tp - for dp in range(1, gpus_left_after_tp + 1): - if gpus_left_after_tp % dp != 0: - continue - gpus_left_after_dp = gpus_left_after_tp // dp - for pp in range(1, gpus_left_after_dp + 1): - if gpus_left_after_dp % pp != 0: - continue - if tp * dp * pp == gpus: - result.append((pp, dp, tp)) - return result - - -def rerun_if_address_is_in_use(max_try: int = 500): - """ - This function reruns a wrapped function if "address already in use" occurs - in testing spawned with torch.multiprocessing - - Credits: https://github.com/hpcaitech/ColossalAI/blob/adae123df3badfb15d044bd416f0cf29f250bc86/colossalai/testing/utils.py#L157 - - Usage:: - - @rerun_if_address_is_in_use() - def test_something(): - ... - - """ - # check version - torch_version = version.parse(torch.__version__) - assert torch_version.major >= 1 - - # only torch >= 1.8 has ProcessRaisedException - if torch_version >= version.parse("1.8.0"): - exception = torch.multiprocessing.ProcessRaisedException - else: - exception = Exception - - func_wrapper = rerun_on_exception(exception_type=exception, pattern=".*Address already in use.*", max_try=max_try) - return func_wrapper - - -def rerun_on_exception(exception_type: Exception = Exception, pattern: str = None, max_try: int = 10) -> Callable: - """ - A decorator on a function to re-run when an exception occurs. - - Credits: https://github.com/hpcaitech/ColossalAI/blob/adae123df3badfb15d044bd416f0cf29f250bc86/colossalai/testing/utils.py#L71 - - Usage:: - - # rerun for all kinds of exception - @rerun_on_exception() - def test_method(): - print('hey') - raise RuntimeError('Address already in use') - - # rerun for RuntimeError only - @rerun_on_exception(exception_type=RuntimeError) - def test_method(): - print('hey') - raise RuntimeError('Address already in use') - - # rerun for maximum 10 times if Runtime error occurs - @rerun_on_exception(exception_type=RuntimeError, max_try=10) - def test_method(): - print('hey') - raise RuntimeError('Address already in use') - - # rerun for infinite times if Runtime error occurs - @rerun_on_exception(exception_type=RuntimeError, max_try=None) - def test_method(): - print('hey') - raise RuntimeError('Address already in use') - - # rerun only the exception message is matched with pattern - # for infinite times if Runtime error occurs - @rerun_on_exception(exception_type=RuntimeError, pattern="^Address.*$") - def test_method(): - print('hey') - raise RuntimeError('Address already in use') - - Args: - exception_type (Exception, Optional): The type of exception to detect for rerun - pattern (str, Optional): The pattern to match the exception message. - If the pattern is not None and matches the exception message, - the exception will be detected for rerun - max_try (int, Optional): Maximum reruns for this function. The default value is 5. - If max_try is None, it will rerun forever if exception keeps occurring - """ - - def _match_lines(lines, pattern): - for line in lines: - if re.match(pattern, line): - return True - return False - - def _wrapper(func): - def _run_until_success(*args, **kwargs): - try_count = 0 - assert max_try is None or isinstance( - max_try, int - ), f"Expected max_try to be None or int, but got {type(max_try)}" - - while max_try is None or try_count < max_try: - try: - try_count += 1 - ret = func(*args, **kwargs) - return ret - except exception_type as e: - error_lines = str(e).split("\n") - if try_count < max_try and (pattern is None or _match_lines(error_lines, pattern)): - - print("Exception is caught, retrying...") - # when pattern is not specified, we always skip the exception - # when pattern is specified, we only skip when pattern is matched - continue - else: - print("Maximum number of attempts is reached or pattern is not matched, no more retrying...") - raise e - - # Override signature - # otherwise pytest.mark.parameterize will raise the following error: - # function does not use argument xxx - sig = signature(func) - _run_until_success.__signature__ = sig - - return _run_until_success - - return _wrapper - - -def global_wrapper(rank, func, tp, pp, dp, port, kwargs): - def setup_dist_env(rank, world_size, port): - os.environ["WORLD_SIZE"] = str(world_size) - os.environ["RANK"] = str(rank) - # NOTE: since we do unit tests in a - # single node => this is fine! - os.environ["LOCAL_RANK"] = str(rank) - os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = str(port) - - world_size = tp * pp * dp - setup_dist_env(rank, world_size, port) - parallel_context = ParallelContext(data_parallel_size=dp, pipeline_parallel_size=pp, tensor_parallel_size=tp) - func(parallel_context, **kwargs) - - -def init_distributed(tp: int, dp: int, pp: int): - def _init_distributed(func): - def wrapper(**kwargs): - from nanotron.utils import find_free_port - - world_size = tp * pp * dp - port = find_free_port() - - # Note that kwargs needs to be passed as part of args in a way that can be unpacked - args = (func, tp, pp, dp, port, kwargs) - mp.spawn(global_wrapper, args=args, nprocs=world_size) - - return wrapper - - return _init_distributed diff --git a/tests/nanoset/test_build_nanoset_dataloader.py b/tests/nanoset/test_build_nanoset_dataloader.py index 113c545c..e47d425e 100644 --- a/tests/nanoset/test_build_nanoset_dataloader.py +++ b/tests/nanoset/test_build_nanoset_dataloader.py @@ -8,7 +8,6 @@ import numpy as np import pytest -from helpers.context import TestContext from helpers.data import ( assert_batch_dataloader, assert_nanoset_sync_across_all_ranks, @@ -17,11 +16,17 @@ create_dummy_json_dataset, preprocess_dummy_dataset, ) -from helpers.utils import available_gpus, get_all_3d_configurations, init_distributed, rerun_if_address_is_in_use from nanotron.data.dataloader_builder import build_nanoset_dataloader from nanotron.data.nanoset import Nanoset from nanotron.data.utils import count_dataset_indexes, normalize from nanotron.parallel import ParallelContext +from nanotron.testing.utils import ( + TestContext, + available_gpus, + get_all_3d_configurations, + init_distributed, + rerun_if_address_is_in_use, +) from nanotron.utils import main_rank_first from transformers import AutoTokenizer diff --git a/tests/pytest.ini b/tests/pytest.ini index 0e0b2653..c0c6f3eb 100644 --- a/tests/pytest.ini +++ b/tests/pytest.ini @@ -1,4 +1,4 @@ [pytest] -addopts=-n 35 +; addopts=-n 10 markers = fa2: FA2-related diff --git a/tests/test_base_model.py b/tests/test_base_model.py index b4759905..ab86a3b5 100644 --- a/tests/test_base_model.py +++ b/tests/test_base_model.py @@ -1,14 +1,56 @@ import pytest import torch import torch.distributed as dist -from helpers.llama import TINY_LLAMA_CONFIG, create_llama_from_config, get_llama_training_config -from helpers.utils import init_distributed, rerun_if_address_is_in_use from nanotron.config import Config, ModelArgs, RandomInit +from nanotron.models.base import init_on_device_and_dtype from nanotron.parallel import ParallelContext from nanotron.parallel.pipeline_parallel.block import PipelineBlock +from nanotron.testing.llama import TINY_LLAMA_CONFIG, create_llama_from_config, get_llama_training_config +from nanotron.testing.utils import init_distributed, rerun_if_address_is_in_use from torch import nn +@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16]) +@pytest.mark.parametrize("device", [torch.device("cpu"), torch.device("cuda:0")]) +def test_override_dtype_and_device_in_module_init(dtype, device): + class ModuleWithBuffer(nn.Module): + def __init__(self): + super().__init__() + self.register_buffer("buffer", torch.randn(2, 2)) + self.weight = nn.Parameter(torch.randn(2, 2)) + + with init_on_device_and_dtype(device=device, dtype=dtype): + linear = ModuleWithBuffer() + + assert all(p.dtype == dtype for p in linear.parameters()) + assert all(p.device == device for p in linear.parameters()) + + assert all(b.dtype == dtype for b in linear.buffers()) + assert all(b.device == device for b in linear.buffers()) + + +@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16]) +@pytest.mark.parametrize("device", [torch.device("cpu"), torch.device("cuda:0")]) +@rerun_if_address_is_in_use() +def test_dtype_of_model_initialization(dtype: torch.dtype, device: torch.device): + init_distributed(tp=1, dp=1, pp=1)(_test_dtype_of_model_initialization)(dtype=dtype, device=device) + + +def _test_dtype_of_model_initialization(parallel_context: ParallelContext, dtype: torch.dtype, device: torch.device): + model_args = ModelArgs(init_method=RandomInit(std=1.0), model_config=TINY_LLAMA_CONFIG) + config = get_llama_training_config(model_args) + llama = create_llama_from_config( + model_config=TINY_LLAMA_CONFIG, device=device, parallel_context=parallel_context, dtype=dtype + ) + llama.init_model_randomly(config=config) + + assert all(p.dtype == dtype for p in llama.parameters()) + assert all(p.device == device for p in llama.parameters()) + + assert all(b.dtype == dtype for b in llama.buffers()) + assert all(b.device == device for b in llama.buffers()) + + @pytest.mark.parametrize("tp,dp,pp", [(1, 1, 1), (2, 2, 2)]) @pytest.mark.skip @rerun_if_address_is_in_use() diff --git a/tests/test_clip_grads.py b/tests/test_clip_grads.py index b4682875..c1aa55fd 100644 --- a/tests/test_clip_grads.py +++ b/tests/test_clip_grads.py @@ -4,7 +4,6 @@ import pytest import torch from helpers.dummy import DummyModel, dummy_infinite_data_loader -from helpers.utils import available_gpus, init_distributed, rerun_if_address_is_in_use from nanotron import distributed as dist from nanotron.models import init_on_device_and_dtype from nanotron.optim.clip_grads import clip_grad_norm @@ -27,6 +26,7 @@ ) from nanotron.parallel.utils import initial_sync from nanotron.sanity_checks import assert_tensor_synced_across_pg +from nanotron.testing.utils import available_gpus, init_distributed, rerun_if_address_is_in_use from torch import nn diff --git a/tests/test_data_parallel.py b/tests/test_data_parallel.py index 21ae191a..2a03331f 100644 --- a/tests/test_data_parallel.py +++ b/tests/test_data_parallel.py @@ -3,12 +3,12 @@ import pytest import torch from helpers.exception import assert_fail_except_rank_with -from helpers.utils import available_gpus, init_distributed, rerun_if_address_is_in_use from nanotron import distributed as dist from nanotron.parallel import ParallelContext from nanotron.parallel.data_parallel.utils import ddp_trigger_sync_in_bwd from nanotron.parallel.parameters import NanotronParameter from nanotron.sanity_checks import assert_tensor_synced_across_pg +from nanotron.testing.utils import available_gpus, init_distributed, rerun_if_address_is_in_use from torch import nn from torch.distributed import GradBucket diff --git a/tests/test_distributed.py b/tests/test_distributed.py index 7c0d2462..2cf58bff 100644 --- a/tests/test_distributed.py +++ b/tests/test_distributed.py @@ -1,13 +1,13 @@ import numpy as np import pytest import torch.distributed as dist -from helpers.utils import ( +from nanotron.parallel import ParallelContext +from nanotron.testing.utils import ( available_gpus, get_all_3d_configurations, init_distributed, rerun_if_address_is_in_use, ) -from nanotron.parallel import ParallelContext from torch.distributed import ProcessGroup @@ -45,4 +45,4 @@ def _test_init_parallel_context(parallel_context: ParallelContext): ) @rerun_if_address_is_in_use() def test_init_parallel_context(tp: int, dp: int, pp: int): - init_distributed(tp=tp, dp=dp, pp=pp)(_test_init_parallel_context)() \ No newline at end of file + init_distributed(tp=tp, dp=dp, pp=pp)(_test_init_parallel_context)() diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index 1a28f967..9acc9374 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -2,12 +2,12 @@ import pytest import torch -from helpers.llama import TINY_LLAMA_CONFIG, create_llama_from_config, get_llama_training_config -from helpers.utils import init_distributed, rerun_if_address_is_in_use from nanotron.config import ModelArgs, RandomInit, SpectralMupInit from nanotron.helpers import get_custom_lr_for_named_parameters from nanotron.parallel import ParallelContext from nanotron.scaling.parametrization import ParametrizationMethod +from nanotron.testing.llama import TINY_LLAMA_CONFIG, create_llama_from_config, get_llama_training_config +from nanotron.testing.utils import init_distributed, rerun_if_address_is_in_use @pytest.mark.parametrize("tp,dp,pp", [(1, 1, 1), (2, 1, 1), (1, 1, 2), (2, 1, 2)]) diff --git a/tests/test_optimizer_params_groups.py b/tests/test_optimizer_params_groups.py index fa835e1c..3fa682ca 100644 --- a/tests/test_optimizer_params_groups.py +++ b/tests/test_optimizer_params_groups.py @@ -3,13 +3,13 @@ import torch.nn as nn import torch.nn.functional as F import torch.optim as optim -from helpers.utils import available_gpus, init_distributed, rerun_if_address_is_in_use from nanotron.optim.gradient_accumulator import FP32GradientAccumulator from nanotron.optim.named_optimizer import NamedOptimizer from nanotron.optim.optimizer_from_gradient_accumulator import OptimizerFromGradientAccumulator from nanotron.parallel.context import ParallelContext from nanotron.parallel.parameters import NanotronParameter from nanotron.random import set_random_seed +from nanotron.testing.utils import available_gpus, init_distributed, rerun_if_address_is_in_use class DummyModel(nn.Module): diff --git a/tests/test_p2p.py b/tests/test_p2p.py index ed8245a8..c49d8d9a 100644 --- a/tests/test_p2p.py +++ b/tests/test_p2p.py @@ -3,10 +3,10 @@ import pytest import torch from helpers.exception import assert_fail_with -from helpers.utils import available_gpus, init_distributed, rerun_if_address_is_in_use from nanotron import distributed as dist from nanotron.parallel import ParallelContext from nanotron.parallel.pipeline_parallel.p2p import P2P +from nanotron.testing.utils import available_gpus, init_distributed, rerun_if_address_is_in_use @pytest.mark.skipif(available_gpus() < 2, reason="Testing test_ddp_with_afab requires at least 2 gpus") diff --git a/tests/test_parameter.py b/tests/test_parameter.py index ea031ef2..cb8ef797 100644 --- a/tests/test_parameter.py +++ b/tests/test_parameter.py @@ -1,10 +1,51 @@ import torch from helpers.exception import assert_fail_with from nanotron.models.base import DTypeInvariantTensor, init_on_device_and_dtype + +# from nanotron.testing.utils import TestContext +from nanotron.parallel import ParallelContext from nanotron.parallel.parameters import NanotronParameter +from nanotron.parallel.sharded_parameters import SplitConfig, create_sharded_parameter_from_config +from nanotron.testing.utils import init_distributed, rerun_if_address_is_in_use from torch import nn +@rerun_if_address_is_in_use() +def test_get_parameter_data(): + init_distributed(tp=2, dp=1, pp=1)(_test_get_parameter_data)() + + +def _test_get_parameter_data(parallel_context: ParallelContext): + param = torch.nn.Parameter(torch.randn(16, 64)) + split_config = SplitConfig( + split_dim=0, + contiguous_chunks=(8, 8), + ) + param = create_sharded_parameter_from_config(parameter=param, pg=parallel_context.tp_pg, split_config=split_config) + + new_data = torch.randn(16, 64) + param.data = new_data + + assert param.data is new_data + + +@rerun_if_address_is_in_use() +def test_random_hash_nanotron_parameter(): + init_distributed(tp=2, dp=1, pp=1)(_test_random_hash_nanotron_parameter)() + + +def _test_random_hash_nanotron_parameter(parallel_context: ParallelContext): + param = torch.nn.Parameter(torch.randn(16, 64)) + split_config = SplitConfig( + split_dim=0, + contiguous_chunks=(8, 8), + ) + param = create_sharded_parameter_from_config(parameter=param, pg=parallel_context.tp_pg, split_config=split_config) + + assert hash(param) is not None + assert type(hash(param)) == int + + def test_nanotron_parameter_does_not_override_some_parameter_variable(): param = nn.Parameter(torch.empty(3)) assert not hasattr(param, NanotronParameter.NANOTRON_PARAMETER_METADATA_ATTRIBUTE_NAME) @@ -58,3 +99,35 @@ def test_register_buffer_does_not_update_uncastable_tensor(): # Test that dtype has not been modified assert module.buffer.dtype is old_dtype + + +@rerun_if_address_is_in_use() +def test_create_param_that_share_metadata(): + init_distributed(tp=2, dp=1, pp=1)(_test_create_param_that_share_metadata)() + + +def _test_create_param_that_share_metadata(parallel_context: ParallelContext): + param = torch.nn.Parameter(torch.randn(16, 64)) + split_config = SplitConfig( + split_dim=0, + contiguous_chunks=(8, 8), + ) + orig_param = create_sharded_parameter_from_config( + parameter=param, pg=parallel_context.tp_pg, split_config=split_config + ) + new_param = NanotronParameter.create_param_that_share_metadata(torch.randn(16, 64), param=orig_param) + + new_param_metadata = getattr(new_param, NanotronParameter.NANOTRON_PARAMETER_METADATA_ATTRIBUTE_NAME) + orig_param_metadata = getattr(orig_param, NanotronParameter.NANOTRON_PARAMETER_METADATA_ATTRIBUTE_NAME) + + for (p1_k, p1_v), (p2_k, p2_v) in zip(new_param_metadata.items(), orig_param_metadata.items()): + assert p1_k == p2_k + assert p1_v == p2_v + + # orig_hash = getattr(orig_param, NanotronParameter.NANOTRON_PARAMETER_HASH_ATTRIBUTE_NAME) + # new_hash = getattr(new_param, NanotronParameter.NANOTRON_PARAMETER_HASH_ATTRIBUTE_NAME) + + # assert new_hash == orig_hash + assert hash(new_param) == hash(orig_param) + + parallel_context.destroy() diff --git a/tests/test_parameters_accumulate_gradient_in_fp32.py b/tests/test_parameters_accumulate_gradient_in_fp32.py index ba0debd6..b993cf03 100644 --- a/tests/test_parameters_accumulate_gradient_in_fp32.py +++ b/tests/test_parameters_accumulate_gradient_in_fp32.py @@ -5,7 +5,6 @@ import torch from helpers.dummy import DummyModel, dummy_infinite_data_loader from helpers.exception import assert_fail_except_rank_with, timeout_after -from helpers.utils import available_gpus, init_distributed, rerun_if_address_is_in_use from nanotron.models import init_on_device_and_dtype from nanotron.optim import ZeroDistributedOptimizer from nanotron.optim.gradient_accumulator import FP32GradBucketManager, FP32GradientAccumulator, get_fp32_accum_hook @@ -29,6 +28,7 @@ ) from nanotron.parallel.utils import initial_sync from nanotron.sanity_checks import assert_tensor_synced_across_pg +from nanotron.testing.utils import available_gpus, init_distributed, rerun_if_address_is_in_use from nanotron.utils import ContextManagers from torch import nn @@ -57,6 +57,11 @@ def test_gradient_promoting_in_fp32(half_precision: torch.dtype): torch.testing.assert_close(model.weight.grad, torch.zeros_like(model.weight.grad), atol=1e-6, rtol=1e-7) +# TODO: test gradient accumulator skips creating master weights for fp32 parameters +# TODO: test gradient accumulator creates master weights for FP8 parameter +# TODO: test the number of master weights created for an fp8 model + + @pytest.mark.parametrize("half_precision", [torch.float16, torch.bfloat16]) def test_gradient_accumulated_in_fp32(half_precision: torch.dtype): model = nn.Linear(3, 2, bias=False, dtype=half_precision, device="cuda") @@ -257,7 +262,7 @@ def _test_ddp_with_grad_accum_in_fp32( accumulator.backward(loss_fp32_accum) for name, param in model_ddp_fp32_accum.named_parameters(): - # Check that half grads has been set to None in sync step, to avoid it being uncorrectly used + # Check that half grads has been set to None in sync step, to avoid it being incorrectly used half_grad = param.grad assert half_grad is None, f"{half_grad} != None" diff --git a/tests/test_parametrization.py b/tests/test_parametrization.py index fe76826a..ff120bcf 100644 --- a/tests/test_parametrization.py +++ b/tests/test_parametrization.py @@ -3,11 +3,11 @@ import pytest import torch -from helpers.llama import TINY_LLAMA_CONFIG, create_llama_from_config, get_llama_training_config -from helpers.utils import init_distributed, rerun_if_address_is_in_use from nanotron.config import ModelArgs, RandomInit, SpectralMupInit from nanotron.parallel import ParallelContext from nanotron.scaling.parametrization import ParametrizationMethod +from nanotron.testing.llama import TINY_LLAMA_CONFIG, create_llama_from_config, get_llama_training_config +from nanotron.testing.utils import init_distributed, rerun_if_address_is_in_use @pytest.mark.parametrize("tp,dp,pp", [(2, 1, 1)]) diff --git a/tests/test_pipeline_parallel.py b/tests/test_pipeline_parallel.py index a7f8008f..6e311f12 100644 --- a/tests/test_pipeline_parallel.py +++ b/tests/test_pipeline_parallel.py @@ -3,7 +3,6 @@ import pytest import torch from helpers.dummy import DummyModel, dummy_infinite_data_loader -from helpers.utils import available_gpus, init_distributed, rerun_if_address_is_in_use from nanotron import distributed as dist from nanotron.models import init_on_device_and_dtype from nanotron.parallel import ParallelContext @@ -15,6 +14,7 @@ ) from nanotron.parallel.pipeline_parallel.p2p import P2P from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer +from nanotron.testing.utils import available_gpus, init_distributed, rerun_if_address_is_in_use from torch import nn from torch.nn import functional as F diff --git a/tests/test_random_state.py b/tests/test_random_state.py index 7abd0b13..3f9a16f1 100644 --- a/tests/test_random_state.py +++ b/tests/test_random_state.py @@ -1,6 +1,5 @@ import pytest import torch -from helpers.utils import available_gpus, init_distributed, rerun_if_address_is_in_use from nanotron import distributed as dist from nanotron.parallel import ParallelContext from nanotron.random import ( @@ -9,6 +8,7 @@ get_current_random_state, get_synced_random_state, ) +from nanotron.testing.utils import available_gpus, init_distributed, rerun_if_address_is_in_use @pytest.mark.skipif(available_gpus() < 2, reason="Testing test_random_state_sync requires at least 2 gpus") diff --git a/tests/test_serialize.py b/tests/test_serialize.py index 5234710e..3652f69a 100644 --- a/tests/test_serialize.py +++ b/tests/test_serialize.py @@ -1,14 +1,6 @@ import pytest import torch -from helpers.context import TestContext from helpers.dummy import dummy_infinite_data_loader, init_dummy_model -from helpers.utils import ( - available_gpus, - get_all_3d_configurations, - init_distributed, - is_dict_equal, - rerun_if_address_is_in_use, -) from nanotron import distributed as dist from nanotron.constants import CHECKPOINT_VERSION from nanotron.optim.gradient_accumulator import FP32GradientAccumulator @@ -33,6 +25,14 @@ save_weights, ) from nanotron.serialize.metadata import TensorMetadata +from nanotron.testing.utils import ( + TestContext, + available_gpus, + get_all_3d_configurations, + init_distributed, + is_dict_equal, + rerun_if_address_is_in_use, +) from torch.nn.parallel import DistributedDataParallel diff --git a/tests/test_tensor_parallel.py b/tests/test_tensor_parallel.py index 16008eaa..0b6b28bd 100644 --- a/tests/test_tensor_parallel.py +++ b/tests/test_tensor_parallel.py @@ -2,7 +2,6 @@ import pytest import torch -from helpers.utils import available_gpus, init_distributed, rerun_if_address_is_in_use from nanotron import distributed as dist from nanotron.distributed import get_global_rank from nanotron.parallel import ParallelContext @@ -12,6 +11,7 @@ TensorParallelEmbedding, TensorParallelRowLinear, ) +from nanotron.testing.utils import available_gpus, init_distributed, rerun_if_address_is_in_use from torch import nn as torch_nn diff --git a/tests/test_tie_weights.py b/tests/test_tie_weights.py index eecfc097..c92091fd 100644 --- a/tests/test_tie_weights.py +++ b/tests/test_tie_weights.py @@ -1,7 +1,6 @@ import torch from helpers.distributed_tensor import assert_tensor_equal_over_group from helpers.exception import assert_fail_with -from helpers.utils import init_distributed, rerun_if_address_is_in_use from nanotron import distributed as dist from nanotron.parallel import ParallelContext from nanotron.parallel.parameters import NanotronParameter @@ -10,6 +9,7 @@ sync_tied_weights_gradients, tie_parameters, ) +from nanotron.testing.utils import init_distributed, rerun_if_address_is_in_use from torch import nn diff --git a/tests/test_zero.py b/tests/test_zero.py index f1127f94..830ea739 100644 --- a/tests/test_zero.py +++ b/tests/test_zero.py @@ -5,7 +5,6 @@ from helpers.distributed_tensor import assert_tensor_equal_over_group from helpers.dummy import dummy_infinite_data_loader, init_dummy_model from helpers.exception import assert_fail_with -from helpers.utils import available_gpus, init_distributed, rerun_if_address_is_in_use from nanotron import distributed as dist from nanotron.optim import NamedOptimizer, ZeroDistributedOptimizer from nanotron.optim.zero import SlicedFlatTensor @@ -18,6 +17,7 @@ from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode from nanotron.parallel.tied_parameters import sync_tied_weights_gradients from nanotron.random import RandomStates, branch_random_state, get_current_random_state, get_synced_random_state +from nanotron.testing.utils import available_gpus, init_distributed, rerun_if_address_is_in_use from torch import nn as torch_nn from torch.nn.parallel import DistributedDataParallel