[QST] Does GEMM support vectors of alpha and beta? #1398

Closed
hychiang-git opened this issue Mar 13, 2024 · 4 comments
Labels
question

Comments

@hychiang-git commented Mar 13, 2024

What is your question?
I checked the related issues #1155 and #1000, but I am still not sure if GEMM supports vectors of alpha and beta.

If I have inputs with shapes x: (1, M, K), w: (K, N), b: (N), alpha: (M), beta: (N), where alpha and beta are shared across the batch (batch size = 1 for now) for simplicity, and I want to compute alpha*matmul(x, w) + beta*b, which API or example comes closest to what I want? Thanks!
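
To make the broadcasting concrete, here is a plain C++ reference sketch of the computation I have in mind (no CUTLASS, just loops; alpha is indexed by the row m and beta/b by the column n):

#include <vector>

// Reference sketch for x: (M, K), w: (K, N), b: (N), alpha: (M), beta: (N),
// batch size 1:  D[m][n] = alpha[m] * sum_k x[m][k] * w[k][n] + beta[n] * b[n]
void reference_gemm(int M, int N, int K,
                    std::vector<float> const &x,      // M x K, row-major
                    std::vector<float> const &w,      // K x N, row-major
                    std::vector<float> const &b,      // N
                    std::vector<float> const &alpha,  // M (per-row scale)
                    std::vector<float> const &beta,   // N (per-column scale)
                    std::vector<float> &d) {          // M x N, row-major
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      float acc = 0.f;
      for (int k = 0; k < K; ++k) {
        acc += x[m * K + k] * w[k * N + n];
      }
      d[m * N + n] = alpha[m] * acc + beta[n] * b[n];
    }
  }
}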

@hychiang-git (Author)

Hello, I think it is supported, but I am still not sure how to use it.

struct ScaleType {
  enum Kind {
    Default,                      // D = scalar_alpha x Acc + scalar_beta x C
    NoBetaScaling,                // D = scalar_alpha x Acc + C
    OnlyAlphaScaling,             // D = scalar_alpha x Acc
    PerChannelScaling,            // D = vector_alpha x Acc + vector_beta x C
    OnlyAlphaPerChannelScaling,   // D = vector_alpha x Acc
    Nothing                       // D = Acc
  };
};

if (Scale == ScaleType::OnlyAlphaPerChannelScaling)
  intermediate = mul_add_accumulator(scale, converted_accumulator, bias);   // D = scale * Accum + bias
else
  intermediate = mul_add_accumulator(alpha_, converted_accumulator, bias);  // D = alpha * Accum + bias

I found a similar issue here: #568
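
If it helps, my rough guess at how the scale kind would be selected is below. This is just an untested sketch based on the template parameters in include/cutlass/epilogue/thread/linear_combination.h; the type aliases ElementOutput/ElementAccumulator/ElementCompute are placeholders I chose, and the defaults may differ between CUTLASS versions.

#include "cutlass/numeric_types.h"
#include "cutlass/epilogue/thread/linear_combination.h"

using ElementOutput      = cutlass::half_t;
using ElementAccumulator = float;
using ElementCompute     = float;

// Select per-channel (vector) alpha/beta via the ScaleType template argument.
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
    ElementOutput,                                            // element type of C/D
    128 / cutlass::sizeof_bits<ElementOutput>::value,         // elements per vectorized access
    ElementAccumulator,                                       // accumulator element type
    ElementCompute,                                           // compute (scaling) element type
    cutlass::epilogue::thread::ScaleType::PerChannelScaling   // vector alpha and vector beta
>;

If I understand it correctly, though, the per-channel vectors run along the N (output-channel) dimension, so a per-row alpha of length M may not fit this path directly.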

@hwu36 (Collaborator) commented Mar 15, 2024

@apuaaChen, could you comment on how to support it via EVT?

@apuaaChen

Hi!

The pattern can be constructed through EVT. You can try to follow example 47 (streamk_broadcast) to construct the epilogue. Your pattern should be something like:

using OutputTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout<
  ThreadblockShape, 
  WarpShape, 
  ElementC, 
  AlignmentC, 
  EVTEpilogueStages
>;
// Accumulator
using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;
// alpha
using Alpha = cutlass::epilogue::threadblock::VisitorColBroadcast<
    OutputTileThreadMap, ElementC,
    cute::Stride<_1,_0,int32_t>
>;

// mul
using Mul0 = cutlass::epilogue::threadblock::VisitorCompute<
    cutlass::multiplies, ElementCompute, ElementCompute,
    cutlass::FloatRoundStyle::round_to_nearest
>;

// alpha * accumulator
using EVTMul0 = cutlass::epilogue::threadblock::Sm80EVT<
    Mul0, Alpha, Accum>;

// beta
using Beta = cutlass::epilogue::threadblock::VisitorRowBroadcast<
    OutputTileThreadMap, ElementC,
    cute::Stride<_0, _1, int32_t>  // StrideMNL
>;

// b
using B = cutlass::epilogue::threadblock::VisitorRowBroadcast<
    OutputTileThreadMap, ElementC,
    cute::Stride<_0, _1, int32_t>  // StrideMNL
>;

// mul
using Mul1 = cutlass::epilogue::threadblock::VisitorCompute<
    cutlass::multiplies, ElementCompute, ElementCompute,
    cutlass::FloatRoundStyle::round_to_nearest
>;

// beta * b
using EVTMul1 = cutlass::epilogue::threadblock::Sm80EVT<
    Mul1, Beta, B>;
    
// add
using Add = cutlass::epilogue::threadblock::VisitorCompute<
    cutlass::plus, ElementOutput, ElementCompute,
    cutlass::FloatRoundStyle::round_to_nearest
>;

// alpha * accumulator + beta * b
using EVTAdd = cutlass::epilogue::threadblock::Sm80EVT<
    Add, EVTMul0, EVTMul1>;
    
using D = cutlass::epilogue::threadblock::VisitorAuxStore<
    OutputTileThreadMap, ElementOutput, cutlass::FloatRoundStyle::round_to_nearest,
    cute::Stride<int64_t, _1, int64_t> // StrideMNL
>;

using EVTD = cutlass::epilogue::threadblock::Sm80EVT<
    D,
    EVTAdd>;
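
At run time the callback arguments mirror this tree: each Sm80EVT node takes its children's arguments in declaration order, followed by the node's own arguments. The nesting below is only a rough sketch; the per-node fields (pointer, null default, stride) are assumptions, so please check example 47 for the exact aggregate layout. alpha_ptr, beta_ptr, b_ptr and d_ptr are placeholder device pointers, and ldd is the leading dimension of D.

typename EVTD::Arguments callback_args{
  {                                                         // EVTAdd
    {                                                       // EVTMul0 = alpha * accumulator
      {alpha_ptr, ElementC(0), {_1{}, _0{}, int32_t(0)}},   // Alpha: column broadcast, length M
      {},                                                   // Accum
      {}                                                    // Mul0
    },
    {                                                       // EVTMul1 = beta * b
      {beta_ptr, ElementC(0), {_0{}, _1{}, int32_t(0)}},    // Beta: row broadcast, length N
      {b_ptr,    ElementC(0), {_0{}, _1{}, int32_t(0)}},    // B: row broadcast, length N
      {}                                                    // Mul1
    },
    {}                                                      // Add
  },
  {d_ptr, {int64_t(ldd), _1{}, int64_t(0)}}                 // D: aux store (row stride, unit column stride, batch stride)
};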

@hychiang-git (Author)

Thanks!
