Skip to content

Commit

Permalink
fp8 rowwise regular gemm tuning for llm new shapes (#3654)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #3654

X-link: facebookresearch/FBGEMM#730

1. Add more FP8 rowwise instances
2. Extend the template to allow specifying AK1 and BK1
3. Add tuning for new shapes

Reviewed By: jianyuh

Differential Revision: D69070084

fbshipit-source-id: 7f841bd0855743d5389d7afa0f135619ca0cd61a
  • Loading branch information
mxz297 authored and facebook-github-bot committed Feb 4, 2025
1 parent 6a78ec6 commit bdcce9c
Show file tree
Hide file tree
Showing 17 changed files with 628 additions and 209 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -269,11 +269,72 @@ static const std::map<int, RowwiseKernel> N_5120_K_1024_dispatch_table = {
{ 8192, fp8_rowwise_256x256x192x128_16x16_8x6_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3}
};

static const std::map<int, RowwiseKernel> N_2048_K_5120_dispatch_table = {
{ 4, fp8_rowwise_256x16x64x128_16x16_1x1_16x16x1_8x32x1_1x16x1x16_4x4x1_1x1_intrawave_v2_8},
{ 8, fp8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2_4},
{ 64, fp8_rowwise_128x32x16x512_16x16_1x1_32x4x1_32x4x1_1x32x1x4_4x4x1_1x1_intrawave_v2},
{ 288, fp8_rowwise_256x32x64x512_16x16_1x2_32x8x1_32x8x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
{ 576, fp8_rowwise_256x64x64x512_32x32_1x1_32x8x1_32x8x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
{ 1216, fp8_rowwise_256x64x128x256_32x32_1x2_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
{ 1664, fp8_rowwise_256x128x96x256_32x32_1x3_16x16x1_16x16x1_1x64x1x4_8x8x1_1x1_intrawave_v3},
{ 2432, fp8_rowwise_256x128x128x256_32x32_2x2_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
{ 2944, fp8_rowwise_256x128x160x128_32x32_1x5_8x32x1_8x32x1_1x64x1x4_8x8x1_1x1_intrawave_v3},
{ 3456, fp8_rowwise_256x128x192x128_32x32_2x3_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
{ 4864, fp8_rowwise_256x256x128x128_16x16_8x4_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
{ 5888, fp8_rowwise_256x256x160x128_16x16_8x5_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
{ 5984, fp8_rowwise_256x256x192x128_16x16_8x6_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
};

static const std::map<int, RowwiseKernel> N_896_K_5120_dispatch_table = {
{ 64, fp8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2_8},
{ 72, fp8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2_8},
{ 80, fp8_rowwise_128x16x32x512_16x16_1x1_32x4x1_32x4x1_1x16x1x8_4x4x1_1x1_interwave_v2_2},
{ 160, fp8_rowwise_128x32x16x512_16x16_1x1_32x4x1_32x4x1_1x32x1x4_4x4x1_1x1_intrawave_v2},
{ 200, fp8_rowwise_256x16x64x512_16x16_1x1_32x8x1_32x8x1_1x16x1x16_4x4x1_1x1_intrawave_v3},
{ 256, fp8_rowwise_256x64x16x512_16x16_1x1_32x8x1_32x8x1_1x64x1x4_4x4x1_1x1_intrawave_v2},
{ 672, fp8_rowwise_256x32x64x512_16x16_1x2_32x8x1_32x8x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
{ 1344, fp8_rowwise_256x64x64x512_32x32_1x1_32x8x1_32x8x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
{ 2752, fp8_rowwise_256x64x128x256_32x32_1x2_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
{ 3840, fp8_rowwise_256x128x96x256_32x32_1x3_16x16x1_16x16x1_1x64x1x4_8x8x1_1x1_intrawave_v3},
{ 5504, fp8_rowwise_256x128x128x256_32x32_2x2_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
{ 5984, fp8_rowwise_256x128x160x128_32x32_1x5_8x32x1_8x32x1_1x64x1x4_8x8x1_1x1_intrawave_v3},
};

static const std::map<int, RowwiseKernel> N_5120_K_640_dispatch_table = {
{ 64, fp8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2},
{ 80, fp8_rowwise_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2},
{ 112, fp8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1},
{ 192, fp8_rowwise_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
{ 224, fp8_rowwise_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2},
{ 256, fp8_rowwise_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2},
{ 384, fp8_rowwise_256x128x64x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
{ 448, fp8_rowwise_256x64x128x128_32x32_1x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
{ 512, fp8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
{ 704, fp8_rowwise_256x64x192x128_32x32_1x3_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
{ 896, fp8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
{ 960, fp8_rowwise_256x64x256x128_32x32_1x4_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
{ 1152, fp8_rowwise_256x128x160x128_32x32_1x5_8x32x1_8x32x1_1x64x1x4_8x8x1_1x1_intrawave_v3},
{ 1280, fp8_rowwise_256x256x96x128_32x32_2x3_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
{ 1408, fp8_rowwise_256x128x192x128_32x32_2x3_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
{ 1920, fp8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
{ 2304, fp8_rowwise_256x256x160x128_16x16_8x5_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
{ 2816, fp8_rowwise_256x256x192x128_16x16_8x6_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
{ 3360, fp8_rowwise_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
{ 3840, fp8_rowwise_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
{ 4864, fp8_rowwise_256x256x160x128_16x16_8x5_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
{ 5520, fp8_rowwise_256x256x192x128_16x16_8x6_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
{ 5760, fp8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
{ 5984, fp8_rowwise_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
};

static const std::unordered_map<std::tuple<int, int>, NKLookupTableType, IntTupleHash> NK_lookup_table = {
{{7168, 8192}, N_7168_K_8192_dispatch_table},
{{8192, 3584}, N_8192_K_3584_dispatch_table},
{{1024, 5120}, N_1024_K_5120_dispatch_table},
{{5120, 1024}, N_5120_K_1024_dispatch_table}
{{5120, 1024}, N_5120_K_1024_dispatch_table},
{{2048, 5120}, N_2048_K_5120_dispatch_table},
{{896, 5120}, N_896_K_5120_dispatch_table},
{{5120, 640}, N_5120_K_640_dispatch_table}
};

RowwiseKernel rowwise_nk_lookup(int M, const NKLookupTableType& table) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include "fp8_rowwise_common.h"

at::Tensor
fp8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2_4(
at::Tensor XQ,
at::Tensor WQ,
at::Tensor x_scale,
at::Tensor w_scale,
at::Tensor Y) {
using DeviceGemmInstance = DeviceGemmHelper<
128,
16,
32,
128,
16,
16,
1,
1,
S<8, 16, 1>,
S<8, 16, 1>,
S<1, 16, 1, 8>,
S<4, 4, 1>,
1,
1,
ck::BlockGemmPipelineScheduler::Interwave,
ck::BlockGemmPipelineVersion::v2,
ck::tensor_operation::device::GemmSpecialization::Default>;
// Run kernel instance.
return f8f8bf16_rowwise_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y, 4);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include "fp8_rowwise_common.h"

at::Tensor
fp8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1(
at::Tensor XQ,
at::Tensor WQ,
at::Tensor x_scale,
at::Tensor w_scale,
at::Tensor Y) {
using DeviceGemmInstance = DeviceGemmHelper<
128,
16,
32,
128,
16,
16,
1,
1,
S<8, 16, 1>,
S<8, 16, 1>,
S<1, 16, 1, 8>,
S<4, 4, 1>,
1,
1,
ck::BlockGemmPipelineScheduler::Intrawave,
ck::BlockGemmPipelineVersion::v1,
ck::tensor_operation::device::GemmSpecialization::Default>;
// Run kernel instance.
return f8f8bf16_rowwise_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y);
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include "fp8_rowwise_common.h"

at::Tensor
fp8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2_8(
at::Tensor XQ,
at::Tensor WQ,
at::Tensor x_scale,
at::Tensor w_scale,
at::Tensor Y) {
using DeviceGemmInstance = DeviceGemmHelper<
128,
16,
32,
128,
16,
16,
1,
1,
S<8, 16, 1>,
S<8, 16, 1>,
S<1, 16, 1, 8>,
S<4, 4, 1>,
1,
1,
ck::BlockGemmPipelineScheduler::Intrawave,
ck::BlockGemmPipelineVersion::v2,
ck::tensor_operation::device::GemmSpecialization::Default>;
// Run kernel instance.
return f8f8bf16_rowwise_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y, 8);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include "fp8_rowwise_common.h"

at::Tensor
fp8_rowwise_128x16x32x512_16x16_1x1_32x4x1_32x4x1_1x16x1x8_4x4x1_1x1_interwave_v2_2(
at::Tensor XQ,
at::Tensor WQ,
at::Tensor x_scale,
at::Tensor w_scale,
at::Tensor Y) {
using DeviceGemmInstance = DeviceGemmHelper<
128,
16,
32,
512,
16,
16,
1,
1,
S<32, 4, 1>,
S<32, 4, 1>,
S<1, 16, 1, 8>,
S<4, 4, 1>,
1,
1,
ck::BlockGemmPipelineScheduler::Interwave,
ck::BlockGemmPipelineVersion::v2,
ck::tensor_operation::device::GemmSpecialization::Default>;
// Run kernel instance.
return f8f8bf16_rowwise_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y, 2);
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,54 +15,25 @@ fp8_rowwise_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v
at::Tensor x_scale,
at::Tensor w_scale,
at::Tensor Y) {
// A small kernel for small but not tiny shapes.

// Check if this input needs to be padded.
int M = size_to_dim_(XQ.dim() - 1, XQ.sizes());
int N = WQ.size(0);
int K = WQ.size(1);
bool pad = (M % 32 != 0) || (N % 16 != 0) || (K % 128 != 0);

if (pad) {
using DeviceGemmInstance = DeviceGemmHelper<
128,
32,
16,
128,
16,
16,
1,
1,
S<8, 16, 1>,
S<8, 16, 1>,
S<1, 16, 1, 8>,
S<2, 2, 1>,
1,
1,
ck::BlockGemmPipelineScheduler::Interwave,
ck::BlockGemmPipelineVersion::v2>;
// Run kernel instance.
return f8f8bf16_rowwise_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y);
} else {
using DeviceGemmInstance = DeviceGemmHelper<
128,
32,
16,
128,
16,
16,
1,
1,
S<8, 16, 1>,
S<8, 16, 1>,
S<1, 16, 1, 8>,
S<2, 2, 1>,
1,
1,
ck::BlockGemmPipelineScheduler::Interwave,
ck::BlockGemmPipelineVersion::v2,
ck::tensor_operation::device::GemmSpecialization::Default>;
// Run kernel instance.
return f8f8bf16_rowwise_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y);
}
using DeviceGemmInstance = DeviceGemmHelper<
128,
32,
16,
128,
16,
16,
1,
1,
S<8, 16, 1>,
S<8, 16, 1>,
S<1, 16, 1, 8>,
S<2, 2, 1>,
1,
1,
ck::BlockGemmPipelineScheduler::Interwave,
ck::BlockGemmPipelineVersion::v2,
ck::tensor_operation::device::GemmSpecialization::Default>;
// Run kernel instance.
return f8f8bf16_rowwise_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y);
}

Original file line number Diff line number Diff line change
Expand Up @@ -15,54 +15,25 @@ fp8_rowwise_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v
at::Tensor x_scale,
at::Tensor w_scale,
at::Tensor Y) {
// A kernel that works well on small but not super tiny shapes.

// Check if this input needs to be padded.
int M = size_to_dim_(XQ.dim() - 1, XQ.sizes());
int N = WQ.size(0);
int K = WQ.size(1);
bool pad = (M % 32 != 0) || (N % 64 != 0) || (K % 128 != 0);

if (pad) {
using DeviceGemmInstance = DeviceGemmHelper<
128,
32,
64,
128,
32,
32,
1,
1,
S<8, 16, 1>,
S<8, 16, 1>,
S<1, 16, 1, 8>,
S<8, 8, 1>,
1,
1,
ck::BlockGemmPipelineScheduler::Interwave,
ck::BlockGemmPipelineVersion::v2>;
// Run kernel instance.
return f8f8bf16_rowwise_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y);
} else {
using DeviceGemmInstance = DeviceGemmHelper<
128,
32,
64,
128,
32,
32,
1,
1,
S<8, 16, 1>,
S<8, 16, 1>,
S<1, 16, 1, 8>,
S<8, 8, 1>,
1,
1,
ck::BlockGemmPipelineScheduler::Interwave,
ck::BlockGemmPipelineVersion::v2,
ck::tensor_operation::device::GemmSpecialization::Default>;
// Run kernel instance.
return f8f8bf16_rowwise_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y);
}
using DeviceGemmInstance = DeviceGemmHelper<
128,
32,
64,
128,
32,
32,
1,
1,
S<8, 16, 1>,
S<8, 16, 1>,
S<1, 16, 1, 8>,
S<8, 8, 1>,
1,
1,
ck::BlockGemmPipelineScheduler::Interwave,
ck::BlockGemmPipelineVersion::v2,
ck::tensor_operation::device::GemmSpecialization::Default>;
// Run kernel instance.
return f8f8bf16_rowwise_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y);
}

Loading

0 comments on commit bdcce9c

Please sign in to comment.