From d5b95d1897e0b0207e6a211ec4f9bd97ea959127 Mon Sep 17 00:00:00 2001 From: Peter Caday Date: Mon, 25 Nov 2024 15:23:19 -0800 Subject: [PATCH] xe: jit: gemm: fix dangling references in kernel selection --- src/gpu/intel/jit/gemm/gen_gemm_kernel.cpp | 27 +++++++++++----------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/src/gpu/intel/jit/gemm/gen_gemm_kernel.cpp b/src/gpu/intel/jit/gemm/gen_gemm_kernel.cpp index cd63fcac3dd..b1a5475186c 100644 --- a/src/gpu/intel/jit/gemm/gen_gemm_kernel.cpp +++ b/src/gpu/intel/jit/gemm/gen_gemm_kernel.cpp @@ -487,41 +487,42 @@ status_t gen_gemm_nocopy_kernel_desc_t::select_kernel(compute::gpu_arch_t arch, // Select a kernel from the catalog. std::vector match_params; + MatchParams base(hw_, has_systolic, is_integrated, problem_); - match_params.emplace_back(hw_, has_systolic, is_integrated, problem_); + base.sizes.m = m; + base.sizes.n = n; + base.sizes.k = k; + base.sizes.batch = batch; + base.stepping = stepping; - match_params[0].sizes.m = m; - match_params[0].sizes.n = n; - match_params[0].sizes.k = k; - match_params[0].sizes.batch = batch; - match_params[0].stepping = stepping; - - auto tags = const_cast(match_params[0].tags); + auto tags = const_cast(base.tags); while (*tags) tags++; if (can_2d_a) *tags++ = kcatalog::ReqBlock2DA; if (can_2d_b) *tags++ = kcatalog::ReqBlock2DB; if (can_2d_c) *tags++ = kcatalog::ReqBlock2DC; + match_params.push_back(base); + bool fpmath_tf32 = mode & mode_tf32; bool fpmath_bf16 = mode & mode_bf16x1; bool fpmath_f16 = mode & mode_f16x1; auto add_mode_matches = [&](bool has_mode, const char *(*match)(Type)) { if (!has_mode) return; - auto &def = match_params[0].selector.precisions; + auto &def = base.selector.precisions; if (match(problem_.Ta)) { - match_params.emplace_back(match_params[0]); + match_params.push_back(base); match_params.back().selector.precisions[0] = match(problem_.Ta); match_params.back().selector.precisions[1] = def[1]; } if (match(problem_.Tb)) { - match_params.emplace_back(match_params[0]); + match_params.push_back(base); match_params.back().selector.precisions[0] = def[0]; match_params.back().selector.precisions[1] = match(problem_.Tb); } if (match(problem_.Ta) && match(problem_.Tb)) { - match_params.emplace_back(match_params[0]); + match_params.push_back(base); match_params.back().selector.precisions[0] = match(problem_.Ta); match_params.back().selector.precisions[1] = match(problem_.Tb); } @@ -551,7 +552,7 @@ status_t gen_gemm_nocopy_kernel_desc_t::select_kernel(compute::gpu_arch_t arch, EvaluateParams eval_params; - eval_params.sizes = match_params[0].sizes; + eval_params.sizes = base.sizes; eval_params.alpha = alpha; eval_params.beta = beta; eval_params.postOps = !problem_.postOps.empty();