From 5f812a3131d870d3668a7dea056071a3f1089ff0 Mon Sep 17 00:00:00 2001
From: YISH
Date: Sat, 16 Dec 2023 10:03:28 +0800
Subject: [PATCH] Merge CUDA templates of scatter_add and scatter

---
 candle-kernels/src/indexing.cu | 153 +++++++++++++--------------------
 1 file changed, 58 insertions(+), 95 deletions(-)

diff --git a/candle-kernels/src/indexing.cu b/candle-kernels/src/indexing.cu
index 92204af71b..7ca7a67c56 100644
--- a/candle-kernels/src/indexing.cu
+++ b/candle-kernels/src/indexing.cu
@@ -112,65 +112,7 @@ extern "C" __global__ void FN_NAME( \
 ) { index_add(ids, ids_dim_size, inp, out, left_size, src_dim_size, dst_dim_size, right_size); } \
 
 
-template <typename T, typename I>
-__device__ void scatter_assign(
-    const I *ids,
-    const T *inp,
-    T *out,
-    const size_t left_size,
-    const size_t src_dim_size,
-    const size_t dst_dim_size,
-    const size_t right_size
-) {
-    const size_t numel = left_size * right_size;
-    for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
-        const size_t pre = i / right_size;
-        const size_t post = i % right_size;
-        for (unsigned int j = 0; j < src_dim_size; ++j) {
-            const size_t src_i = (pre * src_dim_size + j) * right_size + post;
-            const size_t idx = ids[src_i];
-            const size_t dst_i = (pre * dst_dim_size + idx) * right_size + post;
-            out[dst_i] = inp[src_i];
-        }
-    }
-}
-
-#define SCATTER_OP(TYPENAME, INDEX_TYPENAME, FN_NAME) \
-extern "C" __global__ void FN_NAME( \
-    const INDEX_TYPENAME *ids, \
-    const TYPENAME *inp, \
-    TYPENAME *out, \
-    const size_t left_size, \
-    const size_t src_dim_size, \
-    const size_t dst_dim_size, \
-    const size_t right_size \
-) { scatter_assign(ids, inp, out, left_size, src_dim_size, dst_dim_size, right_size); } \
-
-
-template <typename T, typename I>
-__device__ void scatter_add(
-    const I *ids,
-    const T *inp,
-    T *out,
-    const size_t left_size,
-    const size_t src_dim_size,
-    const size_t dst_dim_size,
-    const size_t right_size
-) {
-    const size_t numel = left_size * right_size;
-    for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
-        const size_t pre = i / right_size;
-        const size_t post = i % right_size;
-        for (unsigned int j = 0; j < src_dim_size; ++j) {
-            const size_t src_i = (pre * src_dim_size + j) * right_size + post;
-            const size_t idx = ids[src_i];
-            const size_t dst_i = (pre * dst_dim_size + idx) * right_size + post;
-            out[dst_i] += inp[src_i];
-        }
-    }
-}
-
-#define SA_OP(TYPENAME, INDEX_TYPENAME, FN_NAME) \
+#define SCATTER_OP(TYPENAME, INDEX_TYPENAME, FN_NAME, OP) \
 extern "C" __global__ void FN_NAME( \
     const INDEX_TYPENAME *ids, \
     const TYPENAME *inp, \
@@ -179,7 +121,19 @@ extern "C" __global__ void FN_NAME( \
     const size_t src_dim_size, \
     const size_t dst_dim_size, \
     const size_t right_size \
-) { scatter_add(ids, inp, out, left_size, src_dim_size, dst_dim_size, right_size); } \
+) { \
+    const size_t numel = left_size * right_size;\
+    for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {\
+        const size_t pre = i / right_size;\
+        const size_t post = i % right_size;\
+        for (unsigned int j = 0; j < src_dim_size; ++j) {\
+            const size_t src_i = (pre * src_dim_size + j) * right_size + post;\
+            const size_t idx = ids[src_i];\
+            const size_t dst_i = (pre * dst_dim_size + idx) * right_size + post;\
+            out[dst_i] OP inp[src_i];\
+        }\
+    }\
+} \
 
 
 #if __CUDA_ARCH__ >= 800
@@ -192,9 +146,12 @@ GATHER_OP(__nv_bfloat16, uint8_t, gather_u8_bf16)
 IA_OP(__nv_bfloat16, int64_t, ia_i64_bf16)
 IA_OP(__nv_bfloat16, uint32_t, ia_u32_bf16)
 IA_OP(__nv_bfloat16, uint8_t, ia_u8_bf16)
-SA_OP(__nv_bfloat16, int64_t, sa_i64_bf16)
-SA_OP(__nv_bfloat16, uint32_t, sa_u32_bf16)
-SA_OP(__nv_bfloat16, uint8_t, sa_u8_bf16)
+SCATTER_OP(__nv_bfloat16, int64_t, sa_i64_bf16, +=)
+SCATTER_OP(__nv_bfloat16, uint32_t, sa_u32_bf16, +=)
+SCATTER_OP(__nv_bfloat16, uint8_t, sa_u8_bf16, +=)
+SCATTER_OP(__nv_bfloat16, int64_t, scatter_i64_bf16, =)
+SCATTER_OP(__nv_bfloat16, uint32_t, scatter_u32_bf16, =)
+SCATTER_OP(__nv_bfloat16, uint8_t, scatter_u8_bf16, =)
 #endif
 
 #if __CUDA_ARCH__ >= 530
@@ -206,8 +163,10 @@ GATHER_OP(__half, uint32_t, gather_u32_f16)
 GATHER_OP(__half, uint8_t, gather_u8_f16)
 IA_OP(__half, uint32_t, ia_u32_f16)
 IA_OP(__half, uint8_t, ia_u8_f16)
-SA_OP(__half, uint32_t, sa_u32_f16)
-SA_OP(__half, uint8_t, sa_u8_f16)
+SCATTER_OP(__half, uint32_t, sa_u32_f16, +=)
+SCATTER_OP(__half, uint8_t, sa_u8_f16, +=)
+SCATTER_OP(__half, uint32_t, scatter_u32_f16, =)
+SCATTER_OP(__half, uint8_t, scatter_u8_f16, =)
 #endif
 
 IS_OP(float, int64_t, is_i64_f32)
@@ -264,39 +223,43 @@ IA_OP(uint8_t, uint8_t, ia_u8_u8)
 IA_OP(uint32_t, uint8_t, ia_u8_u32)
 IA_OP(int64_t, uint8_t, ia_u8_i64)
 
-SA_OP(float, int64_t, sa_i64_f32)
-SA_OP(double, int64_t, sa_i64_f64)
-SA_OP(uint8_t, int64_t, sa_i64_u8)
-SA_OP(int64_t, int64_t, sa_i64_i64)
-SA_OP(uint32_t, int64_t, sa_i64_u32)
-SA_OP(float, uint32_t, sa_u32_f32)
-SA_OP(double, uint32_t, sa_u32_f64)
-SA_OP(uint8_t, uint32_t, sa_u32_u8)
-SA_OP(int64_t, uint32_t, sa_u32_i64)
-SA_OP(uint32_t, uint32_t, sa_u32_u32)
+#pragma region scatter_add
+SCATTER_OP(float, int64_t, sa_i64_f32, +=)
+SCATTER_OP(double, int64_t, sa_i64_f64, +=)
+SCATTER_OP(uint8_t, int64_t, sa_i64_u8, +=)
+SCATTER_OP(int64_t, int64_t, sa_i64_i64, +=)
+SCATTER_OP(uint32_t, int64_t, sa_i64_u32, +=)
 
-SA_OP(float, uint8_t, sa_u8_f32)
-SA_OP(double, uint8_t, sa_u8_f64)
-SA_OP(uint8_t, uint8_t, sa_u8_u8)
-SA_OP(uint32_t, uint8_t, sa_u8_u32)
-SA_OP(int64_t, uint8_t, sa_u8_i64)
+SCATTER_OP(float, uint32_t, sa_u32_f32, +=)
+SCATTER_OP(double, uint32_t, sa_u32_f64, +=)
+SCATTER_OP(uint8_t, uint32_t, sa_u32_u8, +=)
+SCATTER_OP(int64_t, uint32_t, sa_u32_i64, +=)
+SCATTER_OP(uint32_t, uint32_t, sa_u32_u32, +=)
+SCATTER_OP(float, uint8_t, sa_u8_f32, +=)
+SCATTER_OP(double, uint8_t, sa_u8_f64, +=)
+SCATTER_OP(uint8_t, uint8_t, sa_u8_u8, +=)
+SCATTER_OP(uint32_t, uint8_t, sa_u8_u32, +=)
+SCATTER_OP(int64_t, uint8_t, sa_u8_i64, +=)
+#pragma endregion scatter_add
 
-SCATTER_OP(float, int64_t, scatter_i64_f32)
-SCATTER_OP(double, int64_t, scatter_i64_f64)
-SCATTER_OP(uint8_t, int64_t, scatter_i64_u8)
-SCATTER_OP(int64_t, int64_t, scatter_i64_i64)
-SCATTER_OP(uint32_t, int64_t, scatter_i64_u32)
+#pragma region scatter_assign
+SCATTER_OP(float, int64_t, scatter_i64_f32, =)
+SCATTER_OP(double, int64_t, scatter_i64_f64, =)
+SCATTER_OP(uint8_t, int64_t, scatter_i64_u8, =)
+SCATTER_OP(int64_t, int64_t, scatter_i64_i64, =)
+SCATTER_OP(uint32_t, int64_t, scatter_i64_u32, =)
 
-SCATTER_OP(float, uint32_t, scatter_u32_f32)
-SCATTER_OP(double, uint32_t, scatter_u32_f64)
-SCATTER_OP(uint8_t, uint32_t, scatter_u32_u8)
-SCATTER_OP(int64_t, uint32_t, scatter_u32_i64)
-SCATTER_OP(uint32_t, uint32_t, scatter_u32_u32)
+SCATTER_OP(float, uint32_t, scatter_u32_f32, =)
+SCATTER_OP(double, uint32_t, scatter_u32_f64, =)
+SCATTER_OP(uint8_t, uint32_t, scatter_u32_u8, =)
+SCATTER_OP(int64_t, uint32_t, scatter_u32_i64, =)
+SCATTER_OP(uint32_t, uint32_t, scatter_u32_u32, =)
 
-SCATTER_OP(float, uint8_t, scatter_u8_f32)
-SCATTER_OP(double, uint8_t, scatter_u8_f64)
-SCATTER_OP(uint8_t, uint8_t, scatter_u8_u8)
-SCATTER_OP(uint32_t, uint8_t, scatter_u8_u32)
-SCATTER_OP(int64_t, uint8_t, scatter_u8_i64)
+SCATTER_OP(float, uint8_t, scatter_u8_f32, =)
+SCATTER_OP(double, uint8_t, scatter_u8_f64, =)
+SCATTER_OP(uint8_t, uint8_t, scatter_u8_u8, =)
+SCATTER_OP(uint32_t, uint8_t, scatter_u8_u32, =)
+SCATTER_OP(int64_t, uint8_t, scatter_u8_i64, =)
+#pragma endregion scatter_assign
\ No newline at end of file
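
For reference, a minimal standalone sketch of what the merged macro generates: instantiating SCATTER_OP(float, uint32_t, sa_u32_f32, +=) from the diff above expands, modulo whitespace, to roughly the kernel below, with OP substituted into the innermost statement; passing = instead of += produces the scatter-assign kernels. The two includes are added here only to make the sketch self-contained; the real file relies on its existing headers.

#include <cstddef>
#include <cstdint>

// Illustrative expansion of SCATTER_OP(float, uint32_t, sa_u32_f32, +=).
// Grid-stride loop over the left_size * right_size slots; the scattered
// dimension is walked by the inner loop over j, and each source element is
// routed to the destination row given by ids[src_i].
extern "C" __global__ void sa_u32_f32(
    const uint32_t *ids,
    const float *inp,
    float *out,
    const size_t left_size,
    const size_t src_dim_size,
    const size_t dst_dim_size,
    const size_t right_size
) {
    const size_t numel = left_size * right_size;
    for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
        const size_t pre = i / right_size;
        const size_t post = i % right_size;
        for (unsigned int j = 0; j < src_dim_size; ++j) {
            const size_t src_i = (pre * src_dim_size + j) * right_size + post;
            const size_t idx = ids[src_i];
            const size_t dst_i = (pre * dst_dim_size + idx) * right_size + post;
            out[dst_i] += inp[src_i];  // OP is += here; = yields scatter-assign
        }
    }
}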