
Commit

[Reduce][Kernel] Pack f16/bf16x8 & fp8/i8x16 LD/ST (#43)
* Update README.md

* Update block_all_reduce.cu

* Update block_all_reduce.py

* Update README.md

* Update block_all_reduce.cu

* Update README.md

* Update block_all_reduce.cu

* Update block_all_reduce.py

* Update README.md

* Delete fuse-multihead-attention directory

* Create elementwise.cu

* Create relu.cu

* Create .gitignore

* Create README.md

* Update README.md
DefTruth authored Sep 24, 2024
1 parent d43c53d commit bf283f2
Showing 9 changed files with 655 additions and 129 deletions.
11 changes: 9 additions & 2 deletions README.md
@@ -5,11 +5,11 @@
<img src=https://img.shields.io/github/watchers/DefTruth/cuda-learn-note?color=9cc >
<img src=https://img.shields.io/github/forks/DefTruth/cuda-learn-note.svg?style=social >
<img src=https://img.shields.io/github/stars/DefTruth/cuda-learn-note.svg?style=social >
<img src=https://img.shields.io/badge/Release-v2.3-brightgreen.svg >
<img src=https://img.shields.io/badge/Release-v2.4-brightgreen.svg >
<img src=https://img.shields.io/badge/License-GPLv3.0-turquoise.svg >
</div>

📖 **CUDA Learn Notes with PyTorch**: **fp32, fp16/bf16, fp8/int8**, flash_attn, sgemm, sgemv, warp/block reduce, dot prod, elementwise, softmax, layernorm, rmsnorm, hist etc. 👉News: Most of my time now is focused on **LLM/VLM/Diffusion** Inference. Please check 📖[Awesome-LLM-Inference](https://github.com/DefTruth/Awesome-LLM-Inference) ![](https://img.shields.io/github/stars/DefTruth/Awesome-LLM-Inference.svg?style=social), 📖[Awesome-SD-Inference](https://github.com/DefTruth/Awesome-SD-Inference) ![](https://img.shields.io/github/stars/DefTruth/Awesome-SD-Inference.svg?style=social) and 📖[CUDA-Learn-Notes](https://github.com/DefTruth/CUDA-Learn-Notes) ![](https://img.shields.io/github/stars/DefTruth/CUDA-Learn-Notes.svg?style=social) for more details.
🎉 **CUDA Learn Notes**: This repo aims to build a set of **Modern CUDA Learn Notes with PyTorch** for beginners, including **fp32, fp16/bf16, fp8/int8, Tensor/CUDA Cores**, flash_attn, sgemm, sgemv, hgemm, hgemv, warp/block reduce, dot prod, elementwise, sigmoid, relu, softmax, layernorm, rmsnorm, hist and some CUDA optimization techniques (pack LDST, warp gemv, sliced_k/split_k/pipeline gemm, bank-conflict-free access, MMA, etc.).

<img width="1438" alt="image" src="https://github.com/user-attachments/assets/0c5e5125-586f-43fa-8e8b-e2c61c1afbbe">

@@ -49,13 +49,20 @@
| ✔️ [block_all_reduce_f16_f32](./reduce/block_all_reduce.cu)|f16|f32|[link](./reduce/)|⭐️⭐️|
| ✔️ [block_all_reduce_f16x2_f16](./reduce/block_all_reduce.cu)|f16|f16|[link](./reduce/)|⭐️⭐️|
| ✔️ [block_all_reduce_f16x2_f32](./reduce/block_all_reduce.cu)|f16|f32|[link](./reduce/)|⭐️⭐️|
| ✔️ [block_all_reduce_f16x8_pack_f16](./reduce/block_all_reduce.cu)|f16|f16|[link](./reduce/)|⭐️⭐️|
| ✔️ [block_all_reduce_f16x8_pack_f32](./reduce/block_all_reduce.cu)|f16|f32|[link](./reduce/)|⭐️⭐️|
| ✔️ [block_all_reduce_bf16_bf16](./reduce/block_all_reduce.cu)|bf16|bf16|[link](./reduce/)|⭐️⭐️|
| ✔️ [block_all_reduce_bf16_f32](./reduce/block_all_reduce.cu)|bf16|f32|[link](./reduce/)|⭐️⭐️|
| ✔️ [block_all_reduce_bf16x2_bf16](./reduce/block_all_reduce.cu)|bf16|bf16|[link](./reduce/)|⭐️⭐️|
| ✔️ [block_all_reduce_bf16x2_f32](./reduce/block_all_reduce.cu)|bf16|f32|[link](./reduce/)|⭐️⭐️|
| ✔️ [block_all_reduce_bf16x8_pack_bf16](./reduce/block_all_reduce.cu)|bf16|bf16|[link](./reduce/)|⭐️⭐️|
| ✔️ [block_all_reduce_bf16x8_pack_f32](./reduce/block_all_reduce.cu)|bf16|f32|[link](./reduce/)|⭐️⭐️|
| ✔️ [block_all_reduce_fp8_e4m3_f16](./reduce/block_all_reduce.cu)|fp8_e4m3|f16|[link](./reduce/)|⭐️⭐️|
| ✔️ [block_all_reduce_fp8_e5m2_f16](./reduce/block_all_reduce.cu)|fp8_e5m2|f16|[link](./reduce/)|⭐️⭐️|
| ✔️ [block_all_reduce_fp8_e4m3x16_pack_f16](./reduce/block_all_reduce.cu)|fp8_e4m3|f16|[link](./reduce/)|⭐️⭐️|
| ✔️ [block_all_reduce_fp8_e5m2x16_pack_f16](./reduce/block_all_reduce.cu)|fp8_e5m2|f16|[link](./reduce/)|⭐️⭐️|
| ✔️ [block_all_reduce_i8_i32](./reduce/block_all_reduce.cu)|i8|i32|[link](./reduce/)|⭐️⭐️|
| ✔️ [block_all_reduce_i8x16_pack_i32](./reduce/block_all_reduce.cu)|i8|i32|[link](./reduce/)|⭐️⭐️|
| ✔️ [dot_product_f32](./dot-product/dot_product.cu)|f32|f32|[link](./dot-product/)|⭐️⭐️|
| ✔️ [dot_product_f32x4](./dot-product/dot_product.cu)|f32|f32|[link](./dot-product/)|⭐️⭐️|
| ✔️ [dot_product_f16_f32](./dot-product/dot_product.cu)|f16|f32|[link](./dot-product/)|⭐️⭐️|
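The new `*x8_pack_*` and `*x16_pack_*` rows above come from `reduce/block_all_reduce.cu`, whose body this page does not display. As a rough illustration of the pack-LDST idea named in the commit title, here is a minimal sketch of an f16x8 packed block all-reduce; the kernel name, the `warp_reduce_sum_f32` helper, and the assumption that N is a multiple of 8 are illustrative, not the repo's actual code.

```cuda
// Sketch only: NOT the contents of reduce/block_all_reduce.cu from this commit.
// Illustrates "pack LDST": one 128-bit load brings 8 halves into registers,
// which are reduced locally before the usual warp/block reduction.
#include <cuda_runtime.h>
#include <cuda_fp16.h>

#define WARP_SIZE 32
#define LDST128BITS(value) (reinterpret_cast<float4*>(&(value))[0])

__device__ __forceinline__ float warp_reduce_sum_f32(float val) {
  #pragma unroll
  for (int mask = WARP_SIZE >> 1; mask >= 1; mask >>= 1) {
    val += __shfl_xor_sync(0xffffffff, val, mask);
  }
  return val;
}

// grid(N / (NUM_THREADS * 8)), block(NUM_THREADS); assumes N % 8 == 0.
template<const int NUM_THREADS = 256>
__global__ void block_all_reduce_sum_f16x8_pack_f32_sketch(half* a, float* y, int N) {
  constexpr int NUM_WARPS = (NUM_THREADS + WARP_SIZE - 1) / WARP_SIZE;
  __shared__ float smem[NUM_WARPS];
  int tid = threadIdx.x;
  int idx = (blockIdx.x * NUM_THREADS + tid) * 8;

  float sum = 0.0f;
  if ((idx + 7) < N) {
    half pack_a[8];                               // 8 x 16 bits = 128 bits
    LDST128BITS(pack_a[0]) = LDST128BITS(a[idx]); // one 128-bit load
    #pragma unroll
    for (int i = 0; i < 8; ++i) sum += __half2float(pack_a[i]);
  }

  int warp = tid / WARP_SIZE, lane = tid % WARP_SIZE;
  sum = warp_reduce_sum_f32(sum);                 // reduce within each warp
  if (lane == 0) smem[warp] = sum;
  __syncthreads();
  sum = (lane < NUM_WARPS) ? smem[lane] : 0.0f;   // first warp reduces the per-warp sums
  if (warp == 0) sum = warp_reduce_sum_f32(sum);
  if (tid == 0) atomicAdd(y, sum);                // one atomic per block
}
```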
20 changes: 0 additions & 20 deletions fuse-multihead-attention/.gitignore

This file was deleted.

18 changes: 18 additions & 0 deletions nvidia-nsight/.gitignore
@@ -0,0 +1,18 @@
*.i
*.ii
*.gpu
*.ptx
*.cubin
*.fatbin
*.so
*.a
*.dylib
*.dll
*.lib
.DS_Store
build
*.whl
tmp
*.nsys*
*.profile*
*.cubin
2 changes: 2 additions & 0 deletions nvidia-nsight/README.md
@@ -0,0 +1,2 @@
# NVIDIA Nsight Systems

114 changes: 114 additions & 0 deletions nvidia-nsight/elementwise.cu
@@ -0,0 +1,114 @@
#include <stdio.h>
#include <stdlib.h>
#include <float.h>
#include <vector>
#include <algorithm>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <cuda_fp8.h>

#define WARP_SIZE 32
#define INT4(value) (reinterpret_cast<int4*>(&(value))[0])
#define FLOAT4(value) (reinterpret_cast<float4*>(&(value))[0])
#define HALF2(value) (reinterpret_cast<half2*>(&(value))[0])
#define BFLOAT2(value) (reinterpret_cast<__nv_bfloat162*>(&(value))[0])
#define LDST128BITS(value) (reinterpret_cast<float4*>(&(value))[0])

// -------------------------------------- FP32 --------------------------------------
// ElementWise Add
// grid(N/256), block(256)
// a: Nx1, b: Nx1, c: Nx1, c = elementwise_add(a, b)
__global__ void elementwise_add_f32_kernel(float* a, float* b, float* c, int N) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < N) c[idx] = a[idx] + b[idx];
}

// ElementWise Add + Vec4
// grid(N/256), block(256/4)
// a: Nx1, b: Nx1, c: Nx1, c = elementwise_add(a, b)
__global__ void elementwise_add_f32x4_kernel(float* a, float* b, float* c, int N) {
int idx = 4 * (blockIdx.x * blockDim.x + threadIdx.x);
if (idx < N) {
float4 reg_a = FLOAT4(a[idx]);
float4 reg_b = FLOAT4(b[idx]);
float4 reg_c;
reg_c.x = reg_a.x + reg_b.x;
reg_c.y = reg_a.y + reg_b.y;
reg_c.z = reg_a.z + reg_b.z;
reg_c.w = reg_a.w + reg_b.w;
FLOAT4(c[idx]) = reg_c;
}
}

// -------------------------------------- FP16 --------------------------------------
// ElementWise Add
// grid(N/256), block(256)
// a: Nx1, b: Nx1, c: Nx1, c = elementwise_add(a, b)
__global__ void elementwise_add_f16_kernel(half* a, half* b, half* c, int N) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < N) c[idx] = __hadd(a[idx], b[idx]);
}

// a: Nx1, b: Nx1, c: Nx1, c = elementwise_add(a, b)
__global__ void elementwise_add_f16x2_kernel(half* a, half* b, half* c, int N) {
int idx = 2 * (blockIdx.x * blockDim.x + threadIdx.x);
if (idx < N) {
half2 reg_a = HALF2(a[idx]);
half2 reg_b = HALF2(b[idx]);
half2 reg_c;
reg_c.x = __hadd(reg_a.x, reg_b.x);
reg_c.y = __hadd(reg_a.y, reg_b.y);
HALF2(c[idx]) = reg_c;
}
}

__global__ void elementwise_add_f16x8_kernel(half* a, half* b, half* c, int N) {
int idx = 8 * (blockIdx.x * blockDim.x + threadIdx.x);
// manual unroll to improve the L2 cache hit rate.
// L2 cache only: each load issues 32 bytes per memory transaction (default).
// with L1 cache enabled (-Xptxas -dlcm=ca): each load issues 128 bytes per transaction.
// why fp16x8 within 1 thread? ref: https://zhuanlan.zhihu.com/p/641639133
// 0. first, tid_0 loads 32 bytes in 1 memory transaction and the data is cached in L2.
// 1. then, tid_1,...,tid_3 hit the L2 cache and load the data from it directly.
half2 reg_a_0 = HALF2(a[idx + 0]);
half2 reg_a_1 = HALF2(a[idx + 2]);
half2 reg_a_2 = HALF2(a[idx + 4]);
half2 reg_a_3 = HALF2(a[idx + 6]);
half2 reg_b_0 = HALF2(b[idx + 0]);
half2 reg_b_1 = HALF2(b[idx + 2]);
half2 reg_b_2 = HALF2(b[idx + 4]);
half2 reg_b_3 = HALF2(b[idx + 6]);
half2 reg_c_0, reg_c_1, reg_c_2, reg_c_3;
reg_c_0.x = __hadd(reg_a_0.x, reg_b_0.x);
reg_c_0.y = __hadd(reg_a_0.y, reg_b_0.y);
reg_c_1.x = __hadd(reg_a_1.x, reg_b_1.x);
reg_c_1.y = __hadd(reg_a_1.y, reg_b_1.y);
reg_c_2.x = __hadd(reg_a_2.x, reg_b_2.x);
reg_c_2.y = __hadd(reg_a_2.y, reg_b_2.y);
reg_c_3.x = __hadd(reg_a_3.x, reg_b_3.x);
reg_c_3.y = __hadd(reg_a_3.y, reg_b_3.y);
if ((idx + 0) < N) { HALF2(c[idx + 0]) = reg_c_0; }
if ((idx + 2) < N) { HALF2(c[idx + 2]) = reg_c_1; }
if ((idx + 4) < N) { HALF2(c[idx + 4]) = reg_c_2; }
if ((idx + 6) < N) { HALF2(c[idx + 6]) = reg_c_3; }
}

__global__ void elementwise_add_f16x8_pack_kernel(half* a, half* b, half* c, int N) {
int idx = 8 * (blockIdx.x * blockDim.x + threadIdx.x);
// temporary storage: registers, or addressable .local space in PTX if spilled.
half pack_a[8], pack_b[8], pack_c[8]; // 8 x 16 bits = 128 bits.
// reinterpret as float4 and load 128 bits in 1 memory issue.
LDST128BITS(pack_a[0]) = LDST128BITS(a[idx]); // load 128 bits
LDST128BITS(pack_b[0]) = LDST128BITS(b[idx]); // load 128 bits

#pragma unroll
for (int i = 0; i < 8; i += 2) {
// __hadd2 adds one half2 per iteration; 4 iterations cover all 8 halves.
HALF2(pack_c[i]) = __hadd2(HALF2(pack_a[i]), HALF2(pack_b[i]));
}
// reinterpret as float4 and store 128 bits in 1 memory issue.
if ((idx + 7) < N) { LDST128BITS(c[idx]) = LDST128BITS(pack_c[0]); }
}
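
For reference, a hypothetical host-side launch for the x8 pack kernel above (not part of this commit; the block size is an illustrative assumption) sizes the grid so each thread covers 8 halves:

```cuda
// Illustrative launch helper (assumption, not from this commit): each thread
// handles 8 halves, so one block of 256 threads covers 2048 elements.
void launch_elementwise_add_f16x8_pack(half* d_a, half* d_b, half* d_c, int N) {
  const int THREADS = 256;
  const int ELEMS_PER_BLOCK = THREADS * 8;
  dim3 block(THREADS);
  dim3 grid((N + ELEMS_PER_BLOCK - 1) / ELEMS_PER_BLOCK);
  elementwise_add_f16x8_pack_kernel<<<grid, block>>>(d_a, d_b, d_c, N);
}
```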


99 changes: 99 additions & 0 deletions nvidia-nsight/relu.cu
@@ -0,0 +1,99 @@
#include <stdio.h>
#include <stdlib.h>
#include <float.h>
#include <vector>
#include <algorithm>
#include <cuda_runtime.h>
#include <cuda_fp16.h>

#define WARP_SIZE 32
#define INT4(value) (reinterpret_cast<int4*>(&(value))[0])
#define FLOAT4(value) (reinterpret_cast<float4*>(&(value))[0])
#define HALF2(value) (reinterpret_cast<half2*>(&(value))[0])
#define BFLOAT2(value) (reinterpret_cast<__nv_bfloat162*>(&(value))[0])
#define LDST128BITS(value) (reinterpret_cast<float4*>(&(value))[0])

// -------------------------------------- FP32 --------------------------------------
// ReLU x: N, y: N, y = max(0, x)
// grid(N/256), block(K=256)
__global__ void relu_f32_kernel(float* x, float* y, int N) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < N) y[idx] = fmaxf(0.0f, x[idx]);
}

// ReLU x: N, y: N, y = max(0, x), Vec4
// grid(N/256), block(256/4)
__global__ void relu_f32x4_kernel(float* x, float* y, int N) {
int idx = (blockIdx.x * blockDim.x + threadIdx.x) * 4;
if (idx < N) {
float4 reg_x = FLOAT4(x[idx]);
float4 reg_y;
reg_y.x = fmaxf(0.0f, reg_x.x);
reg_y.y = fmaxf(0.0f, reg_x.y);
reg_y.z = fmaxf(0.0f, reg_x.z);
reg_y.w = fmaxf(0.0f, reg_x.w);
FLOAT4(y[idx]) = reg_y;
}
}

// -------------------------------------- FP16 --------------------------------------
__global__ void relu_f16_kernel(half* x, half* y, int N) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < N) y[idx] = __hmax(__float2half(0.0f), x[idx]);
}

__global__ void relu_f16x2_kernel(half* x, half* y, int N) {
int idx = 2 * (blockIdx.x * blockDim.x + threadIdx.x);
if (idx < N) {
half2 reg_x = HALF2(x[idx]);
half2 reg_y; // y is fully overwritten below, no need to load it first
reg_y.x = __hmax(__float2half(0.0f), reg_x.x);
reg_y.y = __hmax(__float2half(0.0f), reg_x.y);
HALF2(y[idx]) = reg_y;
}
}

__global__ void relu_f16x8_kernel(half* x, half* y, int N) {
int idx = 8 * (blockIdx.x * blockDim.x + threadIdx.x);
// manual unroll to improve the L2 cache hit rate.
// L2 cache only: each load issues 32 bytes per memory transaction (default).
// with L1 cache enabled (-Xptxas -dlcm=ca): each load issues 128 bytes per transaction.
// why fp16x8 within 1 thread? ref: https://zhuanlan.zhihu.com/p/641639133
// 0. first, tid_0 loads 32 bytes in 1 memory transaction and the data is cached in L2.
// 1. then, tid_1,...,tid_3 hit the L2 cache and load the data from it directly.
half2 reg_x_0 = HALF2(x[idx + 0]);
half2 reg_x_1 = HALF2(x[idx + 2]);
half2 reg_x_2 = HALF2(x[idx + 4]);
half2 reg_x_3 = HALF2(x[idx + 6]);
half2 reg_y_0, reg_y_1, reg_y_2, reg_y_3;
reg_y_0.x = __hmax(__float2half(0.0f), reg_x_0.x);
reg_y_0.y = __hmax(__float2half(0.0f), reg_x_0.y);
reg_y_1.x = __hmax(__float2half(0.0f), reg_x_1.x);
reg_y_1.y = __hmax(__float2half(0.0f), reg_x_1.y);
reg_y_2.x = __hmax(__float2half(0.0f), reg_x_2.x);
reg_y_2.y = __hmax(__float2half(0.0f), reg_x_2.y);
reg_y_3.x = __hmax(__float2half(0.0f), reg_x_3.x);
reg_y_3.y = __hmax(__float2half(0.0f), reg_x_3.y);
if ((idx + 0) < N) { HALF2(y[idx + 0]) = reg_y_0; }
if ((idx + 2) < N) { HALF2(y[idx + 2]) = reg_y_1; }
if ((idx + 4) < N) { HALF2(y[idx + 4]) = reg_y_2; }
if ((idx + 6) < N) { HALF2(y[idx + 6]) = reg_y_3; }
}

__global__ void relu_f16x8_pack_kernel(half* x, half* y, int N) {
int idx = 8 * (blockIdx.x * blockDim.x + threadIdx.x);
const half2 z2 = {__float2half(0.0f), __float2half(0.0f)};
// temporary storage: registers, or addressable .local space in PTX if spilled.
half pack_x[8], pack_y[8]; // 8 x 16 bits = 128 bits.
// reinterpret as float4 and load 128 bits in 1 memory issue.
LDST128BITS(pack_x[0]) = LDST128BITS(x[idx]); // load 128 bits

#pragma unroll
for (int i = 0; i < 8; i += 2) {
// __hmax2 handles one half2 per iteration; 4 iterations cover all 8 halves.
HALF2(pack_y[i]) = __hmax2(HALF2(pack_x[i]), z2);
}
// reinterpret as float4 and store 128 bits in 1 memory issue.
if ((idx + 7) < N) { LDST128BITS(y[idx]) = LDST128BITS(pack_y[0]); }
}
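
The fp8/i8x16 pack variants named in the commit title live in reduce/block_all_reduce.cu, which this page does not show. The load/store trick is the same 128-bit one: 16 8-bit values fit in a single int4. Below is a minimal sketch of that pattern; the function name and the int32 accumulator are assumptions, not the repo's code.

```cuda
// Sketch only: illustrates an i8x16 packed load feeding an int32 accumulator,
// as a kernel like block_all_reduce_i8x16_pack_i32 might do. Not the repo's actual code.
#include <stdint.h>
#include <cuda_runtime.h>

#define INT4(value) (reinterpret_cast<int4*>(&(value))[0])

__device__ __forceinline__ int sum_i8x16_pack(int8_t* a, int idx) {
  int8_t pack_a[16];               // 16 x 8 bits = 128 bits
  INT4(pack_a[0]) = INT4(a[idx]);  // one 128-bit load
  int sum = 0;
  #pragma unroll
  for (int i = 0; i < 16; ++i) sum += static_cast<int>(pack_a[i]);
  return sum;                      // feed into the usual warp/block reduction
}
```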
