Skip to content

Commit

Permalink
Add some can implement rules of hopper convolution. (#1835)
Browse files Browse the repository at this point in the history
  • Loading branch information
Junkai-Wu authored Sep 25, 2024
1 parent 44dae8b commit e2b0789
Showing 1 changed file with 55 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,61 @@ struct CollectiveConv<
return false;
}

if (is_im2col_A || is_im2col_B) {
// Check valid corner values for TMA_LOAD_IM2COL, signed int ranging from [-corner_limit, corner_limit - 1]
constexpr int32_t corner_limit = 1 << (16 / NumSpatialDimensions - 1);
auto lower_corner_whd = detail::compute_lower_corner_whd(problem_shape);
for (int i = 0; i < problem_shape.RankS; ++i) {
implementable = implementable && lower_corner_whd[i] >= -corner_limit && lower_corner_whd[i] <= (corner_limit - 1);
}
auto upper_corner_whd = detail::compute_upper_corner_whd(problem_shape);
for (int i = 0; i < problem_shape.RankS; ++i) {
implementable = implementable && upper_corner_whd[i] >= -corner_limit && upper_corner_whd[i] <= (corner_limit - 1);
}

if (!implementable) {
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Padding values don't meet requirements for TMA LOAD IM2COL.\n");
return false;
}
}

// Wgrad kernels don't support non-packed output strides, non-packed tensor A stride (linearized)
if constexpr (ConvOp == conv::Operator::kWgrad) {

const auto & input_shape = problem_shape.shape_A;
const auto & input_stride = problem_shape.stride_A;

implementable &= input_stride[ProblemShape::RankT - 1] == 1;
int input_shape_size = 1;
for (int i = ProblemShape::RankT - 2; i >= 0; --i) {
input_shape_size *= input_shape[i + 1];
implementable &= input_stride[i] == input_shape_size;
}

const auto & output_shape = problem_shape.shape_C;
const auto & output_stride = problem_shape.stride_C;

implementable &= output_stride[ProblemShape::RankT - 1] == 1;
int output_shape_size = 1;
for (int i = ProblemShape::RankT - 2; i >= 0; --i) {
output_shape_size *= output_shape[i + 1];
implementable &= output_stride[i] == output_shape_size;
}

if (!implementable) {
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Wgrad kernels don't support non-packed output strides.\n");
return false;
}
}

// Conv kernels only support cross correlation mode currently.
implementable &= problem_shape.mode == cutlass::conv::Mode::kCrossCorrelation;

if (!implementable) {
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Conv kernels only support cross correlation mode currently.\n");
return false;
}

if (problem_shape.groups > 1) {
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: This kernel does not support conv groups > 1.\n");
return false;
Expand Down

0 comments on commit e2b0789

Please sign in to comment.