Skip to content

Commit

Permalink
feat (kernel): llm matmul增加alpha beta属性
Browse files Browse the repository at this point in the history
  • Loading branch information
PanZezhong1725 committed Feb 29, 2024
1 parent bf0b90a commit f42f250
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 5 deletions.
11 changes: 7 additions & 4 deletions src/08-01llm/src/operators/mat_mul.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,19 @@ namespace refactor::llm {
using Op = MatMul;

Op::MatMul(
float _alpha,
float _beta,
decltype(transA) transA_,
decltype(transB) transB_)
: Operator(),
transA(transA_),
transB(transB_) {}
alpha(_alpha), beta(_beta), transA(transA_), transB(transB_) {}

auto Op::build(ModelContext const &, std::string_view, Attributes attributes) -> OpBox {
auto alpha = attributes.getOrInsert( "alpha", {1.0f}).float_();
auto beta = attributes.getOrInsert( "beta", {1.0f}).float_();
auto transA = attributes.getOrInsert("transA", {0}).int_() != 0;
auto transB = attributes.getOrInsert("transB", {0}).int_() != 0;
return OpBox(std::make_unique<Op>(transA, transB));
return OpBox(std::make_unique<Op>(alpha, beta, transA, transB));
}
auto Op::typeId() -> size_t {
static uint8_t ID = 1;
Expand Down Expand Up @@ -81,7 +84,7 @@ namespace refactor::llm {

auto Op::lower(TensorRefs) const -> computation::OpBox {
using Op_ = computation::MatMul;
return std::make_unique<Op_>(1., 1., transA, transB);
return std::make_unique<Op_>(alpha, beta, transA, transB);
}

}// namespace refactor::llm
3 changes: 2 additions & 1 deletion src/08-01llm/src/operators/mat_mul.hh
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@ namespace refactor::llm {
using namespace frontend;

struct MatMul final : public Operator {
float alpha, beta;
bool transA, transB;

MatMul(decltype(transA), decltype(transB));
MatMul(decltype(alpha), decltype(beta), decltype(transA), decltype(transB));

static OpBox build(ModelContext const &, std::string_view, Attributes);
static size_t typeId();
Expand Down

0 comments on commit f42f250

Please sign in to comment.