[QST] Does GEMM support vectors of alpha and beta? #1398
Comments
Hello, I think it is supported, but I am still not sure how to use it.
cutlass/include/cutlass/epilogue/thread/scale_type.h Lines 51 to 60 in ffa34e7
cutlass/include/cutlass/epilogue/thread/linear_combination_relu.h Lines 274 to 279 in ffa34e7
A similar issue I found is #568.
@apuaaChen, could you comment on how to support it via EVT?
Hi! The pattern can be constructed through EVT. You can try to follow example 47 (streamk_broadcast) to construct the epilogue. Your pattern should be something like:

using OutputTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout<
  ThreadblockShape,
  WarpShape,
  ElementC,
  AlignmentC,
  EVTEpilogueStages
>;

// Accumulator
using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;

// alpha (per-row vector, broadcast along columns)
using Alpha = cutlass::epilogue::threadblock::VisitorColBroadcast<
  OutputTileThreadMap, ElementC,
  cute::Stride<_1, _0, int32_t>  // StrideMNL
>;

// mul
using Mul0 = cutlass::epilogue::threadblock::VisitorCompute<
  cutlass::multiplies, ElementCompute, ElementCompute,
  cutlass::FloatRoundStyle::round_to_nearest
>;

// alpha * accumulator
using EVTMul0 = cutlass::epilogue::threadblock::Sm80EVT<
  Mul0, Alpha, Accum>;

// beta (per-column vector, broadcast along rows)
using Beta = cutlass::epilogue::threadblock::VisitorRowBroadcast<
  OutputTileThreadMap, ElementC,
  cute::Stride<_0, _1, int32_t>  // StrideMNL
>;

// b (bias vector, broadcast along rows)
using B = cutlass::epilogue::threadblock::VisitorRowBroadcast<
  OutputTileThreadMap, ElementC,
  cute::Stride<_0, _1, int32_t>  // StrideMNL
>;

// mul
using Mul1 = cutlass::epilogue::threadblock::VisitorCompute<
  cutlass::multiplies, ElementCompute, ElementCompute,
  cutlass::FloatRoundStyle::round_to_nearest
>;

// beta * b
using EVTMul1 = cutlass::epilogue::threadblock::Sm80EVT<
  Mul1, Beta, B>;

// add
using Add = cutlass::epilogue::threadblock::VisitorCompute<
  cutlass::plus, ElementOutput, ElementCompute,
  cutlass::FloatRoundStyle::round_to_nearest
>;

// alpha * accumulator + beta * b
using EVTAdd = cutlass::epilogue::threadblock::Sm80EVT<
  Add, EVTMul0, EVTMul1>;

// store the result D
using D = cutlass::epilogue::threadblock::VisitorAuxStore<
  OutputTileThreadMap, ElementOutput, cutlass::FloatRoundStyle::round_to_nearest,
  cute::Stride<int64_t, _1, int64_t>  // StrideMNL
>;

using EVTD = cutlass::epilogue::threadblock::Sm80EVT<
  D,
  EVTAdd>;
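For completeness, here is a rough sketch of how the nested arguments for EVTD might be assembled at launch, following the layout used in example 47 (each node's children come first, the node's own arguments last; compute nodes take empty arguments). The pointer names and sizes are illustrative placeholders, not verified against the current API, so check example 47 for the exact structure.

// Sketch only: nested callback arguments mirroring the EVTD tree above.
// alpha_device_ptr, beta_device_ptr, b_device_ptr, d_device_ptr, M, N are placeholders.
typename EVTD::Arguments callback_args{
  {                                                               // EVTAdd
    {                                                             // EVTMul0: alpha * acc
      {alpha_device_ptr, ElementC(0), {_1{}, _0{}, int32_t(M)}},  // Alpha (length-M column vector)
      {},                                                         // Accum
      {}                                                          // Mul0 (compute node, no args)
    },
    {                                                             // EVTMul1: beta * b
      {beta_device_ptr, ElementC(0), {_0{}, _1{}, int32_t(N)}},   // Beta (length-N row vector)
      {b_device_ptr, ElementC(0), {_0{}, _1{}, int32_t(N)}},      // B (length-N row vector)
      {}                                                          // Mul1 (compute node, no args)
    },
    {}                                                            // Add (compute node, no args)
  },
  {d_device_ptr, {int64_t(N), _1{}, int64_t(M) * N}}              // D (row-major output store)
};

Example 47 also shows how these arguments are passed through the GEMM at launch.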
Thanks!
What is your question?
I checked the related issues #1155 and #1000, but I am still not sure whether GEMM supports vectors of alpha and beta.
Suppose I have inputs with shapes x: (1, M, K), w: (K, N), b: (N), alpha: (M), and beta: (N). alpha and beta are shared across the batch (bsize = 1 for now, for simplicity), and I want to compute
alpha * matmul(x, w) + beta * b.
Which API or example is closest to what I want? Thanks!
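To pin down the broadcast semantics I am after, here is a plain C++ reference of the computation (row-major tensors; names are illustrative only, not CUTLASS API): alpha is indexed by output row, beta and b by output column.

#include <cstddef>
#include <vector>

// d = alpha * (x @ w) + beta * b, with alpha of length M and beta, b of length N.
void reference(std::size_t M, std::size_t N, std::size_t K,
               std::vector<float> const& x,      // (M, K)
               std::vector<float> const& w,      // (K, N)
               std::vector<float> const& b,      // (N)
               std::vector<float> const& alpha,  // (M)
               std::vector<float> const& beta,   // (N)
               std::vector<float>& d) {          // (M, N)
  for (std::size_t i = 0; i < M; ++i) {
    for (std::size_t j = 0; j < N; ++j) {
      float acc = 0.0f;
      for (std::size_t k = 0; k < K; ++k) {
        acc += x[i * K + k] * w[k * N + j];
      }
      d[i * N + j] = alpha[i] * acc + beta[j] * b[j];  // alpha per row, beta and b per column
    }
  }
}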