diff --git a/tests/cpp/operator/test_act.cu b/tests/cpp/operator/test_act.cu index e95d8ad11f..1236e3b7ed 100644 --- a/tests/cpp/operator/test_act.cu +++ b/tests/cpp/operator/test_act.cu @@ -194,8 +194,12 @@ void performTestGLU(const size_t N, const size_t H) { ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); if (otype == DType::kFloat8E4M3 || otype == DType::kFloat8E5M2) { - auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32); - compareResults("amax", output.amax(), ref_amax, atol_amax, rtol_amax); + auto [atol, rtol] = getTolerances(DType::kFloat32); + compareResults("amax", output.amax(), ref_amax, atol, rtol); + if (output.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) { + const float ref_scale = 1.f; // assuming input scale is 1.f + compareResults("scale_inv", *output.rowwise_cpu_scale_inv_ptr(), ref_scale, atol, rtol); + } } auto [atol, rtol] = getTolerances(otype); compareResults("output_gelu", output, ref_output.get(), atol, rtol); diff --git a/tests/cpp/operator/test_cast.cu b/tests/cpp/operator/test_cast.cu index 8c18f048bc..be0b6acf04 100644 --- a/tests/cpp/operator/test_cast.cu +++ b/tests/cpp/operator/test_cast.cu @@ -23,31 +23,31 @@ namespace { template void compute_ref(const InputType *data, OutputType *output_c, - const size_t N, const size_t H, + const size_t size, float *amax, float scale) { using compute_t = float; compute_t current_max = -1e100; - for (size_t i = 0; i < N; ++i) { - for (size_t j = 0; j < H; ++j) { - compute_t current = static_cast(data[i * H + j]); + for (size_t i = 0; i < size; ++i) { + compute_t current = static_cast(data[i]); current_max = fmaxf(current_max, fabsf(current)); - output_c[i * H + j] = OutputType(scale * current); - } + output_c[i] = OutputType(scale * current); } *amax = current_max; } template -void performTest(const size_t N, const size_t H) { +void performTest(const std::vector& shape) { using namespace test; + const size_t full_size = product(shape); + DType itype = TypeInfo::dtype; DType otype = TypeInfo::dtype; - Tensor input({ N, H }, itype); - Tensor output_c({ N, H }, otype); + Tensor input(shape, itype); + Tensor output_c(shape, otype); - std::unique_ptr ref_output_c = std::make_unique(N * H); + std::unique_ptr ref_output_c = std::make_unique(full_size); fillUniform(&input); setRandomScale(&output_c); @@ -56,7 +56,7 @@ void performTest(const size_t N, const size_t H) { float ref_amax; compute_ref(input.rowwise_cpu_dptr(), ref_output_c.get(), - N, H, &ref_amax, output_c.scale()); + full_size, &ref_amax, output_c.scale()); cudaDeviceSynchronize(); auto err = cudaGetLastError(); @@ -71,7 +71,9 @@ void performTest(const size_t N, const size_t H) { compareResults("output_c", output_c, ref_output_c.get(), true, atol, rtol); } -std::vector> test_cases = { +std::vector> test_cases = { + {16}, + {16000}, {128, 128}, {256, 256}, {768, 1024}, @@ -79,19 +81,19 @@ std::vector> test_cases = { {2048, 12288}, {65536, 128}, {65536, 160}, - {16384, 6144}, {16384, 1616}, {1, 128}, {1, 1296}, {1, 16}, {5, 160}, + {5, 4, 3, 160}, {217, 256}, }; } // namespace class CastTestSuite : public ::testing::TestWithParam>> {}; + std::vector>> {}; TEST_P(CastTestSuite, TestCast) { using namespace transformer_engine; @@ -103,7 +105,7 @@ TEST_P(CastTestSuite, TestCast) { TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType, TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType, - performTest(size.first, size.second); + performTest(size); ); ); } @@ -119,8 +121,10 @@ INSTANTIATE_TEST_SUITE_P( ::testing::ValuesIn(test_cases)), [](const 
testing::TestParamInfo& info) { std::string name = test::typeName(std::get<0>(info.param)) + "X" + - test::typeName(std::get<1>(info.param)) + "X" + - std::to_string(std::get<2>(info.param).first) + "X" + - std::to_string(std::get<2>(info.param).second); + test::typeName(std::get<1>(info.param)); + const auto& shape = std::get<2>(info.param); + for ( const auto& s: shape) { + name += "X" + std::to_string(s); + } return name; }); diff --git a/tests/cpp/operator/test_cast_dbias.cu b/tests/cpp/operator/test_cast_dbias.cu index 3fa8383a83..20ae33e304 100644 --- a/tests/cpp/operator/test_cast_dbias.cu +++ b/tests/cpp/operator/test_cast_dbias.cu @@ -56,16 +56,19 @@ void compute_ref_cast_dbias(const IT *input_h, } template -void performTest(const size_t N, const size_t H) { +void performTest(const std::vector& shape) { using namespace test; using CType = fp32; DType itype = TypeInfo::dtype; DType otype = TypeInfo::dtype; - Tensor input({N, H}, itype); + const size_t N = first_dimension(shape); + const size_t H = last_dimension(shape); - Tensor output_c({N, H}, otype); + Tensor input(shape, itype); + + Tensor output_c(shape, otype); // dbias has the same data type with "output grad" Tensor dbias({H}, itype); @@ -117,7 +120,7 @@ void performTest(const size_t N, const size_t H) { compareResults("output_dbias", dbias, ref_output_dbias.get(), true, atol_dbias, rtol_dbias); } -std::vector> test_cases = { +std::vector> test_cases = { {128, 128}, {256, 256}, {768, 1024}, @@ -125,12 +128,12 @@ std::vector> test_cases = { {2048, 12288}, {65536, 128}, {65536, 160}, - {16384, 6144}, {16384, 1616}, {1, 128}, {1, 1296}, {1, 16}, {5, 160}, + {5, 4, 3, 160}, {217, 256}, }; @@ -139,7 +142,7 @@ std::vector> test_cases = { class CastDBiasTestSuite : public ::testing::TestWithParam>> {}; + std::vector>> {}; TEST_P(CastDBiasTestSuite, TestCastDBias) { using namespace transformer_engine; @@ -155,7 +158,7 @@ TEST_P(CastDBiasTestSuite, TestCastDBias) { TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType, TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType, - performTest(size.first, size.second); + performTest(size); ); ); } @@ -169,8 +172,10 @@ INSTANTIATE_TEST_SUITE_P( ::testing::ValuesIn(test_cases)), [](const testing::TestParamInfo& info) { std::string name = test::typeName(std::get<0>(info.param)) + "X" + - test::typeName(std::get<1>(info.param)) + "X" + - std::to_string(std::get<2>(info.param).first) + "X" + - std::to_string(std::get<2>(info.param).second); + test::typeName(std::get<1>(info.param)); + const auto& shape = std::get<2>(info.param); + for ( const auto& s: shape) { + name += "X" + std::to_string(s); + } return name; }); diff --git a/tests/cpp/operator/test_cast_dbias_dgelu.cu b/tests/cpp/operator/test_cast_dbias_dgelu.cu index 34e59be2ec..1fb6acf834 100644 --- a/tests/cpp/operator/test_cast_dbias_dgelu.cu +++ b/tests/cpp/operator/test_cast_dbias_dgelu.cu @@ -64,17 +64,20 @@ void compute_ref_cast_dbias_dgelu(const IT *input, } template -void performTest(const size_t N, const size_t H) { +void performTest(const std::vector& shape) { using namespace test; using CType = fp32; DType itype = TypeInfo::dtype; DType otype = TypeInfo::dtype; - Tensor input({N, H}, itype); - Tensor gelu_input({N, H}, itype); + const size_t N = first_dimension(shape); + const size_t H = last_dimension(shape); - Tensor output_c({N, H}, otype); + Tensor input(shape, itype); + Tensor gelu_input(shape, itype); + + Tensor output_c(shape, otype); // dbias has the same data type with "output grad" Tensor dbias({H}, itype); @@ 
-132,7 +135,7 @@ void performTest(const size_t N, const size_t H) { compareResults("output_dbias", dbias, ref_output_dbias.get(), true, atol_dbias, rtol_dbias); } -std::vector> test_cases = { +std::vector> test_cases = { {128, 128}, {256, 256}, {768, 1024}, @@ -140,12 +143,12 @@ std::vector> test_cases = { {2048, 12288}, {65536, 128}, {65536, 160}, - {16384, 6144}, {16384, 1616}, {1, 128}, {1, 1296}, {1, 16}, {5, 160}, + {5, 4, 3, 160}, {217, 256}, }; @@ -154,7 +157,7 @@ std::vector> test_cases = { class CastDBiasDGeluTestSuite : public ::testing::TestWithParam>> {}; + std::vector>> {}; TEST_P(CastDBiasDGeluTestSuite, TestCastDBiasDgelu) { using namespace transformer_engine; @@ -170,7 +173,7 @@ TEST_P(CastDBiasDGeluTestSuite, TestCastDBiasDgelu) { TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType, TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType, - performTest(size.first, size.second); + performTest(size); ); ); } @@ -184,8 +187,10 @@ INSTANTIATE_TEST_SUITE_P( ::testing::ValuesIn(test_cases)), [](const testing::TestParamInfo& info) { std::string name = test::typeName(std::get<0>(info.param)) + "X" + - test::typeName(std::get<1>(info.param)) + "X" + - std::to_string(std::get<2>(info.param).first) + "X" + - std::to_string(std::get<2>(info.param).second); + test::typeName(std::get<1>(info.param)); + const auto& shape = std::get<2>(info.param); + for ( const auto& s: shape) { + name += "X" + std::to_string(s); + } return name; }); diff --git a/tests/cpp/operator/test_cast_gated_swiglu.cu b/tests/cpp/operator/test_cast_gated_swiglu.cu index d165807168..83cc5f6dbf 100644 --- a/tests/cpp/operator/test_cast_gated_swiglu.cu +++ b/tests/cpp/operator/test_cast_gated_swiglu.cu @@ -58,21 +58,29 @@ void compute_ref_cast_dgated_swiglu(const IType * const grad, } template -void performTest(const size_t rows, const size_t cols) { +void performTest(const std::vector& shape) { using namespace test; DType itype = TypeInfo::dtype; DType otype = TypeInfo::dtype; - Tensor grad({rows, cols}, itype); - Tensor input({rows, cols * 2}, itype); - Tensor output_c({rows, cols * 2}, otype); + std::vector input_shape = shape; + input_shape[input_shape.size() - 1] *= 2; + + const size_t input_size = product(input_shape); + + const size_t rows = first_dimension(shape); + const size_t cols = last_dimension(shape); + + Tensor grad(shape, itype); + Tensor input(input_shape, itype); + Tensor output_c(input_shape, otype); fillUniform(&grad); fillUniform(&input); setRandomScale(&output_c); - std::unique_ptr ref_output_c = std::make_unique(rows * cols * 2); + std::unique_ptr ref_output_c = std::make_unique(input_size); nvte_dswiglu(grad.data(), input.data(), output_c.data(), 0); cudaDeviceSynchronize(); @@ -100,21 +108,28 @@ void performTest(const size_t rows, const size_t cols) { compareResults("output_c", output_c, ref_output_c.get(), true, atol, rtol); } -std::vector> test_cases = { +std::vector> test_cases = { {128, 128}, {256, 256}, {768, 1024}, - // {256, 65536}, - // {2048, 12288}, - // {65536, 128}, - // {16384, 6144}, + {256, 65536}, + {2048, 12288}, + {65536, 128}, + {65536, 160}, + {16384, 1616}, + {1, 128}, + {1, 1296}, + {1, 16}, + {5, 160}, + {5, 4, 3, 160}, + {217, 256}, }; } // namespace class CastSwiGLUTestSuite : public ::testing::TestWithParam>> {}; + transformer_engine::DType, transformer_engine::DType, std::vector>> {}; TEST_P(CastSwiGLUTestSuite, TestCastSwiGLU) { using namespace transformer_engine; @@ -131,7 +146,7 @@ TEST_P(CastSwiGLUTestSuite, TestCastSwiGLU) { 
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( input_type, InputType, TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( - output_type, OutputType, performTest(size.first, size.second););); + output_type, OutputType, performTest(size););); } INSTANTIATE_TEST_SUITE_P( @@ -142,8 +157,10 @@ INSTANTIATE_TEST_SUITE_P( ::testing::ValuesIn(test_cases)), [](const testing::TestParamInfo &info) { std::string name = test::typeName(std::get<0>(info.param)) + "X" + - test::typeName(std::get<1>(info.param)) + "X" + - std::to_string(std::get<2>(info.param).first) + "X" + - std::to_string(std::get<2>(info.param).second); + test::typeName(std::get<1>(info.param)); + const auto& shape = std::get<2>(info.param); + for ( const auto& s: shape) { + name += "X" + std::to_string(s); + } return name; }); diff --git a/tests/cpp/operator/test_cast_mxfp8.cu b/tests/cpp/operator/test_cast_mxfp8.cu index db574748cc..93627a7231 100644 --- a/tests/cpp/operator/test_cast_mxfp8.cu +++ b/tests/cpp/operator/test_cast_mxfp8.cu @@ -169,8 +169,7 @@ void compute_ref_x2(const ProcessingMethod processing_method, template void performTest_x1(const ProcessingMethod processing_method, - const size_t rows, - const size_t cols, + const std::vector& shape, const bool rowwise, const bool colwise, InputsFillCase fill_case) { @@ -179,6 +178,9 @@ void performTest_x1(const ProcessingMethod processing_method, DType itype = TypeInfo::dtype; DType otype = TypeInfo::dtype; + const size_t rows = first_dimension(shape); + const size_t cols = last_dimension(shape); + const size_t block_size_rows = rowwise ? 1 : 32; const size_t block_size_cols = colwise ? 1 : 32; const size_t unpadded_blocks_Y = (rows + block_size_rows - 1) / block_size_rows; @@ -196,9 +198,9 @@ void performTest_x1(const ProcessingMethod processing_method, const size_t blocks_X = ((unpadded_blocks_X + block_alignment_X - 1) / block_alignment_X) * block_alignment_X; const size_t scales_stride = blocks_X; - Tensor input({ rows, cols }, itype); - Tensor act_input({ rows, cols }, itype); - Tensor output_c({ rows, cols }, otype, rowwise, colwise, NVTE_MXFP8_1D_SCALING); + Tensor input(shape, itype); + Tensor act_input(shape, itype); + Tensor output_c(shape, otype, rowwise, colwise, NVTE_MXFP8_1D_SCALING); Tensor output_dbias({ cols }, itype); std::unique_ptr ref_output_c = std::make_unique(rows * cols); @@ -301,8 +303,7 @@ void performTest_x1(const ProcessingMethod processing_method, */ template void performTest_x2(const ProcessingMethod processing_method, - const size_t rows, - const size_t cols, + const std::vector& shape, const size_t block_size_rows, const size_t block_size_cols, InputsFillCase fill_case) { @@ -311,6 +312,9 @@ void performTest_x2(const ProcessingMethod processing_method, DType itype = TypeInfo::dtype; DType otype = TypeInfo::dtype; + const size_t rows = first_dimension(shape); + const size_t cols = last_dimension(shape); + const size_t unpadded_blocks_Y_rowwise = rows; const size_t unpadded_blocks_X_rowwise = divide_round_up(cols, block_size_cols); const size_t unpadded_blocks_Y_colwise = divide_round_up(rows, block_size_rows); @@ -328,9 +332,9 @@ void performTest_x2(const ProcessingMethod processing_method, const size_t scales_stride_rowwise = blocks_X_rowwise; const size_t scales_stride_colwise = blocks_X_colwise; - Tensor input({ rows, cols }, itype); - Tensor act_input({ rows, cols }, itype); - Tensor output({ rows, cols }, otype, true, true, NVTE_MXFP8_1D_SCALING); + Tensor input(shape, itype); + Tensor act_input(shape, itype); + Tensor output(shape, otype, true, true, 
NVTE_MXFP8_1D_SCALING); Tensor output_dbias({ cols }, itype); std::unique_ptr ref_output_c_rowwise = std::make_unique(rows * cols); @@ -429,18 +433,25 @@ void performTest_x2(const ProcessingMethod processing_method, } } -std::vector> matrix_sizes = { - {1, 16}, - {16, 48}, - {65, 96}, - {128, 128}, - {256, 256}, - {993, 512}, - {768, 1024}, - // {2048, 12288}, - // {65536, 128}, - // {16384, 1632}, - // {16384, 6144}, +std::vector> matrix_sizes = { + {1, 16}, + {32}, + {65536}, + {128, 128}, + {256, 256}, + {768, 1024}, + {993, 512}, + {256, 65536}, + {2048, 12288}, + {65536, 128}, + {65536, 160}, + {16384, 1632}, + {1, 128}, + {1, 1296}, + {1, 32}, + {5, 160}, + {5, 4, 3, 160}, + {217, 256}, }; std::vector> block_sizes = { @@ -480,7 +491,7 @@ std::vector Activation_types = { class FusedCastMXFP8TestSuite : public ::testing::TestWithParam , + std::vector, std::pair, transformer_engine::DType, transformer_engine::DType, @@ -544,11 +555,11 @@ TEST_P(FusedCastMXFP8TestSuite, TestFusedCastMXFP8) { TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(output_type, OutputType, if (block_size.first == 1 || block_size.second == 1) { performTest_x1( - processing_method, matrix_size.first, matrix_size.second, + processing_method, matrix_size, rowwise, colwise, fill_case); } else { performTest_x2( - processing_method, matrix_size.first, matrix_size.second, + processing_method, matrix_size, block_size.first, block_size.second, fill_case); } ); @@ -560,11 +571,11 @@ TEST_P(FusedCastMXFP8TestSuite, TestFusedCastMXFP8) { TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(output_type, OutputType, if (block_size.first == 1 || block_size.second == 1) { performTest_x1( - processing_method, matrix_size.first, matrix_size.second, + processing_method, matrix_size, rowwise, colwise, fill_case); } else { performTest_x2( - processing_method, matrix_size.first, matrix_size.second, + processing_method, matrix_size, block_size.first, block_size.second, fill_case); } ); @@ -609,13 +620,15 @@ INSTANTIATE_TEST_SUITE_P( ::testing::ValuesIn(input_scenarios)), [](const testing::TestParamInfo& info) { std::string name = to_string(std::get<0>(info.param)) + "X" + - to_string(std::get<1>(info.param)) + "X" + - std::to_string(std::get<2>(info.param).first) + "X" + - std::to_string(std::get<2>(info.param).second) + "X" + - std::to_string(std::get<3>(info.param).first) + "X" + - std::to_string(std::get<3>(info.param).second) + "X" + - test::typeName(std::get<4>(info.param)) + "X" + - test::typeName(std::get<5>(info.param)) + "X" + - test::caseName(std::get<6>(info.param)); + to_string(std::get<1>(info.param)); + const auto& shape = std::get<2>(info.param); + for ( const auto& s: shape) { + name += "X" + std::to_string(s); + } + name += "X" + std::to_string(std::get<3>(info.param).first) + + "X" + std::to_string(std::get<3>(info.param).second) + + "X" + test::typeName(std::get<4>(info.param)) + + "X" + test::typeName(std::get<5>(info.param)) + + "X" + test::caseName(std::get<6>(info.param)); return name; }); diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu index a0b65318e4..1b33157078 100644 --- a/tests/cpp/test_common.cu +++ b/tests/cpp/test_common.cu @@ -752,4 +752,15 @@ int32_t getDeviceComputeCapability() return 10 * deviceProp.major + deviceProp.minor; } +size_t first_dimension(const std::vector &shape) { + if (shape.size() == 0) return 0; + if (shape.size() == 1) return 1; + return product(shape, 0, shape.size() - 1); +} + +size_t last_dimension(const std::vector &shape) { + if (shape.size() == 0) return 0; + return 
shape[shape.size() - 1]; +} + } // namespace test diff --git a/tests/cpp/test_common.h b/tests/cpp/test_common.h index 9ab59dfd96..d79131d3a4 100644 --- a/tests/cpp/test_common.h +++ b/tests/cpp/test_common.h @@ -407,6 +407,10 @@ inline float dsrelu(const float x) { return fmaxf(0, 2 * x); } size_t typeToSize(DType type); size_t product(const NVTEShape &shape); +size_t product(const std::vector &shape); + +size_t first_dimension(const std::vector &shape); +size_t last_dimension(const std::vector &shape); bool areShapesEqual(const NVTEShape &s1, const NVTEShape &s2); diff --git a/transformer_engine/common/util/cast_gated_kernels.cuh b/transformer_engine/common/util/cast_gated_kernels.cuh index 064b913bf2..50dadeaee8 100644 --- a/transformer_engine/common/util/cast_gated_kernels.cuh +++ b/transformer_engine/common/util/cast_gated_kernels.cuh @@ -80,25 +80,25 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) DIVUP(dshmem_unaligned_as_uint, static_cast(ALIGNMENT_SIZE)) * ALIGNMENT_SIZE; char *dshmem = reinterpret_cast(dshmem_aligned_as_uint); - const size_t buff_elems = SHMEM_DIM_Y * SHMEM_DIM_X; - const size_t buff_elems_total = BUFFERS_NUM * buff_elems; - const size_t buff_size_aligned_in = + constexpr size_t buff_elems = SHMEM_DIM_Y * SHMEM_DIM_X; + constexpr size_t buff_elems_total = BUFFERS_NUM * buff_elems; + constexpr size_t buff_size_aligned_in = DIVUP(buff_elems_total * sizeof(IType), ALIGNMENT_SIZE) * ALIGNMENT_SIZE; - const size_t buff_size_aligned_out = + constexpr size_t buff_size_aligned_out = DIVUP(buff_elems_total * sizeof(OType), ALIGNMENT_SIZE) * ALIGNMENT_SIZE; - const size_t grad_mem = IS_DGATED ? buff_size_aligned_in : 0; + constexpr size_t grad_mem = IS_DGATED ? buff_size_aligned_in : 0; - const size_t in_act_mem = buff_size_aligned_in; - const size_t in_gate_mem = buff_size_aligned_in; - const size_t in_mem = in_act_mem + in_gate_mem; + constexpr size_t in_act_mem = buff_size_aligned_in; + constexpr size_t in_gate_mem = buff_size_aligned_in; + constexpr size_t in_mem = in_act_mem + in_gate_mem; - const size_t out_act_mem = buff_size_aligned_out; - const size_t out_gate_mem = buff_size_aligned_out; - const size_t out_mem = out_act_mem + out_gate_mem; + constexpr size_t out_act_mem = buff_size_aligned_out; + constexpr size_t out_gate_mem = buff_size_aligned_out; + constexpr size_t out_mem = out_act_mem + out_gate_mem; // const size_t in_transaction_size = grad_mem + in_mem; - const size_t in_transaction_size = (IS_DGATED ? 3 : 2) * buff_elems * sizeof(IType); + constexpr size_t in_transaction_size = buff_elems * sizeof(IType); // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned IType *in_grad_sh = reinterpret_cast(dshmem); @@ -118,44 +118,21 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) #pragma nv_diag_suppress static_var_with_dynamic_init __shared__ alignas(8) uint64_t mbar[ITERATIONS]; - if (is_master_thread) { -// Initialize barrier. All `blockDim.x * blockDim.y` threads in block participate. -#pragma unroll - for (int it = 0; it < ITERATIONS; ++it) { - ptx::mbarrier_init(&mbar[it], THREADS_PER_CHUNK); - } - ptx::fence_proxy_async_shared_cta(); - } - // Syncthreads so initialized barrier is visible to all threads. 
- __syncthreads(); + initialize_barriers(mbar, is_master_thread); int parity = 0; // Prefetch data of the first stage - if (is_master_thread) { - // Initiate bulk tensor copy - if constexpr (IS_DGATED) { - // Grad - ptx::cp_async_bulk_tensor_2d_global_to_shared(reinterpret_cast(&in_grad_sh[0]), - TMAP_grad_in, chunk_offset_X, chunk_offset_Y, - &mbar[0]); - } - // Act - ptx::cp_async_bulk_tensor_2d_global_to_shared(reinterpret_cast(&in_act_sh[0]), - TMAP_gate_in, chunk_offset_X, chunk_offset_Y, - &mbar[0]); - - // Gate - ptx::cp_async_bulk_tensor_2d_global_to_shared(reinterpret_cast(&in_gate_sh[0]), - TMAP_gate_in, chunk_offset_X + cols, - chunk_offset_Y, &mbar[0]); - - // Arrive on the barrier and tell how many bytes are expected to come in. - ptx::mbarrier_arrive_expect_tx(&mbar[0], in_transaction_size); + if constexpr (IS_DGATED) { + copy_2d_to_sharedx3(in_grad_sh, TMAP_grad_in, chunk_offset_X, chunk_offset_Y, + in_act_sh, TMAP_gate_in, chunk_offset_X, chunk_offset_Y, + in_gate_sh, TMAP_gate_in, chunk_offset_X + cols, chunk_offset_Y, + in_transaction_size, &mbar[0], is_master_thread); } else { - // Other threads just arrive - ptx::mbarrier_arrive(&mbar[0]); + copy_2d_to_sharedx2(in_act_sh, TMAP_gate_in, chunk_offset_X, chunk_offset_Y, + in_gate_sh, TMAP_gate_in, chunk_offset_X + cols, chunk_offset_Y, + in_transaction_size, &mbar[0], is_master_thread); } #pragma unroll @@ -163,31 +140,23 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const int buff = it % BUFFERS_NUM; const int next_it = it + 1; if (next_it < ITERATIONS) { - if (is_master_thread) { - const int next_buff = next_it % BUFFERS_NUM; - const int chunk_it_offset_y = chunk_offset_Y + next_it * BUFFER_DIM_Y; - const int chunk_it_offset_x = chunk_offset_X; - // Initiate bulk tensor copy - if constexpr (IS_DGATED) { - // Grad - ptx::cp_async_bulk_tensor_2d_global_to_shared( - reinterpret_cast(&in_grad_sh[next_buff * buff_elems]), TMAP_grad_in, - chunk_it_offset_x, chunk_it_offset_y, &mbar[next_it]); - } - // Act - ptx::cp_async_bulk_tensor_2d_global_to_shared( - reinterpret_cast(&in_act_sh[next_buff * buff_elems]), TMAP_gate_in, - chunk_it_offset_x, chunk_it_offset_y, &mbar[next_it]); - // Gate - ptx::cp_async_bulk_tensor_2d_global_to_shared( - reinterpret_cast(&in_gate_sh[next_buff * buff_elems]), TMAP_gate_in, - chunk_it_offset_x + cols, chunk_it_offset_y, &mbar[next_it]); - - // Arrive on the barrier and tell how many bytes are expected to come in. 
- ptx::mbarrier_arrive_expect_tx(&mbar[next_it], in_transaction_size); + const int next_buff = next_it % BUFFERS_NUM; + const int chunk_it_offset_y = chunk_offset_Y + next_it * BUFFER_DIM_Y; + const int chunk_it_offset_x = chunk_offset_X; + if constexpr (IS_DGATED) { + copy_2d_to_sharedx3(&in_grad_sh[next_buff * buff_elems], TMAP_grad_in, + chunk_it_offset_x, chunk_it_offset_y, + &in_act_sh[next_buff * buff_elems], TMAP_gate_in, + chunk_it_offset_x, chunk_it_offset_y, + &in_gate_sh[next_buff * buff_elems], TMAP_gate_in, + chunk_it_offset_x + cols, chunk_it_offset_y, + in_transaction_size, &mbar[next_it], is_master_thread); } else { - // Other threads just arrive - ptx::mbarrier_arrive(&mbar[next_it]); + copy_2d_to_sharedx2(&in_act_sh[next_buff * buff_elems], TMAP_gate_in, + chunk_it_offset_x, chunk_it_offset_y, + &in_gate_sh[next_buff * buff_elems], TMAP_gate_in, + chunk_it_offset_x + cols, chunk_it_offset_y, + in_transaction_size, &mbar[next_it], is_master_thread); } } @@ -697,9 +666,9 @@ void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr, "Scaling tensor must be allocated."); } - NVTE_CHECK(!output->has_columnwise_data(), "Only cast supported in this function."); - const size_t rows = gated_input.data.shape[0]; - const size_t cols = gated_input.data.shape[1] / 2; + NVTE_CHECK(!output->has_columnwise_data(), "Only rowwise cast supported in this function."); + const size_t rows = gated_input.flat_first_dim(); + const size_t cols = gated_input.flat_last_dim() / 2; const size_t output_cols = (IS_DGATED ? 2 : 1) * cols; const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y); @@ -899,19 +868,25 @@ void cast_dgated(const Tensor &grad, const Tensor &input, Tensor *output, cudaSt CheckInputTensor(grad, "dgated_act_grad"); CheckInputTensor(input, "dgated_act_input"); CheckOutputTensor(*output, "dgated_act_output"); - NVTE_CHECK(grad.data.shape.size() == 2, "Grad must have 2 dimensions."); - NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions."); - NVTE_CHECK(output->data.shape.size() == 2, "Output must have 2 dimensions."); - NVTE_CHECK(output->data.shape[0] == grad.data.shape[0], - "Output shape[0] must be equal to grad shape[0]."); - NVTE_CHECK(output->data.shape[1] == grad.data.shape[1] * 2, - "Output shape[1] must be 2x larger than grad shape[1]."); - NVTE_CHECK(input.data.shape == output->data.shape, "Input and output shapes must match."); + NVTE_CHECK(output->flat_first_dim() == grad.flat_first_dim(), + "Wrong output shape. Expected (after flattenting) [", + grad.flat_first_dim(), ", *], got [", + output->flat_first_dim(), ", ", + output->flat_last_dim(), "]."); + NVTE_CHECK(output->flat_last_dim() == grad.flat_last_dim() * 2, + "Wrong output shape. Expected (after flattenting) [*, ", + grad.flat_last_dim() * 2, "], got [", + output->flat_first_dim(), ", ", + output->flat_last_dim(), "]."); + NVTE_CHECK(input.data.shape == output->data.shape, + "Input and output shapes must match. 
Input shape: ", + input.data.shape, ", output shape: ", + output->data.shape, "."); TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( - input.data.dtype, IType, + input.dtype(), IType, TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( - output->data.dtype, OType, + output->dtype(), OType, if (!is_fp8_dtype(output->data.dtype) || is_delayed_tensor_scaling(output->scaling_mode)) { @@ -919,8 +894,11 @@ void cast_dgated(const Tensor &grad, const Tensor &input, Tensor *output, cudaSt DGatedActivationKernelLauncher( reinterpret_cast(grad.data.dptr), reinterpret_cast(input.data.dptr), - reinterpret_cast(output->data.dptr), grad.data.shape[0], - grad.data.shape[1], {}, stream); + reinterpret_cast(output->data.dptr), + reinterpret_cast(output->scale.dptr), + reinterpret_cast(output->amax.dptr), + reinterpret_cast(output->scale_inv.dptr), grad.flat_first_dim(), + grad.flat_last_dim(), {}, stream); } else { NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + "."); }); // NOLINT(*) @@ -936,20 +914,18 @@ void quantize_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu CheckInputTensor(gated_input, "gated_input"); CheckOutputTensor(*output, "output"); - const size_t rows = gated_input.data.shape[0]; - const size_t cols = gated_input.data.shape[1] / 2; - const size_t output_cols = (IS_DGATED ? 2 : 1) * cols; + NVTE_CHECK(gated_input.flat_last_dim() % 2 == 0, "Number of columns must be even."); - NVTE_CHECK(gated_input.data.shape[1] % 2 == 0, "Number of columns must be even."); - NVTE_CHECK(gated_input.data.shape.size() == 2, "Gated input must have 2 dimensions."); + const size_t rows = gated_input.flat_first_dim(); + const size_t cols = gated_input.flat_last_dim() / 2; + const size_t output_cols = (IS_DGATED ? 2 : 1) * cols; if constexpr (IS_DGATED) { CheckInputTensor(grad, "grad"); NVTE_CHECK(!is_fp8_dtype(grad.data.dtype), "Grad input must be in higher precision."); NVTE_CHECK(grad.data.dtype == gated_input.data.dtype, "Types of both inputs must match."); - NVTE_CHECK(grad.data.shape.size() == 2, "Grad input must have 2 dimensions."); - NVTE_CHECK(grad.data.shape[0] == rows, "Wrong dimension of the grad input."); - NVTE_CHECK(grad.data.shape[1] == cols, "Wrong dimension of the grad input."); + NVTE_CHECK(grad.flat_first_dim() == rows, "Wrong dimension of the grad input."); + NVTE_CHECK(grad.flat_last_dim() == cols, "Wrong dimension of the grad input."); } NVTE_CHECK(output->has_data() || output->has_columnwise_data(), @@ -959,15 +935,13 @@ void quantize_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu bool is_fp8_colwise_output = true; if (output->has_data()) { is_fp8_rowwise_output = is_fp8_dtype(output->data.dtype); - NVTE_CHECK(output->data.shape.size() == 2, "Output must have 2 dimensions."); - NVTE_CHECK(output->data.shape[0] == rows, "Wrong dimension of the output."); - NVTE_CHECK(output->data.shape[1] == output_cols, "Wrong dimension of the output."); + NVTE_CHECK(output->flat_first_dim() == rows, "Wrong dimension of the output."); + NVTE_CHECK(output->flat_last_dim() == output_cols, "Wrong dimension of the output."); } if (output->has_columnwise_data()) { is_fp8_colwise_output = is_fp8_dtype(output->columnwise_data.dtype); - NVTE_CHECK(output->columnwise_data.shape.size() == 2, "Output must have 2 dimensions."); - NVTE_CHECK(output->columnwise_data.shape[0] == rows, "Wrong dimension of the output."); - NVTE_CHECK(output->columnwise_data.shape[1] == output_cols, "Wrong dimension of the output."); + NVTE_CHECK(output->flat_first_dim() == rows, "Wrong 
dimension of the output."); + NVTE_CHECK(output->flat_last_dim() == output_cols, "Wrong dimension of the output."); } const bool is_full_tile = (rows % CHUNK_DIM_Y == 0) && (cols % CHUNK_DIM_X == 0); @@ -987,7 +961,7 @@ void quantize_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu if (use_tma_kernels) { cast_mxfp8_gated(grad, gated_input, output, stream); } else { - NVTE_ERROR("MX FP8 quantization supports full tiles only."); + NVTE_ERROR("MXFP8 quantization supports full tiles only."); } } else { NVTE_ERROR("Not supported scaling mode"); diff --git a/transformer_engine/common/util/cast_kernels.cuh b/transformer_engine/common/util/cast_kernels.cuh index d713738a4e..30061fc3b9 100644 --- a/transformer_engine/common/util/cast_kernels.cuh +++ b/transformer_engine/common/util/cast_kernels.cuh @@ -28,104 +28,6 @@ namespace transformer_engine { -namespace { - -template -__forceinline__ __device__ void initialize_barriers(uint64_t *mbar, const bool is_master_thread) { -#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - if (is_master_thread) { - // Initialize barrier. All `blockDim.x * blockDim.y` threads in block participate. -#pragma unroll - for (int iter = 0; iter < num_barriers; ++iter) { - ptx::mbarrier_init(&mbar[iter], THREADS_PER_BLOCK); - } - ptx::fence_proxy_async_shared_cta(); - } - // Syncthreads so initialized barrier is visible to all threads. - __syncthreads(); -#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) -} - -template -__forceinline__ __device__ void destroy_barriers(uint64_t *mbar, const bool is_master_thread) { -#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - // Destroy barrier. This invalidates the memory region of the barrier. If - // further computations were to take place in the kernel, this allows the - // memory location of the shared memory barrier to be reused. - if (is_master_thread) { -#pragma unroll - for (int iter = 0; iter < num_barriers; ++iter) { - ptx::mbarrier_invalid(&mbar[iter]); - } - } -#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) -} - -__forceinline__ __device__ void copy_1d_to_shared(void *dst, const void *src, - const size_t num_bytes, uint64_t *barrier, - const bool is_master_thread) { -#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - if (is_master_thread) { - // Initiate bulk tensor copy - ptx::cp_async_bulk_tensor_1d_global_to_shared(reinterpret_cast(dst), - reinterpret_cast(src), - num_bytes, barrier); - - // Arrive on the barrier and tell how many bytes are expected to come in. - ptx::mbarrier_arrive_expect_tx(barrier, num_bytes); - } else { - // Other threads just arrive - ptx::mbarrier_arrive(barrier); - } -#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) -} - -__forceinline__ __device__ void copy_2d_to_shared(void *dst, const void *src, const size_t chunk_X, - const size_t chunk_Y, const size_t num_bytes, - uint64_t *barrier, const bool is_master_thread) { -#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - if (is_master_thread) { - // Initiate bulk tensor copy - ptx::cp_async_bulk_tensor_2d_global_to_shared(reinterpret_cast(dst), - reinterpret_cast(src), chunk_X, - chunk_Y, barrier); - - // Arrive on the barrier and tell how many bytes are expected to come in. 
- ptx::mbarrier_arrive_expect_tx(barrier, num_bytes); - } else { - // Other threads just arrive - ptx::mbarrier_arrive(barrier); - } -#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) -} - -__forceinline__ __device__ void copy_2d_to_sharedx2(void *dst, const void *src, void *dst2, - const void *src2, const size_t chunk_X, - const size_t chunk_Y, const size_t num_bytes, - uint64_t *barrier, - const bool is_master_thread) { -#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - if (is_master_thread) { - // Initiate bulk tensor copy - ptx::cp_async_bulk_tensor_2d_global_to_shared(reinterpret_cast(dst), - reinterpret_cast(src), chunk_X, - chunk_Y, barrier); - - ptx::cp_async_bulk_tensor_2d_global_to_shared(reinterpret_cast(dst2), - reinterpret_cast(src2), chunk_X, - chunk_Y, barrier); - - // Arrive on the barrier and tell how many bytes are expected to come in. - ptx::mbarrier_arrive_expect_tx(barrier, 2 * num_bytes); - } else { - // Other threads just arrive - ptx::mbarrier_arrive(barrier); - } -#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) -} - -} // namespace - constexpr size_t MXFP8_CHUNK_DIM_Y = 64; constexpr size_t MXFP8_CHUNK_DIM_X = 64; constexpr size_t MXFP8_CHUNKS_PER_BLOCK_Y = 1; @@ -278,7 +180,9 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) const int chunk_stage_offset_Y = chunk_offset_Y + prefetch_buff * MXFP8_BUFFER_DIM_Y; const int chunk_stage_offset_X = chunk_offset_X; if constexpr (IS_DACT) { - copy_2d_to_sharedx2(&in_sh[prefetch_buff], &tensor_map_input, &act_in_sh[prefetch_buff], + copy_2d_to_sharedx2(&in_sh[prefetch_buff], &tensor_map_input, + chunk_stage_offset_X, chunk_stage_offset_Y, + &act_in_sh[prefetch_buff], &tensor_map_act_input, chunk_stage_offset_X, chunk_stage_offset_Y, shmem_buff_size, &mbar[prefetch_buff], is_master_thread); } else { @@ -299,8 +203,10 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) const int chunk_it_offset_y = chunk_offset_Y + next_iter * MXFP8_BUFFER_DIM_Y; const int chunk_it_offset_x = chunk_offset_X; if constexpr (IS_DACT) { - copy_2d_to_sharedx2(&in_sh[next_buff], &tensor_map_input, &act_in_sh[next_buff], - &tensor_map_act_input, chunk_it_offset_x, chunk_it_offset_y, + copy_2d_to_sharedx2(&in_sh[next_buff], &tensor_map_input, + chunk_it_offset_x, chunk_it_offset_y, + &act_in_sh[next_buff], &tensor_map_act_input, + chunk_it_offset_x, chunk_it_offset_y, shmem_buff_size, &mbar[next_iter], is_master_thread); } else { copy_2d_to_shared(&in_sh[next_buff], &tensor_map_input, chunk_it_offset_x, @@ -616,8 +522,10 @@ __global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK) const int chunk_stage_offset_Y = chunk_offset_Y + prefetch_buff * FP8_BUFFER_DIM_Y; const int chunk_stage_offset_X = chunk_offset_X; if constexpr (IS_DACT) { - copy_2d_to_sharedx2(&in_sh[prefetch_buff], &tensor_map_input, &act_in_sh[prefetch_buff], - &tensor_map_act_input, chunk_stage_offset_X, chunk_stage_offset_Y, + copy_2d_to_sharedx2(&in_sh[prefetch_buff], &tensor_map_input, + chunk_stage_offset_X, chunk_stage_offset_Y, + &act_in_sh[prefetch_buff], &tensor_map_act_input, + chunk_stage_offset_X, chunk_stage_offset_Y, shmem_buff_size, &mbar[prefetch_buff], is_master_thread); } else { copy_2d_to_shared(&in_sh[prefetch_buff], &tensor_map_input, chunk_stage_offset_X, @@ -636,8 +544,10 @@ __global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK) const int chunk_it_offset_y = chunk_offset_Y + next_iter * FP8_BUFFER_DIM_Y; const int chunk_it_offset_x = chunk_offset_X; if constexpr (IS_DACT) { - 
copy_2d_to_sharedx2(&in_sh[next_buff], &tensor_map_input, &act_in_sh[next_buff], - &tensor_map_act_input, chunk_it_offset_x, chunk_it_offset_y, + copy_2d_to_sharedx2(&in_sh[next_buff], &tensor_map_input, + chunk_it_offset_x, chunk_it_offset_y, + &act_in_sh[next_buff], &tensor_map_act_input, + chunk_it_offset_x, chunk_it_offset_y, shmem_buff_size, &mbar[next_iter], is_master_thread); } else { copy_2d_to_shared(&in_sh[next_buff], &tensor_map_input, chunk_it_offset_x, @@ -935,10 +845,9 @@ template = 2, "Input must have at least 2 dimensions."); NVTE_CHECK(is_fp8_dtype(output->dtype()), "Output must have FP8 type."); if (use_rowwise_scaling) { diff --git a/transformer_engine/common/util/ptx.cuh b/transformer_engine/common/util/ptx.cuh index 46fdb82a48..02b478349a 100644 --- a/transformer_engine/common/util/ptx.cuh +++ b/transformer_engine/common/util/ptx.cuh @@ -167,6 +167,144 @@ __device__ __forceinline__ void fence_proxy_async_shared_cta() { #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } // namespace ptx + +namespace { + +template +__forceinline__ __device__ void initialize_barriers(uint64_t *mbar, const bool is_master_thread) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + if (is_master_thread) { + // Initialize barrier. All `blockDim.x * blockDim.y` threads in block participate. +#pragma unroll + for (int iter = 0; iter < num_barriers; ++iter) { + ptx::mbarrier_init(&mbar[iter], THREADS_PER_BLOCK); + } + ptx::fence_proxy_async_shared_cta(); + } + // Syncthreads so initialized barrier is visible to all threads. + __syncthreads(); +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} + +template +__forceinline__ __device__ void destroy_barriers(uint64_t *mbar, const bool is_master_thread) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + // Destroy barrier. This invalidates the memory region of the barrier. If + // further computations were to take place in the kernel, this allows the + // memory location of the shared memory barrier to be reused. + if (is_master_thread) { +#pragma unroll + for (int iter = 0; iter < num_barriers; ++iter) { + ptx::mbarrier_invalid(&mbar[iter]); + } + } +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} + +__forceinline__ __device__ void copy_1d_to_shared(void *dst, const void *src, + const size_t num_bytes, uint64_t *barrier, + const bool is_master_thread) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + if (is_master_thread) { + // Initiate bulk tensor copy + ptx::cp_async_bulk_tensor_1d_global_to_shared(reinterpret_cast(dst), + reinterpret_cast(src), + num_bytes, barrier); + + // Arrive on the barrier and tell how many bytes are expected to come in. + ptx::mbarrier_arrive_expect_tx(barrier, num_bytes); + } else { + // Other threads just arrive + ptx::mbarrier_arrive(barrier); + } +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} + +__forceinline__ __device__ void copy_2d_to_shared(void *dst, const void *src, const size_t chunk_X, + const size_t chunk_Y, const size_t num_bytes, + uint64_t *barrier, const bool is_master_thread) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + if (is_master_thread) { + // Initiate bulk tensor copy + ptx::cp_async_bulk_tensor_2d_global_to_shared(reinterpret_cast(dst), + reinterpret_cast(src), chunk_X, + chunk_Y, barrier); + + // Arrive on the barrier and tell how many bytes are expected to come in. 
+ ptx::mbarrier_arrive_expect_tx(barrier, num_bytes); + } else { + // Other threads just arrive + ptx::mbarrier_arrive(barrier); + } +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} + +__forceinline__ __device__ void copy_2d_to_sharedx2(void *dst, const void *src, + const size_t chunk_X1, + const size_t chunk_Y1, + void *dst2, const void *src2, + const size_t chunk_X2, + const size_t chunk_Y2, + const size_t num_bytes, + uint64_t *barrier, + const bool is_master_thread) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + if (is_master_thread) { + // Initiate bulk tensor copy + ptx::cp_async_bulk_tensor_2d_global_to_shared(reinterpret_cast(dst), + reinterpret_cast(src), chunk_X1, + chunk_Y1, barrier); + + ptx::cp_async_bulk_tensor_2d_global_to_shared(reinterpret_cast(dst2), + reinterpret_cast(src2), chunk_X2, + chunk_Y2, barrier); + + // Arrive on the barrier and tell how many bytes are expected to come in. + ptx::mbarrier_arrive_expect_tx(barrier, 2 * num_bytes); + } else { + // Other threads just arrive + ptx::mbarrier_arrive(barrier); + } +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} + +__forceinline__ __device__ void copy_2d_to_sharedx3(void *dst, const void *src, + const size_t chunk_X1, + const size_t chunk_Y1, + void *dst2, const void *src2, + const size_t chunk_X2, + const size_t chunk_Y2, + void *dst3, const void *src3, + const size_t chunk_X3, + const size_t chunk_Y3, + const size_t num_bytes, + uint64_t *barrier, + const bool is_master_thread) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + if (is_master_thread) { + // Initiate bulk tensor copy + ptx::cp_async_bulk_tensor_2d_global_to_shared(reinterpret_cast(dst), + reinterpret_cast(src), chunk_X1, + chunk_Y1, barrier); + + ptx::cp_async_bulk_tensor_2d_global_to_shared(reinterpret_cast(dst2), + reinterpret_cast(src2), chunk_X2, + chunk_Y2, barrier); + + ptx::cp_async_bulk_tensor_2d_global_to_shared(reinterpret_cast(dst3), + reinterpret_cast(src3), chunk_X3, + chunk_Y3, barrier); + + // Arrive on the barrier and tell how many bytes are expected to come in. 
+ ptx::mbarrier_arrive_expect_tx(barrier, 3 * num_bytes); + } else { + // Other threads just arrive + ptx::mbarrier_arrive(barrier); + } +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} + +} // namespace } // namespace transformer_engine #endif // TRANSFORMER_ENGINE_PTX_CUH_ diff --git a/transformer_engine/common/util/vectorized_pointwise.h b/transformer_engine/common/util/vectorized_pointwise.h index a20449045d..420b9ed3bb 100644 --- a/transformer_engine/common/util/vectorized_pointwise.h +++ b/transformer_engine/common/util/vectorized_pointwise.h @@ -404,18 +404,19 @@ __launch_bounds__(unary_kernel_threads) __global__ ComputeType *amax, ComputeType *scale_inv, const size_t m, const size_t n, const Param p, const size_t num_aligned_elements) { const size_t M = num_aligned_elements * m; + ComputeType max = 0; + ComputeType s = 1; + if constexpr (is_fp8::value) { + if (scale != nullptr) s = *scale; + } + const int warp_id = threadIdx.x / THREADS_PER_WARP; + for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < M; tid += gridDim.x * blockDim.x) { const size_t id_x = tid % num_aligned_elements; const size_t id_y = tid / num_aligned_elements; VectorizedLoader loader0(input + id_y * n * 2, n); VectorizedLoader loader1(input + id_y * n * 2 + n, n); VectorizedStorer storer(output + id_y * n, n); - ComputeType max = 0; - ComputeType s = 1; - if constexpr (is_fp8::value) { - if (scale != nullptr) s = *scale; - } - const int warp_id = threadIdx.x / THREADS_PER_WARP; loader0.load(id_x, n); loader1.load(id_x, n); @@ -432,21 +433,20 @@ __launch_bounds__(unary_kernel_threads) __global__ storer.separate()[i] = static_cast(static_cast(temp)); } storer.store(id_x, n); - - if constexpr (is_fp8::value) { - // Reduce amax over block - if (amax != nullptr) { - max = reduce_max(max, warp_id); - if (threadIdx.x == 0) { - static_assert(std::is_same::value); - atomicMaxFloat(amax, max); - } + } + if constexpr (is_fp8::value) { + // Reduce amax over block + if (amax != nullptr) { + max = reduce_max(max, warp_id); + if (threadIdx.x == 0) { + static_assert(std::is_same::value); + atomicMaxFloat(amax, max); } + } - // Update scale-inverse - if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv != nullptr) { - reciprocal(scale_inv, s); - } + // Update scale-inverse + if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv != nullptr) { + reciprocal(scale_inv, s); } } } @@ -491,9 +491,17 @@ template __launch_bounds__(unary_kernel_threads) __global__ void dgated_act_kernel(const InputType *grad, const InputType *input, OutputType *output, + const ComputeType *scale, ComputeType *amax, ComputeType *scale_inv, const size_t m, const size_t n, const Param p, const size_t num_aligned_elements) { const size_t M = num_aligned_elements * m; + ComputeType max = 0; + ComputeType s = 1; + if constexpr (is_fp8::value) { + if (scale != nullptr) s = *scale; + } + const int warp_id = threadIdx.x / THREADS_PER_WARP; + for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < M; tid += gridDim.x * blockDim.x) { const size_t id_x = tid % num_aligned_elements; const size_t id_y = tid / num_aligned_elements; @@ -516,12 +524,35 @@ __launch_bounds__(unary_kernel_threads) __global__ ComputeType after_dgelu = Dactivation(gelu_in, p) * grad_val * gate_in; ComputeType after_dgate = grad_val * Activation(gelu_in, p); + if constexpr (is_fp8::value) { + __builtin_assume(max >= 0); + max = fmaxf(fabsf(after_dgelu), max); + after_dgelu = after_dgelu * s; + max = fmaxf(fabsf(after_dgate), max); + after_dgate = after_dgate 
* s; + } + storer0.separate()[i] = static_cast(after_dgelu); storer1.separate()[i] = static_cast(after_dgate); } storer0.store(id_x, n); storer1.store(id_x, n); } + if constexpr (is_fp8::value) { + // Reduce amax over block + if (amax != nullptr) { + max = reduce_max(max, warp_id); + if (threadIdx.x == 0) { + static_assert(std::is_same::value); + atomicMaxFloat(amax, max); + } + } + + // Update scale-inverse + if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv != nullptr) { + reciprocal(scale_inv, s); + } + } } template void DGatedActivationKernelLauncher(const InputType *grad, const InputType *input, - OutputType *output, const size_t m, const size_t n, - const Param &p, cudaStream_t stream) { + OutputType *output, const fp32 *scale, fp32 *amax, + fp32 *scale_inv, const size_t m, const size_t n, const Param &p, + cudaStream_t stream) { if (m != 0 && n != 0) { size_t num_aligned_elements = get_num_aligned_elements(grad, n, nvec, sizeof(InputType)); constexpr size_t threads = unary_kernel_threads; @@ -541,18 +573,19 @@ void DGatedActivationKernelLauncher(const InputType *grad, const InputType *inpu switch (auto align = CheckAlignment(n, nvec, input, input + n, output, output + n)) { case Alignment::SAME_ALIGNED: dgated_act_kernel - <<>>(grad, input, output, m, n, p, - num_aligned_elements); + <<>>(grad, input, output, scale, amax, scale_inv, m, n, + p, num_aligned_elements); break; case Alignment::SAME_UNALIGNED: dgated_act_kernel - <<>>(grad, input, output, m, n, p, - num_aligned_elements); + <<>>(grad, input, output, scale, amax, scale_inv, m, n, + p, num_aligned_elements); break; case Alignment::DIFFERENT: { // If the pointers are aligned differently we cannot vectorize dgated_act_kernel<1, true, ComputeType, Param, Activation, Dactivation> - <<>>(grad, input, output, m, n, p, n); + <<>>(grad, input, output, scale, amax, scale_inv, m, n, + p, n); break; } }
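
For reference, a minimal standalone sketch (not part of the patch) of the shape-flattening convention the updated tests rely on: an N-D shape is treated as a 2-D matrix whose row count is the product of all leading dimensions and whose column count is the last dimension, mirroring the first_dimension/last_dimension helpers added to tests/cpp/test_common.cu. The local product helper here is a stand-in assumption for the repo's product(shape, begin, end) overload, which is only declared in this diff.

// Standalone sketch of the flattening convention used by the updated tests.
#include <cstddef>
#include <cstdio>
#include <vector>

// Stand-in for the repo's product(shape, begin, end) overload (assumed semantics).
static size_t product(const std::vector<size_t>& shape, size_t begin, size_t end) {
  size_t p = 1;
  for (size_t i = begin; i < end; ++i) p *= shape[i];
  return p;
}

static size_t first_dimension(const std::vector<size_t>& shape) {
  if (shape.empty()) return 0;
  if (shape.size() == 1) return 1;  // a 1-D tensor flattens to a single row
  return product(shape, 0, shape.size() - 1);
}

static size_t last_dimension(const std::vector<size_t>& shape) {
  if (shape.empty()) return 0;
  return shape.back();
}

int main() {
  const std::vector<size_t> shape = {5, 4, 3, 160};  // one of the new N-D test cases
  // Prints "rows=60 cols=160": the kernels see this tensor as a 60x160 matrix.
  std::printf("rows=%zu cols=%zu\n", first_dimension(shape), last_dimension(shape));
  return 0;
}

This is the same flattening that flat_first_dim()/flat_last_dim() perform on the kernel side in cast_gated_kernels.cuh, which is why the gated-activation checks in the patch compare flattened dimensions instead of requiring strictly 2-D inputs.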