feat(llm): add the frontend operator, computation-graph operator, and CPU kernel for rms normalization
Signed-off-by: YdrMaster <[email protected]>
Showing 12 changed files with 332 additions and 5 deletions.
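For orientation (the standard definition, not spelled out in the commit itself): RMS normalization rescales each length-d slice along the last axis by the reciprocal root of its mean square, then applies an elementwise weight. In LaTeX:

    y_j = w_j \cdot \frac{x_j}{\sqrt{\frac{1}{d}\sum_{k=1}^{d} x_k^2 + \epsilon}}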
src/04kernel/include/kernel/collectors/rms_normalization.h (20 additions, 0 deletions)
@@ -0,0 +1,20 @@
#ifndef KERNEL_RMS_NORMALIZATION_H
#define KERNEL_RMS_NORMALIZATION_H

#include "../collector.h"

namespace refactor::kernel {

    struct RmsNormalizationCollector final : public InfoCollector {
        float epsilon;

        constexpr RmsNormalizationCollector(decltype(_target) target, float epsilon_) noexcept
            : InfoCollector(target), epsilon(epsilon_) {}

        std::vector<KernelBox>
        filter(TensorRefs inputs, TensorRefs outputs) const final;
    };

}// namespace refactor::kernel

#endif// KERNEL_RMS_NORMALIZATION_H
@@ -0,0 +1,29 @@
#include "kernel/collectors/rms_normalization.h"
#include "../kernels/rms_normalization/cpu_kernel.hh"
#include "../kernels/rms_normalization/cuda_kernel.hh"

namespace refactor::kernel {

#define REGISTER(T)                                  \
    if (auto ptr = T::build(epsilon, inputs); ptr) { \
        ans.emplace_back(std::move(ptr));            \
    }

    std::vector<KernelBox>
    RmsNormalizationCollector::filter(TensorRefs inputs, TensorRefs outputs) const {

        std::vector<KernelBox> ans;
        switch (_target) {
            case decltype(_target)::Cpu:
                REGISTER(RmsNormalizationCpu)
                break;
            case decltype(_target)::Nvidia:
                REGISTER(RmsNormalizationCuda)
                break;
            default:
                UNREACHABLEX(void, "Unknown target");
        }
        return ans;
    }

}// namespace refactor::kernel
@@ -0,0 +1,74 @@
#include "cpu_kernel.hh"
#include <cmath>
#include <numeric>

namespace refactor::kernel {
    using K = RmsNormalizationCpu;
    using DT = DataType;

    K::RmsNormalizationCpu(
        decltype(epsilon) epsilon_,
        decltype(dataType) dataType_,
        decltype(blockCount) blockCount_,
        decltype(blockSize) blockSize_) noexcept
        : Kernel(),
          epsilon(epsilon_),
          dataType(dataType_),
          blockCount(blockCount_),
          blockSize(blockSize_) {}

    auto K::build(float epsilon, TensorRefs inputs) noexcept -> KernelBox {
        auto const &x = inputs[0].get();
        auto const &w = inputs[1].get();
        if ((x.dataType != DataType::F32 && x.dataType != DataType::F64) || x.dataType != w.dataType) {
            return nullptr;
        }
        auto it = x.shape.rbegin();
        dim_t blockSize = *it++;
        dim_t blockCount = std::accumulate(it, x.shape.rend(), 1, std::multiplies());
        return std::make_unique<K>(epsilon, x.dataType, blockCount, blockSize);
    }
    auto K::typeId() noexcept -> size_t {
        static uint8_t ID = 1;
        return reinterpret_cast<size_t>(&ID);
    }

    auto K::kernelTypeId() const noexcept -> size_t { return typeId(); }
    auto K::description() const noexcept -> std::string_view {
        return "Performing rms normalization on generic cpu";
    }

    template<decltype(DT::internal) T>
    static Routine lowerTyped(float epsilon, dim_t blockCount, dim_t blockSize) {
        using namespace runtime;
        using dt = typename primitive<T>::type;

        return [epsilon, blockCount, blockSize]//
            (Resources &, void *, void const *const *inputs, void *const *outputs) {
                auto x = reinterpret_cast<dt const *>(inputs[0]);
                auto w = reinterpret_cast<dt const *>(inputs[1]);
                auto y = reinterpret_cast<dt *>(outputs[0]);
                for (auto i : range0_(blockCount)) {
                    auto x_ = x + i * blockSize;
                    auto y_ = y + i * blockSize;

                    // sum of squares over the last axis, then mean, epsilon, reciprocal root
                    auto ss = std::accumulate(x_, x_ + blockSize, dt(0), [](auto acc, auto it) {
                        return acc + it * it;
                    });
                    ss /= blockSize;
                    ss += epsilon;
                    ss = 1. / std::sqrt(ss);

                    for (auto j : range0_(blockSize)) {
                        y_[j] = x_[j] * ss * w[j];
                    }
                }
            };
    }

    auto K::lower(Resources &) const noexcept -> RoutineWorkspace {
        return dataType == DataType::F32
                   ? lowerTyped<DataType::F32>(epsilon, blockCount, blockSize)
                   : lowerTyped<DataType::F64>(epsilon, blockCount, blockSize);
    }

}// namespace refactor::kernel
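To make the lowered routine above concrete, here is a minimal standalone sketch of the same per-block computation for F32. It is not part of the commit; rmsNormRef and the sample values are illustrative only.

#include <cmath>
#include <cstddef>
#include <cstdio>
#include <numeric>
#include <vector>

// Reference per-block RMS normalization: y = x / sqrt(mean(x^2) + eps) * w,
// applied independently to each of blockCount rows of length blockSize.
static void rmsNormRef(float const *x, float const *w, float *y,
                       std::size_t blockCount, std::size_t blockSize, float epsilon) {
    for (std::size_t i = 0; i < blockCount; ++i) {
        float const *x_ = x + i * blockSize;
        float *y_ = y + i * blockSize;
        // sum of squares, then mean, epsilon, and reciprocal root
        float ss = std::accumulate(x_, x_ + blockSize, 0.f,
                                   [](float acc, float v) { return acc + v * v; });
        float k = 1.f / std::sqrt(ss / blockSize + epsilon);
        for (std::size_t j = 0; j < blockSize; ++j) {
            y_[j] = x_[j] * k * w[j];
        }
    }
}

int main() {
    std::vector<float> x{1, 2, 3, 4, 5, 6}, w{1, 1, 1}, y(6);
    rmsNormRef(x.data(), w.data(), y.data(), /*blockCount=*/2, /*blockSize=*/3, 1e-6f);
    std::printf("%f %f %f\n", y[0], y[1], y[2]);// expect ~0.4629 0.9258 1.3887
    return 0;
}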
@@ -0,0 +1,30 @@
#ifndef KERNEL_RMS_NORMALIZATION_CPU_KERNEL_HH
#define KERNEL_RMS_NORMALIZATION_CPU_KERNEL_HH

#include "kernel/kernel.h"
#include "kernel/tensor.h"

namespace refactor::kernel {

    struct RmsNormalizationCpu final : public Kernel {
        float epsilon;
        DataType dataType;
        dim_t blockCount, blockSize;

        RmsNormalizationCpu(
            decltype(epsilon),
            decltype(dataType),
            decltype(blockCount),
            decltype(blockSize)) noexcept;

        static KernelBox build(float, TensorRefs) noexcept;
        static size_t typeId() noexcept;

        size_t kernelTypeId() const noexcept final;
        std::string_view description() const noexcept final;
        RoutineWorkspace lower(Resources &) const noexcept final;
    };

}// namespace refactor::kernel

#endif// KERNEL_RMS_NORMALIZATION_CPU_KERNEL_HH
@@ -0,0 +1,50 @@
#include "cuda_kernel.hh"
#include <numeric>

namespace refactor::kernel {
    using K = RmsNormalizationCuda;
    using DT = DataType;

    K::RmsNormalizationCuda(
        decltype(epsilon) epsilon_,
        decltype(dataType) dataType_,
        decltype(blockCount) blockCount_,
        decltype(blockSize) blockSize_) noexcept
        : Kernel(),
          epsilon(epsilon_),
          dataType(dataType_),
          blockCount(blockCount_),
          blockSize(blockSize_) {}

    auto K::build(float epsilon, TensorRefs inputs) noexcept -> KernelBox {
#ifndef USE_CUDA
        return nullptr;
#endif

        auto const &x = inputs[0].get();
        auto const &w = inputs[1].get();
        if (!x.dataType.isFloat() || x.dataType != w.dataType) {
            return nullptr;
        }
        auto it = x.shape.rbegin();
        dim_t blockSize = *it++;
        dim_t blockCount = std::accumulate(it, x.shape.rend(), 1, std::multiplies());
        return std::make_unique<K>(epsilon, x.dataType, blockCount, blockSize);
    }
    auto K::typeId() noexcept -> size_t {
        static uint8_t ID = 1;
        return reinterpret_cast<size_t>(&ID);
    }

    auto K::kernelTypeId() const noexcept -> size_t { return typeId(); }
    auto K::description() const noexcept -> std::string_view {
        return "Performing rms normalization using CUDA";
    }

#ifdef USE_CUDA
    auto K::lower(Resources &) const -> RoutineWorkspace {
        TODO("");
    }
#endif

}// namespace refactor::kernel
@@ -0,0 +1,31 @@
#ifndef KERNEL_RMS_NORMALIZATION_CUDA_KERNEL_HH
#define KERNEL_RMS_NORMALIZATION_CUDA_KERNEL_HH

#include "kernel/kernel.h"
#include "kernel/tensor.h"

namespace refactor::kernel {
    struct RmsNormalizationCuda final : public Kernel {
        float epsilon;
        DataType dataType;
        dim_t blockCount, blockSize;

        RmsNormalizationCuda(
            decltype(epsilon),
            decltype(dataType),
            decltype(blockCount),
            decltype(blockSize)) noexcept;

        static KernelBox build(float, TensorRefs) noexcept;
        static size_t typeId() noexcept;

        size_t kernelTypeId() const noexcept final;
        std::string_view description() const noexcept final;
#ifdef USE_CUDA
        RoutineWorkspace lower(Resources &) const final;
#endif
    };

}// namespace refactor::kernel

#endif// KERNEL_RMS_NORMALIZATION_CUDA_KERNEL_HH
src/05computation/include/computation/operators/rms_normalization.h (23 additions, 0 deletions)
@@ -0,0 +1,23 @@
#ifndef COMPUTATION_RMS_NORMALIZATION_H
#define COMPUTATION_RMS_NORMALIZATION_H

#include "../operator.h"

namespace refactor::computation {

    struct RmsNormalization final : public Operator {
        float epsilon;

        constexpr explicit RmsNormalization(float epsilon_) noexcept
            : Operator(), epsilon(epsilon_) {}

        static size_t typeId() noexcept;
        size_t opTypeId() const noexcept final;
        std::string_view name() const noexcept final;
        kernel::CollectorBox candidateKernels(Target) const final;
        std::string serialize() const noexcept final;
    };

}// namespace refactor::computation

#endif// COMPUTATION_RMS_NORMALIZATION_H
@@ -0,0 +1,27 @@
#include "computation/operators/rms_normalization.h"
#include "kernel/collectors/rms_normalization.h"

namespace refactor::computation {
    using Op = RmsNormalization;

    auto Op::typeId() noexcept -> size_t {
        static uint8_t ID = 1;
        return reinterpret_cast<size_t>(&ID);
    }
    auto Op::opTypeId() const noexcept -> size_t { return typeId(); }
    auto Op::name() const noexcept -> std::string_view { return "RmsNormalization"; }
    auto Op::candidateKernels(Target target) const -> kernel::CollectorBox {
        using Collector_ = kernel::RmsNormalizationCollector;
        return std::make_unique<Collector_>(target, epsilon);
    }
    auto Op::serialize() const noexcept -> std::string {
        // print epsilon both as a decimal and as its exact f32 bit pattern
        union code {
            float f;
            int32_t i;
        };
        return fmt::format(("{}({:e}={:#010x})"),
                           name(), epsilon,
                           code{epsilon}.i);
    }

}// namespace refactor::computation
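serialize() above prints epsilon twice, as a decimal and as its raw f32 bit pattern, so the exact value survives a round trip through text. A minimal standalone sketch of the same union-based encoding (illustrative only; it mirrors the type punning the commit itself uses):

#include <cstdint>
#include <cstdio>

int main() {
    union code {
        float f;
        int32_t i;
    };
    float epsilon = 1e-6f;
    // 1e-6f has the bit pattern 0x358637bd; printing it preserves the value exactly.
    std::printf("%e = %#010x\n", epsilon, static_cast<unsigned>(code{epsilon}.i));
    return 0;
}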
(Diffs for the remaining changed files were not captured here.)
@@ -0,0 +1,22 @@
#include "../src/operators/rms_normalization.hh"
#include "llm/operators.h"
#include <gtest/gtest.h>

using namespace refactor;
using namespace llm;

TEST(infer, RmsNormalization) {
    llm::register_();
    auto edges = Edges{
        {Tensor::share(DataType::F32, Shape{DimExpr(7), DimExpr(2), DimExpr(3)}, {}), ""},
        {Tensor::share(DataType::F32, Shape{DimExpr(3)}, {}), ""},
    };
    count_t inputs[]{0, 1};
    auto infered = RmsNormalization(1e-6).infer(TensorRefs(edges, inputs), {true});
    ASSERT_TRUE(infered.isOk());
    auto outputs = std::move(infered.unwrap());
    ASSERT_EQ(outputs.size(), 1);
    auto y = std::move(outputs[0]);
    ASSERT_EQ(y->dataType, DataType::F32);
    ASSERT_EQ(y->shape, (Shape{DimExpr(7), DimExpr(2), DimExpr(3)}));
}