Generic region to loops #2306

Open
wants to merge 3 commits into
base: main
@nsmithtt nsmithtt commented Feb 27, 2025

This change adds a few new passes:

linalg.generic to affine passes

Simply calls the upstream pass for converting a linalg.generic into an affine loop nest.
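As a rough illustration of the loop-nest shape this lowering produces for an elementwise operation, here is a plain C++ sketch (not MLIR, and not the pass itself). The `Grid` type and `elementwiseMax` function are hypothetical names, and each tile is stood in for by a single `float`.

```cpp
#include <algorithm>
#include <array>
#include <cstddef>

// Hypothetical stand-in for a 2x4 grid of tiles; each "tile" is one float here.
using Grid = std::array<std::array<float, 4>, 2>;

// Elementwise maximum written as an explicit two-level loop nest, mirroring
// the structure of an affine.for nest over a 2x4 iteration space.
inline Grid elementwiseMax(const Grid &lhs, const Grid &rhs) {
  Grid out{};
  for (std::size_t i = 0; i < 2; ++i) {   // outer loop: 0 to 2
    for (std::size_t j = 0; j < 4; ++j) { // inner loop: 0 to 4
      out[i][j] = std::max(lhs[i][j], rhs[i][j]);
    }
  }
  return out;
}
```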

Lower affine pass

Again, this just uses an upstream pass, here the one that converts affine operations into the SCF and arith dialects.

Linearize memref accesses pass

A custom TTIR pass that takes a nested loop structure over n-dimensional memrefs and linearizes the memrefs, along with the accesses into them, into a single dimension. This is useful because circular buffers in metal are only one-dimensional.

For example, this pass converts the following code:

  affine.for %arg5 = 0 to 2 {
    affine.for %arg6 = 0 to 4 {
      %0 = affine.load %arg2[%arg5, %arg6] : memref<2x4x!tt.tile<32x32, f32>, #l1_>
      %1 = affine.load %arg3[%arg5, %arg6] : memref<2x4x!tt.tile<32x32, f32>, #l1_>
      %2 = "ttir.tile_maximum"(%0, %1) : (!tt.tile<32x32, f32>, !tt.tile<32x32, f32>) -> !tt.tile<32x32, f32>
      affine.store %2, %arg4[%arg5, %arg6] : memref<2x4x!tt.tile<32x32, f32>, #l1_>
    }
  }

Into:

  %collapse_shape = memref.collapse_shape %arg2 [[0, 1]] : memref<2x4x!tt.tile<32x32, f32>, #l1_> into memref<8x!tt.tile<32x32, f32>, #l1_>
  %collapse_shape_0 = memref.collapse_shape %arg3 [[0, 1]] : memref<2x4x!tt.tile<32x32, f32>, #l1_> into memref<8x!tt.tile<32x32, f32>, #l1_>
  %collapse_shape_1 = memref.collapse_shape %arg4 [[0, 1]] : memref<2x4x!tt.tile<32x32, f32>, #l1_> into memref<8x!tt.tile<32x32, f32>, #l1_>
  affine.for %arg5 = 0 to 2 {
    affine.for %arg6 = 0 to 4 {
      %0 = affine.load %collapse_shape[%arg5 * 4 + %arg6] : memref<8x!tt.tile<32x32, f32>, #l1_>
      %1 = affine.load %collapse_shape_0[%arg5 * 4 + %arg6] : memref<8x!tt.tile<32x32, f32>, #l1_>
      %2 = "ttir.tile_maximum"(%0, %1) : (!tt.tile<32x32, f32>, !tt.tile<32x32, f32>) -> !tt.tile<32x32, f32>
      affine.store %2, %collapse_shape_1[%arg5 * 4 + %arg6] : memref<8x!tt.tile<32x32, f32>, #l1_>
    }
  }
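The index expression `%arg5 * 4 + %arg6` in the output above matches ordinary row-major linearization of a 2x4 shape. The following C++ sketch shows that arithmetic for an arbitrary rank; it is only an illustration under that row-major assumption, and `rowMajorStrides` and `linearize` are hypothetical helper names, not functions from this PR.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Row-major strides: strides[k] is the product of all dimensions after k.
inline std::vector<std::size_t>
rowMajorStrides(const std::vector<std::size_t> &shape) {
  std::vector<std::size_t> strides(shape.size(), 1);
  // Walk from the innermost dimension outwards, accumulating the product.
  for (std::size_t k = shape.size(); k-- > 1;) {
    strides[k - 1] = strides[k] * shape[k];
  }
  return strides;
}

// Collapse an n-dimensional index into a single flat offset.
inline std::size_t linearize(const std::vector<std::size_t> &indices,
                             const std::vector<std::size_t> &shape) {
  assert(indices.size() == shape.size() && "rank mismatch");
  const std::vector<std::size_t> strides = rowMajorStrides(shape);
  std::size_t flat = 0;
  for (std::size_t k = 0; k < shape.size(); ++k) {
    assert(indices[k] < shape[k] && "index out of bounds");
    flat += indices[k] * strides[k];
  }
  return flat;
}
```

For the 2x4 example, the strides are 4 and 1, so index (i, j) maps to i * 4 + j, and the eight flat offsets cover the collapsed `memref<8x...>` exactly once.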

Closes #1910
Closes #1911

@github-actions github-actions bot left a comment

⚠️ Clang-Tidy found issue(s) with the introduced code (1/1)

arg.getLoc(), arg, collapsedDims);
rewriter.replaceAllUsesExcept(arg, linearizedArg->getResult(0),
linearizedArg);
for (auto user : linearizedArg->getUsers()) {

⚠️ llvm-qualified-auto ⚠️
auto user can be declared as auto *user

Suggested change:
- for (auto user : linearizedArg->getUsers()) {
+ for (auto *user : linearizedArg->getUsers()) {
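The suggestion does not change the deduced type; it only makes the pointer visible at the declaration. A minimal standalone sketch, where `Op` and `sumIds` are hypothetical stand-ins rather than MLIR types or code from this PR:

```cpp
#include <type_traits>
#include <vector>

// Hypothetical stand-in for an IR operation; not an MLIR type.
struct Op {
  int id;
};

inline int sumIds(const std::vector<Op *> &users) {
  int total = 0;
  // `auto *user` deduces the same type as `auto user` would (Op *),
  // but spells out that the loop variable is a pointer.
  for (auto *user : users) {
    static_assert(std::is_same<decltype(user), Op *>::value,
                  "auto * deduces a plain pointer to Op");
    total += user->id;
  }
  return total;
}
```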

@azecevicTT azecevicTT left a comment

Great pass documentation!

- `linalg.generic` to affine passes
- Linearize memref accesses pass

Closes #1910
Closes #1911
@nsmithtt nsmithtt force-pushed the nsmith/generic-loops2 branch from 7b686f4 to 22f3e49 Compare February 28, 2025 15:12
Successfully merging this pull request may close these issues.

[Metal Direct Pass] Linearize memref
[Metal Direct Pass] Loop Nest Generation