-
Notifications
You must be signed in to change notification settings - Fork 30
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MFMA] Refactor dot pipeline to reduce code duplication #400
Conversation
This PR generalizes llvm values generalted by ttg->llvm op loading: shared to mfma op generates array of repNxrepK vectors of matrix elements.
5ef7bbb
to
327d9aa
Compare
if (nonKIdx == 1) | ||
waveId = udiv(waveId, i32_val(wpt[0])); | ||
return urem(urem(waveId, i32_val(wpt[nonKIdx])), | ||
i32_val(tensorSizeNonK / elemPerInstrNonK)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm confused about this part.
Say we have warpsPerCTA={2,2}, waveId=1, then we are talking about the top left wave in the workgroup, right?
If so, we have
- for opA, i.e. nonKIdx=0, spatialWarpId = (waveId % wpt[0]) % (M / 32) = 1
- for opB, i.e. nonKIdx=1, spatialWarpId = ((waveId/wpt[0])%wpt[1]) % (N / 8) = 0
But shouldn't wave1 has index 1 for opB and index 0 for opA?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
But shouldn't wave1 has index 1 for opB and index 0 for opA?
It was originally implemented like this. I am not sure, if one or another orientation have advantages. I did not try other layout.
If you think transposed wave indexing could have advantages or preferable for style reasons, we can swap it and see what happens.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
so it's assumed that the 2x2 layout is
wave0 wave2
wave1 wave3
These two assumptions are not about styles, but correctness. Maybe this is the reason why some gemm tests failed in #402, in which the mfma layout is used directly for global store.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You are probably right. I've investigated failures in 402 a little a few weeks ago, and found that current mfma layout is not compatible with global store implementation.
auto rawElems = elems[n1 * i + j]; | ||
Value convertedElems; | ||
if (type.isF32()) { | ||
convertedElems = extract_element(type, rawElems, i32_val(0)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do we need to dereference rawElems here for i32 input?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This part is related to processing of A/B dot operands.
This code is an adapter from generic vec<base_type x kwidth>
format to the format that rocdl intrinsic expects.
Previously this transformation was done in Type converter (see TypeConverter.cpp below), but I feel that it is better to have this transformation closer to MFMA emitting code. So I simplified type converter and moved this code here.
A and B could be one of variety of types: fp32, fp16, bf16, int8, fp8*
But rocdl mfma intrinsics takes some of these types in a packed integer format.
For example:
fp32xfp32 -> fp32
version intrinsic takes plain scalarfp32
,fp32
as A/B arguments.fp8xfp8 -> fp32
orint8xint8 -> fp32
version takes several values in for of packed int32 or int64 (depending on kwidth of operation)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM.
There might be a disagreement about wave layout that causes the some gemm test failures in #402. We'll fix that one later.
This PR:
shared->mfma dot op
layout conversions. Do not pack data types in int32 or int64