From 07852c81da8bfeef1a39666d81e6536e9f3fc4dd Mon Sep 17 00:00:00 2001 From: Przemyslaw Tredak Date: Mon, 14 Oct 2024 14:47:33 -0700 Subject: [PATCH 1/6] Beginning of the work towards C++ Float8Tensor Signed-off-by: Przemyslaw Tredak --- .../pytorch/cpp_extensions/gemm.py | 145 ++++++++++++++++++ transformer_engine/pytorch/csrc/common.h | 8 + transformer_engine/pytorch/csrc/extensions.h | 11 ++ .../pytorch/csrc/extensions/gemm.cu | 89 +++++++++++ .../pytorch/csrc/extensions/pybind.cpp | 81 ++++++++++ 5 files changed, 334 insertions(+) diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index fd1eb4a810..7bd3900c1a 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -9,6 +9,7 @@ import transformer_engine_torch as tex from ..constants import TE_DType from ..utils import assert_dim_for_fp8_exec +from ..tensor import Float8Tensor __all__ = [ @@ -24,6 +25,150 @@ def _empty_tensor() -> torch.Tensor: """Get tensor with no entries and no data""" return torch.Tensor() +def general_gemm( + A: Union[torch.Tensor, Float8Tensor], + B: Union[torch.Tensor, Float8Tensor], + out_dtype: torch.dtype, + workspace: torch.Tensor, + gelu: bool = False, + accumulate: bool = False, + out: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + use_bias: bool = False, + use_split_accumulator: bool = False, + D_dtype: Optional[tex.DType] = None, + ub_algo: tex.UbufOverlapAlgo = None, + ub: Union[tex.UbufCommOverlap, tex.UbufP2PCommOverlap] = None, + extra_output_tensor: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """GEMM supporting fp8 inputs.""" + + empty_tensor = _empty_tensor() + if D_dtype is not None and D_dtype in [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]: + raise ValueError("FP8 output not supported") + # assert_dim_for_fp8_exec(A) + # assert_dim_for_fp8_exec(B) + + if out is None: + out = torch.empty( + B.shape[0], + A.shape[0], + dtype=out_dtype, + device="cuda", + ) + else: + if not out.is_contiguous(): + raise ValueError("Output tensor is not contiguous.") + + # Use bfloat16 as default bias_dtype + bias_dtype = torch.bfloat16 if bias is None else bias.dtype + if gelu: + gelu_input = torch.empty_like(out, dtype=bias_dtype) + else: + gelu_input = empty_tensor + bias_dtype = TE_DType[bias_dtype] + + out_dtype = TE_DType[out.dtype] if D_dtype is None else D_dtype + + args = ( + A, + True, # transa + B, + False, # transb + out, + empty_tensor, # if out_index is None else fp8_meta_tensor.scale[out_index], + out_dtype, + empty_tensor, # if out_index is None else fp8_meta_tensor.amax_history[0][out_index], + bias if use_bias else empty_tensor, + bias_dtype, + gelu_input, # this is pre_gelu_out + False, # grad + workspace, + workspace.shape[0], + accumulate, + use_split_accumulator, + ) + fn = tex.te_gemm2 + if ub_algo is not None: + raise ValueError("Not implemented yet!") + assert ub is not None, "ub object is None!" 
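        # Each userbuffers overlap algorithm handled below maps to a method on
        # the `ub` object and appends its extra output tensor to the argument
        # tuple, mirroring the dispatch in the existing fp8_gemm/gemm wrappers;
        # note that the unconditional raise above makes these branches
        # unreachable for now.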
+ if ub_algo == tex.UbufOverlapAlgo.BULK_OVERLAP_AG: + fn = ub.bulk_overlap + extra_output_tensor = ( + empty_tensor if extra_output_tensor is None else extra_output_tensor + ) + args = tuple( + args + + ( + 1, + extra_output_tensor, + ) + ) + elif ub_algo == tex.UbufOverlapAlgo.BULK_OVERLAP_RS: + fn = ub.bulk_overlap + extra_output_tensor = ( + empty_tensor if extra_output_tensor is None else extra_output_tensor + ) + args = tuple( + args + + ( + 0, + extra_output_tensor, + ) + ) + elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P: + fn = ub.split_overlap_ag_p2p + extra_output_tensor = ( + empty_tensor if extra_output_tensor is None else extra_output_tensor + ) + args = tuple(args + (extra_output_tensor,)) + elif ub_algo == tex.UbufOverlapAlgo.ATOMIC_GEMM_AG_P2P: + fn = ub.atomic_gemm_overlap_ag_p2p + extra_output_tensor = ( + empty_tensor if extra_output_tensor is None else extra_output_tensor + ) + args = tuple(args + (extra_output_tensor,)) + elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS: + fn = ub.split_overlap_rs + assert ( + extra_output_tensor is not None + ), "SPLIT_PIPELINED_RS requires extra output tensor" + args = tuple( + args + + ( + True, + extra_output_tensor, + ) + ) + elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P: + fn = ub.split_overlap_rs_p2p + assert ( + extra_output_tensor is not None + ), "SPLIT_PIPELINED_RS_P2P requires extra output tensor" + args = tuple(args + (extra_output_tensor,)) + elif ub_algo == tex.UbufOverlapAlgo.ATOMIC_GEMM_RS: + fn = ub.atomic_gemm_overlap_rs + assert extra_output_tensor is not None, "ATOMIC_GEMM_RS requires extra output tensor" + args = tuple( + args + + ( + True, + extra_output_tensor, + ) + ) + elif ub_algo == tex.UbufOverlapAlgo.ATOMIC_GEMM_RS_P2P: + fn = ub.atomic_gemm_overlap_rs_p2p + assert ( + extra_output_tensor is not None + ), "ATOMIC_GEMM_RS_P2P requires extra output tensor" + args = tuple(args + (extra_output_tensor,)) + if ub_algo is not None and ub_algo == tex.UbufOverlapAlgo.ATOMIC_GEMM_AG_P2P: + out = fn(*args) + else: + out, gelu_input = fn(*args) + + return out, gelu_input + def fp8_gemm( A: torch.Tensor, diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 04a1193a71..c41193b9c6 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -83,6 +83,14 @@ enum FP8BwdTensors { GRAD_INPUT3 = 5 }; +class Float8Tensor { + public: + at::Tensor data; + std::optional transpose = std::nullopt; + at::Tensor scale_inv; + DType dtype; +}; + } // namespace transformer_engine transformer_engine::DType getTransformerEngineFP8Type(bool e4m3_if_hybrid, diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index c797208e06..761f992f21 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -138,6 +138,17 @@ at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v); * GEMM **************************************************************************************************/ +std::vector te_gemm2(transformer_engine::Float8Tensor A, bool transa, + transformer_engine::Float8Tensor B, + bool transb, std::optional D, at::Tensor D_scale, transformer_engine::DType D_type, + at::Tensor D_amax, at::Tensor bias, transformer_engine::DType bias_type, + at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, size_t workspaceSize, + bool accumulate, bool use_split_accumulator, int math_sm_count); +std::vector 
te_gemm2(at::Tensor A, bool transa, at::Tensor B, bool transb, std::optional D, + at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, + at::Tensor bias, transformer_engine::DType bias_type, at::Tensor pre_gelu_out, + bool grad, at::Tensor workspace, size_t workspaceSize, bool accumulate, + bool use_split_accumulator, int math_sm_count); void te_gemm(at::Tensor A, at::Tensor A_scale_inverse, transformer_engine::DType A_type, bool transa, at::Tensor B, at::Tensor B_scale_inverse, transformer_engine::DType B_type, bool transb, at::Tensor D, at::Tensor D_scale, diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cu b/transformer_engine/pytorch/csrc/extensions/gemm.cu index ba9851e7e8..bc0166362d 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cu +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cu @@ -4,8 +4,97 @@ * See LICENSE for license information. ************************************************************************/ +#include + +#include "common.h" #include "common/util/cuda_runtime.h" #include "extensions.h" +#include "pytorch/csrc/common.h" +#include "transformer_engine/transformer_engine.h" + +std::vector te_gemm2_helper( + at::Tensor A, transformer_engine::DType A_dtype, std::optional A_scale_inv, + bool transa, at::Tensor B, transformer_engine::DType B_dtype, + std::optional B_scale_inv, bool transb, std::optional D, + at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, at::Tensor bias, + transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, + size_t workspaceSize, bool accumulate, bool use_split_accumulator, int math_sm_count) { + using namespace transformer_engine; + if (A.data_ptr() == nullptr || B.data_ptr() == nullptr) { + at::Tensor out; + if (D.has_value() && D->data_ptr() != nullptr && !accumulate) { + D->zero_(); // TODO: Handle D without a value + out = *D; + } + if (pre_gelu_out.data_ptr() != nullptr) pre_gelu_out.zero_(); + return {out, pre_gelu_out}; + } + + A = A.contiguous(); + B = B.contiguous(); + + if (!D.has_value()) { + auto type = GetATenDType(D_type); + auto opts = at::TensorOptions().dtype(type).device(A.options().device()); + *D = at::empty({B.size(0), A.size(0)}, opts); + } + + auto A_scale_inv_ptr = A_scale_inv.has_value() ? A_scale_inv->data_ptr() : nullptr; + auto B_scale_inv_ptr = B_scale_inv.has_value() ? B_scale_inv->data_ptr() : nullptr; + auto te_A = makeTransformerEngineTensor( + A.data_ptr(), {static_cast(A.size(0)), static_cast(A.size(1))}, A_dtype, + nullptr, nullptr, A_scale_inv_ptr); + auto te_B = makeTransformerEngineTensor( + B.data_ptr(), {static_cast(B.size(0)), static_cast(B.size(1))}, B_dtype, + nullptr, nullptr, B_scale_inv_ptr); + auto te_D = makeTransformerEngineTensor( + D->data_ptr(), {static_cast(D->size(0)), static_cast(D->size(1))}, D_type, + D_amax.data_ptr(), D_scale.data_ptr(), nullptr); + auto te_bias = + makeTransformerEngineTensor(bias.data_ptr(), {static_cast(bias.size(0))}, bias_type); + + const auto gelu_shape = pre_gelu_out.data_ptr() == nullptr + ? 
std::vector{static_cast(pre_gelu_out.size(0))} + : std::vector{static_cast(pre_gelu_out.size(0)), + static_cast(pre_gelu_out.size(1))}; + auto te_pre_gelu_out = makeTransformerEngineTensor( + pre_gelu_out.data_ptr(), gelu_shape, GetTransformerEngineDType(pre_gelu_out.scalar_type())); + auto te_workspace = + makeTransformerEngineTensor(workspace.data_ptr(), {workspaceSize}, DType::kByte); + + nvte_cublas_gemm(te_A.data(), te_B.data(), te_D.data(), te_bias.data(), te_pre_gelu_out.data(), + transa, transb, grad, te_workspace.data(), accumulate, use_split_accumulator, + math_sm_count, at::cuda::getCurrentCUDAStream()); + + return {*D, pre_gelu_out}; +} + +std::vector te_gemm2(transformer_engine::Float8Tensor A, bool transa, + transformer_engine::Float8Tensor B, bool transb, + std::optional D, at::Tensor D_scale, + transformer_engine::DType D_type, at::Tensor D_amax, + at::Tensor bias, transformer_engine::DType bias_type, + at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, + size_t workspaceSize, bool accumulate, bool use_split_accumulator, + int math_sm_count) { + return te_gemm2_helper(A.data, A.dtype, A.scale_inv, transa, B.data, B.dtype, B.scale_inv, transb, + D, D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad, workspace, + workspaceSize, accumulate, use_split_accumulator, math_sm_count); +} + +std::vector te_gemm2(at::Tensor A, bool transa, at::Tensor B, bool transb, + std::optional D, at::Tensor D_scale, + transformer_engine::DType D_type, at::Tensor D_amax, + at::Tensor bias, transformer_engine::DType bias_type, + at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, + size_t workspaceSize, bool accumulate, bool use_split_accumulator, + int math_sm_count) { + transformer_engine::DType A_dtype = GetTransformerEngineDType(A.scalar_type()); + transformer_engine::DType B_dtype = GetTransformerEngineDType(B.scalar_type()); + return te_gemm2_helper(A, A_dtype, std::nullopt, transa, B, B_dtype, std::nullopt, transb, D, + D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad, workspace, + workspaceSize, accumulate, use_split_accumulator, math_sm_count); +} void te_gemm(at::Tensor A, at::Tensor A_scale_inverse, transformer_engine::DType A_type, bool transa, at::Tensor B, at::Tensor B_scale_inverse, diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 7bd5a2d8c8..4772210237 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -4,12 +4,93 @@ * See LICENSE for license information. ************************************************************************/ +#include #include +#include +#include #include "../comm_gemm_overlap.h" #include "../extensions.h" +#include "pytorch/csrc/common.h" + +namespace pybind11::detail { + +template <> +struct type_caster { + public: + PYBIND11_TYPE_CASTER(transformer_engine::Float8Tensor, _("transformer_engine.pytorch.tensor.Float8Tensor")); + + bool load(handle src, bool) { + std::cout << "Loading Float8Tensor!" 
<< std::endl; + auto py_data = src.attr("_data"); + value.data = py_data.cast(); + auto py_transpose = src.attr("_transpose"); + if (!py_transpose.is_none()) { + value.transpose = py_transpose.cast(); + } + auto py_scale_inv = src.attr("_scale_inv"); + value.scale_inv = py_scale_inv.cast(); + auto py_dtype = src.attr("_fp8_dtype"); + value.dtype = py_dtype.cast(); + return true; + } + + static handle cast(const transformer_engine::Float8Tensor& src, return_value_policy, handle) { + throw std::runtime_error("Casting back from Float8Tensor not implemented yet!"); + return none().release(); + } +}; + +} // namespace pybind11::detail + +void test(pybind11::handle handle) { + at::Tensor t = handle.cast(); + std::cout << t.size(0) << std::endl; +} + +std::string to_string(transformer_engine::DType t) { + switch (t) { + case transformer_engine::DType::kInt32: + return "int32"; + case transformer_engine::DType::kInt64: + return "int64"; + case transformer_engine::DType::kFloat32: + return "float32"; + case transformer_engine::DType::kFloat16: + return "float16"; + case transformer_engine::DType::kBFloat16: + return "bfloat16"; + case transformer_engine::DType::kByte: + return "byte"; + case transformer_engine::DType::kFloat8E4M3: + return "float8e4m3"; + case transformer_engine::DType::kFloat8E5M2: + return "float8e5m2"; + default: + NVTE_ERROR("Invalid type"); + } +} + +void test2(transformer_engine::Float8Tensor tensor) { + //at::Tensor t = handle.cast(); + std::cout << tensor.data.size(0) << std::endl; + std::cout << tensor.scale_inv.size(0) << std::endl; + std::cout << to_string(tensor.dtype) << std::endl; +} + +template +using GemmFunc = std::vector (*)(InputType, bool, InputType, + bool, std::optional, at::Tensor, transformer_engine::DType, + at::Tensor, at::Tensor, transformer_engine::DType, + at::Tensor, bool, at::Tensor, size_t, + bool, bool, int); PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("cast_test", test); + m.def("cast_test2", test2); + m.def("te_gemm2", static_cast>(&te_gemm2), "CublasLt GEMM"); + m.def("te_gemm2", static_cast>(&te_gemm2), "CublasLt GEMM"); + // Permutation functions m.def("moe_permute_fwd", moe_permute_fwd); m.def("moe_permute_bwd", moe_permute_bwd); From b751c133112f178acd069d5a284b37bbdfe3e864 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 14 Oct 2024 21:56:28 +0000 Subject: [PATCH 2/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../pytorch/cpp_extensions/gemm.py | 5 +-- transformer_engine/pytorch/csrc/extensions.h | 24 ++++++++------ .../pytorch/csrc/extensions/pybind.cpp | 31 ++++++++++--------- 3 files changed, 34 insertions(+), 26 deletions(-) diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index 7bd3900c1a..5b379c86dc 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -25,6 +25,7 @@ def _empty_tensor() -> torch.Tensor: """Get tensor with no entries and no data""" return torch.Tensor() + def general_gemm( A: Union[torch.Tensor, Float8Tensor], B: Union[torch.Tensor, Float8Tensor], @@ -76,9 +77,9 @@ def general_gemm( B, False, # transb out, - empty_tensor, # if out_index is None else fp8_meta_tensor.scale[out_index], + empty_tensor, # if out_index is None else fp8_meta_tensor.scale[out_index], out_dtype, - empty_tensor, # if out_index is None else 
fp8_meta_tensor.amax_history[0][out_index], + empty_tensor, # if out_index is None else fp8_meta_tensor.amax_history[0][out_index], bias if use_bias else empty_tensor, bias_dtype, gelu_input, # this is pre_gelu_out diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 761f992f21..652c2bd01d 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -139,16 +139,20 @@ at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v); **************************************************************************************************/ std::vector te_gemm2(transformer_engine::Float8Tensor A, bool transa, - transformer_engine::Float8Tensor B, - bool transb, std::optional D, at::Tensor D_scale, transformer_engine::DType D_type, - at::Tensor D_amax, at::Tensor bias, transformer_engine::DType bias_type, - at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, size_t workspaceSize, - bool accumulate, bool use_split_accumulator, int math_sm_count); -std::vector te_gemm2(at::Tensor A, bool transa, at::Tensor B, bool transb, std::optional D, - at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, - at::Tensor bias, transformer_engine::DType bias_type, at::Tensor pre_gelu_out, - bool grad, at::Tensor workspace, size_t workspaceSize, bool accumulate, - bool use_split_accumulator, int math_sm_count); + transformer_engine::Float8Tensor B, bool transb, + std::optional D, at::Tensor D_scale, + transformer_engine::DType D_type, at::Tensor D_amax, + at::Tensor bias, transformer_engine::DType bias_type, + at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, + size_t workspaceSize, bool accumulate, bool use_split_accumulator, + int math_sm_count); +std::vector te_gemm2(at::Tensor A, bool transa, at::Tensor B, bool transb, + std::optional D, at::Tensor D_scale, + transformer_engine::DType D_type, at::Tensor D_amax, + at::Tensor bias, transformer_engine::DType bias_type, + at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, + size_t workspaceSize, bool accumulate, bool use_split_accumulator, + int math_sm_count); void te_gemm(at::Tensor A, at::Tensor A_scale_inverse, transformer_engine::DType A_type, bool transa, at::Tensor B, at::Tensor B_scale_inverse, transformer_engine::DType B_type, bool transb, at::Tensor D, at::Tensor D_scale, diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 4772210237..224d6bc82e 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -7,6 +7,7 @@ #include #include #include + #include #include "../comm_gemm_overlap.h" @@ -18,7 +19,8 @@ namespace pybind11::detail { template <> struct type_caster { public: - PYBIND11_TYPE_CASTER(transformer_engine::Float8Tensor, _("transformer_engine.pytorch.tensor.Float8Tensor")); + PYBIND11_TYPE_CASTER(transformer_engine::Float8Tensor, + _("transformer_engine.pytorch.tensor.Float8Tensor")); bool load(handle src, bool) { std::cout << "Loading Float8Tensor!" 
<< std::endl; @@ -35,7 +37,7 @@ struct type_caster { return true; } - static handle cast(const transformer_engine::Float8Tensor& src, return_value_policy, handle) { + static handle cast(const transformer_engine::Float8Tensor &src, return_value_policy, handle) { throw std::runtime_error("Casting back from Float8Tensor not implemented yet!"); return none().release(); } @@ -44,8 +46,8 @@ struct type_caster { } // namespace pybind11::detail void test(pybind11::handle handle) { - at::Tensor t = handle.cast(); - std::cout << t.size(0) << std::endl; + at::Tensor t = handle.cast(); + std::cout << t.size(0) << std::endl; } std::string to_string(transformer_engine::DType t) { @@ -72,23 +74,24 @@ std::string to_string(transformer_engine::DType t) { } void test2(transformer_engine::Float8Tensor tensor) { - //at::Tensor t = handle.cast(); - std::cout << tensor.data.size(0) << std::endl; - std::cout << tensor.scale_inv.size(0) << std::endl; - std::cout << to_string(tensor.dtype) << std::endl; + //at::Tensor t = handle.cast(); + std::cout << tensor.data.size(0) << std::endl; + std::cout << tensor.scale_inv.size(0) << std::endl; + std::cout << to_string(tensor.dtype) << std::endl; } template -using GemmFunc = std::vector (*)(InputType, bool, InputType, - bool, std::optional, at::Tensor, transformer_engine::DType, - at::Tensor, at::Tensor, transformer_engine::DType, - at::Tensor, bool, at::Tensor, size_t, - bool, bool, int); +using GemmFunc = std::vector (*)(InputType, bool, InputType, bool, + std::optional, at::Tensor, + transformer_engine::DType, at::Tensor, at::Tensor, + transformer_engine::DType, at::Tensor, bool, + at::Tensor, size_t, bool, bool, int); PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("cast_test", test); m.def("cast_test2", test2); - m.def("te_gemm2", static_cast>(&te_gemm2), "CublasLt GEMM"); + m.def("te_gemm2", static_cast>(&te_gemm2), + "CublasLt GEMM"); m.def("te_gemm2", static_cast>(&te_gemm2), "CublasLt GEMM"); // Permutation functions From 0327c7fb50270810bfae56a004047b4a889c2311 Mon Sep 17 00:00:00 2001 From: Przemyslaw Tredak Date: Tue, 15 Oct 2024 20:30:57 -0700 Subject: [PATCH 3/6] More changes to GEMM Signed-off-by: Przemyslaw Tredak --- .../pytorch/cpp_extensions/gemm.py | 31 +- transformer_engine/pytorch/csrc/extensions.h | 25 +- .../pytorch/csrc/extensions/attention.cu | 1 + .../pytorch/csrc/extensions/gemm.cu | 99 +++--- .../pytorch/csrc/extensions/pybind.cpp | 8 +- .../pytorch/csrc/extensions/recipe.cu | 1 + transformer_engine/pytorch/module/linear.py | 318 +++++++----------- .../pytorch/tensor/float8_tensor.py | 4 +- 8 files changed, 212 insertions(+), 275 deletions(-) diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index 5b379c86dc..31846d4d9a 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -13,6 +13,7 @@ __all__ = [ + "general_gemm", "gemm", "fp8_gemm", "grouped_gemm", @@ -29,13 +30,11 @@ def _empty_tensor() -> torch.Tensor: def general_gemm( A: Union[torch.Tensor, Float8Tensor], B: Union[torch.Tensor, Float8Tensor], - out_dtype: torch.dtype, workspace: torch.Tensor, gelu: bool = False, accumulate: bool = False, out: Optional[torch.Tensor] = None, bias: Optional[torch.Tensor] = None, - use_bias: bool = False, use_split_accumulator: bool = False, D_dtype: Optional[tex.DType] = None, ub_algo: tex.UbufOverlapAlgo = None, @@ -50,26 +49,19 @@ def general_gemm( # assert_dim_for_fp8_exec(A) # assert_dim_for_fp8_exec(B) - if out is 
None: - out = torch.empty( - B.shape[0], - A.shape[0], - dtype=out_dtype, - device="cuda", - ) - else: + if out is not None: if not out.is_contiguous(): raise ValueError("Output tensor is not contiguous.") # Use bfloat16 as default bias_dtype bias_dtype = torch.bfloat16 if bias is None else bias.dtype - if gelu: - gelu_input = torch.empty_like(out, dtype=bias_dtype) - else: - gelu_input = empty_tensor + # if gelu: + # gelu_input = torch.empty_like(out, dtype=bias_dtype) + # else: + # gelu_input = empty_tensor bias_dtype = TE_DType[bias_dtype] - out_dtype = TE_DType[out.dtype] if D_dtype is None else D_dtype + out_dtype = TE_DType[A.dtype] if D_dtype is None else D_dtype args = ( A, @@ -77,12 +69,12 @@ def general_gemm( B, False, # transb out, - empty_tensor, # if out_index is None else fp8_meta_tensor.scale[out_index], + None, # if out_index is None else fp8_meta_tensor.scale[out_index], out_dtype, - empty_tensor, # if out_index is None else fp8_meta_tensor.amax_history[0][out_index], - bias if use_bias else empty_tensor, + None, # if out_index is None else fp8_meta_tensor.amax_history[0][out_index], + bias, bias_dtype, - gelu_input, # this is pre_gelu_out + gelu, False, # grad workspace, workspace.shape[0], @@ -165,6 +157,7 @@ def general_gemm( args = tuple(args + (extra_output_tensor,)) if ub_algo is not None and ub_algo == tex.UbufOverlapAlgo.ATOMIC_GEMM_AG_P2P: out = fn(*args) + gelu_input = empty_tensor else: out, gelu_input = fn(*args) diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 652c2bd01d..9216b39cac 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -8,7 +8,6 @@ #define TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_H_ #include "common.h" -#include "common/common.h" /*************************************************************************************************** * Permutation @@ -138,21 +137,21 @@ at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v); * GEMM **************************************************************************************************/ +using MaybeTensor = std::optional; + std::vector te_gemm2(transformer_engine::Float8Tensor A, bool transa, transformer_engine::Float8Tensor B, bool transb, - std::optional D, at::Tensor D_scale, - transformer_engine::DType D_type, at::Tensor D_amax, - at::Tensor bias, transformer_engine::DType bias_type, - at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, - size_t workspaceSize, bool accumulate, bool use_split_accumulator, - int math_sm_count); + MaybeTensor D, MaybeTensor D_scale, + transformer_engine::DType D_type, MaybeTensor D_amax, + MaybeTensor bias, transformer_engine::DType bias_type, + bool gelu, bool grad, at::Tensor workspace, + size_t workspaceSize, bool accumulate, bool use_split_accumulator); std::vector te_gemm2(at::Tensor A, bool transa, at::Tensor B, bool transb, - std::optional D, at::Tensor D_scale, - transformer_engine::DType D_type, at::Tensor D_amax, - at::Tensor bias, transformer_engine::DType bias_type, - at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, - size_t workspaceSize, bool accumulate, bool use_split_accumulator, - int math_sm_count); + MaybeTensor D, MaybeTensor D_scale, + transformer_engine::DType D_type, MaybeTensor D_amax, + MaybeTensor bias, transformer_engine::DType bias_type, + bool gelu, bool grad, at::Tensor workspace, + size_t workspaceSize, bool accumulate, bool use_split_accumulator); void te_gemm(at::Tensor A, at::Tensor 
A_scale_inverse, transformer_engine::DType A_type, bool transa, at::Tensor B, at::Tensor B_scale_inverse, transformer_engine::DType B_type, bool transb, at::Tensor D, at::Tensor D_scale, diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cu b/transformer_engine/pytorch/csrc/extensions/attention.cu index b2968a688d..2b37a8f181 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cu +++ b/transformer_engine/pytorch/csrc/extensions/attention.cu @@ -4,6 +4,7 @@ * See LICENSE for license information. ************************************************************************/ +#include "common/common.h" #include "extensions.h" constexpr int block_size = 512; diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cu b/transformer_engine/pytorch/csrc/extensions/gemm.cu index bc0166362d..f3e98a2647 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cu +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cu @@ -6,30 +6,51 @@ #include -#include "common.h" #include "common/util/cuda_runtime.h" +#include "common/util/system.h" #include "extensions.h" #include "pytorch/csrc/common.h" #include "transformer_engine/transformer_engine.h" +namespace { + +void* get_data_ptr(MaybeTensor tensor) { + if (tensor.has_value()) return tensor->data_ptr(); + return nullptr; +} + +size_t get_size(MaybeTensor tensor, int dim) { + if (tensor.has_value()) return static_cast(tensor->size(dim)); + return 0; +} + +} // namespace + std::vector te_gemm2_helper( - at::Tensor A, transformer_engine::DType A_dtype, std::optional A_scale_inv, - bool transa, at::Tensor B, transformer_engine::DType B_dtype, - std::optional B_scale_inv, bool transb, std::optional D, - at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, at::Tensor bias, - transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, - size_t workspaceSize, bool accumulate, bool use_split_accumulator, int math_sm_count) { + at::Tensor A, transformer_engine::DType A_dtype, MaybeTensor A_scale_inv, bool transa, + at::Tensor B, transformer_engine::DType B_dtype, MaybeTensor B_scale_inv, bool transb, + MaybeTensor D, MaybeTensor D_scale, transformer_engine::DType D_type, MaybeTensor D_amax, + MaybeTensor bias, transformer_engine::DType bias_type, bool gelu, bool grad, + at::Tensor workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator) { using namespace transformer_engine; if (A.data_ptr() == nullptr || B.data_ptr() == nullptr) { at::Tensor out; if (D.has_value() && D->data_ptr() != nullptr && !accumulate) { - D->zero_(); // TODO: Handle D without a value + D->zero_(); out = *D; + } else { + out = at::Tensor(); // TODO: Handle D without a value } - if (pre_gelu_out.data_ptr() != nullptr) pre_gelu_out.zero_(); - return {out, pre_gelu_out}; + return {out, at::Tensor()}; } + // Set an external SM Margin to all the GEMMs. + // This comes in handy when DP is overlapped with GEMMs + + const int device_id = at::cuda::current_device(); + const int sm_count = transformer_engine::cuda::sm_count(device_id); + int num_math_sms = sm_count - transformer_engine::getenv("NVTE_EXT_MARGIN_SM", sm_count); + A = A.contiguous(); B = B.contiguous(); @@ -39,24 +60,26 @@ std::vector te_gemm2_helper( *D = at::empty({B.size(0), A.size(0)}, opts); } - auto A_scale_inv_ptr = A_scale_inv.has_value() ? A_scale_inv->data_ptr() : nullptr; - auto B_scale_inv_ptr = B_scale_inv.has_value() ? 
B_scale_inv->data_ptr() : nullptr; auto te_A = makeTransformerEngineTensor( A.data_ptr(), {static_cast(A.size(0)), static_cast(A.size(1))}, A_dtype, - nullptr, nullptr, A_scale_inv_ptr); + nullptr, nullptr, get_data_ptr(A_scale_inv)); auto te_B = makeTransformerEngineTensor( B.data_ptr(), {static_cast(B.size(0)), static_cast(B.size(1))}, B_dtype, - nullptr, nullptr, B_scale_inv_ptr); + nullptr, nullptr, get_data_ptr(B_scale_inv)); auto te_D = makeTransformerEngineTensor( D->data_ptr(), {static_cast(D->size(0)), static_cast(D->size(1))}, D_type, - D_amax.data_ptr(), D_scale.data_ptr(), nullptr); - auto te_bias = - makeTransformerEngineTensor(bias.data_ptr(), {static_cast(bias.size(0))}, bias_type); + get_data_ptr(D_amax), get_data_ptr(D_scale), nullptr); + auto te_bias = makeTransformerEngineTensor(get_data_ptr(bias), {get_size(bias, 0)}, bias_type); - const auto gelu_shape = pre_gelu_out.data_ptr() == nullptr - ? std::vector{static_cast(pre_gelu_out.size(0))} - : std::vector{static_cast(pre_gelu_out.size(0)), - static_cast(pre_gelu_out.size(1))}; + at::Tensor pre_gelu_out; + if (gelu) { + auto dtype = GetATenDType(bias_type); + auto opts = A.options().dtype(dtype); + pre_gelu_out = at::empty_like(*D, opts); + } + const auto gelu_shape = gelu ? std::vector{static_cast(pre_gelu_out.size(0)), + static_cast(pre_gelu_out.size(1))} + : std::vector{0}; auto te_pre_gelu_out = makeTransformerEngineTensor( pre_gelu_out.data_ptr(), gelu_shape, GetTransformerEngineDType(pre_gelu_out.scalar_type())); auto te_workspace = @@ -64,36 +87,34 @@ std::vector te_gemm2_helper( nvte_cublas_gemm(te_A.data(), te_B.data(), te_D.data(), te_bias.data(), te_pre_gelu_out.data(), transa, transb, grad, te_workspace.data(), accumulate, use_split_accumulator, - math_sm_count, at::cuda::getCurrentCUDAStream()); + num_math_sms, at::cuda::getCurrentCUDAStream()); return {*D, pre_gelu_out}; } std::vector te_gemm2(transformer_engine::Float8Tensor A, bool transa, - transformer_engine::Float8Tensor B, bool transb, - std::optional D, at::Tensor D_scale, - transformer_engine::DType D_type, at::Tensor D_amax, - at::Tensor bias, transformer_engine::DType bias_type, - at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, - size_t workspaceSize, bool accumulate, bool use_split_accumulator, - int math_sm_count) { + transformer_engine::Float8Tensor B, bool transb, MaybeTensor D, + MaybeTensor D_scale, transformer_engine::DType D_type, + MaybeTensor D_amax, MaybeTensor bias, + transformer_engine::DType bias_type, bool gelu, bool grad, + at::Tensor workspace, size_t workspaceSize, bool accumulate, + bool use_split_accumulator) { return te_gemm2_helper(A.data, A.dtype, A.scale_inv, transa, B.data, B.dtype, B.scale_inv, transb, - D, D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad, workspace, - workspaceSize, accumulate, use_split_accumulator, math_sm_count); + D, D_scale, D_type, D_amax, bias, bias_type, gelu, grad, workspace, + workspaceSize, accumulate, use_split_accumulator); } std::vector te_gemm2(at::Tensor A, bool transa, at::Tensor B, bool transb, - std::optional D, at::Tensor D_scale, - transformer_engine::DType D_type, at::Tensor D_amax, - at::Tensor bias, transformer_engine::DType bias_type, - at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, - size_t workspaceSize, bool accumulate, bool use_split_accumulator, - int math_sm_count) { + MaybeTensor D, MaybeTensor D_scale, + transformer_engine::DType D_type, MaybeTensor D_amax, + MaybeTensor bias, transformer_engine::DType bias_type, bool gelu, + bool grad, 
at::Tensor workspace, size_t workspaceSize, + bool accumulate, bool use_split_accumulator) { transformer_engine::DType A_dtype = GetTransformerEngineDType(A.scalar_type()); transformer_engine::DType B_dtype = GetTransformerEngineDType(B.scalar_type()); return te_gemm2_helper(A, A_dtype, std::nullopt, transa, B, B_dtype, std::nullopt, transb, D, - D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad, workspace, - workspaceSize, accumulate, use_split_accumulator, math_sm_count); + D_scale, D_type, D_amax, bias, bias_type, gelu, grad, workspace, + workspaceSize, accumulate, use_split_accumulator); } void te_gemm(at::Tensor A, at::Tensor A_scale_inverse, transformer_engine::DType A_type, diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 224d6bc82e..4f206e4141 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -82,10 +82,10 @@ void test2(transformer_engine::Float8Tensor tensor) { template using GemmFunc = std::vector (*)(InputType, bool, InputType, bool, - std::optional, at::Tensor, - transformer_engine::DType, at::Tensor, at::Tensor, - transformer_engine::DType, at::Tensor, bool, - at::Tensor, size_t, bool, bool, int); + MaybeTensor, MaybeTensor, + transformer_engine::DType, MaybeTensor, MaybeTensor, + transformer_engine::DType, bool, bool, + at::Tensor, size_t, bool, bool); PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("cast_test", test); diff --git a/transformer_engine/pytorch/csrc/extensions/recipe.cu b/transformer_engine/pytorch/csrc/extensions/recipe.cu index a130169fe7..9aff601926 100644 --- a/transformer_engine/pytorch/csrc/extensions/recipe.cu +++ b/transformer_engine/pytorch/csrc/extensions/recipe.cu @@ -9,6 +9,7 @@ #include +#include "common/common.h" #include "extensions.h" void fused_amax_and_scale_update_after_reduction( diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index d6406f6119..f40b9836cc 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -37,12 +37,13 @@ _fsdp_gather_tensors, ) from ..cpp_extensions import ( + general_gemm, fp8_gemm, gemm, fp8_cast_transpose_fused, cast_to_fp8, ) -from ..constants import GemmParallelModes, dist_group_type +from ..constants import GemmParallelModes, dist_group_type, TE_DType from ..jit import no_torch_dynamo from ..graph import is_graph_capturing from ..float8_tensor import Float8Tensor @@ -62,10 +63,8 @@ class _Linear(torch.autograd.Function): def forward( ctx, weight: Union[Float8Tensor, torch.Tensor], - weight_fp8: Optional[Float8Tensor], inp: torch.Tensor, - bias: torch.Tensor, - use_bias: bool, + bias: Optional[torch.Tensor], is_first_microbatch: Union[bool, None], fp8: bool, fp8_calibration: bool, @@ -84,6 +83,8 @@ def forward( ub_name: str, fp8_output: bool, fsdp_group: Union[dist_group_type, None], + module: torch.nn.Module, + skip_fp8_weight_update: bool, ) -> torch.Tensor: is_input_fp8 = isinstance(inp, Float8Tensor) @@ -103,187 +104,133 @@ def forward( inputmat = cast_if_needed(inputmat, activation_dtype) inputmat_t = None inputmat_no_fp8 = inputmat - inputmat_scale_inv = None if fp8: fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) - if isinstance(inputmat, Float8Tensor): - inputmat_scale_inv = inputmat._scale_inv - else: - inputmat_scale_inv = torch.empty([1], dtype=torch.float32, device=inputmat.device) - if ( - not 
fp8_meta["recipe"].override_linear_precision.wgrad - and is_grad_enabled - and weight.requires_grad + if not isinstance(inputmat, Float8Tensor): + backward_needs_input = not fp8_meta["recipe"].override_linear_precision.wgrad \ + and is_grad_enabled \ + and weight.requires_grad \ and not sequence_parallel - ): - # FP8 input for forward, FP8 input transpose for backward wgrad - inputmat, inputmat_t = fp8_cast_transpose_fused( - inputmat, - fp8_meta["scaling_fwd"], - tex.FP8FwdTensors.GEMM1_INPUT, - fp8_dtype_forward, - scale_inv=inputmat_scale_inv, - ) - else: - # FP8 input for forward - inputmat = cast_to_fp8( - inputmat, - fp8_meta["scaling_fwd"], - tex.FP8FwdTensors.GEMM1_INPUT, - fp8_dtype_forward, - scale_inv=inputmat_scale_inv, - ) - - # Hack for ONNX export - # Note: ONNX models are represented as a graph of tensor - # operations, so the in-place scale-inv update doesn't fit - # very well. We work around this by making it look like - # the scale-inv tensor is initialized with a copy. - # Note: ONNX export expects FP8 scales can be represented - # with constant ops. However, copying into a buffer - # involves an expand op for array broadcasting. We work - # around this by filling the buffer instead. - if is_in_onnx_export_mode(): - inputmat_scale_inv.fill_(inputmat_scale_inv.item()) + inputmat = Float8Tensor.to_float8(inputmat, + fp8_meta=fp8_meta["scaling_fwd"], + fp8_meta_index=tex.FP8FwdTensors.GEMM1_INPUT, + fp8_dtype=fp8_dtype_forward, + with_transpose_cache=backward_needs_input) # Column Parallel Linear if parallel_mode == "column" and sequence_parallel: inputmat_total, _ = gather_along_first_dim(inputmat, tp_group) else: inputmat_total = inputmat - if fp8: - bias_dtype = torch.bfloat16 if activation_dtype == torch.float32 else activation_dtype - bias = cast_if_needed(bias, bias_dtype) if use_bias else bias - - # Use FP8 weights - if weight_fp8 is None: - weight_fp8 = weight - assert isinstance(weight_fp8, Float8Tensor) - - if fp8_output: - proj_out_index, meta_tensor, proj_out_tetype, proj_out_pttype = ( - tex.FP8FwdTensors.GEMM1_OUTPUT, - fp8_meta["scaling_fwd"], - fp8_dtype_forward, - torch.uint8, - ) + # Initialize FP8 weights if needed + weight_fp8 = weight + if fp8: + if isinstance(weight, Float8Tensor): + # Make sure transpose cache is valid, if present + # Note: Transpose cache may have been invalidated + # externally, e.g. by optimizer. + # TODO: Do we actually need this? 
+ if weight._transpose is not None: + weight.transpose_2d( + fill_cache=True, + noop_flag=skip_fp8_weight_update, + ) else: - proj_out_index, meta_tensor, proj_out_tetype, proj_out_pttype = ( - None, - None, - None, - activation_dtype, + # FP8 cast to workspace buffer + update_workspace = is_first_microbatch is None or is_first_microbatch + weight_fp8 = module.get_fp8_workspace( + tensor=weight, + fp8_meta_forward=True, + fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT, + cache_name=(None if is_first_microbatch is None else "weight"), + update_workspace=update_workspace, + skip_update_flag=skip_fp8_weight_update, + fsdp_group=fsdp_group, ) - if ub_overlap_rs: - ub_obj_projout = get_ub(ub_name + "_fprop") - out = ub_obj_projout.get_ubuf_output(1) - dim_size = list(inputmat_total.size()) - dim_size[0] = dim_size[0] // tp_world_size - dim_size[1] = out_features - rs_out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device) - if ub_obj_projout.is_p2p_overlap(): - if ub_obj_projout.is_atomic_gemm(): - ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_RS_P2P - else: - ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P - else: - if ub_obj_projout.is_atomic_gemm(): - ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_RS - else: - ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS - if ub_obj_projout.is_fp8_ubuf(): - proj_out_index = tex.FP8FwdTensors.GEMM1_OUTPUT - meta_tensor = fp8_meta["scaling_fwd"] - proj_out_tetype = fp8_dtype_forward - proj_out_pttype = torch.uint8 - ub_obj_projout.set_ubuf_scale_inv(meta_tensor.scale_inv[proj_out_index]) - else: - dim_size = list(inputmat_total.size()) - dim_size[1] = out_features - out = torch.empty(dim_size, dtype=proj_out_pttype, device=inputmat_total.device) - - _ = fp8_gemm( - weight_fp8._data, - weight_fp8._scale_inv, - 0, - weight_fp8._fp8_dtype, - ( - inputmat_total._data - if isinstance(inputmat_total, Float8Tensor) - else inputmat_total - ), - inputmat_scale_inv, - 0, + bias_dtype = torch.bfloat16 if activation_dtype == torch.float32 else activation_dtype + else: + bias_dtype = activation_dtype + weight_fp8 = cast_if_needed(weight, activation_dtype) + + bias = cast_if_needed(bias, bias_dtype) if bias is not None else bias + + if not fp8 and fp8_calibration: + # amax of input + amin, amax = inputmat_total.aminmax() + fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_INPUT] = torch.max( + -amin, amax + ).float() + # amax of weight + amin, amax = weight.aminmax() + fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_WEIGHT] = torch.max( + -amin, amax + ).float() + + if fp8_output: + proj_out_index, meta_tensor, proj_out_tetype = ( + tex.FP8FwdTensors.GEMM1_OUTPUT, + fp8_meta["scaling_fwd"], fp8_dtype_forward, - proj_out_pttype, - get_workspace(), - bias=bias, - use_bias=use_bias, - use_split_accumulator=_2X_ACC_FPROP, - out=out, - ub_algo=ub_algo if ub_overlap_rs else None, - ub=ub_obj_projout if ub_overlap_rs else None, - extra_output_tensor=rs_out if ub_overlap_rs else None, - out_index=proj_out_index, - fp8_meta_tensor=meta_tensor, - D_dtype=proj_out_tetype, ) - if fp8_output: - out = Float8Tensor( - data=out, - fp8_meta=fp8_meta, - fp8_meta_forward=True, - fp8_meta_index=tex.FP8FwdTensors.GEMM1_OUTPUT, - fp8_dtype=fp8_dtype_forward, - dtype=activation_dtype, - ) else: - # Cast for native AMP - weight = cast_if_needed(weight, activation_dtype) - bias = cast_if_needed(bias, activation_dtype) if use_bias else bias - - if fp8_calibration: - # amax of input - amin, amax = inputmat_total.aminmax() - 
fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_INPUT] = torch.max( - -amin, amax - ).float() - # amax of weight - amin, amax = weight.aminmax() - fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_WEIGHT] = torch.max( - -amin, amax - ).float() - - if ub_overlap_rs: - ub_obj_projout = get_ub(ub_name + "_fprop") - out = ub_obj_projout.get_ubuf_output(1) - dim_size = list(inputmat_total.size()) - dim_size[0] = dim_size[0] // get_distributed_world_size(tp_group) - dim_size[1] = out_features - rs_out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device) - if ub_obj_projout.is_p2p_overlap(): + proj_out_index, meta_tensor, proj_out_tetype = ( + None, + None, + TE_DType[activation_dtype], + ) + + if ub_overlap_rs: + ub_obj_projout = get_ub(ub_name + "_fprop") + out = ub_obj_projout.get_ubuf_output(1) + dim_size = list(inputmat_total.size()) + dim_size[0] = dim_size[0] // tp_world_size + dim_size[1] = out_features + rs_out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device) + if ub_obj_projout.is_p2p_overlap(): + if ub_obj_projout.is_atomic_gemm(): + ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_RS_P2P + else: ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P + else: + if ub_obj_projout.is_atomic_gemm(): + ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_RS else: ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS - else: - dim_size = list(inputmat_total.size()) - dim_size[1] = out_features - out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device) - - _ = gemm( - weight, - inputmat_total, - activation_dtype, - get_workspace(), - bias=bias, - use_bias=use_bias, - out=out, - ub_algo=ub_algo if ub_overlap_rs else None, - ub=ub_obj_projout if ub_overlap_rs else None, - extra_output_tensor=rs_out if ub_overlap_rs else None, + if fp8 and ub_obj_projout.is_fp8_ubuf(): + proj_out_index = tex.FP8FwdTensors.GEMM1_OUTPUT + meta_tensor = fp8_meta["scaling_fwd"] + proj_out_tetype = fp8_dtype_forward + ub_obj_projout.set_ubuf_scale_inv(meta_tensor.scale_inv[proj_out_index]) + else: + dim_size = list(inputmat_total.size()) + dim_size[1] = out_features + out = None + + out, _ = general_gemm( + weight_fp8, + inputmat_total, + get_workspace(), + bias=bias, + use_split_accumulator=_2X_ACC_FPROP, + out=out, + ub_algo=ub_algo if ub_overlap_rs else None, + ub=ub_obj_projout if ub_overlap_rs else None, + extra_output_tensor=rs_out if ub_overlap_rs else None, + # out_index=proj_out_index, + # fp8_meta_tensor=meta_tensor, + D_dtype=proj_out_tetype, + ) + if fp8_output: + out = Float8Tensor( + data=out, + fp8_meta=fp8_meta, + fp8_meta_forward=True, + fp8_meta_index=tex.FP8FwdTensors.GEMM1_OUTPUT, + fp8_dtype=fp8_dtype_forward, + dtype=activation_dtype, ) if is_grad_enabled: @@ -321,7 +268,7 @@ def forward( ctx.save_for_backward( saved_inputmat, saved_inputmat_t, - inputmat_scale_inv, + inputmat._scale_inv, weight, weight_fp8, weight.main_grad if cpu_offloading and fuse_wgrad_accumulation else None, @@ -333,7 +280,7 @@ def forward( ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation ctx.cpu_offloading = cpu_offloading ctx.is_first_microbatch = is_first_microbatch - ctx.use_bias = use_bias + ctx.use_bias = bias is not None ctx.sequence_parallel = sequence_parallel ctx.tensor_parallel = tensor_parallel ctx.inp_shape = inp_shape @@ -610,10 +557,8 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], return ( wgrad, - None, # weight_fp8 dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None, 
grad_bias, - None, # use_bias None, # is_first_microbatch None, # fp8 None, # fp8_calibration @@ -632,6 +577,8 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], None, # ub_name None, # fp8_output None, # fsdp_group + None, # module + None, # skip_fp8_weight_update ) @@ -956,32 +903,7 @@ def forward( if self.use_bias: bias_tensor = _noop_cat([self._fast_get_param(name) for name in self.bias_names]) else: - bias_tensor = self._fast_get_param(self.bias_names[0]) # Unused - - # Initialize FP8 weights if needed - weight_fp8 = None - if self.fp8: - if isinstance(weight_tensor, Float8Tensor): - # Make sure transpose cache is valid, if present - # Note: Transpose cache may have been invalidated - # externally, e.g. by optimizer. - if weight_tensor._transpose is not None: - weight_tensor.transpose_2d( - fill_cache=True, - noop_flag=skip_fp8_weight_update, - ) - else: - # FP8 cast to workspace buffer - update_workspace = is_first_microbatch is None or is_first_microbatch - weight_fp8 = self.get_fp8_workspace( - tensor=weight_tensor, - fp8_meta_forward=True, - fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT, - cache_name=(None if is_first_microbatch is None else "weight"), - update_workspace=update_workspace, - skip_update_flag=skip_fp8_weight_update, - fsdp_group=self.fsdp_group, - ) + bias_tensor = None if torch.is_grad_enabled(): linear_fn = _Linear.apply @@ -991,10 +913,8 @@ def forward( args = [None] args += ( weight_tensor, - weight_fp8, inp, - bias_tensor, - self.apply_bias and not self.gemm_bias_unfused_add, + bias_tensor if (self.apply_bias and not self.gemm_bias_unfused_add) else None, is_first_microbatch, self.fp8, self.fp8_calibration, @@ -1013,6 +933,8 @@ def forward( self.ub_name, fp8_output, self.fsdp_group, + self, + skip_fp8_weight_update, ) out = linear_fn(*args) diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index 0b381258a4..51f0f04b9a 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -12,10 +12,10 @@ from transformer_engine_torch import DType as TE_DType from ..constants import TE_DType as torch_to_transformer_engine_dtype -from ..cpp_extensions import ( +from ..cpp_extensions.transpose import fp8_cast_transpose_fused +from ..cpp_extensions.cast import ( cast_from_fp8, cast_to_fp8, - fp8_cast_transpose_fused, ) from ..fp8 import FP8GlobalStateManager from ..utils import devices_match From c5b02a0c8c4367fd3df339967f7506ed7be4adad Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 16 Oct 2024 03:31:34 +0000 Subject: [PATCH 4/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/csrc/extensions.h | 18 ++++++++--------- .../pytorch/csrc/extensions/pybind.cpp | 7 +++---- transformer_engine/pytorch/module/linear.py | 20 +++++++++++-------- 3 files changed, 24 insertions(+), 21 deletions(-) diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 9216b39cac..7ab6225191 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -140,18 +140,18 @@ at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v); using MaybeTensor = std::optional; std::vector te_gemm2(transformer_engine::Float8Tensor A, bool transa, - transformer_engine::Float8Tensor B, bool 
transb, - MaybeTensor D, MaybeTensor D_scale, - transformer_engine::DType D_type, MaybeTensor D_amax, - MaybeTensor bias, transformer_engine::DType bias_type, - bool gelu, bool grad, at::Tensor workspace, - size_t workspaceSize, bool accumulate, bool use_split_accumulator); + transformer_engine::Float8Tensor B, bool transb, MaybeTensor D, + MaybeTensor D_scale, transformer_engine::DType D_type, + MaybeTensor D_amax, MaybeTensor bias, + transformer_engine::DType bias_type, bool gelu, bool grad, + at::Tensor workspace, size_t workspaceSize, bool accumulate, + bool use_split_accumulator); std::vector te_gemm2(at::Tensor A, bool transa, at::Tensor B, bool transb, MaybeTensor D, MaybeTensor D_scale, transformer_engine::DType D_type, MaybeTensor D_amax, - MaybeTensor bias, transformer_engine::DType bias_type, - bool gelu, bool grad, at::Tensor workspace, - size_t workspaceSize, bool accumulate, bool use_split_accumulator); + MaybeTensor bias, transformer_engine::DType bias_type, bool gelu, + bool grad, at::Tensor workspace, size_t workspaceSize, + bool accumulate, bool use_split_accumulator); void te_gemm(at::Tensor A, at::Tensor A_scale_inverse, transformer_engine::DType A_type, bool transa, at::Tensor B, at::Tensor B_scale_inverse, transformer_engine::DType B_type, bool transb, at::Tensor D, at::Tensor D_scale, diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 4f206e4141..30980acbfa 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -81,10 +81,9 @@ void test2(transformer_engine::Float8Tensor tensor) { } template -using GemmFunc = std::vector (*)(InputType, bool, InputType, bool, - MaybeTensor, MaybeTensor, - transformer_engine::DType, MaybeTensor, MaybeTensor, - transformer_engine::DType, bool, bool, +using GemmFunc = std::vector (*)(InputType, bool, InputType, bool, MaybeTensor, + MaybeTensor, transformer_engine::DType, MaybeTensor, + MaybeTensor, transformer_engine::DType, bool, bool, at::Tensor, size_t, bool, bool); PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index f40b9836cc..385cc177ac 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -108,15 +108,19 @@ def forward( if fp8: fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) if not isinstance(inputmat, Float8Tensor): - backward_needs_input = not fp8_meta["recipe"].override_linear_precision.wgrad \ - and is_grad_enabled \ - and weight.requires_grad \ + backward_needs_input = ( + not fp8_meta["recipe"].override_linear_precision.wgrad + and is_grad_enabled + and weight.requires_grad and not sequence_parallel - inputmat = Float8Tensor.to_float8(inputmat, - fp8_meta=fp8_meta["scaling_fwd"], - fp8_meta_index=tex.FP8FwdTensors.GEMM1_INPUT, - fp8_dtype=fp8_dtype_forward, - with_transpose_cache=backward_needs_input) + ) + inputmat = Float8Tensor.to_float8( + inputmat, + fp8_meta=fp8_meta["scaling_fwd"], + fp8_meta_index=tex.FP8FwdTensors.GEMM1_INPUT, + fp8_dtype=fp8_dtype_forward, + with_transpose_cache=backward_needs_input, + ) # Column Parallel Linear if parallel_mode == "column" and sequence_parallel: From 3697c10b732fddac2ae5faf1f2ac8ab670497e06 Mon Sep 17 00:00:00 2001 From: Przemyslaw Tredak Date: Sun, 20 Oct 2024 22:52:13 -0700 Subject: [PATCH 5/6] Mock up Signed-off-by: Przemyslaw Tredak --- 
transformer_engine/common/recipe/__init__.py | 8 +- .../pytorch/cpp_extensions/activation.py | 29 +--- .../pytorch/cpp_extensions/gemm.py | 20 +-- transformer_engine/pytorch/csrc/common.h | 33 ++++ transformer_engine/pytorch/csrc/extensions.h | 6 + .../pytorch/csrc/extensions/cast.cu | 35 ++++ .../pytorch/csrc/extensions/gemm.cu | 45 ++++-- .../pytorch/csrc/extensions/pybind.cpp | 22 +++ transformer_engine/pytorch/csrc/pybind.h | 22 +++ transformer_engine/pytorch/fp8.py | 12 +- transformer_engine/pytorch/module/base.py | 44 +++-- transformer_engine/pytorch/module/linear.py | 108 ++++--------- .../pytorch/quantization_meta.py | 65 ++++++++ .../pytorch/quantization_params.py | 30 ++++ .../pytorch/tensor/float8_tensor.py | 152 +++++++++--------- .../pytorch/tensor/quantized_tensor.py | 32 ++++ 16 files changed, 435 insertions(+), 228 deletions(-) create mode 100644 transformer_engine/pytorch/csrc/pybind.h create mode 100644 transformer_engine/pytorch/quantization_meta.py create mode 100644 transformer_engine/pytorch/quantization_params.py diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index ba276ad406..0b6d4a808f 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -49,9 +49,15 @@ class _OverrideLinearPrecision(NamedTuple): dgrad: bool = False wgrad: bool = False +@dataclass +class Recipe: + """ + Abstract base recipe class. + """ + pass @dataclass() -class DelayedScaling: +class DelayedScaling(Recipe): """ Use the delayed scaling factor strategy. Use scale factor from previous iteration and record amax history of `amax_history_len` steps. diff --git a/transformer_engine/pytorch/cpp_extensions/activation.py b/transformer_engine/pytorch/cpp_extensions/activation.py index f204982aa0..f2264c115b 100644 --- a/transformer_engine/pytorch/cpp_extensions/activation.py +++ b/transformer_engine/pytorch/cpp_extensions/activation.py @@ -8,6 +8,8 @@ import torch import transformer_engine_torch as tex + +from quantization_params import Float8Params from ._common import canonicalize_fp8_scales __all__ = ["gelu", "relu", "reglu", "geglu", "swiglu", "qgelu", "srelu"] @@ -15,33 +17,18 @@ def gelu( inp: torch.Tensor, - fp8_meta_tensor: Optional[tex.FP8TensorMeta], - fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None], - otype: tex.DType, - scale: Optional[torch.Tensor] = None, - amax: Optional[torch.Tensor] = None, - scale_inv: Optional[torch.Tensor] = None, + qparams: Float8Params, ) -> torch.Tensor: """GeLU with FP8 output""" - # Get FP8 scaling factors - fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( - scale=scale, - amax=amax, - scale_inv=scale_inv, - fp8_meta=fp8_meta_tensor, - fp8_meta_index=fp8_tensor, - allow_multiple_offsets=False, - ) - # Launch kernel return torch.ops.tex_ts.gelu_ts( inp, - fp8_scales["scale"], - fp8_scales["amax"], - fp8_scales["scale_inv"], - fp8_scales_offsets["scale_offset"], - otype, + qparams.scale, + qparams.amax, + qparams.scale_inv, + 0, + qparams.dtype, ) diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index 31846d4d9a..f7c8f2fb4b 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -9,7 +9,6 @@ import transformer_engine_torch as tex from ..constants import TE_DType from ..utils import assert_dim_for_fp8_exec -from ..tensor import Float8Tensor __all__ = [ @@ -28,41 +27,32 @@ def _empty_tensor() -> 
torch.Tensor: def general_gemm( - A: Union[torch.Tensor, Float8Tensor], - B: Union[torch.Tensor, Float8Tensor], + A: torch.Tensor, + B: torch.Tensor, workspace: torch.Tensor, + out_dtype: tex.DType, gelu: bool = False, accumulate: bool = False, out: Optional[torch.Tensor] = None, bias: Optional[torch.Tensor] = None, use_split_accumulator: bool = False, - D_dtype: Optional[tex.DType] = None, ub_algo: tex.UbufOverlapAlgo = None, ub: Union[tex.UbufCommOverlap, tex.UbufP2PCommOverlap] = None, - extra_output_tensor: Optional[torch.Tensor] = None, + ub_buffer: Optional[torch.Tensor] = None, ) -> torch.Tensor: """GEMM supporting fp8 inputs.""" empty_tensor = _empty_tensor() - if D_dtype is not None and D_dtype in [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]: + if out_dtype in [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]: raise ValueError("FP8 output not supported") - # assert_dim_for_fp8_exec(A) - # assert_dim_for_fp8_exec(B) - if out is not None: if not out.is_contiguous(): raise ValueError("Output tensor is not contiguous.") # Use bfloat16 as default bias_dtype bias_dtype = torch.bfloat16 if bias is None else bias.dtype - # if gelu: - # gelu_input = torch.empty_like(out, dtype=bias_dtype) - # else: - # gelu_input = empty_tensor bias_dtype = TE_DType[bias_dtype] - out_dtype = TE_DType[A.dtype] if D_dtype is None else D_dtype - args = ( A, True, # transa diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index c41193b9c6..9a66c99bb3 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -45,6 +45,7 @@ #include #include +#include "c10/util/ArrayRef.h" #include "common/util/logging.h" namespace transformer_engine { @@ -134,6 +135,7 @@ inline transformer_engine::DType GetTransformerEngineDType(at::ScalarType t) { case torch::kInt64: return transformer_engine::DType::kInt64; default: + std::cout << "Type: " << static_cast(t) << std::endl; NVTE_ERROR("Invalid type"); } } @@ -176,4 +178,35 @@ at::Tensor allocateTorchTensor(int M, transformer_engine::DType dtype); void* getDataPtr(at::Tensor tensor, int offset = 0); +namespace std { + template + string to_string(const vector& vec) { + string ret = "["; + for (const auto& val : vec) { + ret += to_string(val) + ","; + } + if (ret.size() > 1) { + ret[ret.size() - 1] = ']'; + } else { + ret += "]"; + } + return ret; + } + + // Torch shape -> string + template + string to_string(const c10::ArrayRef& vec) { + string ret = "["; + for (const auto& val : vec) { + ret += to_string(val) + ","; + } + if (ret.size() > 1) { + ret[ret.size() - 1] = ']'; + } else { + ret += "]"; + } + return ret; + } +} // namespace std + #endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_COMMON_H_ diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 7ab6225191..387e14aa0d 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -364,6 +364,12 @@ at::Tensor rmsnorm_fwd_inf(const at::Tensor &input, const at::Tensor &weight, fl * Cast **************************************************************************************************/ +namespace transformer_engine::pytorch { + +py::handle cast(const at::Tensor& tensor, py::handle quantization_params); + +} // namespace transformer_engine::pytorch + at::Tensor cast_to_fp8(const at::Tensor &input, const at::Tensor &scale, at::Tensor amax, at::Tensor scale_inv, transformer_engine::DType otype, const int scale_offset = 0, const int amax_offset 
= 0, diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cu b/transformer_engine/pytorch/csrc/extensions/cast.cu index 47f5825866..e2ec1738d3 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cu +++ b/transformer_engine/pytorch/csrc/extensions/cast.cu @@ -5,6 +5,41 @@ ************************************************************************/ #include "extensions.h" +#include "pybind.h" +#include "object.h" + +namespace transformer_engine::pytorch { + +namespace detail { + +bool IsFloat8QParamsType(PyObject *obj) { + return Py_TYPE(obj) == Float8QParamsClass; +} + +} // namespace detail + +py::handle cast(const at::Tensor& tensor, py::handle quantization_params) { + using namespace pybind11::literals; + init_extension(); + if (detail::IsFloat8QParamsType(quantization_params.ptr())) { + auto py_scale = quantization_params.attr("scale"); + auto py_amax = quantization_params.attr("amax"); + auto py_scale_inv = quantization_params.attr("scale_inv"); + DType type = quantization_params.attr("dtype").cast(); + const at::Tensor& scale = py_scale.cast(); + auto data = cast_to_fp8(tensor, + py_scale.cast(), + py_amax.cast(), + py_scale_inv.cast(), + type); + py::handle Float8TensorClass(reinterpret_cast(Float8TensorPythonClass)); + auto ret = Float8TensorClass("data"_a=data, "fp8_scale_inv"_a=py_scale_inv, "fp8_dtype"_a=type); + return ret.release(); + } + NVTE_ERROR("Invalid type of the quantization params"); +} + +} // transformer_engine::pytorch at::Tensor cast_to_fp8(const at::Tensor& input, const at::Tensor& scale, at::Tensor amax, at::Tensor scale_inv, transformer_engine::DType otype, diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cu b/transformer_engine/pytorch/csrc/extensions/gemm.cu index f3e98a2647..9ab3d288f8 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cu +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cu @@ -6,6 +6,7 @@ #include +#include "common.h" #include "common/util/cuda_runtime.h" #include "common/util/system.h" #include "extensions.h" @@ -54,34 +55,58 @@ std::vector te_gemm2_helper( A = A.contiguous(); B = B.contiguous(); + // TODO: check shapes for FP8 execution + const auto& A_shape = A.sizes(); + NVTE_CHECK(A_shape.size() == 2, + "The A tensor in matmul must have 2 dimensions (got :" + + std::to_string(A_shape.size()) + ")."); + + const auto& B_shape = B.sizes(); + // Compute the product of dimensions except for the last one + int64_t prod = 1; + for (size_t i = 0; i < B_shape.size() - 1; ++i) { + prod *= B_shape[i]; + } + + std::vector D_shape; + for (size_t i = 0; i < B_shape.size() - 1; ++i) { + D_shape.push_back(B_shape[i]); + } + D_shape.push_back(A_shape[0]); if (!D.has_value()) { auto type = GetATenDType(D_type); auto opts = at::TensorOptions().dtype(type).device(A.options().device()); - *D = at::empty({B.size(0), A.size(0)}, opts); + *D = at::empty(D_shape, opts); + } else { + NVTE_CHECK(D_shape == D->sizes(), + "Wrong shape of the provided matmul output. 
Expected " + + std::to_string(D_shape) + " and got " + + std::to_string(D->sizes()) + "."); } auto te_A = makeTransformerEngineTensor( A.data_ptr(), {static_cast(A.size(0)), static_cast(A.size(1))}, A_dtype, nullptr, nullptr, get_data_ptr(A_scale_inv)); auto te_B = makeTransformerEngineTensor( - B.data_ptr(), {static_cast(B.size(0)), static_cast(B.size(1))}, B_dtype, + B.data_ptr(), {static_cast(prod), static_cast(B.size(-1))}, B_dtype, nullptr, nullptr, get_data_ptr(B_scale_inv)); auto te_D = makeTransformerEngineTensor( - D->data_ptr(), {static_cast(D->size(0)), static_cast(D->size(1))}, D_type, + D->data_ptr(), {static_cast(prod), static_cast(D->size(-1))}, D_type, get_data_ptr(D_amax), get_data_ptr(D_scale), nullptr); auto te_bias = makeTransformerEngineTensor(get_data_ptr(bias), {get_size(bias, 0)}, bias_type); - at::Tensor pre_gelu_out; + MaybeTensor pre_gelu_out = std::nullopt; + DType gelu_type = bias_type; if (gelu) { auto dtype = GetATenDType(bias_type); auto opts = A.options().dtype(dtype); - pre_gelu_out = at::empty_like(*D, opts); + *pre_gelu_out = at::empty_like(*D, opts); } - const auto gelu_shape = gelu ? std::vector{static_cast(pre_gelu_out.size(0)), - static_cast(pre_gelu_out.size(1))} + const auto gelu_shape = gelu ? std::vector{static_cast(prod), + static_cast(D->size(-1))} : std::vector{0}; - auto te_pre_gelu_out = makeTransformerEngineTensor( - pre_gelu_out.data_ptr(), gelu_shape, GetTransformerEngineDType(pre_gelu_out.scalar_type())); + auto te_pre_gelu_out = + makeTransformerEngineTensor(get_data_ptr(pre_gelu_out), gelu_shape, gelu_type); auto te_workspace = makeTransformerEngineTensor(workspace.data_ptr(), {workspaceSize}, DType::kByte); @@ -89,7 +114,7 @@ std::vector te_gemm2_helper( transa, transb, grad, te_workspace.data(), accumulate, use_split_accumulator, num_math_sms, at::cuda::getCurrentCUDAStream()); - return {*D, pre_gelu_out}; + return {*D, pre_gelu_out.value_or(at::Tensor())}; } std::vector te_gemm2(transformer_engine::Float8Tensor A, bool transa, diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 30980acbfa..4635451713 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -12,8 +12,27 @@ #include "../comm_gemm_overlap.h" #include "../extensions.h" +#include "object.h" #include "pytorch/csrc/common.h" +namespace transformer_engine::pytorch { + +PyTypeObject *Float8TensorPythonClass = nullptr; +PyTypeObject *Float8QParamsClass = nullptr; + +void init_extension() { + if (Float8TensorPythonClass) return; + auto float8tensor_module = py::module_::import("transformer_engine.pytorch.tensor.float8_tensor"); + auto qparams_module = py::module_::import("transformer_engine.pytorch.quantization_params"); + Float8QParamsClass = reinterpret_cast(PyObject_GetAttrString(qparams_module.ptr(), + "Float8Params")); + Float8TensorPythonClass = reinterpret_cast(PyObject_GetAttrString(float8tensor_module.ptr(), "Float8Tensor")); + NVTE_CHECK(Float8TensorPythonClass != nullptr, + "Internal error: could not initialize pyTorch extension."); +} + +} // namespace transformer_engine::pytorch + namespace pybind11::detail { template <> @@ -24,6 +43,8 @@ struct type_caster { bool load(handle src, bool) { std::cout << "Loading Float8Tensor!" 
<< std::endl; + transformer_engine::pytorch::init_extension(); + if (Py_TYPE(src.ptr()) != transformer_engine::pytorch::Float8TensorPythonClass) return false; auto py_data = src.attr("_data"); value.data = py_data.cast(); auto py_transpose = src.attr("_transpose"); @@ -89,6 +110,7 @@ using GemmFunc = std::vector (*)(InputType, bool, InputType, bool, M PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("cast_test", test); m.def("cast_test2", test2); + m.def("generic_cast", transformer_engine::pytorch::cast); m.def("te_gemm2", static_cast>(&te_gemm2), "CublasLt GEMM"); m.def("te_gemm2", static_cast>(&te_gemm2), "CublasLt GEMM"); diff --git a/transformer_engine/pytorch/csrc/pybind.h b/transformer_engine/pytorch/csrc/pybind.h new file mode 100644 index 0000000000..59b7aa8a81 --- /dev/null +++ b/transformer_engine/pytorch/csrc/pybind.h @@ -0,0 +1,22 @@ +/************************************************************************* + * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_PYBIND_H_ +#define TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_PYBIND_H_ +#include +#include +#include + +namespace transformer_engine::pytorch { + +extern PyTypeObject *Float8TensorPythonClass; +extern PyTypeObject *Float8QParamsClass; + +void init_extension(); + +} // namespace transformer_engine::pytorch + +#endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_PYBIND_H_ diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index 76679eb064..e5fb2f4f67 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -10,7 +10,7 @@ import torch import transformer_engine_torch as tex -from transformer_engine.common.recipe import DelayedScaling, Format +from transformer_engine.common.recipe import DelayedScaling, Format, Recipe from .constants import dist_group_type from .utils import get_device_compute_capability @@ -33,7 +33,7 @@ def check_fp8_support() -> Tuple[bool, str]: return True, "" -def get_default_fp8_recipe() -> DelayedScaling: +def get_default_fp8_recipe() -> Recipe: """FP8 recipe with default args.""" return DelayedScaling() @@ -486,7 +486,8 @@ def restore_fp8_meta_tensors(fp8_meta: Dict[str, Any]) -> None: @contextmanager -def fp8_model_init(enabled: bool = True) -> None: +def fp8_model_init(enabled: bool = True, + recipe: Optional[Recipe] = None) -> None: """ Context manager for FP8 initialization of parameters. @@ -510,15 +511,20 @@ def fp8_model_init(enabled: bool = True) -> None: precision copies of weights are already present in the optimizer. * inference, where only the FP8 copies of the parameters are used. * LoRA-like fine-tuning, where the main parameters of the model do not change. + recipe: transformer_engine.common.recipe.Recipe, default = `None` + Recipe used to create the parameters. If left to None, it uses the default FP8 recipe. This functionality is *EXPERIMENTAL*. 
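    For the new recipe argument, usage is expected to look roughly as follows
    (a sketch; te.Linear and the shortened amax history are just illustrative
    choices):

        import transformer_engine.pytorch as te
        from transformer_engine.common.recipe import DelayedScaling

        # Parameters are created directly as FP8 tensors, quantized according
        # to an explicitly chosen recipe instead of the global default one.
        with te.fp8_model_init(enabled=True,
                               recipe=DelayedScaling(amax_history_len=16)):
            layer = te.Linear(1024, 1024)
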
""" _fp8_parameters = FP8GlobalStateManager.FP8_PARAMETERS + _fp8_recipe = FP8GlobalStateManager.FP8_RECIPE FP8GlobalStateManager.FP8_PARAMETERS = enabled + FP8GlobalStateManager.FP8_RECIPE = get_default_fp8_recipe() if recipe is None else recipe try: yield finally: FP8GlobalStateManager.FP8_PARAMETERS = _fp8_parameters + FP8GlobalStateManager.FP8_RECIPE = _fp8_recipe @contextmanager diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 85fae4798c..fc2bb5c60b 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -21,7 +21,6 @@ from ._common import _ParameterInitMeta from ..export import is_in_onnx_export_mode from ..fp8 import ( - get_default_fp8_recipe, get_fp8_te_dtype, FP8GlobalStateManager, ) @@ -38,6 +37,8 @@ ) from ..constants import dist_group_type from ..float8_tensor import Float8Tensor +from ..quantization_meta import QMeta +from transformer_engine.common.recipe import Recipe __all__ = ["initialize_ub", "destroy_ub"] @@ -396,7 +397,6 @@ def __init__(self) -> None: self.fp8_meta = {} self.fp8_meta["fp8_checkpoint"] = False self.fp8_meta["fp8_group"] = None - self.fp8_meta["recipe"] = get_default_fp8_recipe() self.fp8_meta_tensors_initialized = False self.tp_group = None self.tp_size = 1 @@ -482,37 +482,28 @@ def adjust_amax_history_length(self, length: int, fwd: Optional[bool] = None) -> self.fp8_meta[meta_key].amax_history ) - def set_meta_tensor(self, fwd: bool) -> None: + def set_meta_tensor(self, + fwd: bool, + recipe: Recipe) -> None: """Init scales and amaxes for fwd | bwd.""" fp8_meta_tensor_key = "scaling_fwd" if fwd else "scaling_bwd" if self.fp8_meta_tensors_initialized: # Handle changed amax history size. - self.adjust_amax_history_length(self.fp8_meta["recipe"].amax_history_len, fwd=fwd) + self.adjust_amax_history_length(recipe.amax_history_len, fwd=fwd) return # Max. number of fp8 tensors per GEMM = 3 (input, weight, output) for fwd and # 2 (grad_output and grad_input) for bwd num_fp8_tensors = self.fp8_meta["num_gemms"] * 3 if fwd else self.fp8_meta["num_gemms"] * 2 - self.fp8_meta[fp8_meta_tensor_key] = tex.FP8TensorMeta() - self.fp8_meta[fp8_meta_tensor_key].scale = torch.ones( - num_fp8_tensors, dtype=torch.float32, device="cuda" - ) - self.fp8_meta[fp8_meta_tensor_key].scale_inv = torch.ones( - num_fp8_tensors, dtype=torch.float32, device="cuda" - ) - self.fp8_meta[fp8_meta_tensor_key].amax_history = torch.zeros( - self.fp8_meta["recipe"].amax_history_len, - num_fp8_tensors, - dtype=torch.float32, - device="cuda", - ) + self.fp8_meta[fp8_meta_tensor_key] = QMeta(recipe, num_fp8_tensors) - def init_fp8_meta_tensors(self) -> None: + def init_fp8_meta_tensors(self, + recipe: Recipe) -> None: """Init scales and amaxes.""" - self.set_meta_tensor(True) - self.set_meta_tensor(False) + self.set_meta_tensor(True, recipe) + self.set_meta_tensor(False, recipe) self.fp8_meta_tensors_initialized = True def get_fp8_meta_tensors(self) -> None: @@ -607,7 +598,7 @@ def set_extra_state(self, state: torch.Tensor) -> None: del self.fp8_meta["global_fp8_buffer_pos_fwd_recompute"] # Initialize before loading. 
- self.init_fp8_meta_tensors() + self.init_fp8_meta_tensors(self.fp8_meta["recipe"]) self.fp8_meta["scaling_fwd"].scale.copy_(state["scale_fwd"]) self.fp8_meta["scaling_fwd"].amax_history.copy_(state["amax_history_fwd"]) self.fp8_meta["scaling_bwd"].scale.copy_(state["scale_bwd"]) @@ -667,9 +658,13 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None: self.fp8_calibration = FP8GlobalStateManager.is_fp8_calibration() self.fp8_meta["fp8_checkpoint"] = self.fp8 or self.fp8_calibration + if (self.fp8_parameters or self.fp8): + self.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe() + + if self.fp8_parameters and not self.fp8_initialized: self.fp8_meta["num_gemms"] = num_gemms - self.init_fp8_meta_tensors() + self.init_fp8_meta_tensors(self.fp8_meta["recipe"]) if self.fp8 or self.fp8_calibration: # FP8 init has already been run and recipe is the same, don't do anything. @@ -679,8 +674,7 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None: ): return - # Set FP8, recipe, and other FP8 metadata - self.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe() + # Set FP8 and other FP8 metadata self.fp8_meta["num_gemms"] = num_gemms self.fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group() @@ -689,7 +683,7 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None: self.fp8_meta["fp8_max_bwd"] = self.fp8_meta["recipe"].fp8_format.value.max_bwd # Allocate scales and amaxes - self.init_fp8_meta_tensors() + self.init_fp8_meta_tensors(self.fp8_meta["recipe"]) self.fp8_initialized = True else: # If fp8 isn't enabled, turn off and return. diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 385cc177ac..3bba2c8305 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -22,7 +22,6 @@ from ..utils import ( divide, cast_if_needed, - assert_dim_for_fp8_exec, clear_tensor_data, init_method_constant, requires_grad, @@ -40,8 +39,6 @@ general_gemm, fp8_gemm, gemm, - fp8_cast_transpose_fused, - cast_to_fp8, ) from ..constants import GemmParallelModes, dist_group_type, TE_DType from ..jit import no_torch_dynamo @@ -62,7 +59,7 @@ class _Linear(torch.autograd.Function): @staticmethod def forward( ctx, - weight: Union[Float8Tensor, torch.Tensor], + weight: torch.Tensor, inp: torch.Tensor, bias: Optional[torch.Tensor], is_first_microbatch: Union[bool, None], @@ -89,38 +86,30 @@ def forward( is_input_fp8 = isinstance(inp, Float8Tensor) # Make sure input dimensions are compatible - out_features, in_features = weight.shape + _, in_features = weight.shape inp_shape = inp.shape assert inp_shape[-1] == in_features, "GEMM not possible" - inputmat = inp.view(-1, in_features) - if fp8: - assert_dim_for_fp8_exec(inputmat) - assert_dim_for_fp8_exec(weight) tp_world_size = get_distributed_world_size(tp_group) ub_overlap_rs = False if tp_world_size == 1 else ub_overlap_rs # Cast input to expected dtype - inputmat = cast_if_needed(inputmat, activation_dtype) - inputmat_t = None - inputmat_no_fp8 = inputmat + inputmat = cast_if_needed(inp, activation_dtype) + + if not fp8 or fp8_meta["recipe"].override_linear_precision.wgrad: + inputmat_no_fp8 = inputmat if fp8: - fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) - if not isinstance(inputmat, Float8Tensor): + if not isinstance(inputmat, QuantizedTensor): backward_needs_input = ( not fp8_meta["recipe"].override_linear_precision.wgrad and is_grad_enabled and weight.requires_grad and not sequence_parallel ) - inputmat = 
Float8Tensor.to_float8( - inputmat, - fp8_meta=fp8_meta["scaling_fwd"], - fp8_meta_index=tex.FP8FwdTensors.GEMM1_INPUT, - fp8_dtype=fp8_dtype_forward, - with_transpose_cache=backward_needs_input, - ) + meta = fp8_meta["scaling_fwd"] + inputmat = meta.quantize(inputmat, tex.FP8FwdTensors.GEMM1_INPUT, + columnwise = backward_needs_input) # Column Parallel Linear if parallel_mode == "column" and sequence_parallel: @@ -174,25 +163,16 @@ def forward( ).float() if fp8_output: - proj_out_index, meta_tensor, proj_out_tetype = ( - tex.FP8FwdTensors.GEMM1_OUTPUT, - fp8_meta["scaling_fwd"], - fp8_dtype_forward, - ) + meta_tensor = fp8_meta["scaling_fwd"] + out_index = tex.FP8FwdTensors.GEMM1_OUTPUT + qparams = meta_tensor.get_quantization_params(out_index) else: - proj_out_index, meta_tensor, proj_out_tetype = ( - None, - None, - TE_DType[activation_dtype], - ) + qparams = None if ub_overlap_rs: + # I think this should be inside the gemm call rather than linear ub_obj_projout = get_ub(ub_name + "_fprop") - out = ub_obj_projout.get_ubuf_output(1) - dim_size = list(inputmat_total.size()) - dim_size[0] = dim_size[0] // tp_world_size - dim_size[1] = out_features - rs_out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device) + ub_buffer = ub_obj_projout.get_ubuf_output(1) if ub_obj_projout.is_p2p_overlap(): if ub_obj_projout.is_atomic_gemm(): ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_RS_P2P @@ -204,50 +184,28 @@ def forward( else: ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS if fp8 and ub_obj_projout.is_fp8_ubuf(): - proj_out_index = tex.FP8FwdTensors.GEMM1_OUTPUT - meta_tensor = fp8_meta["scaling_fwd"] - proj_out_tetype = fp8_dtype_forward - ub_obj_projout.set_ubuf_scale_inv(meta_tensor.scale_inv[proj_out_index]) - else: - dim_size = list(inputmat_total.size()) - dim_size[1] = out_features - out = None + assert fp8_output + ub_obj_projout.set_ubuf_scale_inv(qparams.scale_inv) out, _ = general_gemm( weight_fp8, inputmat_total, get_workspace(), + quantization_params=qparams, + out_dtype=activation_dtype, bias=bias, use_split_accumulator=_2X_ACC_FPROP, - out=out, ub_algo=ub_algo if ub_overlap_rs else None, ub=ub_obj_projout if ub_overlap_rs else None, - extra_output_tensor=rs_out if ub_overlap_rs else None, - # out_index=proj_out_index, - # fp8_meta_tensor=meta_tensor, - D_dtype=proj_out_tetype, + ub_buffer=ub_buffer if ub_overlap_rs else None, ) - if fp8_output: - out = Float8Tensor( - data=out, - fp8_meta=fp8_meta, - fp8_meta_forward=True, - fp8_meta_index=tex.FP8FwdTensors.GEMM1_OUTPUT, - fp8_dtype=fp8_dtype_forward, - dtype=activation_dtype, - ) if is_grad_enabled: saved_inputmat = None - saved_inputmat_t = None if weight.requires_grad: if fp8 and not fp8_meta["recipe"].override_linear_precision.wgrad: - if inputmat_t is None: - saved_inputmat = inputmat - else: - saved_inputmat_t = inputmat_t - if cpu_offloading: - saved_inputmat_t.activation_offloading = True + inputmat.update_usage(rowwise=False) + saved_inputmat = inputmat else: saved_inputmat = inputmat_no_fp8 @@ -265,15 +223,11 @@ def forward( ctx.fsdp_shapes = _fsdp_scatter_tensors( fsdp_group, saved_inputmat, # None if fp8 == False - saved_inputmat_t, # None if fp8 == False AND not is_grad_enabled weight_fp8 if fp8 and not isinstance(weight, Float8Tensor) else None, ) ctx.save_for_backward( saved_inputmat, - saved_inputmat_t, - inputmat._scale_inv, - weight, weight_fp8, weight.main_grad if cpu_offloading and fuse_wgrad_accumulation else None, ) @@ -303,15 +257,13 @@ def forward( ) # Row Parallel Linear - if 
ub_overlap_rs: - out = rs_out - elif parallel_mode == "row" and sequence_parallel: - out, _ = reduce_scatter_along_first_dim(out, tp_group) - elif parallel_mode == "row" and tensor_parallel: - out, _ = allreduce(out, tp_group) - - # [*, in_features] -> [*, out_features] except first dimension changes for SP - return out.view(-1, *inp_shape[1:-1], out_features) + if not ub_overlap_rs: + if parallel_mode == "row" and sequence_parallel: + out, _ = reduce_scatter_along_first_dim(out, tp_group) + elif parallel_mode == "row" and tensor_parallel: + out, _ = allreduce(out, tp_group) + + return out @staticmethod def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]: diff --git a/transformer_engine/pytorch/quantization_meta.py b/transformer_engine/pytorch/quantization_meta.py new file mode 100644 index 0000000000..24cf9f325b --- /dev/null +++ b/transformer_engine/pytorch/quantization_meta.py @@ -0,0 +1,65 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Quantization metadata class""" + +from transformer_engine.common.recipe import ( + Recipe, DelayedScaling +) +import torch + +import transformer_engine_torch as tex + +from fp8 import get_fp8_te_dtype +from tensor import QuantizedTensor, Float8Tensor + +from .quantization_params import Float8Params + +class QMeta: + def __init__(self, + recipe: Recipe, + num_tensors: int, + forward: bool): + if isinstance(recipe, DelayedScaling): + self.recipe_type = DelayedScaling + self.scale = torch.ones(num_tensors, dtype=torch.float32, device="cuda") + self.scale_inv = torch.ones(num_tensors, dtype=torch.float32, device="cuda") + self.amax_history = torch.zeros( + recipe.amax_history_len, + num_tensors, + dtype=torch.float32, + device="cuda", + ) + self.fp8_type = get_fp8_te_dtype(recipe, forward) + raise ValueError("Unknown recipe type.") + + def quantize(self, + tensor: torch.Tensor, + index: int, + *, + rowwise: bool = True, + columnwise: bool = True) -> QuantizedTensor: + if self.recipe_type == DelayedScaling: + return Float8Tensor.quantize(tensor, + self.get_quantization_params(index), + rowwise_usage=rowwise, + columnwise_usage=columnwise) + raise NotImplementedError("Not implemented yet!") + + def quantize_param(self, + tensor: torch.Tensor, + index: int): + if self.recipe_type == DelayedScaling: + pass + raise NotImplementedError("Not implemented yet! Same as quantize but also sets proxy") + + def get_quantization_params(self, + index: int): + # Could be cached + if self.recipe_type == DelayedScaling: + return Float8Params(scale=self.scale[index], + amax=self.amax_history[0][index], + scale_inv=self.scale_inv[index], + dtype=self.fp8_type) + raise NotImplementedError("Not implemented yet!") diff --git a/transformer_engine/pytorch/quantization_params.py b/transformer_engine/pytorch/quantization_params.py new file mode 100644 index 0000000000..1ad57ce4e1 --- /dev/null +++ b/transformer_engine/pytorch/quantization_params.py @@ -0,0 +1,30 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
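The params objects defined in this new file are plain containers consumed by ops such as the reworked gelu() above; a sketch of standalone use (the scale/amax/scale_inv buffers here are freshly allocated for illustration rather than slices of a QMeta):

    import torch
    import transformer_engine_torch as tex
    from transformer_engine.pytorch.quantization_params import Float8Params
    from transformer_engine.pytorch.cpp_extensions.activation import gelu

    qparams = Float8Params(
        scale=torch.ones(1, device="cuda"),
        amax=torch.zeros(1, device="cuda"),
        scale_inv=torch.ones(1, device="cuda"),
        dtype=tex.DType.kFloat8E4M3,
    )
    x = torch.randn(32, 64, device="cuda", dtype=torch.bfloat16)
    y_fp8 = gelu(x, qparams)  # FP8 (uint8) GeLU output quantized with qparams
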
+ +"""Parameters needed for quantization using different recipes.""" + +import torch +from transformer_engine_torch import DType as TE_DType + +class QuantizationParams: + def __init__(self): + pass + +class Float8Params(QuantizationParams): + scale: torch.Tensor + amax: torch.Tensor + scale_inv: torch.Tensor + dtype: TE_DType + + def __init__(self, + scale: torch.Tensor, + amax: torch.Tensor, + scale_inv: torch.Tensor, + dtype: TE_DType): + super().__init__() + self.scale = scale + self.amax = amax + self.scale_inv = scale_inv + self.dtype = dtype + diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index 51f0f04b9a..e007c23134 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -19,11 +19,28 @@ ) from ..fp8 import FP8GlobalStateManager from ..utils import devices_match -from .quantized_tensor import QuantizedTensor +from ..quantization_params import QuantizationParams +from .quantized_tensor import QuantizationParamsProxy, QuantizedTensor aten = torch.ops.aten updated_fp8_params = {} +class Float8ParamsProxy(QuantizationParamsProxy): + def __init__(self, + meta, + index, + dtype): + super().__init__() + self.meta = meta + self.index = index + self.dtype = dtype + + def get_quantization_params(self) -> QuantizationParams: + return Float8Params( + self.meta.scale[self.index], + self.meta.amax[0][self.index], + self.meta.scale_inv[self.index], + self.dtype) def _make_fp8_attr_property_funcs(name: str) -> Any: """Make accessors for an FP8 attribute @@ -103,14 +120,14 @@ class _ToFloat8Func(torch.autograd.Function): def forward( _ctx: torch.autograd.function.FunctionCtx, # unused tensor: torch.Tensor, - fp8_meta: Optional[Dict[str, Any]] = None, - fp8_meta_forward: bool = True, - fp8_meta_index: Optional[int] = None, - fp8_dtype: TE_DType = TE_DType.kFloat8E4M3, - scale: Optional[torch.Tensor] = None, - amax: Optional[torch.Tensor] = None, - scale_inv: Optional[torch.Tensor] = None, - with_transpose_cache: bool = False, + scale: torch.Tensor, + amax: torch.Tensor, + scale_inv: torch.Tensor, + fp8_dtype: TE_DType, + rowwise_usage: bool = True, + columnwise_usage: bool = True, + *, + proxy: Optional[Float8ParamsProxy] = None, ) -> Float8Tensor: # Tensor attributes @@ -125,20 +142,16 @@ def forward( data = torch.empty(tensor.size(), dtype=torch.uint8, device=device) # Check scale - if scale is None and fp8_meta is None: - scale = torch.full([1], 1, dtype=torch.float32, device=device) - if scale is not None: + if not devices_match(scale.device, device) or scale.dtype != dtype: scale = scale.to(device=device, dtype=torch.float32) # Check scale-inverse - if scale_inv is None: - scale_inv = torch.empty([1], dtype=torch.float32, device=device) - elif not devices_match(scale_inv.device, device) or scale_inv.dtype != dtype: + if not devices_match(scale_inv.device, device) or scale_inv.dtype != dtype: scale_inv = scale_inv.to(device=device, dtype=torch.float32) # Transpose cache data_transpose = None - if with_transpose_cache: + if columnwise_usage: data_transpose = torch.empty( (data.size(-1), data.numel() // data.size(-1)), dtype=torch.uint8, @@ -148,13 +161,11 @@ def forward( # Construct FP8 tensor out = Float8Tensor( data=data, - fp8_meta=fp8_meta, - fp8_meta_forward=fp8_meta_forward, - fp8_meta_index=fp8_meta_index, - fp8_dtype=fp8_dtype, fp8_scale_inv=scale_inv, + fp8_dtype=fp8_dtype, dtype=dtype, data_transpose=data_transpose, + proxy=proxy, ) # Cast to FP8 
tensor @@ -332,11 +343,9 @@ class Float8Tensor(QuantizedTensor): """ - _data: torch.Tensor + _data: Optional[torch.Tensor] _fp8_attrs: Dict[str, Any] - _fp8_meta: Optional[Dict[str, Any]] - _fp8_meta_forward: bool - _fp8_meta_index: Optional[int] + _proxy: Optional[Float8ParamsProxy] _fp8_dtype: TE_DType _scale_inv: torch.Tensor @@ -348,17 +357,14 @@ def __new__( cls, *, data: torch.Tensor, - fp8_attrs: Optional[Dict[str, Any]] = None, - fp8_meta: Optional[Dict[str, Any]] = None, - fp8_meta_forward: bool = True, - fp8_meta_index: Optional[int] = None, - fp8_dtype: TE_DType = TE_DType.kFloat8E4M3, - fp8_scale_inv: Optional[torch.Tensor] = None, + fp8_scale_inv: torch.Tensor, + fp8_dtype: TE_DType, dtype: torch.dtype = torch.float32, requires_grad: bool = False, data_transpose: Optional[torch.Tensor] = None, + proxy: Optional[Float8ParamsProxy] = None, + fp8_attrs: Optional[Dict[str, Any]] = None, ): - # Check that data buffer is valid if data.element_size() != 1: raise ValueError( @@ -392,15 +398,7 @@ def __new__( self._fp8_attrs = fp8_attrs return self - # FP8 meta tensors - if fp8_meta is not None and fp8_meta_index is None: - raise ValueError( - "To initialize Float8Tensor with FP8 meta tensors, " - "the FP8 meta tensor index must also be provided" - ) - self._fp8_meta = fp8_meta - self._fp8_meta_forward = fp8_meta_forward - self._fp8_meta_index = fp8_meta_index + self._proxy = proxy # FP8 dtype assert fp8_dtype in ( @@ -410,16 +408,6 @@ def __new__( self._fp8_dtype = fp8_dtype # FP8 scale-inverse - if fp8_scale_inv is None and self._fp8_meta is not None: - fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( - forward=self._fp8_meta_forward, - ) - fp8_scale_inv = self._fp8_meta[fp8_meta_key].scale_inv[self._fp8_meta_index] - fp8_scale_inv = fp8_scale_inv.detach().view(1).clone() - if fp8_scale_inv is None: - raise ValueError( - "Attempted to initialize Float8Tensor without specifying scale-inverse" - ) if fp8_scale_inv.numel() != 1: raise ValueError( "Attempted to initialize Float8Tensor with invalid scale-inverse tensor" @@ -457,9 +445,7 @@ def make_like( """ default_kwargs = dict( - fp8_meta=tensor._fp8_meta, - fp8_meta_forward=tensor._fp8_meta_forward, - fp8_meta_index=tensor._fp8_meta_index, + proxy=tensor._proxy, fp8_dtype=tensor._fp8_dtype, fp8_scale_inv=tensor._scale_inv, dtype=tensor.dtype, @@ -487,6 +473,7 @@ def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor: # Make sure FP8 data is in expected format data = self._data + assert data is not None if data.device.type != "cuda": data = data.cuda() if not data.is_contiguous(): @@ -555,7 +542,11 @@ def quantize_( return dst.quantize_(src.dequantize()) # Directly copy FP8 data - dst._data.copy_(src._data.detach()) + if src._data is not None: + assert dst._data is not None + dst._data.copy_(src._data.detach()) + else: + dst._data = None dst._scale_inv.copy_(src._scale_inv.detach()) if amax is not None or dst._fp8_meta is not None: src_amax: torch.Tensor @@ -630,6 +621,7 @@ def quantize_( fp8_meta = dst._fp8_meta[fp8_meta_key] # Check local data + assert dst._data is not None if not dst._data.is_contiguous(): raise RuntimeError("Transformer Engine cast kernels require contiguous data") @@ -670,30 +662,26 @@ def quantize_( return self @classmethod - def to_float8( + def quantize( cls, tensor: torch.Tensor, + params: QuantizationParams, *, - fp8_meta: Optional[Dict[str, Any]] = None, - fp8_meta_forward: bool = True, - fp8_meta_index: Optional[int] = None, - fp8_dtype: TE_DType = TE_DType.kFloat8E4M3, - 
scale: Optional[torch.Tensor] = None, - amax: Optional[torch.Tensor] = None, - scale_inv: Optional[torch.Tensor] = None, - with_transpose_cache: bool = False, - ): + proxy: Optional[QuantizationParamsProxy] = None, + rowwise_usage: bool = True, + columnwise_usage: bool = True) -> QuantizedTensor: """Construct Float8Tensor from plain PyTorch tensor""" + assert isinstance(params, Float8Params), \ + f"Invalid quantization params type: {type(params)}" return _ToFloat8Func.apply( tensor, - fp8_meta, - fp8_meta_forward, - fp8_meta_index, - fp8_dtype, - scale, - amax, - scale_inv, - with_transpose_cache, + params.scale, + params.amax, + params.scale_inv, + params.dtype, + rowwise_usage, + columnwise_usage, + proxy=proxy, ) def detach(self) -> Float8Tensor: @@ -703,6 +691,22 @@ def detach(self) -> Float8Tensor: fp8_attrs=self._fp8_attrs, ) + def update_usage(self, rowwise=True, columnwise=True): + if rowwise: + assert self._data is not None, \ + "Rowwise usage of the tensor was already disabled" + else: + if not columnwise or (self._transpose is not None and + not self._transpose_invalid): + self._data = None + if columnwise: + assert (self._transpose is not None and + not self._transpose_invalid) or self._data is not None, \ + "The tensor does not hold any data anymore" + else: + self._transpose = None + self._transpose_invalid = True + def clone(self) -> Float8Tensor: data = self._data.detach().clone() data_transpose = None @@ -1027,9 +1031,7 @@ def _set_data(self, tensor: torch.Tensor) -> None: # Note: We store FP8 attributes in a dictionary so we can share # them between tensors with the same data, e.g. detached tensors. # For convenience, we also expose them as property attributes. - _fp8_meta = property(**_make_fp8_attr_property_funcs("fp8_meta")) - _fp8_meta_forward = property(**_make_fp8_attr_property_funcs("fp8_meta_forward")) - _fp8_meta_index = property(**_make_fp8_attr_property_funcs("fp8_meta_index")) + _proxy = property(**_make_fp8_attr_property_funcs("proxy")) _fp8_dtype = property(**_make_fp8_attr_property_funcs("dtype")) _transpose = property(**_make_fp8_attr_property_funcs("transpose")) _transpose_invalid = property(**_make_fp8_attr_property_funcs("transpose_invalid")) diff --git a/transformer_engine/pytorch/tensor/quantized_tensor.py b/transformer_engine/pytorch/tensor/quantized_tensor.py index f890b0878a..a6849b658d 100644 --- a/transformer_engine/pytorch/tensor/quantized_tensor.py +++ b/transformer_engine/pytorch/tensor/quantized_tensor.py @@ -10,6 +10,16 @@ import torch from torch.utils._pytree import tree_map +from ..quantization_params import QuantizationParams + +class QuantizationParamsProxy: + def __init__(self): + pass + + def get_quantization_params(self) -> QuantizationParams: + raise NotImplementedError( + f"{self.__class__.__name__} class does not implement get_quantization_params function" + ) class _DequantizeFunc(torch.autograd.Function): """Autograd function to convert quantized tensor to standard tensor""" @@ -70,6 +80,18 @@ def quantize_(self, tensor: torch.Tensor) -> QuantizedTensor: f"{self.__class__.__name__} class does not implement quantize_ function" ) + @classmethod + def quantize(cls, + tensor: torch.Tensor, + params: QuantizationParams, + *, + proxy: Optional[QuantizationParamsProxy] = None, + rowwise_usage: bool = True, + columnwise_usage: bool = True) -> QuantizedTensor: + raise NotImplementedError( + f"{cls.__name__} class does not implement quantize function" + ) + def detach(self) -> QuantizedTensor: """Create new quantized tensor with same data 
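Taken together, the new quantize()/update_usage() pair is meant to be used roughly like this (a sketch using the argument names from this patch; shapes and the scale/amax buffers are illustrative):

    import torch
    import transformer_engine_torch as tex
    from transformer_engine.pytorch.quantization_params import Float8Params
    from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor

    qparams = Float8Params(scale=torch.ones(1, device="cuda"),
                           amax=torch.zeros(1, device="cuda"),
                           scale_inv=torch.ones(1, device="cuda"),
                           dtype=tex.DType.kFloat8E4M3)
    x = torch.randn(128, 256, device="cuda", dtype=torch.bfloat16)
    # Quantize with both usages, then drop the rowwise buffer once only the
    # transposed (columnwise) copy is needed, e.g. for the wgrad GEMM.
    x_fp8 = Float8Tensor.quantize(x, qparams,
                                  rowwise_usage=True, columnwise_usage=True)
    x_fp8.update_usage(rowwise=False)  # releases _data, keeps the transpose
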
@@ -81,6 +103,16 @@ def detach(self) -> QuantizedTensor: f"{self.__class__.__name__} class does not implement detach function" ) + def update_usage(self, rowwise=True, columnwise=True): + """Indicate to the tensor how it is going to be used. + This enables optimizations to memory usage in some cases + where forward and backward passes use the tensor in + different directions. + """ + raise NotImplementedError( + f"{self.__class__.__name__} class does not implement update_usage function" + ) + def __repr__(self) -> str: return f"{self.__class__.__name__}(data={self.dequantize(dtype=self.dtype)})" From 4c8ce6b03a12caf45caadef247e967c24f4affc4 Mon Sep 17 00:00:00 2001 From: Przemyslaw Tredak Date: Tue, 22 Oct 2024 16:44:03 -0700 Subject: [PATCH 6/6] More changes to Float8Tensor Signed-off-by: Przemyslaw Tredak --- .../pytorch/cpp_extensions/activation.py | 13 +- .../pytorch/cpp_extensions/gemm.py | 10 +- transformer_engine/pytorch/csrc/extensions.h | 11 +- .../pytorch/csrc/extensions/activation.cu | 3 +- .../pytorch/csrc/extensions/cast.cu | 72 +++- .../pytorch/csrc/extensions/pybind.cpp | 5 +- .../pytorch/csrc/extensions/transpose.cu | 36 +- transformer_engine/pytorch/module/linear.py | 68 ++-- .../pytorch/quantization_meta.py | 15 +- .../pytorch/quantization_params.py | 3 - .../pytorch/tensor/float8_tensor.py | 342 ++++-------------- 11 files changed, 202 insertions(+), 376 deletions(-) diff --git a/transformer_engine/pytorch/cpp_extensions/activation.py b/transformer_engine/pytorch/cpp_extensions/activation.py index f2264c115b..3f5d4807ea 100644 --- a/transformer_engine/pytorch/cpp_extensions/activation.py +++ b/transformer_engine/pytorch/cpp_extensions/activation.py @@ -9,7 +9,7 @@ import transformer_engine_torch as tex -from quantization_params import Float8Params +from ..quantization_params import QuantizationParams from ._common import canonicalize_fp8_scales __all__ = ["gelu", "relu", "reglu", "geglu", "swiglu", "qgelu", "srelu"] @@ -17,19 +17,14 @@ def gelu( inp: torch.Tensor, - qparams: Float8Params, + qparams: QuantizationParams, ) -> torch.Tensor: """GeLU with FP8 output""" # Launch kernel - return torch.ops.tex_ts.gelu_ts( + return tex.gelu( inp, - qparams.scale, - qparams.amax, - qparams.scale_inv, - 0, - qparams.dtype, - ) + qparams) def relu( diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index f7c8f2fb4b..ff02dc0ce9 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -6,6 +6,7 @@ import functools from typing import Optional, Tuple, Union, List import torch +from ..quantization_params import QuantizationParams import transformer_engine_torch as tex from ..constants import TE_DType from ..utils import assert_dim_for_fp8_exec @@ -30,7 +31,8 @@ def general_gemm( A: torch.Tensor, B: torch.Tensor, workspace: torch.Tensor, - out_dtype: tex.DType, + out_dtype: torch.dtype, + quantization_params: Optional[QuantizationParams] = None, gelu: bool = False, accumulate: bool = False, out: Optional[torch.Tensor] = None, @@ -43,8 +45,7 @@ def general_gemm( """GEMM supporting fp8 inputs.""" empty_tensor = _empty_tensor() - if out_dtype in [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]: - raise ValueError("FP8 output not supported") + assert quantization_params is None, "FP8 output not supported yet" if out is not None: if not out.is_contiguous(): raise ValueError("Output tensor is not contiguous.") @@ -59,9 +60,8 @@ def general_gemm( B, False, # transb out, 
- None, # if out_index is None else fp8_meta_tensor.scale[out_index], + quantization_params, out_dtype, - None, # if out_index is None else fp8_meta_tensor.amax_history[0][out_index], bias, bias_dtype, gelu, diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 387e14aa0d..b61461c549 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -7,6 +7,7 @@ #ifndef TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_H_ #define TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_H_ +#include #include "common.h" /*************************************************************************************************** @@ -235,9 +236,7 @@ std::tuple, std::vector> fused_multi_cast_tr std::vector scale_indices, std::vector amax_indices, std::vector scale_inv_indices, transformer_engine::DType otype); -at::Tensor fp8_transpose(at::Tensor input, transformer_engine::DType otype); - -void fp8_transpose_noalloc(at::Tensor input, at::Tensor output, transformer_engine::DType otype); +at::Tensor fp8_transpose(at::Tensor input, transformer_engine::DType otype, std::optional output = std::nullopt); void fp8_transpose_noalloc_noop(at::Tensor input, at::Tensor output, at::Tensor noop, transformer_engine::DType otype); @@ -366,7 +365,11 @@ at::Tensor rmsnorm_fwd_inf(const at::Tensor &input, const at::Tensor &weight, fl namespace transformer_engine::pytorch { -py::handle cast(const at::Tensor& tensor, py::handle quantization_params); +py::handle cast(const at::Tensor& tensor, + py::handle quantization_params, + bool rowwise_usage, + bool columnwise_usage, + py::handle proxy); } // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/extensions/activation.cu b/transformer_engine/pytorch/csrc/extensions/activation.cu index 7f8cff5584..84a231d4a0 100644 --- a/transformer_engine/pytorch/csrc/extensions/activation.cu +++ b/transformer_engine/pytorch/csrc/extensions/activation.cu @@ -6,7 +6,8 @@ #include "extensions.h" -at::Tensor gelu(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, +at::Tensor gelu(at::Tensor input, + at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, transformer_engine::DType otype) { using namespace transformer_engine; diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cu b/transformer_engine/pytorch/csrc/extensions/cast.cu index e2ec1738d3..41dbec8f95 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cu +++ b/transformer_engine/pytorch/csrc/extensions/cast.cu @@ -4,9 +4,11 @@ * See LICENSE for license information. 
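For reference, the reworked general_gemm() above is expected to be called roughly as follows; this is a sketch of the plain-tensor (non-FP8-output) path with made-up shapes, and get_workspace() is borrowed from module/base.py:

    import torch
    from transformer_engine.pytorch.cpp_extensions.gemm import general_gemm
    from transformer_engine.pytorch.module.base import get_workspace

    A = torch.randn(1024, 512, device="cuda", dtype=torch.bfloat16)  # weight
    B = torch.randn(2048, 512, device="cuda", dtype=torch.bfloat16)  # input
    out, _ = general_gemm(
        A, B, get_workspace(),
        out_dtype=torch.bfloat16,
        quantization_params=None,  # FP8 output is still unsupported
        use_split_accumulator=True,
    )
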
************************************************************************/ +#include "c10/core/ScalarType.h" #include "extensions.h" #include "pybind.h" #include "object.h" +#include "torch/types.h" namespace transformer_engine::pytorch { @@ -16,25 +18,77 @@ bool IsFloat8QParamsType(PyObject *obj) { return Py_TYPE(obj) == Float8QParamsClass; } +bool IsFloatingPointType(at::ScalarType type) { + return type == at::kFloat || + type == at::kHalf || + type == at::kBFloat16; +} + } // namespace detail -py::handle cast(const at::Tensor& tensor, py::handle quantization_params) { +py::handle cast(const at::Tensor& tensor, + py::handle quantization_params, + bool rowwise_usage, + bool columnwise_usage, + py::handle proxy + ) { using namespace pybind11::literals; init_extension(); + auto input_tensor = tensor.contiguous(); + NVTE_CHECK(rowwise_usage || columnwise_usage, + "Could not create a QuantizedTensor with no usage."); if (detail::IsFloat8QParamsType(quantization_params.ptr())) { auto py_scale = quantization_params.attr("scale"); auto py_amax = quantization_params.attr("amax"); - auto py_scale_inv = quantization_params.attr("scale_inv"); DType type = quantization_params.attr("dtype").cast(); const at::Tensor& scale = py_scale.cast(); - auto data = cast_to_fp8(tensor, - py_scale.cast(), - py_amax.cast(), - py_scale_inv.cast(), - type); + auto opts = input_tensor.options().dtype(torch::kFloat32); + at::Tensor scale_inv = at::empty({1}, opts); + at::Tensor data, data_transpose; + if (columnwise_usage) { + const auto dim = tensor.dim(); + NVTE_CHECK(dim >= 2, "Tensor needs to be at least 2D for columnwise usage"); + auto reshaped_input = input_tensor.view({-1, tensor.size(dim - 1)}); + auto data_opts = input_tensor.options().dtype(torch::kUInt8); + data = at::empty_like(input_tensor, data_opts); + data_transpose = at::empty({reshaped_input.size(1), + reshaped_input.size(0)}, + data_opts); + fused_cast_transpose(reshaped_input, + py_scale.cast(), + py_amax.cast(), + scale_inv, + data, + data_transpose, + type); + } else { + data = cast_to_fp8(input_tensor, + py_scale.cast(), + py_amax.cast(), + scale_inv, + type); + } + auto fake_tensor_type = tensor.scalar_type(); + if (!detail::IsFloatingPointType(fake_tensor_type)) { + fake_tensor_type = at::kFloat; + } py::handle Float8TensorClass(reinterpret_cast(Float8TensorPythonClass)); - auto ret = Float8TensorClass("data"_a=data, "fp8_scale_inv"_a=py_scale_inv, "fp8_dtype"_a=type); - return ret.release(); + if (columnwise_usage) { + auto ret = Float8TensorClass("data"_a=data, + "data_transpose"_a=data_transpose, + "fp8_scale_inv"_a=scale_inv, + "fp8_dtype"_a=type, + "dtype"_a=fake_tensor_type, + "proxy"_a=proxy); + return ret.release(); + } else { + auto ret = Float8TensorClass("data"_a=data, + "fp8_scale_inv"_a=scale_inv, + "fp8_dtype"_a=type, + "dtype"_a=fake_tensor_type, + "proxy"_a=proxy); + return ret.release(); + } } NVTE_ERROR("Invalid type of the quantization params"); } diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 4635451713..740537eb5e 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -13,6 +13,7 @@ #include "../comm_gemm_overlap.h" #include "../extensions.h" #include "object.h" +#include "pybind11/cast.h" #include "pytorch/csrc/common.h" namespace transformer_engine::pytorch { @@ -229,8 +230,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "Fused Attention FP8/BF16/FP16 BWD with 
separate Q, K and V", py::call_guard()); m.def("fp8_transpose", &fp8_transpose, "Transpose with FP8 I/O", - py::call_guard()); - m.def("fp8_transpose_noalloc", &fp8_transpose_noalloc, "Transpose with FP8 I/O", + py::arg("input"), py::arg("dtype"), + py::kw_only(), py::arg("out"), py::call_guard()); m.def("fp8_transpose_noalloc_noop", &fp8_transpose_noalloc_noop, "Transpose with FP8 I/O with noop option.", py::call_guard()); diff --git a/transformer_engine/pytorch/csrc/extensions/transpose.cu b/transformer_engine/pytorch/csrc/extensions/transpose.cu index 56f6b56769..7790e7ca47 100644 --- a/transformer_engine/pytorch/csrc/extensions/transpose.cu +++ b/transformer_engine/pytorch/csrc/extensions/transpose.cu @@ -4,6 +4,8 @@ * See LICENSE for license information. ************************************************************************/ +#include +#include "ATen/core/TensorBody.h" #include "extensions.h" void fused_cast_transpose(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, @@ -323,33 +325,33 @@ std::tuple, std::vector> fused_multi_cast_tr return std::make_tuple(std::move(cast_output_list), std::move(transposed_output_list)); } -at::Tensor fp8_transpose(at::Tensor input, transformer_engine::DType otype) { +at::Tensor fp8_transpose(at::Tensor input, transformer_engine::DType otype, std::optional output) { using namespace transformer_engine; - size_t M = static_cast(input.size(0)); - size_t N = static_cast(input.size(1)); - - auto output = allocateTorchTensor(input.size(1), input.size(0), DType::kByte); - if (M == 0 || N == 0) return output; - - auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, otype); - auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {N, M}, otype); - - nvte_transpose(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); + const auto dim = input.dim(); + NVTE_CHECK(dim >= 2, "Need at least 2D tensor to transpose."); - return output; -} - -void fp8_transpose_noalloc(at::Tensor input, at::Tensor output, transformer_engine::DType otype) { - using namespace transformer_engine; + if (input.dim() > 2) { + input = input.view({-1, input.size(dim - 1)}); + } size_t M = static_cast(input.size(0)); size_t N = static_cast(input.size(1)); + at::Tensor out; + if (output.has_value()) { + out = *output; + } else { + out = allocateTorchTensor(input.size(1), input.size(0), DType::kByte); + } + if (M == 0 || N == 0) return out; + auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, otype); - auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {N, M}, otype); + auto output_cu = makeTransformerEngineTensor(out.data_ptr(), {N, M}, otype); nvte_transpose(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); + + return out; } void fp8_transpose_noalloc_noop(at::Tensor input, at::Tensor output, at::Tensor noop, diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 3bba2c8305..48a47b2975 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -96,47 +96,39 @@ def forward( # Cast input to expected dtype inputmat = cast_if_needed(inp, activation_dtype) - if not fp8 or fp8_meta["recipe"].override_linear_precision.wgrad: - inputmat_no_fp8 = inputmat + inputmat_no_fp8 = inputmat + + backward_needs_input = ( + is_grad_enabled + and weight.requires_grad + ) + own_quantized_input = not isinstance(inputmat, QuantizedTensor) if fp8: - if not isinstance(inputmat, QuantizedTensor): - 
backward_needs_input = ( - not fp8_meta["recipe"].override_linear_precision.wgrad - and is_grad_enabled - and weight.requires_grad - and not sequence_parallel - ) + if own_quantized_input: meta = fp8_meta["scaling_fwd"] + backward_needs_fp8_input = (backward_needs_input and + not fp8_meta["recipe"].override_linear_precision.wgrad) inputmat = meta.quantize(inputmat, tex.FP8FwdTensors.GEMM1_INPUT, - columnwise = backward_needs_input) + columnwise = backward_needs_fp8_input) # Column Parallel Linear if parallel_mode == "column" and sequence_parallel: - inputmat_total, _ = gather_along_first_dim(inputmat, tp_group) + inputmat_total, _ = gather_along_first_dim(inputmat, tp_group, columnwise=False) else: inputmat_total = inputmat # Initialize FP8 weights if needed weight_fp8 = weight if fp8: - if isinstance(weight, Float8Tensor): - # Make sure transpose cache is valid, if present - # Note: Transpose cache may have been invalidated - # externally, e.g. by optimizer. - # TODO: Do we actually need this? - if weight._transpose is not None: - weight.transpose_2d( - fill_cache=True, - noop_flag=skip_fp8_weight_update, - ) - else: + if not isinstance(weight, QuantizedTensor): # FP8 cast to workspace buffer update_workspace = is_first_microbatch is None or is_first_microbatch weight_fp8 = module.get_fp8_workspace( tensor=weight, - fp8_meta_forward=True, - fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT, + quantization_params=fp8_meta["scaling_fwd"].get_quantization_params( + tex.FP8FwdTensors.GEMM1_WEIGHT + ), cache_name=(None if is_first_microbatch is None else "weight"), update_workspace=update_workspace, skip_update_flag=skip_fp8_weight_update, @@ -151,16 +143,8 @@ def forward( bias = cast_if_needed(bias, bias_dtype) if bias is not None else bias if not fp8 and fp8_calibration: - # amax of input - amin, amax = inputmat_total.aminmax() - fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_INPUT] = torch.max( - -amin, amax - ).float() - # amax of weight - amin, amax = weight.aminmax() - fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_WEIGHT] = torch.max( - -amin, amax - ).float() + fp8_meta["scaling_fwd"].calibrate(inputmat_total, tex.FP8FwdTensors.GEMM1_INPUT) + fp8_meta["scaling_fwd"].calibrate(weight, tex.FP8FwdTensors.GEMM1_WEIGHT) if fp8_output: meta_tensor = fp8_meta["scaling_fwd"] @@ -202,20 +186,20 @@ def forward( if is_grad_enabled: saved_inputmat = None - if weight.requires_grad: - if fp8 and not fp8_meta["recipe"].override_linear_precision.wgrad: + if backward_needs_input: + if own_quantized_input and isinstance(inputmat, QuantizedTensor): inputmat.update_usage(rowwise=False) + if fp8 and not fp8_meta["recipe"].override_linear_precision.wgrad: saved_inputmat = inputmat else: saved_inputmat = inputmat_no_fp8 - if cpu_offloading: - if fp8 and weight_fp8 is not None: - weight_fp8.weight_offloading = True - weight.weight_offloading = True + if cpu_offloading: + weight_fp8.weight_offloading = True + weight.weight_offloading = True - if saved_inputmat is not None: - saved_inputmat.activation_offloading = True + if saved_inputmat is not None: + saved_inputmat.activation_offloading = True # Scatter intermediate/activation tensors saved for the backward pass # NOTE: FSDP sharding is not valid for models initialized with primary Fp8 weights diff --git a/transformer_engine/pytorch/quantization_meta.py b/transformer_engine/pytorch/quantization_meta.py index 24cf9f325b..0139529966 100644 --- a/transformer_engine/pytorch/quantization_meta.py +++ 
b/transformer_engine/pytorch/quantization_meta.py @@ -11,8 +11,8 @@ import transformer_engine_torch as tex -from fp8 import get_fp8_te_dtype -from tensor import QuantizedTensor, Float8Tensor +from .fp8 import get_fp8_te_dtype +from .tensor import QuantizedTensor, Float8Tensor from .quantization_params import Float8Params @@ -24,7 +24,6 @@ def __init__(self, if isinstance(recipe, DelayedScaling): self.recipe_type = DelayedScaling self.scale = torch.ones(num_tensors, dtype=torch.float32, device="cuda") - self.scale_inv = torch.ones(num_tensors, dtype=torch.float32, device="cuda") self.amax_history = torch.zeros( recipe.amax_history_len, num_tensors, @@ -60,6 +59,14 @@ def get_quantization_params(self, if self.recipe_type == DelayedScaling: return Float8Params(scale=self.scale[index], amax=self.amax_history[0][index], - scale_inv=self.scale_inv[index], dtype=self.fp8_type) raise NotImplementedError("Not implemented yet!") + + def calibrate(self, + tensor: torch.Tensor, + index: int): + if self.recipe_type == DelayedScaling: + amin, amax = tensor.aminmax() + self.amax_history[0][index] = torch.max(-amin, amax).float() + return + raise NotImplementedError("Not implemented yet!") diff --git a/transformer_engine/pytorch/quantization_params.py b/transformer_engine/pytorch/quantization_params.py index 1ad57ce4e1..10c6d81571 100644 --- a/transformer_engine/pytorch/quantization_params.py +++ b/transformer_engine/pytorch/quantization_params.py @@ -14,17 +14,14 @@ def __init__(self): class Float8Params(QuantizationParams): scale: torch.Tensor amax: torch.Tensor - scale_inv: torch.Tensor dtype: TE_DType def __init__(self, scale: torch.Tensor, amax: torch.Tensor, - scale_inv: torch.Tensor, dtype: TE_DType): super().__init__() self.scale = scale self.amax = amax - self.scale_inv = scale_inv self.dtype = dtype diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index e007c23134..c733d064a9 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -8,18 +8,18 @@ import warnings import torch +from torch._prims_common import is_contiguous import transformer_engine_torch as tex from transformer_engine_torch import DType as TE_DType from ..constants import TE_DType as torch_to_transformer_engine_dtype from ..cpp_extensions.transpose import fp8_cast_transpose_fused from ..cpp_extensions.cast import ( - cast_from_fp8, cast_to_fp8, ) from ..fp8 import FP8GlobalStateManager from ..utils import devices_match -from ..quantization_params import QuantizationParams +from ..quantization_params import Float8Params, QuantizationParams from .quantized_tensor import QuantizationParamsProxy, QuantizedTensor aten = torch.ops.aten @@ -39,7 +39,6 @@ def get_quantization_params(self) -> QuantizationParams: return Float8Params( self.meta.scale[self.index], self.meta.amax[0][self.index], - self.meta.scale_inv[self.index], self.dtype) def _make_fp8_attr_property_funcs(name: str) -> Any: @@ -120,56 +119,24 @@ class _ToFloat8Func(torch.autograd.Function): def forward( _ctx: torch.autograd.function.FunctionCtx, # unused tensor: torch.Tensor, - scale: torch.Tensor, - amax: torch.Tensor, - scale_inv: torch.Tensor, - fp8_dtype: TE_DType, + qparams: Float8Params, rowwise_usage: bool = True, columnwise_usage: bool = True, - *, proxy: Optional[Float8ParamsProxy] = None, ) -> Float8Tensor: # Tensor attributes - dtype = tensor.dtype - if dtype not in (torch.float32, torch.bfloat16, torch.float16): - dtype = 
torch.float32 - device = tensor.device - if device.type != "cuda": - device = torch.device("cuda") - - # FP8 data buffer - data = torch.empty(tensor.size(), dtype=torch.uint8, device=device) + if not tensor.is_cuda: + tensor = tensor.cuda() - # Check scale - if not devices_match(scale.device, device) or scale.dtype != dtype: - scale = scale.to(device=device, dtype=torch.float32) + if isinstance(tensor, QuantizedTensor): + tensor = tensor.dequantize() - # Check scale-inverse - if not devices_match(scale_inv.device, device) or scale_inv.dtype != dtype: - scale_inv = scale_inv.to(device=device, dtype=torch.float32) - - # Transpose cache - data_transpose = None - if columnwise_usage: - data_transpose = torch.empty( - (data.size(-1), data.numel() // data.size(-1)), - dtype=torch.uint8, - device=tensor.device, - ) - - # Construct FP8 tensor - out = Float8Tensor( - data=data, - fp8_scale_inv=scale_inv, - fp8_dtype=fp8_dtype, - dtype=dtype, - data_transpose=data_transpose, - proxy=proxy, - ) - - # Cast to FP8 tensor - out.quantize_(tensor, scale=scale, amax=amax) + out = tex.generic_cast(tensor, + qparams, + rowwise_usage, + columnwise_usage, + proxy) return out @@ -205,9 +172,7 @@ def forward( # Construct new tensor if constructor kwargs are provided default_kwargs = dict( data=tensor._data, - fp8_meta=tensor._fp8_meta, - fp8_meta_forward=tensor._fp8_meta_forward, - fp8_meta_index=tensor._fp8_meta_index, + proxy=tensor._proxy, fp8_dtype=tensor._fp8_dtype, fp8_scale_inv=tensor._scale_inv, dtype=tensor.dtype, @@ -233,7 +198,7 @@ class _ViewFunc(torch.autograd.Function): def forward( ctx, tensor: torch.Tensor, - shape: Tuple[int] = None, + shape: Optional[Tuple[int]] = None, ) -> torch.Tensor: # Return input tensor if shape is not provided @@ -243,6 +208,8 @@ def forward( # Construct new tensor if shape is provided if isinstance(tensor, Float8Tensor): + if tensor._data is None: + return tensor return Float8Tensor.make_like( tensor, data=tensor._data.view(*shape), @@ -256,6 +223,8 @@ def backward( ) -> Tuple[Optional[torch.Tensor], ...]: if isinstance(grad, Float8Tensor): + if grad._data is None: + return grad, None dgrad = Float8Tensor.make_like( grad, data=grad._data.view(ctx.shape), @@ -275,7 +244,7 @@ class _ReshapeFunc(torch.autograd.Function): def forward( ctx, tensor: torch.Tensor, - shape: Tuple[int] = None, + shape: Optional[Tuple[int]] = None, ) -> torch.Tensor: # Return input tensor if shape is not provided @@ -285,6 +254,8 @@ def forward( # Construct new tensor if shape is provided if isinstance(tensor, Float8Tensor): + if tensor._data is None: + return tensor return Float8Tensor.make_like( tensor, data=tensor._data.reshape(*shape), @@ -298,6 +269,8 @@ def backward( ) -> Tuple[Optional[torch.Tensor], ...]: if isinstance(grad, Float8Tensor): + if grad._data is None: + return grad, None dgrad = Float8Tensor.make_like( grad, data=grad._data.reshape(ctx.shape), @@ -321,15 +294,8 @@ class Float8Tensor(QuantizedTensor): fp8_attrs: dict, optional FP8 metadata, primarily managed by Float8Tensor. If provided, all other FP8 configuration is ignored. - fp8_meta: dict, optional + proxy: FP8ParamsProxy, optional FP8 metadata object, primarily managed by TE modules. - fp8_meta_forward: bool, default = `True` - Whether to access the FP8 metadata for the - forward pass. Ignored if fp8_meta is not - provided. - fp8_meta_index: int, optional - Index to access in FP8 meta tensors. Required if - fp8_meta is provided and otherwise ignored. 
fp8_dtype: transformer_engine_torch.DType, default = kFloat8E4M3 FP8 format. fp8_scale_inv: torch.Tensor @@ -365,16 +331,6 @@ def __new__( proxy: Optional[Float8ParamsProxy] = None, fp8_attrs: Optional[Dict[str, Any]] = None, ): - # Check that data buffer is valid - if data.element_size() != 1: - raise ValueError( - f"Float8Tensor requires data buffer with 8-bit dtype (got dtype={data.dtype})" - ) - if data.requires_grad: - raise ValueError("Float8Tensor requires non-differentiable data buffer") - if not data.is_cuda: - data = data.cuda() - # Initialize tensor object self = torch.Tensor._make_wrapper_subclass( cls, @@ -401,27 +357,9 @@ def __new__( self._proxy = proxy # FP8 dtype - assert fp8_dtype in ( - TE_DType.kFloat8E4M3, - TE_DType.kFloat8E5M2, - ), f"Unsupported fp8_dtype {fp8_dtype}." self._fp8_dtype = fp8_dtype # FP8 scale-inverse - if fp8_scale_inv.numel() != 1: - raise ValueError( - "Attempted to initialize Float8Tensor with invalid scale-inverse tensor" - ) - if fp8_scale_inv.dim() != 1: - fp8_scale_inv = fp8_scale_inv.reshape(1) - if ( - not devices_match(fp8_scale_inv.device, self._data.device) - or fp8_scale_inv.dtype != torch.float32 - ): - fp8_scale_inv = fp8_scale_inv.to( - device=self._data.device, - dtype=torch.float32, - ) self._scale_inv = fp8_scale_inv # FP8 transpose cache @@ -474,28 +412,15 @@ def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor: # Make sure FP8 data is in expected format data = self._data assert data is not None - if data.device.type != "cuda": - data = data.cuda() - if not data.is_contiguous(): - data = data.contiguous() - if data.dim() != 2: - data = data.view(1, -1) # Cast from FP8 - out = cast_from_fp8( - data.view(1, -1), - None, # fp8_meta_tensor - None, # fp8_tensor - self._fp8_dtype, - dtype, - scale_inv=self._scale_inv, + return tex.cast_from_fp8(data, + self._scale_inv, + self._fp8_dtype, + dtype, + 0 ) - # Make sure output is in expected format - if out.size() != self.size(): - out = out.view(self.size()) - return out - def from_float8(self, dtype: Optional[torch.dtype] = None) -> torch.Tensor: """ Construct plain PyTorch tensor from Float8Tensor @@ -509,8 +434,6 @@ def quantize_( self, tensor: torch.Tensor, *, - scale: Optional[torch.Tensor] = None, - amax: Optional[torch.Tensor] = None, noop_flag: Optional[torch.Tensor] = None, ) -> Float8Tensor: """Update FP8 data @@ -675,13 +598,10 @@ def quantize( f"Invalid quantization params type: {type(params)}" return _ToFloat8Func.apply( tensor, - params.scale, - params.amax, - params.scale_inv, - params.dtype, + params, rowwise_usage, columnwise_usage, - proxy=proxy, + proxy, ) def detach(self) -> Float8Tensor: @@ -691,23 +611,34 @@ def detach(self) -> Float8Tensor: fp8_attrs=self._fp8_attrs, ) - def update_usage(self, rowwise=True, columnwise=True): - if rowwise: + def _create_transpose(self): + data = self._data + if not data.is_contiguous(): + data = data.contiguous() + self._transpose = tex.fp8_transpose(self._data, self._fp8_dtype, out=self._transpose) + self._transpose_invalid = False + + def update_usage(self, rowwise_usage=True, columnwise_usage=True): + assert rowwise_usage or columnwise_usage, \ + "Could not disable all usages of the tensor" + if rowwise_usage: assert self._data is not None, \ "Rowwise usage of the tensor was already disabled" else: - if not columnwise or (self._transpose is not None and - not self._transpose_invalid): - self._data = None - if columnwise: - assert (self._transpose is not None and - not self._transpose_invalid) or self._data 
is not None, \ - "The tensor does not hold any data anymore" + if self._transpose is None or self._transpose_invalid: + self._create_transpose() + self._data = None + if columnwise_usage: + if self._transpose is None or self._transpose_invalid: + assert self._data is not None, \ + "The tensor does not hold any data anymore" + self._create_transpose() else: self._transpose = None self._transpose_invalid = True def clone(self) -> Float8Tensor: + assert self._data is not None data = self._data.detach().clone() data_transpose = None if self._transpose is not None: @@ -743,129 +674,6 @@ def contiguous( {"data": self._data.detach().contiguous(memory_format=memory_format)}, ) - def transpose_2d( - self, - *, - force_compute: bool = False, - fill_cache: bool = False, - noop_flag: Optional[torch.Tensor] = None, - cache: Optional[bool] = None, - ) -> torch.Tensor: - """ - 2D transpose with caching support. - - Parameters - ---------- - force_compute: bool, default = `False` - Force computation of transpose. Otherwise use - cached values, if possible. - fill_cache: bool, default = `False` - Cache output tensor for future function calls. - noop_flag: torch.Tensor, optional - float32 flag indicating whether to avoid updating - cached values, if possible. - cache: bool, deprecated - - """ - - # Handle deprecated cache kwarg - if cache is not None: - msg = ( - "cache kwarg for Float8Tensor.transpose_2d is deprecated, " - "please use force_compute and fill_cache instead" - ) - warnings.warn(msg, DeprecationWarning) - if cache: - force_compute = False - fill_cache = True - else: - force_compute = True - fill_cache = False - - # Need to compute transpose if cache is invalid - need_compute = ( - force_compute - or (self._transpose is None) - or self._transpose_invalid - or (noop_flag is not None) - ) - - # Return cached transpose if possible - if not need_compute: - assert self._transpose is not None - return self._transpose - - # Allocate output if needed - data = self._data.contiguous().reshape(-1, self.size(-1)) - out: Optional[torch.Tensor] = self._transpose - if out is None: - out = torch.empty( - (data.size(1), data.size(0)), - dtype=torch.uint8, - device=data.device, - ) - noop_flag = None - else: - self._transpose_invalid = False - - # Apply transpose kernel - fp8_dtype = self._fp8_dtype - if noop_flag is None: - tex.fp8_transpose_noalloc(data, out, fp8_dtype) - else: - noop_flag = noop_flag.to(dtype=torch.float32, device=data.device) - tex.fp8_transpose_noalloc_noop(data, out, noop_flag, fp8_dtype) - - # Fill cache if needed - if fill_cache: - self._transpose = out - self._transpose_invalid = False - - return out - - @torch.no_grad() - def cast_transpose_( - self, - tensor: torch.Tensor, - noop_flag: Optional[torch.Tensor] = None, - ) -> None: - """Cast from tensor and populate transpose cache - - Tensor is reshaped as a 2D matrix. - - Parameters - ---------- - tensor: torch.Tensor - Tensor to copy from. Must have same dimensions as - destination tensor. - noop_flag: torch.Tensor, optional - float32 flag indicating whether to avoid updating - destination tensor. 
- - """ - if self._transpose is None: - self._transpose = torch.empty( - (self.size(-1), self.numel() // self.size(-1)), - dtype=torch.uint8, - device=self.device, - ) - self.quantize_(tensor, noop_flag=noop_flag) - - @torch.no_grad() - def reset_fp8_meta_scale_inv(self) -> None: - """Replace FP8 meta tensor scale-inverse with cached value - - The FP8 meta tensor scale_inv entry corresponding to this - tensor is replaced with the scale_inv value used to construct - the tensor. - - """ - assert self._fp8_meta is not None, "FP8 meta tensors not found." - fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( - forward=self._fp8_meta_forward, - ) - self._fp8_meta[fp8_meta_key].scale_inv[self._fp8_meta_index].copy_(self._scale_inv[0]) - def to_dtype(self, dtype: torch.dtype) -> Float8Tensor: """Create `Float8Tensor` with given nominal dtype @@ -960,10 +768,6 @@ def _set_data(self, tensor: torch.Tensor) -> None: # Tensor device new_device = tensor.device if tensor.is_cuda else self.device - # Check whether grad is required - if self.requires_grad != tensor.requires_grad: - self.requires_grad_(requires_grad=tensor.requires_grad) - # Just copy FP8 data if other tensor is Float8Tensor if isinstance(tensor, Float8Tensor): if ( # pylint: disable=too-many-boolean-expressions @@ -987,42 +791,20 @@ def _set_data(self, tensor: torch.Tensor) -> None: super(Float8Tensor, type(self)).data.__set__(self, dummy_tensor) self._data = tensor._data self._fp8_attrs = tensor._fp8_attrs + if self.requires_grad != tensor.requires_grad: + self.requires_grad_(requires_grad=tensor.requires_grad) return - # Reallocate FP8 data if needed - if ( - self.size() != tensor.size() - or self.stride() != tensor.stride() - or self.dtype != tensor.dtype - or self.layout != tensor.layout - or not devices_match(self.device, new_device) - ): - self._data = torch.empty_like( - tensor, - dtype=torch.uint8, - device=new_device, - ) - dummy_tensor = torch.Tensor._make_wrapper_subclass( - Float8Tensor, - self._data.size(), - strides=self._data.stride(), - storage_offset=self._data.storage_offset(), - dtype=tensor.dtype, - layout=self._data.layout, - requires_grad=tensor.requires_grad, - device=self._data.device, - ) - super(Float8Tensor, type(self)).data.__set__(self, dummy_tensor) - if self._transpose is not None: - self._transpose = torch.empty( - (self._data.size(-1), self._data.numel() // self._data.size(-1)), - dtype=torch.uint8, - device=self.device, - ) - self._transpose_invalid = True + assert self._proxy is not None, "Can't quantize without a proxy" + + self.data = Float8Tensor.quantize(tensor, + self._proxy.get_quantization_params(), + rowwise_usage=self._data is not None, + columnwise_usage=self._transpose is not None, + proxy=self._proxy) - # Copy values from other tensor - self.quantize_(tensor) + if self.requires_grad != tensor.requires_grad: + self.requires_grad_(requires_grad=tensor.requires_grad) # Cast to FP8 when setting Float8Tensor.data data = property(_get_data, _set_data)
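
Note on the calibration path above (an illustrative sketch, not the library implementation): with delayed scaling, calibrate() records the absolute maximum of a tensor in the newest row of that tensor's amax history, which is what the linear forward previously open-coded with aminmax(). A minimal stand-alone version, where DelayedScalingMeta and history_len are hypothetical names:

    import torch

    class DelayedScalingMeta:
        """Toy stand-in for the per-module scaling metadata."""

        def __init__(self, num_tensors: int, history_len: int = 16):
            # One scale value and one amax-history column per quantized tensor.
            self.scale = torch.ones(num_tensors, dtype=torch.float32, device="cuda")
            self.amax_history = torch.zeros(
                history_len, num_tensors, dtype=torch.float32, device="cuda"
            )

        def calibrate(self, tensor: torch.Tensor, index: int) -> None:
            # Record max(|tensor|) in the newest history row for this tensor index.
            amin, amax = tensor.aminmax()
            self.amax_history[0][index] = torch.max(-amin, amax).float()

    meta = DelayedScalingMeta(num_tensors=2)
    meta.calibrate(torch.randn(32, 64, device="cuda"), index=0)  # e.g. the input
    meta.calibrate(torch.randn(64, 64, device="cuda"), index=1)  # e.g. the weight

The two calibrate() calls in the non-FP8 calibration branch of the forward pass follow exactly this pattern, one per FP8 tensor index.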
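Also for reference, a hedged toy model of the rowwise/columnwise bookkeeping behind update_usage() and _create_transpose(): rowwise usage is backed by the plain FP8 data buffer, columnwise usage by a cached transpose, and a representation may only be dropped once the other one exists. ToyFp8Storage and the plain .t() transpose are illustrative stand-ins, not the Float8Tensor implementation:

    import torch

    class ToyFp8Storage:
        """Illustrative container for a 2-D quantized payload (uint8 here)."""

        def __init__(self, data: torch.Tensor):
            self._data = data            # row-major payload (rowwise usage)
            self._transpose = None       # cached transpose (columnwise usage)
            self._transpose_invalid = True

        def _create_transpose(self) -> None:
            # The real code calls a fused FP8 transpose kernel; .t() stands in.
            data = self._data
            if not data.is_contiguous():
                data = data.contiguous()
            self._transpose = data.t().contiguous()
            self._transpose_invalid = False

        def update_usage(self, rowwise_usage: bool = True, columnwise_usage: bool = True) -> None:
            assert rowwise_usage or columnwise_usage, "Cannot disable all usages"
            if rowwise_usage:
                assert self._data is not None, "Rowwise usage was already disabled"
            else:
                # Build the columnwise copy before dropping the rowwise buffer.
                if self._transpose is None or self._transpose_invalid:
                    self._create_transpose()
                self._data = None
            if columnwise_usage:
                if self._transpose is None or self._transpose_invalid:
                    assert self._data is not None, "Tensor holds no data anymore"
                    self._create_transpose()
            else:
                self._transpose = None
                self._transpose_invalid = True

    t = ToyFp8Storage(torch.randint(0, 255, (16, 32), dtype=torch.uint8))
    t.update_usage(rowwise_usage=False, columnwise_usage=True)  # keep only the transpose

This mirrors how the forward pass can drop the rowwise copy of the quantized input when only the transposed (columnwise) copy will be consumed by the wgrad GEMM in the backward pass.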
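Finally, since scale_inv is now carried only by the tensor itself (it was dropped from Float8Params and the scaling metadata), a rough numeric illustration of the scale / scale_inv relationship that the dequantize() path relies on. This assumes PyTorch >= 2.1 for the native torch.float8_e4m3fn dtype, and the margin-free scale computation is a simplification:

    import torch

    x = torch.randn(4, 8, device="cuda")

    # Delayed scaling derives the scale from a recorded amax (simplified: no margin).
    amax = x.abs().max().float()
    fp8_max = 448.0                      # largest representable E4M3 magnitude
    scale = fp8_max / amax
    scale_inv = 1.0 / scale              # stored alongside the FP8 payload

    fp8_data = (x * scale).to(torch.float8_e4m3fn)        # quantize
    x_approx = fp8_data.to(torch.float32) * scale_inv     # dequantize
    print((x - x_approx).abs().max())                     # small rounding error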