This repository has been archived by the owner on Nov 25, 2024. It is now read-only.

subwarp version gather op for small embedding size #165

Merged: 1 commit, May 23, 2024
126 changes: 125 additions & 1 deletion cpp/src/wholememory_ops/functions/gather_scatter_func.cuh
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019-2023, NVIDIA CORPORATION.
+ * Copyright (c) 2019-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -309,6 +309,62 @@ __global__ void gather_func_kernel(wholememory_gref_t embedding_gref,
return;
}

template <int N>
struct IsPowerOfTwo {
static constexpr bool value = (N > 0) && ((N & (N - 1)) == 0);
};

template <typename EmbeddingT,
typename IndexT,
typename OutputT,
int SUB_WARP_SIZE = 1,
int ALIGNMENT = 1>
__global__ void gather_func_sub_warp_kernel(wholememory_gref_t embedding_gref,
wholememory_matrix_description_t embedding_desc,
const IndexT* indices,
int64_t indice_count,
OutputT* output,
wholememory_matrix_description_t output_desc)
{
  static_assert(IsPowerOfTwo<SUB_WARP_SIZE>::value && SUB_WARP_SIZE < 32,
                "SUB_WARP_SIZE must be a power of 2 and smaller than 32.");

auto block = cooperative_groups::this_thread_block();

auto subwarp = cooperative_groups::tiled_partition<SUB_WARP_SIZE>(block);
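  // Each tile ("sub-warp") of SUB_WARP_SIZE threads gathers one embedding row at a
  // time; with the 128-thread blocks launched below and SUB_WARP_SIZE == 4,
  // meta_group_size() is 32 sub-warps per block.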
int sub_warp_id = subwarp.meta_group_size() * blockIdx.x + subwarp.meta_group_rank();
int sub_warp_num = subwarp.meta_group_size() * gridDim.x;

int lane_id_in_sub_warp = subwarp.thread_rank();
wholememory::device_reference<EmbeddingT> embedding_dev_ref(embedding_gref);

int embedding_size = embedding_desc.sizes[1];
int64_t embedding_stride = embedding_desc.stride;
int64_t output_stride = output_desc.stride;

typed_data_vector<EmbeddingT, ALIGNMENT> embeddings;
typed_data_vector<OutputT, ALIGNMENT> outputs;
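  // Grid-stride loop over output rows: sub-warp k of the grid handles rows
  // k, k + sub_warp_num, k + 2 * sub_warp_num, ...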
for (int64_t output_idx = sub_warp_id; output_idx < indice_count; output_idx += sub_warp_num) {
OutputT* output_ptr = output + output_desc.storage_offset + output_stride * output_idx;
IndexT embedding_table_idx = indices[output_idx];
if (embedding_table_idx < 0) continue;
int64_t embedding_offset =
embedding_desc.storage_offset + embedding_table_idx * embedding_stride;
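    // Lanes sweep the embedding dimension cooperatively: each lane moves ALIGNMENT
    // consecutive elements per step, so one pass covers SUB_WARP_SIZE * ALIGNMENT
    // elements of the row.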

for (int emb_idx = lane_id_in_sub_warp * ALIGNMENT; emb_idx < embedding_size;
emb_idx += ALIGNMENT * SUB_WARP_SIZE) {
mov_data<sizeof(EmbeddingT) * ALIGNMENT>(&embeddings,
&embedding_dev_ref[embedding_offset + emb_idx]);
#pragma unroll
for (int sub_idx = 0; sub_idx < ALIGNMENT; sub_idx++) {
typed_data_vector_at(outputs, sub_idx) =
convert_type<EmbeddingT, OutputT>(typed_data_vector_at(embeddings, sub_idx));
}
mov_data<sizeof(OutputT) * ALIGNMENT>(output_ptr + emb_idx, &outputs);
}
}
}

template <typename EmbeddingT, typename IndexT, typename OutputT>
void gather_temp_func(wholememory_gref_t embedding_gref,
wholememory_matrix_description_t embedding_desc,
@@ -338,6 +394,7 @@ void gather_temp_func(wholememory_gref_t embedding_gref,
int64_t,
OutputT*,
wholememory_matrix_description_t) = nullptr;

switch (alignment) {
case 16: {
kernel_fn = gather_func_kernel<EmbeddingT, IndexT, OutputT, 16>;
@@ -367,6 +424,73 @@ void gather_temp_func(wholememory_gref_t embedding_gref,
int block_size = 1024;
int block_count = indice_count > 1568 ? 1568 : indice_count;
if (gather_sms != -1) block_count = gather_sms;

  // for small embedding sizes, use sub-warp gather
int min_threads_per_embedding = embedding_desc.sizes[1] / alignment;
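  // e.g. a 12-float row with alignment 4 needs 3 vectorized loads per row,
  // which selects SUB_WARP_SIZE == 2 below.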
if (min_threads_per_embedding < 32) {
#define SWITCH_GATHER_FUNC_WITH_ALIGNMENT(KERNEL_NAME, SUB_WARP_SIZE) \
switch (alignment) { \
case 16: { \
kernel_fn = KERNEL_NAME<EmbeddingT, IndexT, OutputT, SUB_WARP_SIZE, 16>; \
break; \
} \
case 8: { \
kernel_fn = KERNEL_NAME<EmbeddingT, IndexT, OutputT, SUB_WARP_SIZE, 8>; \
break; \
} \
case 4: { \
kernel_fn = KERNEL_NAME<EmbeddingT, IndexT, OutputT, SUB_WARP_SIZE, 4>; \
break; \
} \
case 2: { \
kernel_fn = KERNEL_NAME<EmbeddingT, IndexT, OutputT, SUB_WARP_SIZE, 2>; \
break; \
} \
case 1: { \
kernel_fn = KERNEL_NAME<EmbeddingT, IndexT, OutputT, SUB_WARP_SIZE, 1>; \
break; \
} \
default: { \
WHOLEMEMORY_FAIL("gather func alignment=%d.", alignment); \
return; \
} \
}

int threads_per_embedding = 16;
if (min_threads_per_embedding >= 16) {
SWITCH_GATHER_FUNC_WITH_ALIGNMENT(gather_func_sub_warp_kernel, 16);
threads_per_embedding = 16;
    } else if (min_threads_per_embedding >= 8) {
      SWITCH_GATHER_FUNC_WITH_ALIGNMENT(gather_func_sub_warp_kernel, 8);
      threads_per_embedding = 8;
    } else if (min_threads_per_embedding >= 4) {
      SWITCH_GATHER_FUNC_WITH_ALIGNMENT(gather_func_sub_warp_kernel, 4);
      threads_per_embedding = 4;
    } else if (min_threads_per_embedding >= 2) {
      SWITCH_GATHER_FUNC_WITH_ALIGNMENT(gather_func_sub_warp_kernel, 2);
      threads_per_embedding = 2;
} else {
SWITCH_GATHER_FUNC_WITH_ALIGNMENT(gather_func_sub_warp_kernel, 1);
threads_per_embedding = 1;
}

#undef SWITCH_GATHER_FUNC_WITH_ALIGNMENT
block_size = 128;
int max_blocks_per_sm = 8;
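    // Query how many such blocks can be resident per SM for this kernel at
    // block_size threads and no dynamic shared memory.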
WM_CUDA_CHECK(
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, kernel_fn, block_size, 0));

int sm_count = 100;
int device_id = 0;
WM_CUDA_CHECK(cudaGetDevice(&device_id));
WM_CUDA_CHECK(cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, device_id));

// block_count = indice_count > 1568 ? 1568 : indice_count;
int min_embedding_per_block = block_size / threads_per_embedding;
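    // Each block hosts block_size / threads_per_embedding rows; cover all rows,
    // but cap the grid at roughly four waves of resident blocks.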
block_count = min((int)(indice_count + min_embedding_per_block - 1) / min_embedding_per_block,
sm_count * max_blocks_per_sm * 4);
if (gather_sms != -1) block_count = gather_sms * max_blocks_per_sm;
}
kernel_fn<<<block_count, block_size, 0, stream>>>(embedding_gref,
embedding_desc,
static_cast<const IndexT*>(indices),
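Side note for readers less familiar with cooperative groups: the kernel above leans entirely on tiled_partition's meta-group bookkeeping. The standalone sketch below is illustrative only (not part of this change; the kernel and its names are made up) and prints the sub-warp decomposition that gather_func_sub_warp_kernel relies on.

#include <cstdio>
#include <cuda_runtime.h>
#include <cooperative_groups.h>

namespace cg = cooperative_groups;

// Print which sub-warp each tile leader belongs to.
template <int TILE_SIZE>
__global__ void show_sub_warps()
{
  auto block = cg::this_thread_block();
  auto tile  = cg::tiled_partition<TILE_SIZE>(block);
  if (tile.thread_rank() == 0) {
    printf("block %d: sub-warp %d of %d\n",
           (int)blockIdx.x, (int)tile.meta_group_rank(), (int)tile.meta_group_size());
  }
}

int main()
{
  // 2 blocks of 16 threads with tiles of 4 -> 4 sub-warps per block, 8 in the grid.
  show_sub_warps<4><<<2, 16>>>();
  cudaDeviceSynchronize();
  return 0;
}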
12 changes: 11 additions & 1 deletion cpp/tests/wholememory_ops/wholememory_gather_tests.cu
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019-2023, NVIDIA CORPORATION.
+ * Copyright (c) 2019-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -311,6 +311,16 @@ INSTANTIATE_TEST_SUITE_P(
.set_embedding_dim(11)
.set_embedding_stride(12)
.set_indices_count(100005),
WholeMemoryGatherTestParam()
.set_memory_type(WHOLEMEMORY_MT_CHUNKED)
.set_embedding_dim(1)
.set_embedding_stride(1)
.set_indices_count(100005),
WholeMemoryGatherTestParam()
.set_memory_type(WHOLEMEMORY_MT_CHUNKED)
.set_embedding_dim(1)
.set_embedding_stride(2)
.set_indices_count(100005),
WholeMemoryGatherTestParam()
.set_memory_type(WHOLEMEMORY_MT_DISTRIBUTED)
.set_embedding_dim(11)
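For context, the semantics these tests exercise can be summarized by a hypothetical host-side reference (not from the repository): row i of the output receives row indices[i] of the embedding table, rows with negative indices are skipped, and the two matrices may carry independent element strides, matching the kernel above.

#include <cstdint>
#include <vector>

// Hypothetical reference gather; strides are in elements, as in
// wholememory_matrix_description_t.
template <typename EmbT, typename IdxT, typename OutT>
void reference_gather(const std::vector<EmbT>& table, int64_t table_stride,
                      const std::vector<IdxT>& indices, int embedding_dim,
                      std::vector<OutT>& out, int64_t out_stride)
{
  for (size_t i = 0; i < indices.size(); ++i) {
    const IdxT row = indices[i];
    if (row < 0) continue;  // negative index: leave this output row untouched
    for (int d = 0; d < embedding_dim; ++d) {
      out[i * out_stride + d] = static_cast<OutT>(table[row * table_stride + d]);
    }
  }
}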