From f42f25001e4a9cd32b213bb32a7926fd9f5f0c5f Mon Sep 17 00:00:00 2001 From: panzezhong Date: Thu, 29 Feb 2024 17:34:35 +0800 Subject: [PATCH] =?UTF-8?q?feat=20(kernel):=20llm=20matmul=E5=A2=9E?= =?UTF-8?q?=E5=8A=A0alpha=20beta=E5=B1=9E=E6=80=A7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/08-01llm/src/operators/mat_mul.cc | 11 +++++++---- src/08-01llm/src/operators/mat_mul.hh | 3 ++- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/src/08-01llm/src/operators/mat_mul.cc b/src/08-01llm/src/operators/mat_mul.cc index 99f123cd..b53b0a7c 100644 --- a/src/08-01llm/src/operators/mat_mul.cc +++ b/src/08-01llm/src/operators/mat_mul.cc @@ -6,16 +6,19 @@ namespace refactor::llm { using Op = MatMul; Op::MatMul( + float _alpha, + float _beta, decltype(transA) transA_, decltype(transB) transB_) : Operator(), - transA(transA_), - transB(transB_) {} + alpha(_alpha), beta(_beta), transA(transA_), transB(transB_) {} auto Op::build(ModelContext const &, std::string_view, Attributes attributes) -> OpBox { + auto alpha = attributes.getOrInsert( "alpha", {1.0f}).float_(); + auto beta = attributes.getOrInsert( "beta", {1.0f}).float_(); auto transA = attributes.getOrInsert("transA", {0}).int_() != 0; auto transB = attributes.getOrInsert("transB", {0}).int_() != 0; - return OpBox(std::make_unique<Op>(transA, transB)); + return OpBox(std::make_unique<Op>(alpha, beta, transA, transB)); } auto Op::typeId() -> size_t { static uint8_t ID = 1; @@ -81,7 +84,7 @@ namespace refactor::llm { auto Op::lower(TensorRefs) const -> computation::OpBox { using Op_ = computation::MatMul; - return std::make_unique<Op_>(1., 1., transA, transB); + return std::make_unique<Op_>(alpha, beta, transA, transB); } }// namespace refactor::llm diff --git a/src/08-01llm/src/operators/mat_mul.hh b/src/08-01llm/src/operators/mat_mul.hh index ef1ca4bb..ce2bca31 100644 --- a/src/08-01llm/src/operators/mat_mul.hh +++ b/src/08-01llm/src/operators/mat_mul.hh @@ -7,9 +7,10 @@
namespace refactor::llm { using namespace frontend; struct MatMul final : public Operator { + float alpha, beta; bool transA, transB; - MatMul(decltype(transA), decltype(transB)); + MatMul(decltype(alpha), decltype(beta), decltype(transA), decltype(transB)); static OpBox build(ModelContext const &, std::string_view, Attributes); static size_t typeId();