Merge CUDA templates of scatter_add and scatter
mokeyish committed Dec 16, 2023
1 parent 11d62be commit 5f812a3
Showing 1 changed file with 58 additions and 95 deletions.
candle-kernels/src/indexing.cu: 153 changes (58 additions, 95 deletions)
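
In short, the previous `scatter_assign`/`scatter_add` device templates and their `SCATTER_OP`/`SA_OP` wrapper macros are folded into a single `SCATTER_OP` macro that takes the update operator (`=` or `+=`) as a fourth argument. As a hand-expanded sketch of what the merged macro generates (type, index, and kernel names are taken from the diff below; the expansion itself is written out here for illustration rather than copied from the repository), `SCATTER_OP(float, uint32_t, sa_u32_f32, +=)` produces roughly:

// Hand-expanded form of SCATTER_OP(float, uint32_t, sa_u32_f32, +=); the
// `=` variant (plain scatter) differs only in the final update statement.
extern "C" __global__ void sa_u32_f32(
    const uint32_t *ids,
    const float *inp,
    float *out,
    const size_t left_size,
    const size_t src_dim_size,
    const size_t dst_dim_size,
    const size_t right_size
) {
    const size_t numel = left_size * right_size;
    // Grid-stride loop over (left, right) coordinate pairs.
    for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
        const size_t pre = i / right_size;
        const size_t post = i % right_size;
        // Walk the scatter dimension of the source and accumulate each element
        // into the destination slot selected by the index tensor.
        for (unsigned int j = 0; j < src_dim_size; ++j) {
            const size_t src_i = (pre * src_dim_size + j) * right_size + post;
            const size_t idx = ids[src_i];
            const size_t dst_i = (pre * dst_dim_size + idx) * right_size + post;
            out[dst_i] += inp[src_i];   // OP substituted as `+=`
        }
    }
}

Passing `=` instead of `+=` yields the scatter (assignment) kernel with identical indexing, which is what removes the need for two near-duplicate templates.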
@@ -112,65 +112,7 @@ extern "C" __global__ void FN_NAME( \
) { index_add(ids, ids_dim_size, inp, out, left_size, src_dim_size, dst_dim_size, right_size); } \


-template<typename T, typename I>
-__device__ void scatter_assign(
-    const I *ids,
-    const T *inp,
-    T *out,
-    const size_t left_size,
-    const size_t src_dim_size,
-    const size_t dst_dim_size,
-    const size_t right_size
-) {
-    const size_t numel = left_size * right_size;
-    for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
-        const size_t pre = i / right_size;
-        const size_t post = i % right_size;
-        for (unsigned int j = 0; j < src_dim_size; ++j) {
-            const size_t src_i = (pre * src_dim_size + j) * right_size + post;
-            const size_t idx = ids[src_i];
-            const size_t dst_i = (pre * dst_dim_size + idx) * right_size + post;
-            out[dst_i] = inp[src_i];
-        }
-    }
-}
-
-#define SCATTER_OP(TYPENAME, INDEX_TYPENAME, FN_NAME) \
-extern "C" __global__ void FN_NAME( \
-    const INDEX_TYPENAME *ids, \
-    const TYPENAME *inp, \
-    TYPENAME *out, \
-    const size_t left_size, \
-    const size_t src_dim_size, \
-    const size_t dst_dim_size, \
-    const size_t right_size \
-) { scatter_assign(ids, inp, out, left_size, src_dim_size, dst_dim_size, right_size); } \
-
-
-template<typename T, typename I>
-__device__ void scatter_add(
-    const I *ids,
-    const T *inp,
-    T *out,
-    const size_t left_size,
-    const size_t src_dim_size,
-    const size_t dst_dim_size,
-    const size_t right_size
-) {
-    const size_t numel = left_size * right_size;
-    for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
-        const size_t pre = i / right_size;
-        const size_t post = i % right_size;
-        for (unsigned int j = 0; j < src_dim_size; ++j) {
-            const size_t src_i = (pre * src_dim_size + j) * right_size + post;
-            const size_t idx = ids[src_i];
-            const size_t dst_i = (pre * dst_dim_size + idx) * right_size + post;
-            out[dst_i] += inp[src_i];
-        }
-    }
-}
-
-#define SA_OP(TYPENAME, INDEX_TYPENAME, FN_NAME) \
+#define SCATTER_OP(TYPENAME, INDEX_TYPENAME, FN_NAME, OP) \
extern "C" __global__ void FN_NAME( \
    const INDEX_TYPENAME *ids, \
    const TYPENAME *inp, \
@@ -179,7 +121,19 @@ extern "C" __global__ void FN_NAME( \
    const size_t src_dim_size, \
    const size_t dst_dim_size, \
    const size_t right_size \
-) { scatter_add(ids, inp, out, left_size, src_dim_size, dst_dim_size, right_size); } \
+) { \
+    const size_t numel = left_size * right_size;\
+    for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {\
+        const size_t pre = i / right_size;\
+        const size_t post = i % right_size;\
+        for (unsigned int j = 0; j < src_dim_size; ++j) {\
+            const size_t src_i = (pre * src_dim_size + j) * right_size + post;\
+            const size_t idx = ids[src_i];\
+            const size_t dst_i = (pre * dst_dim_size + idx) * right_size + post;\
+            out[dst_i] OP inp[src_i];\
+        }\
+    }\
+} \


#if __CUDA_ARCH__ >= 800
@@ -192,9 +146,12 @@ GATHER_OP(__nv_bfloat16, uint8_t, gather_u8_bf16)
IA_OP(__nv_bfloat16, int64_t, ia_i64_bf16)
IA_OP(__nv_bfloat16, uint32_t, ia_u32_bf16)
IA_OP(__nv_bfloat16, uint8_t, ia_u8_bf16)
-SA_OP(__nv_bfloat16, int64_t, sa_i64_bf16)
-SA_OP(__nv_bfloat16, uint32_t, sa_u32_bf16)
-SA_OP(__nv_bfloat16, uint8_t, sa_u8_bf16)
+SCATTER_OP(__nv_bfloat16, int64_t, sa_i64_bf16, +=)
+SCATTER_OP(__nv_bfloat16, uint32_t, sa_u32_bf16, +=)
+SCATTER_OP(__nv_bfloat16, uint8_t, sa_u8_bf16, +=)
+SCATTER_OP(__nv_bfloat16, int64_t, scatter_i64_bf16, =)
+SCATTER_OP(__nv_bfloat16, uint32_t, scatter_u32_bf16, =)
+SCATTER_OP(__nv_bfloat16, uint8_t, scatter_u8_bf16, =)
#endif

#if __CUDA_ARCH__ >= 530
@@ -206,8 +163,10 @@ GATHER_OP(__half, uint32_t, gather_u32_f16)
GATHER_OP(__half, uint8_t, gather_u8_f16)
IA_OP(__half, uint32_t, ia_u32_f16)
IA_OP(__half, uint8_t, ia_u8_f16)
-SA_OP(__half, uint32_t, sa_u32_f16)
-SA_OP(__half, uint8_t, sa_u8_f16)
+SCATTER_OP(__half, uint32_t, sa_u32_f16, +=)
+SCATTER_OP(__half, uint8_t, sa_u8_f16, +=)
+SCATTER_OP(__half, uint32_t, scatter_u32_f16, =)
+SCATTER_OP(__half, uint8_t, scatter_u8_f16, =)
#endif

IS_OP(float, int64_t, is_i64_f32)
@@ -264,39 +223,43 @@ IA_OP(uint8_t, uint8_t, ia_u8_u8)
IA_OP(uint32_t, uint8_t, ia_u8_u32)
IA_OP(int64_t, uint8_t, ia_u8_i64)

-SA_OP(float, int64_t, sa_i64_f32)
-SA_OP(double, int64_t, sa_i64_f64)
-SA_OP(uint8_t, int64_t, sa_i64_u8)
-SA_OP(int64_t, int64_t, sa_i64_i64)
-SA_OP(uint32_t, int64_t, sa_i64_u32)

-SA_OP(float, uint32_t, sa_u32_f32)
-SA_OP(double, uint32_t, sa_u32_f64)
-SA_OP(uint8_t, uint32_t, sa_u32_u8)
-SA_OP(int64_t, uint32_t, sa_u32_i64)
-SA_OP(uint32_t, uint32_t, sa_u32_u32)
+#pragma region scatter_add
+SCATTER_OP(float, int64_t, sa_i64_f32, +=)
+SCATTER_OP(double, int64_t, sa_i64_f64, +=)
+SCATTER_OP(uint8_t, int64_t, sa_i64_u8, +=)
+SCATTER_OP(int64_t, int64_t, sa_i64_i64, +=)
+SCATTER_OP(uint32_t, int64_t, sa_i64_u32, +=)

-SA_OP(float, uint8_t, sa_u8_f32)
-SA_OP(double, uint8_t, sa_u8_f64)
-SA_OP(uint8_t, uint8_t, sa_u8_u8)
-SA_OP(uint32_t, uint8_t, sa_u8_u32)
-SA_OP(int64_t, uint8_t, sa_u8_i64)
+SCATTER_OP(float, uint32_t, sa_u32_f32, +=)
+SCATTER_OP(double, uint32_t, sa_u32_f64, +=)
+SCATTER_OP(uint8_t, uint32_t, sa_u32_u8, +=)
+SCATTER_OP(int64_t, uint32_t, sa_u32_i64, +=)
+SCATTER_OP(uint32_t, uint32_t, sa_u32_u32, +=)

+SCATTER_OP(float, uint8_t, sa_u8_f32, +=)
+SCATTER_OP(double, uint8_t, sa_u8_f64, +=)
+SCATTER_OP(uint8_t, uint8_t, sa_u8_u8, +=)
+SCATTER_OP(uint32_t, uint8_t, sa_u8_u32, +=)
+SCATTER_OP(int64_t, uint8_t, sa_u8_i64, +=)
+#pragma endregion scatter_add

-SCATTER_OP(float, int64_t, scatter_i64_f32)
-SCATTER_OP(double, int64_t, scatter_i64_f64)
-SCATTER_OP(uint8_t, int64_t, scatter_i64_u8)
-SCATTER_OP(int64_t, int64_t, scatter_i64_i64)
-SCATTER_OP(uint32_t, int64_t, scatter_i64_u32)
+#pragma region scatter_assign
+SCATTER_OP(float, int64_t, scatter_i64_f32, =)
+SCATTER_OP(double, int64_t, scatter_i64_f64, =)
+SCATTER_OP(uint8_t, int64_t, scatter_i64_u8, =)
+SCATTER_OP(int64_t, int64_t, scatter_i64_i64, =)
+SCATTER_OP(uint32_t, int64_t, scatter_i64_u32, =)

-SCATTER_OP(float, uint32_t, scatter_u32_f32)
-SCATTER_OP(double, uint32_t, scatter_u32_f64)
-SCATTER_OP(uint8_t, uint32_t, scatter_u32_u8)
-SCATTER_OP(int64_t, uint32_t, scatter_u32_i64)
-SCATTER_OP(uint32_t, uint32_t, scatter_u32_u32)
+SCATTER_OP(float, uint32_t, scatter_u32_f32, =)
+SCATTER_OP(double, uint32_t, scatter_u32_f64, =)
+SCATTER_OP(uint8_t, uint32_t, scatter_u32_u8, =)
+SCATTER_OP(int64_t, uint32_t, scatter_u32_i64, =)
+SCATTER_OP(uint32_t, uint32_t, scatter_u32_u32, =)

-SCATTER_OP(float, uint8_t, scatter_u8_f32)
-SCATTER_OP(double, uint8_t, scatter_u8_f64)
-SCATTER_OP(uint8_t, uint8_t, scatter_u8_u8)
-SCATTER_OP(uint32_t, uint8_t, scatter_u8_u32)
-SCATTER_OP(int64_t, uint8_t, scatter_u8_i64)
+SCATTER_OP(float, uint8_t, scatter_u8_f32, =)
+SCATTER_OP(double, uint8_t, scatter_u8_f64, =)
+SCATTER_OP(uint8_t, uint8_t, scatter_u8_u8, =)
+SCATTER_OP(uint32_t, uint8_t, scatter_u8_u32, =)
+SCATTER_OP(int64_t, uint8_t, scatter_u8_i64, =)
+#pragma endregion scatter_assign
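
For orientation, a minimal host-side launch of one of the generated kernels might look like the sketch below. The wrapper function, device pointer names, and launch geometry are hypothetical (candle itself dispatches these kernels from its Rust backend); only the kernel signature comes from the diff above.

// Hypothetical C++ launcher for the generated sa_u32_f32 kernel
// (f32 scatter_add with u32 indices). Requires the kernel definition
// to be linked in; everything except its signature is illustrative.
#include <cstddef>
#include <cstdint>

extern "C" __global__ void sa_u32_f32(
    const uint32_t *ids, const float *inp, float *out,
    size_t left_size, size_t src_dim_size, size_t dst_dim_size, size_t right_size);

void launch_scatter_add_f32(const uint32_t *d_ids, const float *d_src, float *d_dst,
                            size_t left_size, size_t src_dim_size,
                            size_t dst_dim_size, size_t right_size) {
    const size_t numel = left_size * right_size;          // one thread per (left, right) pair
    const int block = 256;                                 // assumed block size
    const int grid = static_cast<int>((numel + block - 1) / block);
    sa_u32_f32<<<grid, block>>>(d_ids, d_src, d_dst,
                                left_size, src_dim_size, dst_dim_size, right_size);
}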
