feat(llm): add CUDA kernel for rms normalization
Signed-off-by: YdrMaster <[email protected]>
YdrMaster committed Jan 26, 2024
1 parent eca56d8 commit e15f28f
Showing 2 changed files with 99 additions and 23 deletions.
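
For orientation: both kernels in this diff implement RMS normalization, y[j] = x[j] * w[j] / sqrt(mean(x^2) + epsilon), applied independently to each of blockCount rows of blockSize elements. A minimal scalar reference for a single row (a sketch; the names are illustrative, not from the repository):

    #include <cmath>
    #include <cstddef>

    // One row: scale every element by the reciprocal RMS of the row, then by w.
    void rmsNormRow(float *y, float const *x, float const *w,
                    std::size_t n, float epsilon) {
        float ss = 0;
        for (std::size_t j = 0; j < n; ++j) ss += x[j] * x[j];
        float rms = 1.f / std::sqrt(ss / n + epsilon);
        for (std::size_t j = 0; j < n; ++j) y[j] = x[j] * rms * w[j];
    }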
45 changes: 24 additions & 21 deletions src/04kernel/src/kernels/rms_normalization/cpu_kernel.cc
@@ -1,9 +1,9 @@
 #include "cpu_kernel.hh"
+#include <execution>
 #include <numeric>
 
 namespace refactor::kernel {
     using K = RmsNormalizationCpu;
-    using DT = DataType;
 
     K::RmsNormalizationCpu(
         decltype(epsilon) epsilon_,
@@ -37,38 +37,41 @@ namespace refactor::kernel {
         return "Performing rms normalization on generic cpu";
     }
 
-    template<decltype(DT::internal) T>
+    template<class T>
     static Routine lowerTyped(float epsilon, dim_t blockCount, dim_t blockSize) {
         using namespace runtime;
-        using dt = typename primitive<T>::type;
 
         return [epsilon, blockCount, blockSize]//
             (Resources &, void *, void const *const *inputs, void *const *outputs) {
-                auto x = reinterpret_cast<dt const *>(inputs[0]);
-                auto w = reinterpret_cast<dt const *>(inputs[1]);
-                auto y = reinterpret_cast<dt *>(outputs[0]);
-                for (auto i : range0_(blockCount)) {
-                    auto x_ = x + i * blockSize;
-                    auto y_ = y + i * blockSize;
+                auto x = reinterpret_cast<T const *>(inputs[0]);
+                auto w = reinterpret_cast<T const *>(inputs[1]);
+                auto y = reinterpret_cast<T *>(outputs[0]);
+                std::for_each_n(
+                    std::execution::par_unseq,
+                    natural_t(0),
+                    blockCount,
+                    [blockSize, epsilon, x, w, y](auto i) {
+                        auto x_ = x + i * blockSize;
+                        auto y_ = y + i * blockSize;
 
-                    auto ss = std::accumulate(x_, x_ + blockSize, dt(0), [](auto acc, auto it) {
-                        return acc + it * it;
-                    });
-                    ss /= blockSize;
-                    ss += epsilon;
-                    ss = 1. / std::sqrt(ss);
+                        auto ss = std::accumulate(
+                            x_, x_ + blockSize, T(0),
+                            [](auto acc, auto it) { return acc + it * it; });
+                        ss /= blockSize;
+                        ss += epsilon;
+                        ss = 1. / std::sqrt(ss);
 
-                    for (auto j : range0_(blockSize)) {
-                        y_[j] = x_[j] * ss * w[j];
-                    }
-                }
+                        for (auto j : range0_(blockSize)) {
+                            y_[j] = x_[j] * ss * w[j];
+                        }
+                    });
            };
    }

    auto K::lower(Resources &) const noexcept -> RoutineWorkspace {
        return dataType == DataType::F32
-                   ? lowerTyped<DataType::F32>(epsilon, blockCount, blockSize)
-                   : lowerTyped<DataType::F64>(epsilon, blockCount, blockSize);
+                   ? lowerTyped<float>(epsilon, blockCount, blockSize)
+                   : lowerTyped<double>(epsilon, blockCount, blockSize);
    }

}// namespace refactor::kernel
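
The substantive change in this file is the move from a serial loop over rows to std::for_each_n with std::execution::par_unseq over the repository's natural_t counting iterator, so each row is normalized independently and may run in parallel or vectorized. A self-contained sketch of the same idiom, substituting an explicit index vector for natural_t (an assumption; rmsNormRows is an illustrative name):

    #include <algorithm>
    #include <cmath>
    #include <cstddef>
    #include <execution>
    #include <numeric>
    #include <vector>

    void rmsNormRows(float *y, float const *x, float const *w,
                     std::size_t blockCount, std::size_t blockSize, float epsilon) {
        std::vector<std::size_t> rows(blockCount);
        std::iota(rows.begin(), rows.end(), std::size_t(0));// index list stands in for natural_t
        std::for_each(
            std::execution::par_unseq, rows.begin(), rows.end(),
            [=](std::size_t i) {// rows are independent, so par_unseq is safe
                auto x_ = x + i * blockSize;
                auto y_ = y + i * blockSize;
                auto ss = std::accumulate(
                    x_, x_ + blockSize, 0.f,
                    [](float acc, float v) { return acc + v * v; });
                auto rms = 1.f / std::sqrt(ss / blockSize + epsilon);
                for (std::size_t j = 0; j < blockSize; ++j) {
                    y_[j] = x_[j] * rms * w[j];
                }
            });
    }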
77 changes: 75 additions & 2 deletions src/04kernel/src/kernels/rms_normalization/cuda_kernel.cc
@@ -1,9 +1,14 @@
 #include "cuda_kernel.hh"
+#include <numeric>
 
+#ifdef USE_CUDA
+#include "../../generator/nvrtc_repo.h"
+#include <cuda_runtime.h>
+#include <sstream>
+#endif
 
 namespace refactor::kernel {
     using K = RmsNormalizationCuda;
-    using DT = DataType;
 
     K::RmsNormalizationCuda(
         decltype(epsilon) epsilon_,
@@ -42,9 +47,77 @@ namespace refactor::kernel {
     }
 
 #ifdef USE_CUDA
+
+    // 0: data type
+    // 1: block size
+    // 2: epsilon cast
+    constexpr static const char *TEMPLATE = R"~(
+#include <cub/cub.cuh>
+static __device__ __forceinline__ {0:} sum({0:} a, {0:} b) {{
+    return a + b;
+}}
+extern "C" __global__ void kernel(
+    {0:} *__restrict__ const y,
+    {0:} const *__restrict__ const x,
+    {0:} const *__restrict__ const w,
+    float epsilon_) {{
+    auto epsilon = {2:}(epsilon_);
+    x += blockIdx.x * blockDim.x + threadIdx.x;
+    y += blockIdx.x * blockDim.x + threadIdx.x;
+    w += threadIdx.x;
+    using BlockReduce = cub::BlockReduce<{0:}, {1:}>;
+    __shared__ typename BlockReduce::TempStorage tempStorage;
+    __shared__ {0:} rms;
+    auto acc = BlockReduce(tempStorage).Reduce(*x * *x, sum);
+    if (threadIdx.x == 0) {{
+        rms = rsqrt(acc / {0:}(blockDim.x) + epsilon);
+    }}
+    __syncthreads();
+    *y = *x * rms * *w;
+}}
+)~";
+
     auto K::lower(Resources &) const -> RoutineWorkspace {
-        TODO("");
+        using namespace runtime;
+
+        std::stringstream ss;
+        ss << "RmsNorm" << nvrtc::dataType(dataType) << blockSize;
+        ss << ".cu";
+        auto name = ss.str();
+        auto code = fmt::format(
+            TEMPLATE,
+            nvrtc::dataType(dataType),// 0
+            blockSize,                // 1
+            // clang-format off
+            dataType == DataType::F32  ? ""
+            : dataType == DataType::F64  ? "static_cast<double>"
+            : dataType == DataType::FP16 ? "__float2half"
+            : dataType == DataType::BF16 ? "__float2bfloat16"
+            : UNREACHABLEX(const char*, "unreachable")
+            // clang-format on
+        );
+
+        return [h = nvrtc::Handler::compile(name.c_str(), code.c_str(), "kernel"),
+                epsilon_ = this->epsilon,
+                blockCount = this->blockCount,
+                blockSize = this->blockSize]//
+            (Resources &, void *, void const *const *inputs, void *const *outputs) {
+                auto y = outputs[0];
+                auto x = inputs[0];
+                auto w = inputs[1];
+                auto epsilon = epsilon_;
+                void *args[]{&y, &x, &w, &epsilon};
+                h->launch(blockCount, 1, 1,
+                          blockSize, 1, 1,
+                          0, args);
+            };
    }

#endif

}// namespace refactor::kernel
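
The NVRTC template maps one thread block to one row of blockSize elements (blockDim.x == blockSize, so blockSize is capped by CUDA's 1024-thread block limit): each thread loads one element, cub::BlockReduce folds the block-wide sum of squares, thread 0 (the only thread that receives the valid aggregate from Reduce) stores the reciprocal RMS in shared memory, and after __syncthreads() every thread scales its own element. A standalone sketch of the float instantiation, compiled statically instead of through nvrtc::Handler, and using cub's built-in Sum in place of the hand-rolled combiner:

    // A minimal sketch, assuming the float instantiation of the template above.
    #include <cub/cub.cuh>

    template<int BLOCK_SIZE>// plays the role of {1:} (blockSize), <= 1024
    __global__ void rmsNormKernel(float *__restrict__ y,
                                  float const *__restrict__ x,
                                  float const *__restrict__ w,
                                  float epsilon) {
        // one block per row: shift the row pointers to this thread's element
        x += blockIdx.x * BLOCK_SIZE + threadIdx.x;
        y += blockIdx.x * BLOCK_SIZE + threadIdx.x;
        w += threadIdx.x;
        using BlockReduce = cub::BlockReduce<float, BLOCK_SIZE>;
        __shared__ typename BlockReduce::TempStorage tempStorage;
        __shared__ float rms;
        // block-wide sum of squares; only thread 0 receives the aggregate
        float acc = BlockReduce(tempStorage).Sum(*x * *x);
        if (threadIdx.x == 0) {
            rms = rsqrtf(acc / BLOCK_SIZE + epsilon);
        }
        __syncthreads();// make rms visible to the whole block
        *y = *x * rms * *w;
    }

    // launch, one block per row:
    //   rmsNormKernel<1024><<<blockCount, 1024>>>(y, x, w, epsilon);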
