feat(llm): add a MatMul that supports both transpose and broadcasting
Signed-off-by: YdrMaster <[email protected]>
YdrMaster committed Jan 25, 2024
1 parent 6877516 commit 9dce4b3
Showing 5 changed files with 116 additions and 2 deletions.
3 changes: 2 additions & 1 deletion src/08-01llm/src/operators.cpp
@@ -1,12 +1,13 @@
#include "llm/operators.h"
#include "operators/rms_normalization.hh"
#include "operators/mat_mul.hh"

namespace refactor::llm {
    using namespace frontend;

    void register_() {
        // clang-format off
        #define REGISTER(NAME, CLASS) Operator::register_<CLASS>("llm::" #NAME)
        REGISTER(MatMul, MatMul);
        #undef REGISTER
        // clang-format on
    }
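
For reference, the REGISTER(MatMul, MatMul) line above expands to a plain registration call that binds the new operator under the "llm::" prefix:

Operator::register_<MatMul>("llm::MatMul");
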
86 changes: 86 additions & 0 deletions src/08-01llm/src/operators/mat_mul.cc
@@ -0,0 +1,86 @@
#include "computation/operators/mat_mul.h"
#include "common.h"
#include "mat_mul.hh"

namespace refactor::llm {
    using Op = MatMul;

    Op::MatMul(
        decltype(transA) transA_,
        decltype(transB) transB_)
        : transA(transA_),
          transB(transB_) {}

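    // Parse the transA/transB attributes (default 0, i.e. no transpose);
    // these behave like ONNX Gemm-style transpose flags.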
    auto Op::build(ModelContext const &, std::string_view, Attributes attributes) -> OpBox {
        auto transA = attributes.getOrInsert("transA", {0}).int_() != 0;
        auto transB = attributes.getOrInsert("transB", {0}).int_() != 0;
        return OpBox(std::make_unique<Op>(transA, transB));
    }
    auto Op::typeId() -> size_t {
        static uint8_t ID = 1;
        return reinterpret_cast<size_t>(&ID);
    }

    auto Op::opTypeId() const -> size_t { return typeId(); }
    auto Op::opTypeName() const -> std::string_view { return "llm::MatMul"; }

    auto Op::infer(TensorRefs inputs, InferOptions const &options) const -> InferResult {
        EXPECT_SIZE(2)

        auto const &a = inputs[0];
        auto const &b = inputs[1];
        auto dataType = a.dataType;
        if (!dataType.isNumberic() || b.dataType != dataType) {
            return Err(InferError(ERROR_MSG("Input data type not support")));
        }
        auto sa = a.shape, sb = b.shape;
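        // Promote 1-D operands: a 1-D A becomes a row vector [1, k] and a
        // 1-D B becomes a column vector [k, 1]; rank-0 operands are rejected.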
        switch (sa.size()) {
            case 1:
                sa.insert(sa.begin(), DimExpr(1));
                break;
            case 0:
                return Err(InferError(ERROR_MSG("Input shape not support")));
            default:
                break;
        }
        switch (sb.size()) {
            case 1:
                sb.emplace_back(1);
                break;
            case 0:
                return Err(InferError(ERROR_MSG("Input shape not support")));
            default:
                break;
        }
        DimExpr m(1), n(1), ka(1), kb(1);
        if (!transA) {
            m = sa.rbegin()[1];
            ka = sa.rbegin()[0];
        } else {
            m = sa.rbegin()[0];
            ka = sa.rbegin()[1];
        }
        sa.pop_back();
        sa.pop_back();
        if (!transB) {
            kb = sb.rbegin()[1];
            n = sb.rbegin()[0];
        } else {
            kb = sb.rbegin()[0];
            n = sb.rbegin()[1];
        }
        sb.pop_back();
        sb.pop_back();
        ASSERT(ka == kb, "Input shape not support");
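        // Only the batch dims remain in sa/sb here; the MULTIDIR_BROADCAST
        // macro evidently yields their broadcast shape as `output`, to which
        // m and n are then appended.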
        MULTIDIR_BROADCAST((ShapeRefs{sa, sb}))
        output.emplace_back(std::move(m));
        output.emplace_back(std::move(n));
        return Ok(Tensors{Tensor::share(dataType, std::move(output), extractDependency(inputs))});
    }

    auto Op::lower(TensorRefs) const -> computation::OpBox {
        using Op_ = computation::MatMul;
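        // Both scale factors are fixed at 1.0 (no scaling, Gemm-style);
        // only the transpose flags are forwarded to the computation layer.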
        return std::make_unique<Op_>(1., 1., transA, transB);
    }

}// namespace refactor::llm
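
A minimal standalone sketch of the shape rule this infer implements, using plain int64_t extents in place of DimExpr (broadcastBatch and matMulShape are hypothetical helpers for illustration, not part of the repository):

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

using Shape = std::vector<int64_t>;

// Multidirectional broadcast of two batch-dim lists: the shorter list is
// left-padded with 1s, then each pair must match or contain a 1.
Shape broadcastBatch(Shape a, Shape b) {
    auto rank = std::max(a.size(), b.size());
    a.insert(a.begin(), rank - a.size(), 1);
    b.insert(b.begin(), rank - b.size(), 1);
    Shape out(rank);
    for (size_t i = 0; i < rank; ++i) {
        assert(a[i] == b[i] || a[i] == 1 || b[i] == 1);
        out[i] = std::max(a[i], b[i]);
    }
    return out;
}

// Mirrors Op::infer: pad 1-D inputs, pick m/k/n according to the
// transpose flags, broadcast the leading batch dims, append m and n.
Shape matMulShape(Shape sa, Shape sb, bool transA, bool transB) {
    assert(!sa.empty() && !sb.empty());
    if (sa.size() == 1) sa.insert(sa.begin(), 1);// A: [k] -> [1, k]
    if (sb.size() == 1) sb.push_back(1);         // B: [k] -> [k, 1]
    int64_t m = transA ? sa.end()[-1] : sa.end()[-2];
    int64_t ka = transA ? sa.end()[-2] : sa.end()[-1];
    int64_t kb = transB ? sb.end()[-1] : sb.end()[-2];
    int64_t n = transB ? sb.end()[-2] : sb.end()[-1];
    sa.resize(sa.size() - 2);
    sb.resize(sb.size() - 2);
    assert(ka == kb);
    auto out = broadcastBatch(sa, sb);
    out.push_back(m);
    out.push_back(n);
    return out;
}

int main() {
    // A: [2, 1, 3, 4], B: [5, 4, 6], no transposes.
    // Batch dims [2, 1] and [5] broadcast to [2, 5]; m = 3, k = 4, n = 6.
    assert((matMulShape({2, 1, 3, 4}, {5, 4, 6}, false, false) == Shape{2, 5, 3, 6}));
}

Note that, as written, a padded 1-D operand contributes its unit dimension to the output (e.g. [4] x [4, 6] yields [1, 6]), whereas ONNX MatMul squeezes it away.
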
25 changes: 25 additions & 0 deletions src/08-01llm/src/operators/mat_mul.hh
@@ -0,0 +1,25 @@
#ifndef LLM_MAT_MUL_HH
#define LLM_MAT_MUL_HH

#include "frontend/operator.h"

namespace refactor::llm {
    using namespace frontend;

    struct MatMul final : public Operator {
        bool transA, transB;

        MatMul(decltype(transA), decltype(transB));

        static OpBox build(ModelContext const &, std::string_view, Attributes);
        static size_t typeId();

        size_t opTypeId() const final;
        std::string_view opTypeName() const final;
        InferResult infer(TensorRefs, InferOptions const &) const final;
        computation::OpBox lower(TensorRefs) const final;
    };

}// namespace refactor::llm

#endif// LLM_MAT_MUL_HH
2 changes: 1 addition & 1 deletion src/09python_ffi/CMakeLists.txt
@@ -7,7 +7,7 @@ add_subdirectory(pybind11)

file(GLOB_RECURSE PYFFI_SRC src/*.cc src/*.cpp)
pybind11_add_module(python_ffi SHARED ${PYFFI_SRC})
-target_link_libraries(python_ffi PRIVATE onnx communication)
+target_link_libraries(python_ffi PRIVATE onnx llm communication)
target_include_directories(python_ffi PRIVATE include)

# EXAMPLE_VERSION_INFO is defined by setup.py and passed into the C++ code as a
2 changes: 2 additions & 0 deletions src/09python_ffi/src/main.cpp
@@ -1,6 +1,7 @@
#include "communication/operators.h"
#include "hardware/device.h"
#include "import.h"
#include "llm/operators.h"
#include "onnx/operators.h"
#include <pybind11/stl.h>// keep this line to convert stl types

@@ -14,6 +15,7 @@ namespace refactor::python_ffi {
    using namespace frontend;

    onnx::register_();
    llm::register_();
    communication::register_();

    // clang-format off
