[QST] Skinny fp16 gemm on H100 #1176

Closed
divchenko opened this issue Nov 3, 2023 · 10 comments

@divchenko

This is very similar to #1139, but about fp16 this time.
I'm trying to optimize a 2752x16x4096 (MNK) fp16 matmul on H100 (A/B/D are all fp16; I don't use C).
cuBLAS is not very good for this GEMM, reaching only ~40% of memory bandwidth. I've tried the CUTLASS profiler for both SM80 and SM90 GEMMs (the whole op is quite memory-bandwidth-bound), but it only gets up to 1 TB/s, while I want something closer to 3 TB/s.
After spending quite some time playing with the Hopper wgmma op, I've managed to come up with code that gets ~55% of memory bandwidth and outperforms cuBLAS by ~25%, which is a nice win.
Here's the code (it's a PyTorch op): https://gist.github.com/divchenko/528fbc8f2a2a7f1e03cca82a6cf71e38

Question: what settings can I tweak to get closer to 80% of memory bandwidth?
Thanks!

@thakkarV
Collaborator

thakkarV commented Nov 3, 2023

Can you swap A and B and use a tile shape of 64x256?
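
A rough sketch of what that swap might look like with the SM90 collective builder, assuming the current kernel is the usual TN GEMM (A row-major, B column-major); the types below are illustrative, not code from the gist. Computing D^T = B^T * A^T keeps the same TN layouts but moves the long dimension onto N, so a 64x256 tile can run its wide side along it: the problem shape becomes {16, 2752, 4096}, the A/B pointers are exchanged, and D is produced transposed.

  // Sketch only: swapped-operand variant of the skinny GEMM (computes D^T = B^T * A^T).
  using TileShape_MNK    = Shape<_64,_256,_64>;   // 64-wide M tile covers the skinny dim (16)
  using ClusterShape_MNK = Shape<_1,_1,_1>;

  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
      cutlass::half_t, cutlass::layout::RowMajor, 8,     // "A" operand = original B^T (16 x 4096)
      cutlass::half_t, cutlass::layout::ColumnMajor, 8,  // "B" operand = original A^T (4096 x 2752)
      float,
      TileShape_MNK, ClusterShape_MNK,
      cutlass::gemm::collective::StageCountAuto,
      cutlass::gemm::KernelTmaWarpSpecialized
    >::CollectiveOp;

  // An epilogue collective and GemmUniversal kernel are then built around this as usual.
  // At runtime the problem shape would be {16, 2752, 4096, 1} (M, N, K, L), with ptr_A
  // pointing at the original B and ptr_B at the original A; the 16x2752 row-major output
  // is just the transposed view of the desired 2752x16 result.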

@thakkarV
Collaborator

thakkarV commented Nov 3, 2023

You can also use the TMA epilogue and set the C type to void, which should buy you some bandwidth savings by skipping reads of C.

@divchenko
Author

divchenko commented Nov 3, 2023

@thakkarV any example of how exactly to set the C type to void? Thanks!

@jackkosaian
Contributor

@divchenko, here's an example of setting C type to void via the collective builder API:

TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_BiasF32_ReLU_VoidC) {
  using LayoutA = cutlass::layout::RowMajor;
  using LayoutB = cutlass::layout::ColumnMajor;
  using LayoutC = cutlass::layout::RowMajor;
  using TileShape_MNK = Shape<_256,_128,_64>;
  using ClusterShape_MNK = Shape<_2,_2,_1>;
  using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
  using FusionOperation = cutlass::epilogue::fusion::LinCombPerRowBiasEltActAux<
      LayoutC, cutlass::epilogue::thread::ReLu, cutlass::half_t, float, cutlass::half_t, float>;

  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
      TileShape_MNK, ClusterShape_MNK,
      cutlass::epilogue::collective::EpilogueTileAuto,
      float, float,
      void, LayoutC, 8,
      cutlass::half_t, LayoutC, 8,
      EpilogueSchedule,
      FusionOperation
    >::CollectiveOp;

  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
      cutlass::half_t, LayoutA, 8,
      cutlass::half_t, LayoutB, 8,
      float,
      TileShape_MNK, ClusterShape_MNK,
      cutlass::gemm::collective::StageCountAutoCarveout<sizeof(typename CollectiveEpilogue::SharedStorage)>,
      cutlass::gemm::KernelTmaWarpSpecializedCooperative
    >::CollectiveOp;

  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
      Shape<int,int,int,int>,
      CollectiveMainloop,
      CollectiveEpilogue
  >;

  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

  bool passed = test::gemm::device::TestAllBiasElementwise<Gemm>();
  EXPECT_TRUE(passed);
}

The specific line that does this is here: https://github.com/NVIDIA/cutlass/blob/main/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative_bias_elementwise.cu#L399

@IonThruster
Collaborator

IonThruster commented Nov 5, 2023

Thanks for bringing this to our attention. I think you already have a decent set of kernel parameter choices (except the scheduler). The sub-optimal perf you're seeing is most likely because you need an optimized stream-K / split-K implementation, which isn't released yet; the existing stream-K implementation still needs more perf tuning to hit peak, I think.

With a tile size of 64x16, total occupancy is just 43 SMs, so stream-K is essential to get more occupancy and hence higher performance (by having more requests in flight to memory). To pseudo-test / simulate what an optimal stream-K / split-K kernel would buy in terms of perf, try making GEMM_M 3x larger and GEMM_K 3x smaller (I tested it locally and saw 80%+ bandwidth utilization); a rough occupancy calculation is sketched below.
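
To make the 43-SM figure concrete, here is the back-of-the-envelope arithmetic behind it, assuming the plain data-parallel scheduler launches one CTA per output tile and an H100 SXM with 132 SMs (a standalone sketch, not code from the kernel in question):

  #include <cstdio>

  int main() {
    int const gemm_m = 2752, gemm_n = 16;          // skinny problem from this issue
    int const tile_m = 64,   tile_n = 16;          // CTA tile shape (M x N)
    int tiles_m = (gemm_m + tile_m - 1) / tile_m;  // 43 tiles along M
    int tiles_n = (gemm_n + tile_n - 1) / tile_n;  // 1 tile along N
    // 43 CTAs vs. 132 SMs: roughly a third of the GPU has work, so memory
    // bandwidth utilization is limited by how many requests are in flight.
    std::printf("CTAs launched: %d (H100 SXM has 132 SMs)\n", tiles_m * tiles_n);
    return 0;
  }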

cc @jackkosaian, since he may be able to comment better about stream-K / split-K on SM90.

@divchenko
Author

Thanks folks for the detailed replies!
@IonThruster I've experimented following your advice: increased M by 4x and decreased K by 4x. The launch grid jumped to 172 blocks, but memory throughput is still stuck at ~50%. On the other hand, if I just increase M by 4x, throughput jumps closer to 80%.
I've also tried playing with Ampere ops to decrease the tile size and hence increase the launch grid, but it only improved memory bandwidth by a few percent. Here's the code: https://gist.github.com/divchenko/ac39ba61d3ec24ff74a616992e3589d0 - I had to write the CollectiveMma manually, as it looks like there are builders for Hopper only.

@divchenko
Author

I've also discovered that KernelTmaWarpSpecializedCooperative supports cutlass::gemm::StreamKScheduler. My grid jumped to 132, but memory bandwidth actually degraded a bit.
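
For reference, a sketch of how the stream-K scheduler is plugged in: the tile scheduler is the optional fourth template argument of GemmUniversal, with CollectiveMainloop and CollectiveEpilogue standing for the collective types built as in the snippets above.

  using GemmKernelStreamK = cutlass::gemm::kernel::GemmUniversal<
      Shape<int,int,int,int>,           // problem shape (M, N, K, L)
      CollectiveMainloop,
      CollectiveEpilogue,
      cutlass::gemm::StreamKScheduler   // tile scheduler tag
  >;
  using GemmStreamK = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelStreamK>;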

@IonThruster
Collaborator

The SM80 mainloops haven't been optimized extensively in CUTLASS 3.x - @thakkarV may be able to comment on whether yours looks optimal in terms of the number of stages, etc. Yes, you're right regarding the scheduler support - as mentioned, we probably need to tune the stream-K scheduler more to make it optimal for perf.

This is the config I used to simulate optimal stream-K / split-K perf:

  // Modified GEMM_MNK = {8256, 16, 1344}  (approx. 3x larger M & 3x smaller K)
  using ElementA = cutlass::half_t;
  using LayoutA = cutlass::layout::RowMajor;
  using ElementB = cutlass::half_t;
  using LayoutB = cutlass::layout::ColumnMajor;
  using ElementAccumulator = float;
  using LayoutC = cutlass::layout::ColumnMajor;
  using TileShape_MNK = Shape<_64,_16,_64>;
  using ClusterShape_MNK = Shape<_1,_1,_1>;

  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
      TileShape_MNK, ClusterShape_MNK,
      cutlass::epilogue::collective::EpilogueTileAuto,
      float, float,
      cutlass::half_t, LayoutC, 8,
      cutlass::half_t, LayoutC, 8,
      cutlass::epilogue::NoSmemWarpSpecialized
    >::CollectiveOp;

  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
      ElementA, LayoutA, 8,
      ElementB, LayoutB, 8,
      ElementAccumulator,
      TileShape_MNK, ClusterShape_MNK,
      cutlass::gemm::collective::StageCountAuto,
      //cutlass::gemm::KernelTmaWarpSpecializedCooperative
      cutlass::gemm::KernelTmaWarpSpecialized
    >::CollectiveOp;

  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
      Shape<int,int,int,int>,
      CollectiveMainloop,
      CollectiveEpilogue
      //cutlass::gemm::StreamKScheduler
  >;
  
  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;


github-actions bot commented Dec 7, 2023

This issue has been labeled inactive-30d due to no recent activity in the past 30 days. Please close this issue if no further response or action is needed. Otherwise, please respond with a comment indicating any updates or changes to the original issue and/or confirm this issue still needs to be addressed. This issue will be labeled inactive-90d if there is no activity in the next 60 days.

@mnicely
Collaborator

mnicely commented Jan 2, 2024

Closing due to inactivity. Feel free to reopen if needed.

@mnicely mnicely closed this as completed Jan 2, 2024