[MFMA] Refactor dot pipeline to reduce code duplication #400

Merged: 8 commits, Dec 13, 2023

binarman

This PR:

  • simplifies the data types generated by shared->mfma dot op layout conversions: values are no longer packed into int32 or int64
  • reduces code duplication between the fast and normal paths
  • reduces code duplication between operand A and operand B

This PR generalizes the LLVM values generated by ttg->llvm op loading:
the shared-to-mfma op now generates an array of repNxrepK vectors of matrix elements.

if (nonKIdx == 1)
  waveId = udiv(waveId, i32_val(wpt[0]));
return urem(urem(waveId, i32_val(wpt[nonKIdx])),
            i32_val(tensorSizeNonK / elemPerInstrNonK));

I'm confused about this part.
Say we have warpsPerCTA={2,2}, waveId=1, then we are talking about the top left wave in the workgroup, right?
If so, we have

  • for opA, i.e. nonKIdx=0, spatialWarpId = (waveId % wpt[0]) % (M / 32) = 1
  • for opB, i.e. nonKIdx=1, spatialWarpId = ((waveId/wpt[0])%wpt[1]) % (N / 8) = 0

But shouldn't wave1 have index 1 for opB and index 0 for opA?

Author

> But shouldn't wave1 have index 1 for opB and index 0 for opA?

It was originally implemented like this. I am not sure whether either orientation has an advantage; I have not tried the other layout.

If you think transposed wave indexing could have advantages, or is preferable for style reasons, we can swap it and see what happens.

So it's assumed that the 2x2 layout is:

wave0 wave2
wave1 wave3

These two assumptions are a matter of correctness, not style. This may be why some gemm tests failed in #402, where the mfma layout is used directly for the global store.
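
To make the layout above concrete, here is a minimal sketch that replays the index arithmetic from the quoted diff on plain integers. The function name `spatialWaveId` and the parameter `repsNonK` (standing in for `tensorSizeNonK / elemPerInstrNonK`) are illustrative assumptions, not names from the PR, and the value 4 passed for `repsNonK` is also assumed, chosen only so that the outer modulo does not wrap.

```cpp
#include <array>
#include <cassert>

// Scalar model of the quoted index math (hypothetical helper, not PR code).
// wpt      : warpsPerCTA, e.g. {2, 2}
// nonKIdx  : 0 for operand A (M dimension), 1 for operand B (N dimension)
// repsNonK : assumed stand-in for tensorSizeNonK / elemPerInstrNonK
constexpr unsigned spatialWaveId(unsigned waveId, std::array<unsigned, 2> wpt,
                                 unsigned nonKIdx, unsigned repsNonK) {
  if (nonKIdx == 1)
    waveId /= wpt[0];
  return (waveId % wpt[nonKIdx]) % repsNonK;
}

// warpsPerCTA = {2, 2}: wave1 gets index 1 for opA and index 0 for opB.
static_assert(spatialWaveId(1, {2, 2}, 0, 4) == 1, "opA index of wave1");
static_assert(spatialWaveId(1, {2, 2}, 1, 4) == 0, "opB index of wave1");
// wave2 gets index 0 for opA and index 1 for opB.
static_assert(spatialWaveId(2, {2, 2}, 0, 4) == 0, "opA index of wave2");
static_assert(spatialWaveId(2, {2, 2}, 1, 4) == 1, "opB index of wave2");
```

Reading the opA index as the row and the opB index as the column, waves 0 to 3 land at (0,0), (1,0), (0,1), (1,1), which is the column-major grid shown above.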

Author

You are probably right. I investigated the failures in #402 a little a few weeks ago and found that the current mfma layout is not compatible with the global store implementation.

auto rawElems = elems[n1 * i + j];
Value convertedElems;
if (type.isF32()) {
  convertedElems = extract_element(type, rawElems, i32_val(0));

Why do we need to dereference rawElems here for i32 input?

Author

This part handles the processing of the A/B dot operands.
The code is an adapter from the generic vec<base_type x kwidth> format to the format that the rocdl intrinsic expects.

Previously this transformation was done in the type converter (see TypeConverter.cpp below), but I feel it is better to keep it closer to the MFMA-emitting code. So I simplified the type converter and moved the code here.

A and B can be any of a variety of types: fp32, fp16, bf16, int8, fp8*
But rocdl mfma intrinsics take some of these types in a packed integer format.

For example:

  • the fp32xfp32 -> fp32 intrinsic takes plain scalar fp32 values as its A/B arguments.
  • the fp8xfp8 -> fp32 and int8xint8 -> int32 intrinsics take several values packed into an int32 or int64 (depending on the kwidth of the operation).
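
The packing step can be illustrated on the host side. The sketch below assumes that lane 0 of the vector ends up in the lowest byte of the packed integer, which is what a vector-to-integer bitcast gives on a little-endian target. The helper names `packInt8x4` and `packInt8x8` are hypothetical and do not appear in the PR.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Packs four int8 lanes (kwidth = 4) into one 32-bit word.
// Assumption: lane 0 occupies the lowest byte (little-endian bitcast order).
constexpr std::uint32_t packInt8x4(std::array<std::int8_t, 4> v) {
  std::uint32_t word = 0;
  for (unsigned i = 0; i < 4; ++i)
    word |= static_cast<std::uint32_t>(static_cast<std::uint8_t>(v[i]))
            << (8 * i);
  return word;
}

// Packs eight int8 lanes (kwidth = 8) into one 64-bit word, same lane order.
constexpr std::uint64_t packInt8x8(std::array<std::int8_t, 8> v) {
  std::uint64_t word = 0;
  for (unsigned i = 0; i < 8; ++i)
    word |= static_cast<std::uint64_t>(static_cast<std::uint8_t>(v[i]))
            << (8 * i);
  return word;
}

static_assert(packInt8x4({1, 2, 3, 4}) == 0x04030201u, "lane 0 is the low byte");
static_assert(packInt8x4({-1, 0, 0, 0}) == 0x000000FFu, "no sign extension");
```

The fp32 case needs no packing: the adapter only has to pull the single scalar out of its one-element vector, which is what the `extract_element` call in the snippet above does.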

zhanglx13 left a comment

LGTM.

There might be a disagreement about the wave layout that causes some of the gemm test failures in #402. We'll fix that later.

@alefimov-amd alefimov-amd merged commit f2afd65 into ROCm:triton-mlir Dec 13, 2023
1 check passed
4 participants