
Use a single wgmma wait_group to flush async wgmma pipeline #3843

Merged: rdspring1 merged 3 commits from raw_wgmma_sync_fix into main on Feb 8, 2025

Conversation

rdspring1 (Collaborator)

This PR optimizes RAW sync insertion for wgmma operations.

Problem:

  1. Inserting multiple RAW syncs for the wgmma pipeline can slow down performance.
  2. Combining wgmma.commit_group and wgmma.wait_group together can cause the compiler to serialize wgmma.mma_async. wgmma.commit_group is required when issuing a wgmma operation group, while wgmma.wait_group is required when waiting for wgmma operation groups to finish.

Proposed Solution:

  1. Issue a single wgmma.commit_group after completing mma operations but before any consumer operations.
  2. Add bool requires_commit to lower_utils::getSyncExprs, so the commit phase is optional.

Why?

  • wgmma.wait_group 0 flushes the entire pipeline, so all prior wgmma operations are complete and additional RAW syncs are unnecessary (see the sketch below).
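
For context, a minimal sketch of the resulting instruction pattern (illustrative only, requires an sm_90a target; the real generated kernels follow below):

__global__ void wgmma_sync_pattern_sketch() {
  // After the wgmma.mma_async ops of the main loop (elided here), each issued
  // group is closed with a single commit.
  asm volatile("wgmma.commit_group.sync.aligned;\n");
  // Before the first consumer of the mma results, a single wait_group 0
  // flushes every outstanding wgmma group, so no further RAW syncs are needed.
  asm volatile("wgmma.wait_group.sync.aligned %0;\n"::"n"(0LL):"memory");
}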

CUDA Examples

From MLPBenchmarkTest.FwdHorizontalFusion/data_parallel_warpspec

Without Fix: 4 RAW wgmma syncs

__global__ void kernel_without_raw_sync_fix(args) {
  // Compute MMA
  
  asm volatile("wgmma.commit_group.sync.aligned;\n");
  asm volatile("wgmma.wait_group.sync.aligned %0;\n"::"n"(0LL):"memory");
  Array<__bfloat, 64, 8> T18;
  #pragma unroll
  for(nvfuser_index_t i65 = 0; i65 < 16; ++i65) {
    nvfuser_index_t i66;
    i66 = 4 * i65;
    #pragma unroll
    for(nvfuser_index_t i67 = 0; i67 < 2; ++i67) {
      nvfuser_index_t i68;
      i68 = i66 + (2 * i67);
      #pragma unroll
      for(nvfuser_index_t i69 = 0; i69 < 2; ++i69) {
        nvfuser_index_t i70;
        i70 = i68 + i69;
        Array<__bfloat, 1, 1> T23;
        T23[0]
           = __float2bfloat(T22[i70]);
        T18[i70]
           = T23[0];
      }
    }
  }
  #pragma unroll
  for(nvfuser_index_t i71 = 0; i71 < 8; ++i71) {
    asm volatile(
      "stmatrix.sync.aligned.x4.m8n8.shared.b16 [%0], {%1, %2, %3, %4};\n"
      :
      :"r"((uint32_t)((toSmem(T26) + ((((nvfuser_index_t)threadIdx.y) * 16384) + (((i71 / 4) * 8192) + ((i17 * 128) + (((((((nvfuser_index_t)threadIdx.x) % 32) / 16) + ((i71 % 4) * 2)) ^ (i17 % 8)) * 16))))))),
       "r"((*reinterpret_cast<Array<uint32_t, 4, 1>*>(&T18[(8 * i71)]))[0]),
       "r"((*reinterpret_cast<Array<uint32_t, 4, 1>*>(&T18[(8 * i71)]))[1]),
       "r"((*reinterpret_cast<Array<uint32_t, 4, 1>*>(&T18[(8 * i71)]))[2]),
       "r"((*reinterpret_cast<Array<uint32_t, 4, 1>*>(&T18[(8 * i71)]))[3])
    );
  }
  __syncthreads();
  #pragma unroll
  for(nvfuser_index_t i72 = 0; i72 < 2; ++i72) {
    asm volatile("fence.proxy.async;\n");
    if ((Hopper::electSync(4294967295U) && b26)) {
      Hopper::cpAsyncBulkTensorTileS2G((Hopper::CpAsyncBulkTensorTileS2GIndex<2>{ ptr23, (Array<nvfuser_index_t, 2, 1>{(i9 + (64 * i72)), i21}) }), (i22 + (8192 * i72)));
    }
  }
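  // Redundant RAW syncs: the first wait_group above already flushed the async
  // wgmma pipeline, so these three commit/wait pairs do no useful work.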
  asm volatile("wgmma.commit_group.sync.aligned;\n");
  asm volatile("wgmma.wait_group.sync.aligned %0;\n"::"n"(0LL):"memory");
  asm volatile("wgmma.commit_group.sync.aligned;\n");
  asm volatile("wgmma.wait_group.sync.aligned %0;\n"::"n"(0LL):"memory");
  asm volatile("wgmma.commit_group.sync.aligned;\n");
  asm volatile("wgmma.wait_group.sync.aligned %0;\n"::"n"(0LL):"memory");

  // Compute Epilogue
}

With Fix: 1 RAW wgmma sync

__global__ void kernel_with_raw_sync_fix(args) {
  // Compute MMA
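  // With this fix, the single wgmma.wait_group(0) below is the only RAW sync
  // needed to flush the entire async wgmma pipeline.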
  
  asm volatile("wgmma.wait_group.sync.aligned %0;\n"::"n"(0LL):"memory");
  Array<__bfloat, 64, 8> T20;
  #pragma unroll
  for(nvfuser_index_t i52 = 0; i52 < 16; ++i52) {
    nvfuser_index_t i53;
    i53 = 4 * i52;
    #pragma unroll
    for(nvfuser_index_t i54 = 0; i54 < 2; ++i54) {
      nvfuser_index_t i55;
      i55 = i53 + (2 * i54);
      #pragma unroll
      for(nvfuser_index_t i56 = 0; i56 < 2; ++i56) {
        nvfuser_index_t i57;
        i57 = i55 + i56;
        Array<__bfloat, 1, 1> T25;
        T25[0]
           = __float2bfloat(T24[i57]);
        T20[i57]
           = T25[0];
      }
    }
  }
  __syncthreads();
  #pragma unroll
  for(nvfuser_index_t i58 = 0; i58 < 8; ++i58) {
    if ((b30 && (i31 < (-(16 * i58))))) {
      asm volatile(
        "stmatrix.sync.aligned.x4.m8n8.shared.b16 [%0], {%1, %2, %3, %4};\n"
        :
        :"r"((uint32_t)((toSmem(T27) + ((((nvfuser_index_t)threadIdx.y) * 16384) + (((i58 / 4) * 8192) + ((i16 * 128) + (((((((nvfuser_index_t)threadIdx.x) % 32) / 16) + ((i58 % 4) * 2)) ^ (i16 % 8)) * 16))))))),
         "r"((*reinterpret_cast<Array<uint32_t, 4, 1>*>(&T20[(8 * i58)]))[0]),
         "r"((*reinterpret_cast<Array<uint32_t, 4, 1>*>(&T20[(8 * i58)]))[1]),
         "r"((*reinterpret_cast<Array<uint32_t, 4, 1>*>(&T20[(8 * i58)]))[2]),
         "r"((*reinterpret_cast<Array<uint32_t, 4, 1>*>(&T20[(8 * i58)]))[3])
      );
    }
  }
  __syncthreads();
  #pragma unroll
  for(nvfuser_index_t i59 = 0; i59 < 2; ++i59) {
    asm volatile("fence.proxy.async;\n");
    if (((Hopper::electSync(4294967295U) && b26) && b29)) {
      Hopper::cpAsyncBulkTensorTileS2G((Hopper::CpAsyncBulkTensorTileS2GIndex<2>{ ptr19, (Array<nvfuser_index_t, 2, 1>{(i8 + (64 * i59)), i21}) }), (i18 + (8192 * i59)));
    }
  }

  // Compute Epilogue
}

…lar buffering for-loop

* Add requires_commit arg to getSyncExprs
* No wgmma commit required for ReadAfterWriteSyncs
* Create flush_async_mma_pipeline
rdspring1 (Collaborator, Author) commented:

!test

github-actions bot commented Feb 6, 2025

Review updated until commit fdd0f98

Description

  • Use single wgmma.wait_group to flush async pipeline
  • Add requires_commit to getSyncExprs for optional commit
  • Track async mma pipeline state with fill_async_mma_pipeline_ and flush_async_mma_pipeline_


Changes walkthrough 📝

Relevant files (Enhancement):

insert_syncs.cpp (csrc/device_lower/pass/insert_syncs.cpp): Enhance sync logic for wgmma operations (+55/-2)
  • Track async mma pipeline state with fill_async_mma_pipeline_ and flush_async_mma_pipeline_
  • Modify sync insertion logic to insert syncs after async ops
  • Skip unnecessary wgmma.wait_group(0) if pipeline is already flushed
  • Add checkAsyncMmaPipeline to ensure pipeline is empty at kernel end

utils.cpp (csrc/device_lower/utils.cpp): Add commit condition to sync expressions (+8/-3)
  • Add requires_commit parameter to getSyncExprs
  • Conditionally add wgmma.commit_group.sync.aligned based on requires_commit

utils.h (csrc/device_lower/utils.h): Update getSyncExprs signature (+2/-1)
  • Update getSyncExprs signature to include requires_commit parameter

PR Reviewer Guide 🔍

Here are some key observations to aid the review process:

🧪 No relevant tests
⚡ Recommended focus areas for review

Possible Issue

The logic for fill_async_mma_pipeline_ and flush_async_mma_pipeline_ might not handle all edge cases correctly. The conditions and transitions between these states need thorough validation.

    // An mma operation is added to async mma pipeline.
    fill_async_mma_pipeline_ = true;
    // async mma pipeline has not been flushed yet.
    flush_async_mma_pipeline_ = false;

Code Clarity

The logic for inserting sync expressions could be more clearly separated into distinct functions or methods to improve readability and maintainability.

    // here.
    // TODO: unify the handle of cp.async
    std::unordered_map<AsyncOpType, std::unordered_set<Expr*>> input_async_ops;
    for (auto inp : expr->inputs()) {
      auto def = inp->definition();
      auto async_type = ir_utils::getAsyncOpType(def);
    
      NVF_ERROR(
          !flush_async_mma_pipeline_ || !fill_async_mma_pipeline_,
          "The async mma pipeline cannot be filled without encountering ",
          "another mma op after it is flushed with a RAW sync.");
    
      // Detect a expression that consumes async mma operation.
      // The async mma pipeline is already flushed and is empty.
      // Adding a RAW wgmma.wait_group(0) is not necessary so skip it.
      if (async_type == AsyncOpType::WgMma && !fill_async_mma_pipeline_ &&
          flush_async_mma_pipeline_) {
        continue;
      }
    
      if (async_type != AsyncOpType::NotAsync &&
          async_type != AsyncOpType::CpAsync) {
        input_async_ops[async_type].insert(def);
        // async mma pipeline is flushed.
        flush_async_mma_pipeline_ = true;
        // No mma operations are active in the async mma pipeline.
        fill_async_mma_pipeline_ = false;
      }
    }
    for (const auto& [async_type, ops] : input_async_ops) {
      auto sync_exprs = lower_utils::getSyncExprs(
          async_type,
          /*keep_stages=*/0,
          /*requires_commit=*/async_type != AsyncOpType::WgMma);
      for (auto sync_expr : sync_exprs) {
        insertSyncExpr(ops, expr, sync_expr, nullptr);
      }
      for (auto op : ops) {

Code Clarity

The getSyncExprs function now has additional parameters with default values. Ensure that the default values are appropriate and that the function behaves as expected in all scenarios.

    std::vector<Expr*> getSyncExprs(
        AsyncOpType async_type,
        int64_t keep_stages,
        bool requires_commit) {
      std::vector<Expr*> sync_exprs;
      sync_exprs.reserve(2);
      if (requires_commit) {
        auto commit = IrBuilder::create<kir::AsyncCommit>(async_type);
        sync_exprs.push_back(commit);
      }
      auto wait = IrBuilder::create<kir::AsyncWait>(async_type, keep_stages);
      sync_exprs.push_back(wait);
      return sync_exprs;

rdspring1 requested a review from zasdfgbnm on February 6, 2025 at 18:57
rdspring1 marked this pull request as ready for review on February 6, 2025 at 18:58
    @@ -759,6 +769,9 @@ class ReadAfterWriteSyncs : public kir::ExprMutator {
    }

    private:
    //! Only a single wgmma wait_group to flush async mma pipeline.
    bool flush_async_mma_pipeline = false;
A reviewer (Collaborator) commented:

Is the reason for generating multiple commits and waits that there are multiple uses of the mma results, and one commit and wait is generated for each use? If so, would it make more sense to promote the local variable input_async_ops to a class member async_ops_to_sync_, where we add an async op and its type when seeing a new async op of that type, and insert a commit-wait and remove the async type (together with all ops of that type) when seeing an expr whose input's definition is of that async type?

The current approach proposed here is basically saying: "If I have ever inserted an mma wait, then never insert one again in the fusion." I don't think this is safe. For example, if we have a kernel:

    D1 = mma(A1, B1);
    output1 = relu(D1);
    D2 = mma(A2, B2);
    output2 = relu(D2);

then we do need a wait before output2 = relu(D2).

rdspring1 (Collaborator, Author) replied:

> Is the reason for generating multiple commits and waits that there are multiple uses of the mma results, and one commit and wait is generated for each use?

Yes.

rdspring1 (Collaborator, Author) replied:

At the moment, all RAW syncs for wgmma occur outside the K loop, so no new wgmma ops are issued and it is important not to issue another wgmma.commit_group. Essentially, the RAW sync just flushes the async mma pipeline.

rdspring1 (Collaborator, Author) replied:

TL;DR: I pushed a commit to set flush_async_mma_pipeline := false when encountering an mma op, so we can issue more than one RAW sync.

Changes

• fill_async_mma_pipeline is true when any mma expression is issued. A RAW sync is required before any consumer operations use the results of the mma expression.
• flush_async_mma_pipeline is true when a RAW sync is issued for the async mma pipeline. At the moment, the RAW sync for async wgmma is wgmma.wait_group(0); all prior mma operations are complete after this operation.
• fill_async_mma_pipeline is always false at the end of ReadAfterWriteSyncs.

Two independent mma expressions; no shared circular-buffered main loop

    #  <<<--- Step 0: fill_async_mma_pipeline == false && flush_async_mma_pipeline == false
    
    D1 = mma(A1, B1)
    #  <<<--- Step 1: fill_async_mma_pipeline := true && flush_async_mma_pipeline == false
    
    output1 = relu(D1)
    #  <<<--- Step 2: fill_async_mma_pipeline := false && flush_async_mma_pipeline == true
    # (Add wgmma.wait_group(0) before relu)
    
    D2 = mma(A2, B2)
    #  <<<--- Step 3: fill_async_mma_pipeline := true && flush_async_mma_pipeline == false
    
    output2 = relu(D2)
    #  <<<--- Step 4: fill_async_mma_pipeline := false && flush_async_mma_pipeline == true
    # (Add wgmma.wait_group(0) before relu)

Horizontally fused mma expressions; shared circular-buffered main loop

• Since both mma operations share the circular-buffered main loop, they are grouped together under the same wgmma.commit_group.
    #  <<<--- Step 0: fill_async_mma_pipeline == false && flush_async_mma_pipeline == false
    
    D1 = mma(A1, B1)
    D2 = mma(A1, B2)
    #  <<<--- Step 1: fill_async_mma_pipeline := true && flush_async_mma_pipeline == false
    
    output1 = relu(D1)
    output2 = relu(D2)
    #  <<<--- Step 2: fill_async_mma_pipeline := false && flush_async_mma_pipeline == true
    # (Add wgmma.wait_group(0) before relu)

Single mma with epilogue

• Only the first consumer of the mma result needs a RAW sync; once the pipeline is flushed, later epilogue operations skip it.
    #  <<<--- Step 0: fill_async_mma_pipeline == false && flush_async_mma_pipeline == false
    
    D1 = mma(A1, B1)
    #  <<<--- Step 1: fill_async_mma_pipeline := true && flush_async_mma_pipeline == false
    
    b = D1 + bias
    #  <<<--- Step 2: fill_async_mma_pipeline := false && flush_async_mma_pipeline == true
    # (Add wgmma.wait_group(0) before add)
    
    a = relu(b)
    #  <<<--- Step 3: fill_async_mma_pipeline := false && flush_async_mma_pipeline == true
    # (Do nothing; RAW sync not required)
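
To make the transitions above concrete, here is a minimal, self-contained sketch (not the actual nvFuser pass; Expr and the flag handling are simplified stand-ins) that scans a list of expressions and prints where a wgmma.wait_group(0) would be inserted:

    #include <iostream>
    #include <string>
    #include <vector>

    // Simplified stand-in for a kernel IR expression.
    struct Expr {
      std::string name;
      bool is_mma;        // issues work into the async wgmma pipeline
      bool consumes_mma;  // directly reads the result of an earlier mma
    };

    int main() {
      // Mirrors the "Single mma with epilogue" trace above.
      std::vector<Expr> exprs = {
          {"D1 = mma(A1, B1)", /*is_mma=*/true, /*consumes_mma=*/false},
          {"b = D1 + bias", /*is_mma=*/false, /*consumes_mma=*/true},
          {"a = relu(b)", /*is_mma=*/false, /*consumes_mma=*/false},
      };

      bool fill_async_mma_pipeline = false;   // an mma was issued and not yet waited on
      bool flush_async_mma_pipeline = false;  // a wgmma.wait_group(0) already drained it

      for (const Expr& e : exprs) {
        if (e.consumes_mma && fill_async_mma_pipeline && !flush_async_mma_pipeline) {
          std::cout << "  insert wgmma.wait_group(0)\n";
          fill_async_mma_pipeline = false;
          flush_async_mma_pipeline = true;
        }
        std::cout << e.name << "\n";
        if (e.is_mma) {
          fill_async_mma_pipeline = true;
          flush_async_mma_pipeline = false;
        }
      }
      return 0;
    }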

rdspring1 (Collaborator, Author) commented on Feb 7, 2025:

Your proposal is more flexible than what is implemented now because it can be used for both wgmma and tma store.
There are failure cases in the sync logic for both wgmma and tma store. I'd prefer to fix wgmma separately now, then try this proposal when fixing the tma store failures.

> Insert a commit-wait and remove the async type (together with all ops of that type) when seeing an expr whose input's definition is of that async type.

This would flush an async type upon encountering an input whose definition is of that async type.

Would this be sub-optimal for the tma store async group?

If the RAW sync for the tma store async group is always outside the circular buffer loop, then maybe it is fine.
(An outdated review thread on csrc/device_lower/pass/insert_syncs.cpp was resolved.)
rdspring1 (Collaborator, Author) commented:

!test

rdspring1 merged commit 6328cf8 into main on Feb 8, 2025 (54 checks passed).
rdspring1 deleted the raw_wgmma_sync_fix branch on February 8, 2025 at 18:46.