[RELU][FP16] Add f16x8_pack kernel, boost 2.1x (#42)
* Update README.md

* Update relu.cu

* Update relu.py

* Update README.md
DefTruth authored Sep 23, 2024
1 parent 4be041f commit d43c53d
Showing 4 changed files with 80 additions and 86 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -40,6 +40,7 @@
| ✔️ [relu_f16](./relu/relu.cu)|f16|/|[link](./relu/)|⭐️|
| ✔️ [relu_f16x2](./relu/relu.cu)|f16|/|[link](./relu/)|⭐️|
| ✔️ [relu_f16x8](./relu/relu.cu)|f16|/|[link](./relu/)|⭐️|
| ✔️ [relu_f16x8_pack](./relu/relu.cu)|f16|/|[link](./relu/)|⭐️⭐️|
| ✔️ [warp_reduce_f16/bf16/f32/f8/i8](./reduce/block_all_reduce.cu)|all|all|[link](./reduce/)|⭐️⭐️|
| ✔️ [block_reduce_f32](./reduce/block_all_reduce.cu)|f32|f32|[link](./reduce/)|⭐️⭐️|
| ✔️ [block_all_reduce_f32_f32](./reduce/block_all_reduce.cu)|f32|f32|[link](./reduce/)|⭐️⭐️|
25 changes: 9 additions & 16 deletions relu/README.md
@@ -9,6 +9,7 @@
- [X] relu_f16_kernel (fp16 version)
- [X] relu_f16x2_kernel (fp16 vectorized version)
- [X] relu_f16x8_kernel (fp16 vectorized version)
- [X] relu_f16x8_pack_kernel (fp16 vectorized, pack version)
- [X] PyTorch bindings


@@ -24,22 +25,14 @@ python3 relu.py

```bash
--------------------------------------------------------------------------------
out_f32: [0.0, 0.0], time:0.01072860ms
out_f32x4: [0.0, 0.0], time:0.01059222ms
out_f32_th: [0.0, 0.0], time:0.00772071ms
out_f32: [0.0, 0.23360847], time:0.18854451ms
out_f32x4: [0.0, 0.23360847], time:0.18829441ms
out_f32_th: [0.0, 0.23360847], time:0.20471048ms
--------------------------------------------------------------------------------
out_f16: [0.0, 0.0], time:0.01077199ms
out_f16x2: [0.0, 0.0], time:0.01084924ms
out_f16x8: [0.0, 0.0], time:0.01083326ms
out_f16_th: [0.0, 0.0], time:0.00762105ms
--------------------------------------------------------------------------------
out_f32(v2): [0.0, 0.0], time:0.00346351ms
out_f32x4(v2): [0.0, 0.0], time:0.00342798ms
out_f32_th: [0.0, 0.0], time:0.01125073ms
--------------------------------------------------------------------------------
out_f16(v2): [0.0, 0.0], time:0.00343585ms
out_f16x2(v2): [0.0, 0.0], time:0.00339842ms
out_f16x8(v2): [0.0, 0.0], time:0.00347090ms
out_f16_th: [0.0, 0.0], time:0.00776792ms
out_f16: [0.0, 0.23364258], time:0.04058957ms
out_f16x2: [0.0, 0.23364258], time:0.03622127ms
out_f16x8: [0.0, 0.23364258], time:0.03658152ms
out_f16x8pack: [0.0, 0.23364258], time:0.01454449ms
out_f16_th: [0.0, 0.23364258], time:0.04748964ms
--------------------------------------------------------------------------------
```
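The gap between `out_f16x8` and `out_f16x8pack` above largely comes down to how each thread moves its 8 halves. Below is a rough sketch of the two access patterns, an assumed simplification of the kernels in `relu.cu` (not the repo code itself; tail handling is reduced to a single guard):

```c++
#include <cuda_fp16.h>

// Unpacked: per 8 halves, four separate 32-bit half2 loads and stores.
__global__ void relu_f16x8_sketch(half* x, half* y, int N) {
  int idx = 8 * (blockIdx.x * blockDim.x + threadIdx.x);
  const half2 z2 = __float2half2_rn(0.0f);
  if (idx + 7 >= N) return;
  #pragma unroll
  for (int i = 0; i < 8; i += 2) {
    half2 v = *reinterpret_cast<half2*>(x + idx + i);         // 32-bit load
    *reinterpret_cast<half2*>(y + idx + i) = __hmax2(v, z2);  // 32-bit store
  }
}

// Packed: the same 8 halves move as one 128-bit load and one 128-bit store.
__global__ void relu_f16x8_pack_sketch(half* x, half* y, int N) {
  int idx = 8 * (blockIdx.x * blockDim.x + threadIdx.x);
  const half2 z2 = __float2half2_rn(0.0f);
  if (idx + 7 >= N) return;
  __align__(16) half v[8];
  *reinterpret_cast<float4*>(v) = *reinterpret_cast<float4*>(x + idx);  // 128-bit load
  #pragma unroll
  for (int i = 0; i < 8; i += 2) {
    *reinterpret_cast<half2*>(v + i) = __hmax2(*reinterpret_cast<half2*>(v + i), z2);
  }
  *reinterpret_cast<float4*>(y + idx) = *reinterpret_cast<float4*>(v);  // 128-bit store
}
```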
98 changes: 55 additions & 43 deletions relu/relu.cu
@@ -13,6 +13,7 @@
#define FLOAT4(value) (reinterpret_cast<float4*>(&(value))[0])
#define HALF2(value) (reinterpret_cast<half2*>(&(value))[0])
#define BFLOAT2(value) (reinterpret_cast<__nv_bfloat162*>(&(value))[0])
#define LDST128BITS(value) (reinterpret_cast<float4*>(&(value))[0])
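// NOTE: reinterpreting 8 consecutive halves as a float4 lets a single assignment
// move 128 bits; with 16-byte aligned addresses the compiler can typically emit
// one 128-bit global load/store instead of several narrower transactions.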

// -------------------------------------- FP32 --------------------------------------
// Relu x: N, y: N y=max(0,x)
@@ -81,6 +82,24 @@ __global__ void relu_f16x8_kernel(half* x, half* y, int N) {
if ((idx + 6) < N) { HALF2(y[idx + 6]) = reg_y_3; }
}

__global__ void relu_f16x8_pack_kernel(half* x, half* y, int N) {
int idx = 8 * (blockIdx.x * blockDim.x + threadIdx.x);
const half2 z2 = {__float2half(0.0f), __float2half(0.0f)};
// temporary registers (addressable .local space in PTX)
half pack_x[8], pack_y[8]; // 8 x 16 bits = 128 bits
// guard before the packed access: this thread touches x[idx..idx+7]
if ((idx + 7) >= N) return;
// reinterpret as float4 and load 128 bits in one memory transaction
LDST128BITS(pack_x[0]) = LDST128BITS(x[idx]);

#pragma unroll
for (int i = 0; i < 8; i += 2) {
// __hmax2 applies max(x, 0) to one half2 pair; four pairs cover all 8 values
HALF2(pack_y[i]) = __hmax2(HALF2(pack_x[i]), z2);
}
// reinterpret as float4 and store 128 bits in one memory transaction
LDST128BITS(y[idx]) = LDST128BITS(pack_y[0]);
}
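
// Illustrative sketch (not part of this commit): one way the pack kernel could be
// launched directly for a flattened tensor. `d_x`, `d_y` and `N` are assumed to be
// device pointers plus the element count; the grid math mirrors the 1D path of the
// binding macro below (8 halves per thread, 256 elements per block).
void launch_relu_f16x8_pack(half* d_x, half* d_y, int N) {
  dim3 block(256 / 8);             // 32 threads, 8 halves each
  dim3 grid((N + 256 - 1) / 256);  // ceil-divide N over 256 elements per block
  relu_f16x8_pack_kernel<<<grid, block>>>(d_x, d_y, N);
}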


// --------------------- PyTorch bindings for custom kernel -----------------------
#define STRINGFY(str) #str
#define TORCH_BINDING_COMMON_EXTENSION(func) \
@@ -92,61 +111,54 @@ if(((T).options().dtype() != (th_type))) { \
throw std::runtime_error("values must be "#th_type); \
}

#define CHECK_TORCH_TENSOR_SHAPE(T, S0) \
if (((T).size(0) != (S0))) { throw std::runtime_error("Tensor size mismatch!"); }

#define TORCH_BINDING_RELU(packed_type, th_type, element_type, n_elements) \
torch::Tensor relu_##packed_type(torch::Tensor x) { \
CHECK_TORCH_TENSOR_DTYPE(x, (th_type)) \
auto options = torch::TensorOptions().dtype((th_type)).device( \
torch::kCUDA, 0); \
const int N = x.size(0); \
auto y = torch::zeros({N}, options); \
static const int NUM_THREADS_PER_BLOCK = 256 / (n_elements); \
const int NUM_BLOCKS = (N + 256 - 1) / 256; \
dim3 block(NUM_THREADS_PER_BLOCK); \
dim3 grid(NUM_BLOCKS); \
relu_##packed_type##_kernel<<<grid, block>>>( \
reinterpret_cast<element_type*>(x.data_ptr()), \
reinterpret_cast<element_type*>(y.data_ptr()), N); \
return y; \
}

#define TORCH_BINDING_RELU_V2(packed_type, th_type, element_type, n_elements) \
void relu_##packed_type##_v2(torch::Tensor x, torch::Tensor y) { \
void relu_##packed_type(torch::Tensor x, torch::Tensor y) { \
CHECK_TORCH_TENSOR_DTYPE(x, (th_type)) \
CHECK_TORCH_TENSOR_DTYPE(y, (th_type)) \
const int N = x.size(0); \
CHECK_TORCH_TENSOR_SHAPE(y, N) \
static const int NUM_THREADS_PER_BLOCK = 256 / (n_elements); \
const int NUM_BLOCKS = (N + 256 - 1) / 256; \
dim3 block(NUM_THREADS_PER_BLOCK); \
dim3 grid(NUM_BLOCKS); \
relu_##packed_type##_kernel<<<grid, block>>>( \
const int ndim = x.dim(); \
if (ndim != 2) { \
int N = 1; \
for (int i = 0; i < ndim; ++i) { N *= x.size(i); } \
dim3 block(256 / (n_elements)); \
dim3 grid((N + 256 - 1) / 256); \
relu_##packed_type##_kernel<<<grid, block>>>( \
reinterpret_cast<element_type*>(x.data_ptr()), \
reinterpret_cast<element_type*>(y.data_ptr()), N); \
} else { \
const int S = x.size(0); \
const int K = x.size(1); \
const int N = S * K; \
if ((K/(n_elements)) <= 1024) { \
dim3 block(K/(n_elements)); \
dim3 grid(S); \
relu_##packed_type##_kernel<<<grid, block>>>( \
reinterpret_cast<element_type*>(x.data_ptr()), \
reinterpret_cast<element_type*>(y.data_ptr()), N); \
} else { \
int N = 1; \
for (int i = 0; i < ndim; ++i) { N *= x.size(i); } \
dim3 block(256 / (n_elements)); \
dim3 grid((N + 256 - 1) / 256); \
relu_##packed_type##_kernel<<<grid, block>>>( \
reinterpret_cast<element_type*>(x.data_ptr()), \
reinterpret_cast<element_type*>(y.data_ptr()), N); \
} \
} \
}
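
// Worked example (values taken from the 4096x4096 benchmark in relu.py, with
// n_elements = 8 for the pack kernel): K / n_elements = 4096 / 8 = 512 <= 1024,
// so the 2D branch launches dim3 block(512) and dim3 grid(4096); each thread
// covers 8 halves, so one block processes exactly one 4096-wide row. Shapes with
// K / n_elements > 1024 fall back to the same flattened 1D launch.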

TORCH_BINDING_RELU(f32, torch::kFloat32, float, 1)
TORCH_BINDING_RELU(f32x4, torch::kFloat32, float, 4)
TORCH_BINDING_RELU(f16, torch::kHalf, half, 1)
TORCH_BINDING_RELU(f16x2, torch::kHalf, half, 2)
TORCH_BINDING_RELU(f16x8, torch::kHalf, half, 8)
TORCH_BINDING_RELU_V2(f32, torch::kFloat32, float, 1)
TORCH_BINDING_RELU_V2(f32x4, torch::kFloat32, float, 4)
TORCH_BINDING_RELU_V2(f16, torch::kHalf, half, 1)
TORCH_BINDING_RELU_V2(f16x2, torch::kHalf, half, 2)
TORCH_BINDING_RELU_V2(f16x8, torch::kHalf, half, 8)

TORCH_BINDING_RELU(f32, torch::kFloat32, float, 1)
TORCH_BINDING_RELU(f32x4, torch::kFloat32, float, 4)
TORCH_BINDING_RELU(f16, torch::kHalf, half, 1)
TORCH_BINDING_RELU(f16x2, torch::kHalf, half, 2)
TORCH_BINDING_RELU(f16x8, torch::kHalf, half, 8)
TORCH_BINDING_RELU(f16x8_pack, torch::kHalf, half, 8)

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
TORCH_BINDING_COMMON_EXTENSION(relu_f32)
TORCH_BINDING_COMMON_EXTENSION(relu_f32x4)
TORCH_BINDING_COMMON_EXTENSION(relu_f16)
TORCH_BINDING_COMMON_EXTENSION(relu_f16x2)
TORCH_BINDING_COMMON_EXTENSION(relu_f16x8)
TORCH_BINDING_COMMON_EXTENSION(relu_f32_v2)
TORCH_BINDING_COMMON_EXTENSION(relu_f32x4_v2)
TORCH_BINDING_COMMON_EXTENSION(relu_f16_v2)
TORCH_BINDING_COMMON_EXTENSION(relu_f16x2_v2)
TORCH_BINDING_COMMON_EXTENSION(relu_f16x8_v2)
TORCH_BINDING_COMMON_EXTENSION(relu_f16x8_pack)
}
42 changes: 15 additions & 27 deletions relu/relu.py
Expand Up @@ -49,39 +49,27 @@ def run_benchmark(perf_func: callable, x: torch.Tensor, tag: str,
total_time = (end - start) * 1000 # ms
mean_time = total_time / iters
out_info = f"out_{tag}"
out_val = out.detach().cpu().numpy().tolist()[:2]
out_val = out.flatten().detach().cpu().numpy().tolist()[:2]
out_val = [round(v, 8) for v in out_val]
print(f"{out_info:>15}: {out_val}, time:{mean_time:.8f}ms")
print(f"{out_info:>18}: {out_val}, time:{mean_time:.8f}ms")
if show_all: print(out)
return out, mean_time


print("-" * 80)
N_ELEMENTS = 256*256*4
x = torch.randn((N_ELEMENTS)).cuda().float()
run_benchmark(lib.relu_f32, x, "f32")
run_benchmark(lib.relu_f32x4, x, "f32x4")
run_benchmark(torch.relu, x , "f32_th")
S, K = 4096, 4096
x = torch.randn((S, K)).cuda().float().contiguous()
y = torch.zeros_like(x).cuda().float().contiguous()
run_benchmark(lib.relu_f32, x, "f32", y)
run_benchmark(lib.relu_f32x4, x, "f32x4", y)
run_benchmark(torch.relu, x, "f32_th")

print("-" * 80)
x_f16 = x.half()
run_benchmark(lib.relu_f16, x_f16, "f16")
run_benchmark(lib.relu_f16x2, x_f16, "f16x2")
run_benchmark(lib.relu_f16x8, x_f16, "f16x8")
run_benchmark(torch.relu, x_f16 , "f16_th")

print("-" * 80)
# v2: no copy of y Tensor
y = torch.zeros_like(x).cuda().float()
run_benchmark(lib.relu_f32_v2, x, "f32(v2)", y)
run_benchmark(lib.relu_f32x4_v2, x, "f32x4(v2)", y)
run_benchmark(torch.relu, x , "f32_th")

print("-" * 80)
# v2: no copy of y Tensor
y_f16 = torch.zeros_like(x_f16).cuda().half()
run_benchmark(lib.relu_f16_v2, x_f16, "f16(v2)", y_f16)
run_benchmark(lib.relu_f16x2_v2, x_f16, "f16x2(v2)", y_f16)
run_benchmark(lib.relu_f16x8_v2, x_f16, "f16x8(v2)", y_f16)
run_benchmark(torch.relu, x_f16 , "f16_th")
x_f16 = x.half().contiguous()
y_f16 = y.half().contiguous()
run_benchmark(lib.relu_f16, x_f16, "f16", y_f16)
run_benchmark(lib.relu_f16x2, x_f16, "f16x2", y_f16)
run_benchmark(lib.relu_f16x8, x_f16, "f16x8", y_f16)
run_benchmark(lib.relu_f16x8_pack, x_f16, "f16x8pack", y_f16)
run_benchmark(torch.relu, x_f16, "f16_th")
print("-" * 80)
