feat(llm): add rms normalization frontend operator, computation-graph operator, and CPU kernel
Signed-off-by: YdrMaster <[email protected]>
YdrMaster committed Jan 26, 2024
1 parent 9dce4b3 commit eca56d8
Showing 12 changed files with 332 additions and 5 deletions.
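
All three layers added by this commit implement the same computation: for an input x whose last axis has length d, a weight vector w of length d, and a small epsilon, each length-d row of x is scaled by the reciprocal of its root mean square and then multiplied elementwise by w. In LaTeX notation (a summary of the loop in cpu_kernel.cc below, not text from the commit):

y_j = \frac{x_j \cdot w_j}{\sqrt{\frac{1}{d}\sum_{k=1}^{d} x_k^2 + \epsilon}}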
20 changes: 20 additions & 0 deletions src/04kernel/include/kernel/collectors/rms_normalization.h
@@ -0,0 +1,20 @@
#ifndef KERNEL_RMS_NORMALIZATION_H
#define KERNEL_RMS_NORMALIZATION_H

#include "../collector.h"

namespace refactor::kernel {

struct RmsNormalizationCollector final : public InfoCollector {
float epsilon;

constexpr RmsNormalizationCollector(decltype(_target) target, float epsilon_) noexcept
: InfoCollector(target), epsilon(epsilon_) {}

std::vector<KernelBox>
filter(TensorRefs inputs, TensorRefs outputs) const final;
};

}// namespace refactor::kernel

#endif// KERNEL_RMS_NORMALIZATION_H
29 changes: 29 additions & 0 deletions src/04kernel/src/collectors/rms_normalization.cc
@@ -0,0 +1,29 @@
#include "kernel/collectors/rms_normalization.h"
#include "../kernels/rms_normalization/cpu_kernel.hh"
#include "../kernels/rms_normalization/cuda_kernel.hh"

namespace refactor::kernel {

#define REGISTER(T) \
if (auto ptr = T::build(epsilon, inputs); ptr) { \
ans.emplace_back(std::move(ptr)); \
}

std::vector<KernelBox>
RmsNormalizationCollector::filter(TensorRefs inputs, TensorRefs outputs) const {

std::vector<KernelBox> ans;
switch (_target) {
case decltype(_target)::Cpu:
REGISTER(RmsNormalizationCpu)
break;
case decltype(_target)::Nvidia:
REGISTER(RmsNormalizationCuda)
break;
default:
UNREACHABLEX(void, "Unknown target");
}
return ans;
}

}// namespace refactor::kernel
74 changes: 74 additions & 0 deletions src/04kernel/src/kernels/rms_normalization/cpu_kernel.cc
@@ -0,0 +1,74 @@
#include "cpu_kernel.hh"
#include <numeric>

namespace refactor::kernel {
using K = RmsNormalizationCpu;
using DT = DataType;

K::RmsNormalizationCpu(
decltype(epsilon) epsilon_,
decltype(dataType) dataType_,
decltype(blockCount) blockCount_,
decltype(blockSize) blockSize_) noexcept
: Kernel(),
epsilon(epsilon_),
dataType(dataType_),
blockCount(blockCount_),
blockSize(blockSize_) {}

auto K::build(float epsilon, TensorRefs inputs) noexcept -> KernelBox {
auto const &x = inputs[0].get();
auto const &w = inputs[1].get();
if ((x.dataType != DataType::F32 && x.dataType != DataType::F64) || x.dataType != w.dataType) {
return nullptr;
}
auto it = x.shape.rbegin();
dim_t blockSize = *it++;
dim_t blockCount = std::accumulate(it, x.shape.rend(), 1, std::multiplies());
return std::make_unique<K>(epsilon, x.dataType, blockCount, blockSize);
}
auto K::typeId() noexcept -> size_t {
static uint8_t ID = 1;
return reinterpret_cast<size_t>(&ID);
}

auto K::kernelTypeId() const noexcept -> size_t { return typeId(); }
auto K::description() const noexcept -> std::string_view {
return "Performing rms normalization on generic cpu";
}

template<decltype(DT::internal) T>
static Routine lowerTyped(float epsilon, dim_t blockCount, dim_t blockSize) {
using namespace runtime;
using dt = typename primitive<T>::type;

return [epsilon, blockCount, blockSize]//
(Resources &, void *, void const *const *inputs, void *const *outputs) {
auto x = reinterpret_cast<dt const *>(inputs[0]);
auto w = reinterpret_cast<dt const *>(inputs[1]);
auto y = reinterpret_cast<dt *>(outputs[0]);
for (auto i : range0_(blockCount)) {
auto x_ = x + i * blockSize;
auto y_ = y + i * blockSize;

auto ss = std::accumulate(x_, x_ + blockSize, dt(0), [](auto acc, auto it) {
return acc + it * it;
});
ss /= blockSize;
ss += epsilon;
ss = 1. / std::sqrt(ss);

for (auto j : range0_(blockSize)) {
y_[j] = x_[j] * ss * w[j];
}
}
};
}

auto K::lower(Resources &) const noexcept -> RoutineWorkspace {
return dataType == DataType::F32
? lowerTyped<DataType::F32>(epsilon, blockCount, blockSize)
: lowerTyped<DataType::F64>(epsilon, blockCount, blockSize);
}

}// namespace refactor::kernel
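
To make the blockCount/blockSize decomposition concrete, here is a minimal standalone sketch of the same per-row computation (an illustration added for this write-up, not part of the commit; the row values and unit weights are made up):

#include <cmath>
#include <cstdio>

int main() {
    // Illustration only: mirrors the commit's CPU kernel loop for a single row.
    // One row (blockSize = 3) of input plus weights, as in lowerTyped above.
    float x[]{1.f, 2.f, 3.f}, w[]{1.f, 1.f, 1.f}, y[3];
    float epsilon = 1e-5f;

    float ss = 0;                // sum of squares: 1 + 4 + 9 = 14
    for (auto v : x) { ss += v * v; }
    ss /= 3;                     // mean of squares ≈ 4.6667
    ss += epsilon;
    ss = 1.f / std::sqrt(ss);    // reciprocal RMS ≈ 0.46291

    for (int j = 0; j < 3; ++j) { y[j] = x[j] * ss * w[j]; }
    std::printf("%f %f %f\n", y[0], y[1], y[2]);// ≈ 0.46291 0.92582 1.38873
    return 0;
}

For the (7, 2, 3) input used in the test at the bottom of this commit, build() computes blockSize = 3 from the last axis and blockCount = 7 × 2 = 14, i.e. fourteen such rows.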
30 changes: 30 additions & 0 deletions src/04kernel/src/kernels/rms_normalization/cpu_kernel.hh
@@ -0,0 +1,30 @@
#ifndef KERNEL_RMS_NORMALIZATION_CPU_KERNEL_HH
#define KERNEL_RMS_NORMALIZATION_CPU_KERNEL_HH

#include "kernel/kernel.h"
#include "kernel/tensor.h"

namespace refactor::kernel {

struct RmsNormalizationCpu final : public Kernel {
float epsilon;
DataType dataType;
dim_t blockCount, blockSize;

RmsNormalizationCpu(
decltype(epsilon),
decltype(dataType),
decltype(blockCount),
decltype(blockSize)) noexcept;

static KernelBox build(float, TensorRefs) noexcept;
static size_t typeId() noexcept;

size_t kernelTypeId() const noexcept final;
std::string_view description() const noexcept final;
RoutineWorkspace lower(Resources &) const noexcept final;
};

}// namespace refactor::kernel

#endif// KERNEL_RMS_NORMALIZATION_CPU_KERNEL_HH
50 changes: 50 additions & 0 deletions src/04kernel/src/kernels/rms_normalization/cuda_kernel.cc
@@ -0,0 +1,50 @@
#include "cuda_kernel.hh"
#include <numeric>

namespace refactor::kernel {
using K = RmsNormalizationCuda;
using DT = DataType;

K::RmsNormalizationCuda(
decltype(epsilon) epsilon_,
decltype(dataType) dataType_,
decltype(blockCount) blockCount_,
decltype(blockSize) blockSize_) noexcept
: Kernel(),
epsilon(epsilon_),
dataType(dataType_),
blockCount(blockCount_),
blockSize(blockSize_) {}

auto K::build(float epsilon, TensorRefs inputs) noexcept -> KernelBox {
#ifndef USE_CUDA
return nullptr;
#endif

auto const &x = inputs[0].get();
auto const &w = inputs[1].get();
if (!x.dataType.isFloat() || x.dataType != w.dataType) {
return nullptr;
}
auto it = x.shape.rbegin();
dim_t blockSize = *it++;
dim_t blockCount = std::accumulate(it, x.shape.rend(), 1, std::multiplies());
return std::make_unique<K>(epsilon, x.dataType, blockCount, blockSize);
}
auto K::typeId() noexcept -> size_t {
static uint8_t ID = 1;
return reinterpret_cast<size_t>(&ID);
}

auto K::kernelTypeId() const noexcept -> size_t { return typeId(); }
auto K::description() const noexcept -> std::string_view {
return "Performing rms normalization using CUDA";
}

#ifdef USE_CUDA
auto K::lower(Resources &) const -> RoutineWorkspace {
TODO("");
}
#endif

}// namespace refactor::kernel
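
Note the guard pattern here: when USE_CUDA is not defined, build() returns nullptr before ever inspecting the inputs, so the CUDA kernel is never offered on CPU-only builds, and lower() is compiled (for now left as TODO) only when CUDA support is enabled.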
31 changes: 31 additions & 0 deletions src/04kernel/src/kernels/rms_normalization/cuda_kernel.hh
@@ -0,0 +1,31 @@
#ifndef KERNEL_RMS_NORMALIZATION_CUDA_KERNEL_HH
#define KERNEL_RMS_NORMALIZATION_CUDA_KERNEL_HH

#include "kernel/kernel.h"
#include "kernel/tensor.h"

namespace refactor::kernel {
struct RmsNormalizationCuda final : public Kernel {
float epsilon;
DataType dataType;
dim_t blockCount, blockSize;

RmsNormalizationCuda(
decltype(epsilon),
decltype(dataType),
decltype(blockCount),
decltype(blockSize)) noexcept;

static KernelBox build(float, TensorRefs) noexcept;
static size_t typeId() noexcept;

size_t kernelTypeId() const noexcept final;
std::string_view description() const noexcept final;
#ifdef USE_CUDA
RoutineWorkspace lower(Resources &) const final;
#endif
};

}// namespace refactor::kernel

#endif// KERNEL_RMS_NORMALIZATION_CUDA_KERNEL_HH
23 changes: 23 additions & 0 deletions src/05computation/include/computation/operators/rms_normalization.h
@@ -0,0 +1,23 @@
#ifndef COMPUTATION_RMS_NORMALIZATION_H
#define COMPUTATION_RMS_NORMALIZATION_H

#include "../operator.h"

namespace refactor::computation {

struct RmsNormalization final : public Operator {
float epsilon;

constexpr explicit RmsNormalization(float epsilon_) noexcept
: Operator(), epsilon(epsilon_) {}

static size_t typeId() noexcept;
size_t opTypeId() const noexcept final;
std::string_view name() const noexcept final;
kernel::CollectorBox candidateKernels(Target) const final;
std::string serialize() const noexcept final;
};

}// namespace refactor::computation

#endif// COMPUTATION_RMS_NORMALIZATION_H
27 changes: 27 additions & 0 deletions src/05computation/src/operators/rms_normalization.cc
@@ -0,0 +1,27 @@
#include "computation/operators/rms_normalization.h"
#include "kernel/collectors/rms_normalization.h"

namespace refactor::computation {
using Op = RmsNormalization;

auto Op::typeId() noexcept -> size_t {
static uint8_t ID = 1;
return reinterpret_cast<size_t>(&ID);
}
auto Op::opTypeId() const noexcept -> size_t { return typeId(); }
auto Op::name() const noexcept -> std::string_view { return "RmsNormalization"; }
auto Op::candidateKernels(Target target) const -> kernel::CollectorBox {
using Collector_ = kernel::RmsNormalizationCollector;
return std::make_unique<Collector_>(target, epsilon);
}
auto Op::serialize() const noexcept -> std::string {
union code {
float f;
int32_t i;
};
return fmt::format(("{}({:e}={:#010x})"),
name(), epsilon,
code{epsilon}.i);
}

}// namespace refactor::computation
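
The union punning in serialize() records the exact bit pattern of epsilon alongside its decimal form. As an illustration (computed for this write-up, not output recorded in the commit): with epsilon = 1e-5f, whose IEEE-754 bits are 0x3727c5ac, the result would be "RmsNormalization(1.000000e-05=0x3727c5ac)".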
3 changes: 2 additions & 1 deletion src/08-01llm/src/operators/mat_mul.cc
@@ -8,7 +8,8 @@ namespace refactor::llm {
Op::MatMul(
decltype(transA) transA_,
decltype(transB) transB_)
: transA(transA_),
: Operator(),
transA(transA_),
transB(transB_) {}

auto Op::build(ModelContext const &, std::string_view, Attributes attributes) -> OpBox {
24 changes: 21 additions & 3 deletions src/08-01llm/src/operators/rms_normalization.cc
@@ -1,11 +1,16 @@
#include "rms_normalization.hh"
#include "common.h"
#include "computation/operators/rms_normalization.h"

namespace refactor::llm {
using Op = RmsNormalization;

auto Op::build(ModelContext const &, std::string_view, Attributes) -> OpBox {
return OpBox(std::make_unique<Op>());
Op::RmsNormalization(decltype(epsilon) epsilon_)
: Operator(), epsilon(epsilon_) {}

auto Op::build(ModelContext const &, std::string_view, Attributes attributes) -> OpBox {
auto epsilon = attributes.getOrInsert("epsilon", {1e-5f}).float_();
return OpBox(std::make_unique<Op>(epsilon));
}
auto Op::typeId() -> size_t {
static uint8_t ID = 1;
@@ -18,7 +23,20 @@ namespace refactor::llm {
auto Op::infer(TensorRefs inputs, InferOptions const &) const -> InferResult {
EXPECT_SIZE(2)

TODO("");
auto const &x = inputs[0];
auto const &w = inputs[1];
if (x.rank() < 1 || w.rank() != 1 || x.shape.back() != w.shape.back()) {
return Err(InferError(ERROR_MSG("Input shape not support")));
}
if (x.dataType != w.dataType) {
return Err(InferError(ERROR_MSG("Input data type not support")));
}
return Ok(Tensors{Tensor::share(x.dataType, x.shape, extractDependency(inputs))});
}

auto Op::lower(TensorRefs) const -> computation::OpBox {
using Op_ = computation::RmsNormalization;
return std::make_unique<Op_>(epsilon);
}

}// namespace refactor::llm
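
In short, the new infer() accepts x of shape (..., d) plus a 1-D weight of length d with the same data type, and yields an output that shares x's data type and shape; the test at the bottom of this commit exercises exactly this contract with shapes (7, 2, 3) and (3).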
4 changes: 3 additions & 1 deletion src/08-01llm/src/operators/rms_normalization.hh
@@ -7,15 +7,17 @@ namespace refactor::llm {
using namespace frontend;

struct RmsNormalization final : public Operator {
float epsilon;

constexpr RmsNormalization() noexcept = default;
RmsNormalization(decltype(epsilon));

static OpBox build(ModelContext const &, std::string_view, Attributes);
static size_t typeId();

size_t opTypeId() const final;
std::string_view opTypeName() const final;
InferResult infer(TensorRefs, InferOptions const &) const final;
computation::OpBox lower(TensorRefs) const final;
};

}// namespace refactor::llm
22 changes: 22 additions & 0 deletions src/08-01llm/test/test_rms_normalization.cpp
@@ -0,0 +1,22 @@
#include "../src/operators/rms_normalization.hh"
#include "llm/operators.h"
#include <gtest/gtest.h>

using namespace refactor;
using namespace llm;

TEST(infer, RmsNormalization) {
llm::register_();
auto edges = Edges{
{Tensor::share(DataType::F32, Shape{DimExpr(7), DimExpr(2), DimExpr(3)}, {}), ""},
{Tensor::share(DataType::F32, Shape{DimExpr(3)}, {}), ""},
};
count_t inputs[]{0, 1};
auto infered = RmsNormalization(1e-6).infer(TensorRefs(edges, inputs), {true});
ASSERT_TRUE(infered.isOk());
auto outputs = std::move(infered.unwrap());
ASSERT_EQ(outputs.size(), 1);
auto y = std::move(outputs[0]);
ASSERT_EQ(y->dataType, DataType::F32);
ASSERT_EQ(y->shape, (Shape{DimExpr(7), DimExpr(2), DimExpr(3)}));
}
