Skip to content

Commit

Permalink
Add Sm90LinCombPerColBias (#1774)
Browse files Browse the repository at this point in the history
Co-authored-by: Jiayu Sun <[email protected]>
  • Loading branch information
ucassjy and Jiayu Sun authored Sep 4, 2024
1 parent 6c30441 commit 7369adc
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 0 deletions.
17 changes: 17 additions & 0 deletions include/cutlass/epilogue/fusion/operations.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,23 @@ struct LinCombPerRowBiasEltAct
static constexpr bool IsEltActSupported = true;
};

// D = alpha * acc + beta * C + per-column bias
template<
class ElementOutput_,
class ElementCompute_,
class ElementBias_ = ElementOutput_,
class ElementSource_ = ElementOutput_,
class ElementScalar_ = ElementCompute_,
int AlignmentBias_ = 128 / sizeof_bits_v<ElementBias_>,
FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest
>
struct LinCombPerColBias
: LinearCombination<ElementOutput_, ElementCompute_, ElementSource_, ElementScalar_, RoundStyle_> {
using ElementBias = ElementBias_;
static constexpr int AlignmentBias = AlignmentBias_;
static constexpr bool IsPerColBiasSupported = true;
};

// D = activation(alpha * acc + beta * C + per-row bias)
// aux = alpha * acc + beta * C + per-row bias
template<
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,90 @@ struct FusionCallbacks<

/////////////////////////////////////////////////////////////////////////////////////////////////

// D = alpha * acc + beta * C + per-column bias
template<
int StagesC,
class CtaTileShapeMNK,
class EpilogueTile,
class ElementOutput,
class ElementCompute,
class ElementBias = ElementOutput,
class ElementSource = ElementOutput,
class ElementScalar = ElementCompute,
int AlignmentBias = 128 / sizeof_bits_v<ElementBias>,
FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest
>
using Sm90LinCombPerColBias =
Sm90EVT<Sm90Compute<homogeneous_multiply_add, ElementOutput, ElementCompute, RoundStyle>, // beta * C + (alpha * acc + bias)
Sm90ScalarBroadcast<ElementScalar>, // beta
Sm90SrcFetch<ElementSource>, // C
Sm90EVT<Sm90Compute<homogeneous_multiply_add, ElementCompute, ElementCompute, RoundStyle>, // alpha * acc + bias
Sm90ScalarBroadcast<ElementScalar>, // alpha
Sm90AccFetch, // acc
Sm90RowBroadcast<0, CtaTileShapeMNK, ElementBias, Stride<_0,_1,int>, AlignmentBias> // bias
>
>;

template <
int StagesC,
int StagesD,
int FragmentSize,
bool ReuseSmemC,
bool DelayTmaStore,
class ElementOutput,
class ElementCompute,
class ElementBias,
class ElementSource,
class ElementScalar,
int AlignmentBias,
FloatRoundStyle RoundStyle,
class CtaTileShapeMNK,
class EpilogueTile
>
struct FusionCallbacks<
epilogue::Sm90TmaWarpSpecialized<StagesC, StagesD, FragmentSize, ReuseSmemC, DelayTmaStore>,
fusion::LinCombPerColBias<ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle>,
CtaTileShapeMNK,
EpilogueTile
> : Sm90LinCombPerColBias<
StagesC, CtaTileShapeMNK, EpilogueTile, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle> {
using Impl = Sm90LinCombPerColBias<
StagesC, CtaTileShapeMNK, EpilogueTile, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle>;
using Operation = fusion::LinCombPerColBias<
ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle>;

struct Arguments {
ElementScalar alpha = ElementScalar(1);
ElementScalar beta = ElementScalar(0);
ElementScalar const* alpha_ptr = nullptr;
ElementScalar const* beta_ptr = nullptr;

using StrideBias = Stride<_0,_1,int>;
ElementBias const* bias_ptr = nullptr;
StrideBias dBias = {};

operator typename Impl::Arguments() const {
return
{ // ternary op : beta * C + (alpha * acc + bias)
{{beta}, {beta_ptr}}, // leaf args : beta
{}, // leaf args : C
{ // ternary op : alpha * acc + bias
{{alpha}, {alpha_ptr}}, // leaf args : alpha
{}, // leaf args : acc
{bias_ptr, ElementBias(0), dBias}, // leaf args : bias
{} // ternary args : multiply_add
}, // end ternary op
{} // ternary args : multiply_add
}; // end ternary op
}
};

// Ctor inheritance
using Impl::Impl;
};

/////////////////////////////////////////////////////////////////////////////////////////////////

// D = activation(alpha * acc + beta * C + per-row bias)
template<
class CtaTileShapeMNK,
Expand Down

0 comments on commit 7369adc

Please sign in to comment.