Skip to content

Commit

Permalink
xe: jit: gemm: fix dangling references in kernel selection
Browse files: browse the repository at this point in the history
  • Loading branch information
petercad committed Nov 26, 2024
1 parent e3c3d47 commit d5b95d1
Showing 1 changed file with 14 additions and 13 deletions.
27 changes: 14 additions & 13 deletions src/gpu/intel/jit/gemm/gen_gemm_kernel.cpp
Original file line number | Diff line number | Diff line change
@@ -487,41 +487,42 @@ status_t gen_gemm_nocopy_kernel_desc_t::select_kernel(compute::gpu_arch_t arch,

// Select a kernel from the catalog.
std::vector<MatchParams> match_params;
+ MatchParams base(hw_, has_systolic, is_integrated, problem_);

- match_params.emplace_back(hw_, has_systolic, is_integrated, problem_);
+ base.sizes.m = m;
+ base.sizes.n = n;
+ base.sizes.k = k;
+ base.sizes.batch = batch;
+ base.stepping = stepping;

- match_params[0].sizes.m = m;
- match_params[0].sizes.n = n;
- match_params[0].sizes.k = k;
- match_params[0].sizes.batch = batch;
- match_params[0].stepping = stepping;
-
- auto tags = const_cast<char *>(match_params[0].tags);
+ auto tags = const_cast<char *>(base.tags);
while (*tags)
tags++;
if (can_2d_a) *tags++ = kcatalog::ReqBlock2DA;
if (can_2d_b) *tags++ = kcatalog::ReqBlock2DB;
if (can_2d_c) *tags++ = kcatalog::ReqBlock2DC;

+ match_params.push_back(base);
+
bool fpmath_tf32 = mode & mode_tf32;
bool fpmath_bf16 = mode & mode_bf16x1;
bool fpmath_f16 = mode & mode_f16x1;

auto add_mode_matches = [&](bool has_mode, const char *(*match)(Type)) {
if (!has_mode) return;
- auto &def = match_params[0].selector.precisions;
+ auto &def = base.selector.precisions;
if (match(problem_.Ta)) {
- match_params.emplace_back(match_params[0]);
+ match_params.push_back(base);
match_params.back().selector.precisions[0] = match(problem_.Ta);
match_params.back().selector.precisions[1] = def[1];
}
if (match(problem_.Tb)) {
- match_params.emplace_back(match_params[0]);
+ match_params.push_back(base);
match_params.back().selector.precisions[0] = def[0];
match_params.back().selector.precisions[1] = match(problem_.Tb);
}
if (match(problem_.Ta) && match(problem_.Tb)) {
- match_params.emplace_back(match_params[0]);
+ match_params.push_back(base);
match_params.back().selector.precisions[0] = match(problem_.Ta);
match_params.back().selector.precisions[1] = match(problem_.Tb);
}
@@ -551,7 +552,7 @@ status_t gen_gemm_nocopy_kernel_desc_t::select_kernel(compute::gpu_arch_t arch,

EvaluateParams eval_params;

- eval_params.sizes = match_params[0].sizes;
+ eval_params.sizes = base.sizes;
eval_params.alpha = alpha;
eval_params.beta = beta;
eval_params.postOps = !problem_.postOps.empty();
Expand Down

0 comments on commit d5b95d1

Please sign in to comment.