[QST] Skinny fp16 gemm on H100 #1176
Comments
Can you swap A and B and use a tile shape of 64x256?
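For readers following along, here is one reading of that suggestion (my own sketch with illustrative alias names, not code from this thread): compute the transposed problem D^T = B^T * A^T so the long 2752 dimension moves from M to N, and tile it 64x256. For the TN layouts used in this issue (row-major A, column-major B) the transposed problem is still TN, so only the pointers, the problem shape, and the tile shape change.
// Hedged sketch: the swapped-operand view of the 2752x16x4096 problem.
// Original: A (2752x4096, row-major), B (4096x16, column-major), D (2752x16).
// Swapped:  A' = B^T (16x4096, a row-major view of B's memory),
//           B' = A^T (4096x2752, a column-major view of A's memory),
//           D' = D^T (16x2752); problem MNK becomes 16x2752x4096.
using TileShape_MNK_Swapped = Shape<_64, _256, _64>;     // 64x256 output tile as suggested
using LayoutA_Swapped = cutlass::layout::RowMajor;       // B^T consumed as the A operand
using LayoutB_Swapped = cutlass::layout::ColumnMajor;    // A^T consumed as the B operand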
You can also use the TMA epilogue and set the C type to void, which should buy you some b/w savings by skipping reads of C.
@thakkarV any example of how exactly to set the C type to void? Thanks!
@divchenko, here's an example of setting the C type to void via the collective builder API: lines 382 to 425 of test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative_bias_elementwise.cu at commit 39c6a83.
The specific line that does this is here: https://github.com/NVIDIA/cutlass/blob/main/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative_bias_elementwise.cu#L399
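For context, a minimal sketch of what "C type to void" looks like with a TMA epilogue. This is my paraphrase of the linked test, not its verbatim contents; the alias CollectiveEpilogueNoC is an illustrative name, and the type aliases (TileShape_MNK, ClusterShape_MNK, LayoutC) follow the config quoted later in this thread. When ElementC is void, the epilogue never loads C from global memory.
// Hedged sketch: TMA epilogue with no C operand (ElementC = void).
using CollectiveEpilogueNoC = typename cutlass::epilogue::collective::CollectiveBuilder<
    cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
    TileShape_MNK, ClusterShape_MNK,
    cutlass::epilogue::collective::EpilogueTileAuto,
    float, float,                          // accumulator / epilogue compute types
    void, LayoutC, 8,                      // ElementC = void: beta and C loads are elided
    cutlass::half_t, LayoutC, 8,           // D is still written as fp16
    cutlass::epilogue::TmaWarpSpecialized  // TMA-based epilogue schedule
  >::CollectiveOp;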
Thanks for bringing this to our attention. I think you already have a decent set of kernel parameter choices (except the scheduler). The sub-optimal perf you see is likely because this shape needs an optimized stream-K / split-K implementation, which isn't yet released; the existing stream-K implementation still needs more perf tuning to hit peak. With a 64x16 tile the total occupancy is just 43 SMs, so stream-K is essential for getting more occupancy and hence higher performance (more requests in flight to memory). To pseudo-test / simulate what an optimal stream-K / split-K kernel would buy in terms of perf, try making GEMM_M 3x larger and GEMM_K 3x smaller (I tested this locally and saw 80%+ bandwidth utilization). CC @jackkosaian, who may be able to comment better about stream-K / split-K on SM90.
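To make the occupancy point concrete, here is the tile-count arithmetic behind the 43-SM figure and the simulated reshape (my own sketch, not from the thread; 132 is the H100 SXM SM count).
// Output-tile count with a 64x16 CTA tile, before and after the simulated reshape.
#include <cstdio>

constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

int main() {
  // Original 2752x16x4096: ceil(2752/64) * ceil(16/16) = 43 CTAs on 132 SMs.
  std::printf("original tiles:  %d\n", ceil_div(2752, 64) * ceil_div(16, 16));
  // Simulated 8256x16x1344 (3x larger M, ~3x smaller K): 129 CTAs, near full occupancy.
  std::printf("simulated tiles: %d\n", ceil_div(8256, 64) * ceil_div(16, 16));
}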
Thanks folks for the detailed replies! |
I've also discovered that KernelTmaWarpSpecializedCooperative supports cutlass::gemm::StreamKScheduler. My grid jumped to 132, but memory bandwidth actually degraded a bit.
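For reference, a minimal sketch of how that scheduler is selected: cutlass::gemm::StreamKScheduler goes in GemmUniversal's tile-scheduler slot, the same spot as the commented-out line in the config quoted below. The aliases GemmKernelStreamK and GemmStreamK are illustrative names; CollectiveMainloop and CollectiveEpilogue are assumed to be the types from that config, with the cooperative kernel schedule enabled in the mainloop builder.
// Hedged sketch: opting into the stream-K tile scheduler.
using GemmKernelStreamK = cutlass::gemm::kernel::GemmUniversal<
    Shape<int,int,int,int>,            // problem shape (M, N, K, L)
    CollectiveMainloop,
    CollectiveEpilogue,
    cutlass::gemm::StreamKScheduler    // stream-K work decomposition across SMs
  >;
using GemmStreamK = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelStreamK>;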
The SM80 mainloops haven't been optimized extensively in CUTLASS 3.x; @thakkarV may be able to comment on whether it looks optimal in terms of the number of stages, etc. Yes, you're right regarding the scheduler support; as mentioned, we probably need to tune the stream-K scheduler more to make it optimal for perf. This is the config I used to simulate optimal stream-K / split-K perf:
// Modified GEMM_MNK = {8256, 16, 1344} (approx. 3x larger M & 3x smaller K)

// Headers needed to build this snippet:
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"

using namespace cute;
using ElementA = cutlass::half_t;
using LayoutA = cutlass::layout::RowMajor;
using ElementB = cutlass::half_t;
using LayoutB = cutlass::layout::ColumnMajor;
using ElementAccumulator = float;
using LayoutC = cutlass::layout::ColumnMajor;
using TileShape_MNK = Shape<_64,_16,_64>;
using ClusterShape_MNK = Shape<_1,_1,_1>;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
    cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
    TileShape_MNK, ClusterShape_MNK,
    cutlass::epilogue::collective::EpilogueTileAuto,
    float, float,                             // accumulator and epilogue compute types
    cutlass::half_t, LayoutC, 8,              // C: element, layout, alignment
    cutlass::half_t, LayoutC, 8,              // D: element, layout, alignment
    cutlass::epilogue::NoSmemWarpSpecialized
  >::CollectiveOp;

using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
    cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
    ElementA, LayoutA, 8,
    ElementB, LayoutB, 8,
    ElementAccumulator,
    TileShape_MNK, ClusterShape_MNK,
    cutlass::gemm::collective::StageCountAuto,
    //cutlass::gemm::KernelTmaWarpSpecializedCooperative   // cooperative alternative
    cutlass::gemm::KernelTmaWarpSpecialized
  >::CollectiveOp;

using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
    Shape<int,int,int,int>,                   // problem shape: (M, N, K, L)
    CollectiveMainloop,
    CollectiveEpilogue
    //cutlass::gemm::StreamKScheduler         // optional stream-K tile scheduler
  >;

using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
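In case it helps anyone reproducing this, here is a hedged sketch of host-side launch code for the Gemm type above (not code from this thread). The device pointers are placeholders I introduced; the stride construction follows the standard CUTLASS 3.x pattern and needs "cutlass/util/packed_stride.hpp" and <cuda_runtime.h>. This body belongs inside a host function.
// Hedged usage sketch for the Gemm type defined above.
int M = 8256, N = 16, K = 1344, L = 1;        // the simulated problem shape

using StrideA = typename Gemm::GemmKernel::StrideA;
using StrideB = typename Gemm::GemmKernel::StrideB;
using StrideC = typename Gemm::GemmKernel::StrideC;
using StrideD = typename Gemm::GemmKernel::StrideD;

auto stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L));
auto stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L));
auto stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L));
auto stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L));

// Placeholder device pointers; point these at real fp16 buffers.
cutlass::half_t *ptr_A = nullptr, *ptr_B = nullptr, *ptr_C = nullptr, *ptr_D = nullptr;

typename Gemm::Arguments args{
    cutlass::gemm::GemmUniversalMode::kGemm,
    {M, N, K, L},                                         // problem shape
    {ptr_A, stride_A, ptr_B, stride_B},                   // mainloop arguments
    {{1.0f, 0.0f}, ptr_C, stride_C, ptr_D, stride_D}      // epilogue: {alpha, beta}, C, D
};

Gemm gemm;
size_t workspace_bytes = Gemm::get_workspace_size(args);
void* workspace = nullptr;
cudaMalloc(&workspace, workspace_bytes);

if (gemm.can_implement(args) == cutlass::Status::kSuccess &&
    gemm.initialize(args, workspace) == cutlass::Status::kSuccess) {
  gemm.run();                                             // enqueues on the default stream
}
cudaFree(workspace);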
This issue has been labeled |
Closing due to inactivity. Feel free to reopen if needed. |
This is very similar to #1139, but about fp16 this time.
I'm trying to optimize a 2752x16x4096 (MNK) fp16 matmul on H100 (A/B/D are all fp16; I don't use C).
cuBLAS is not very good for this GEMM, reaching only ~40% of memory b/w. I've tried the CUTLASS profiler for both SM80 and SM90 GEMMs (the whole op is quite memory-b/w-bound), but it only gets up to 1 TB/s, while I want something closer to 3 TB/s.
After spending quite some time playing with the Hopper wgmma op, I've managed to come up with code that gets ~55% of memory b/w and outperforms cuBLAS by ~25%, which is a nice win.
Here's the code https://gist.github.com/divchenko/528fbc8f2a2a7f1e03cca82a6cf71e38 - it's a PyTorch op.
Question: what settings can I tweak to get closer to 80% of memory b/w?
Thanks!
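For anyone sizing expectations on this shape, a quick roofline-style check (my own arithmetic, not from the thread; it assumes the ~3.35 TB/s HBM3 figure for H100 SXM and that C is never read).
// Rough roofline check for a 2752x16x4096 fp16 GEMM (2 bytes/element),
// counting A + B + D traffic only.
#include <cstdio>

int main() {
  const double M = 2752, N = 16, K = 4096, bytes_per_elem = 2.0;
  double traffic = bytes_per_elem * (M * K + K * N + M * N);  // ~22.8 MB per GEMM
  double flops   = 2.0 * M * N * K;                           // ~0.36 GFLOP per GEMM
  std::printf("arithmetic intensity: %.1f FLOP/byte\n", flops / traffic);    // ~15.8
  std::printf("min time at 3.35 TB/s: %.2f us\n", traffic / 3.35e12 * 1e6);  // ~6.8 us
}
At roughly 16 FLOP/byte this sits far below the H100 tensor-core ridge point, so targeting memory bandwidth, as this issue does, is the right lens.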