From 19210806718330f3397c75ef619c83c79102368e Mon Sep 17 00:00:00 2001 From: zhuofan1123 Date: Wed, 29 May 2024 21:54:49 +0800 Subject: [PATCH] Sort indices before gathering (#174) In continuous/chunked host mode, sorting indices before gathering can improve bandwidth by enhancing memory locality. Authors: - https://github.com/zhuofan1123 Approvers: - https://github.com/linhu-nv - Brad Rees (https://github.com/BradReesWork) URL: https://github.com/rapidsai/wholegraph/pull/174 --- .../wholememory_ops/functions/gather_func.cu | 81 +++++++++++- ...r_func_impl_floating_data_int32_indices.cu | 20 ++- ...r_func_impl_floating_data_int64_indices.cu | 20 ++- ...er_func_impl_integer_data_int32_indices.cu | 20 ++- ...er_func_impl_integer_data_int64_indices.cu | 20 ++- .../functions/gather_scatter_func.cuh | 18 ++- .../functions/gather_scatter_func.h | 14 +- .../functions/sort_indices_func.cu | 125 ++++++++++++++++++ .../functions/sort_indices_func.h | 34 +++++ cpp/src/wholememory_ops/gather_op.cpp | 18 ++- cpp/src/wholememory_ops/gather_op_impl.h | 3 +- .../wholememory_ops/gather_op_impl_mapped.cu | 49 +++++-- .../wholememory_gather_tests.cu | 10 ++ 13 files changed, 401 insertions(+), 31 deletions(-) create mode 100644 cpp/src/wholememory_ops/functions/sort_indices_func.cu create mode 100644 cpp/src/wholememory_ops/functions/sort_indices_func.h diff --git a/cpp/src/wholememory_ops/functions/gather_func.cu b/cpp/src/wholememory_ops/functions/gather_func.cu index 0b79f0f15..271245d78 100644 --- a/cpp/src/wholememory_ops/functions/gather_func.cu +++ b/cpp/src/wholememory_ops/functions/gather_func.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -24,6 +24,8 @@ wholememory_error_code_t gather_integer_int32_func(wholememory_gref_t embedding_ wholememory_matrix_description_t embedding_desc, void* indices, wholememory_array_description_t indices_desc, + bool gather_with_sorted_ids, + void* raw_indices, void* output, wholememory_matrix_description_t output_desc, cudaStream_t stream, @@ -32,6 +34,8 @@ wholememory_error_code_t gather_integer_int64_func(wholememory_gref_t embedding_ wholememory_matrix_description_t embedding_desc, void* indices, wholememory_array_description_t indices_desc, + bool gather_with_sorted_ids, + void* raw_indices, void* output, wholememory_matrix_description_t output_desc, cudaStream_t stream, @@ -40,6 +44,8 @@ wholememory_error_code_t gather_floating_int32_func(wholememory_gref_t embedding wholememory_matrix_description_t embedding_desc, void* indices, wholememory_array_description_t indices_desc, + bool gather_with_sorted_ids, + void* raw_indices, void* output, wholememory_matrix_description_t output_desc, cudaStream_t stream, @@ -48,6 +54,8 @@ wholememory_error_code_t gather_floating_int64_func(wholememory_gref_t embedding wholememory_matrix_description_t embedding_desc, void* indices, wholememory_array_description_t indices_desc, + bool gather_with_sorted_ids, + void* raw_indices, void* output, wholememory_matrix_description_t output_desc, cudaStream_t stream, @@ -76,6 +84,75 @@ wholememory_error_code_t gather_func(wholememory_gref_t embedding_gref, wholememory_matrix_description_t, void* indices, wholememory_array_description_t, + bool, + void*, + void*, + wholememory_matrix_description_t, + cudaStream_t, + int) = nullptr; + if (embedding_is_float) { + if (indices_desc.dtype == WHOLEMEMORY_DT_INT) { + p_gather_func = gather_floating_int32_func; + } else { + p_gather_func = gather_floating_int64_func; + } + } else { + if (indices_desc.dtype == WHOLEMEMORY_DT_INT) { + p_gather_func = gather_integer_int32_func; + } else { + p_gather_func = gather_integer_int64_func; + } + } + return p_gather_func(embedding_gref, + embedding_desc, + indices, + indices_desc, + false, + nullptr, + output, + output_desc, + stream, + gather_sms); + } catch (const wholememory::cuda_error& rle) { + return WHOLEMEMORY_LOGIC_ERROR; + } catch (const wholememory::logic_error& le) { + return WHOLEMEMORY_LOGIC_ERROR; + } catch (...) { + return WHOLEMEMORY_LOGIC_ERROR; + } + return WHOLEMEMORY_SUCCESS; +} + +wholememory_error_code_t gather_with_sorted_ids_func( + wholememory_gref_t embedding_gref, + wholememory_matrix_description_t embedding_desc, + void* indices, + wholememory_array_description_t indices_desc, + void* raw_indices, + wholememory_array_description_t raw_indices_desc, + void* output, + wholememory_matrix_description_t output_desc, + cudaStream_t stream, + int gather_sms) +{ + try { + bool embedding_is_float = wholememory_dtype_is_floating_number(embedding_desc.dtype); + WHOLEMEMORY_CHECK(embedding_is_float || + wholememory_dtype_is_integer_number(embedding_desc.dtype)); + bool output_is_float = wholememory_dtype_is_floating_number(output_desc.dtype); + WHOLEMEMORY_CHECK(output_is_float || wholememory_dtype_is_integer_number(output_desc.dtype)); + WHOLEMEMORY_EXPECTS( + embedding_is_float == output_is_float, + "embedding and output should be same number type, e.g. 
floating number or integer number."); + if (indices_desc.size == 0) { return WHOLEMEMORY_SUCCESS; } + WHOLEMEMORY_CHECK(indices_desc.size == raw_indices_desc.size); + WHOLEMEMORY_CHECK(indices_desc.dtype == raw_indices_desc.dtype); + wholememory_error_code_t (*p_gather_func)(wholememory_gref_t, + wholememory_matrix_description_t, + void* indices, + wholememory_array_description_t, + bool, + void*, void*, wholememory_matrix_description_t, cudaStream_t, @@ -97,6 +174,8 @@ wholememory_error_code_t gather_func(wholememory_gref_t embedding_gref, embedding_desc, indices, indices_desc, + true, + raw_indices, output, output_desc, stream, diff --git a/cpp/src/wholememory_ops/functions/gather_func_impl_floating_data_int32_indices.cu b/cpp/src/wholememory_ops/functions/gather_func_impl_floating_data_int32_indices.cu index c7679c508..a67ac0040 100644 --- a/cpp/src/wholememory_ops/functions/gather_func_impl_floating_data_int32_indices.cu +++ b/cpp/src/wholememory_ops/functions/gather_func_impl_floating_data_int32_indices.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -27,13 +27,23 @@ void gather_floating_int32_temp_func(wholememory_gref_t embedding_gref, wholememory_matrix_description_t embedding_desc, void* indices, int64_t indice_count, + bool gather_with_sorted_ids, + void* raw_indices, void* output, wholememory_matrix_description_t output_desc, cudaStream_t stream, int gather_sms) { - gather_temp_func( - embedding_gref, embedding_desc, indices, indice_count, output, output_desc, stream, gather_sms); + gather_temp_func(embedding_gref, + embedding_desc, + indices, + indice_count, + gather_with_sorted_ids, + raw_indices, + output, + output_desc, + stream, + gather_sms); } REGISTER_DISPATCH_TWO_TYPES(GatherFuncFloatingInt32, @@ -45,6 +55,8 @@ wholememory_error_code_t gather_floating_int32_func(wholememory_gref_t embedding wholememory_matrix_description_t embedding_desc, void* indices, wholememory_array_description_t indices_desc, + bool gather_with_sorted_ids, + void* raw_indices, void* output, wholememory_matrix_description_t output_desc, cudaStream_t stream, @@ -63,6 +75,8 @@ wholememory_error_code_t gather_floating_int32_func(wholememory_gref_t embedding static_cast(indices) + indices_desc.storage_offset * wholememory_dtype_get_element_size(indices_desc.dtype), indices_desc.size, + gather_with_sorted_ids, + raw_indices, output, output_desc, stream, diff --git a/cpp/src/wholememory_ops/functions/gather_func_impl_floating_data_int64_indices.cu b/cpp/src/wholememory_ops/functions/gather_func_impl_floating_data_int64_indices.cu index af9d6d6ec..159aaf9a6 100644 --- a/cpp/src/wholememory_ops/functions/gather_func_impl_floating_data_int64_indices.cu +++ b/cpp/src/wholememory_ops/functions/gather_func_impl_floating_data_int64_indices.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -27,13 +27,23 @@ void gather_floating_int64_temp_func(wholememory_gref_t embedding_gref, wholememory_matrix_description_t embedding_desc, void* indices, int64_t indice_count, + bool gather_with_sorted_ids, + void* raw_indices, void* output, wholememory_matrix_description_t output_desc, cudaStream_t stream, int gather_sms) { - gather_temp_func( - embedding_gref, embedding_desc, indices, indice_count, output, output_desc, stream, gather_sms); + gather_temp_func(embedding_gref, + embedding_desc, + indices, + indice_count, + gather_with_sorted_ids, + raw_indices, + output, + output_desc, + stream, + gather_sms); } REGISTER_DISPATCH_TWO_TYPES(GatherFuncFloatingInt64, @@ -45,6 +55,8 @@ wholememory_error_code_t gather_floating_int64_func(wholememory_gref_t embedding wholememory_matrix_description_t embedding_desc, void* indices, wholememory_array_description_t indices_desc, + bool gather_with_sorted_ids, + void* raw_indices, void* output, wholememory_matrix_description_t output_desc, cudaStream_t stream, @@ -63,6 +75,8 @@ wholememory_error_code_t gather_floating_int64_func(wholememory_gref_t embedding static_cast(indices) + indices_desc.storage_offset * wholememory_dtype_get_element_size(indices_desc.dtype), indices_desc.size, + gather_with_sorted_ids, + raw_indices, output, output_desc, stream, diff --git a/cpp/src/wholememory_ops/functions/gather_func_impl_integer_data_int32_indices.cu b/cpp/src/wholememory_ops/functions/gather_func_impl_integer_data_int32_indices.cu index bdb7c0be8..9943cb14b 100644 --- a/cpp/src/wholememory_ops/functions/gather_func_impl_integer_data_int32_indices.cu +++ b/cpp/src/wholememory_ops/functions/gather_func_impl_integer_data_int32_indices.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -27,13 +27,23 @@ void gather_integer_int32_temp_func(wholememory_gref_t embedding_gref, wholememory_matrix_description_t embedding_desc, void* indices, int64_t indice_count, + bool gather_with_sorted_ids, + void* raw_indices, void* output, wholememory_matrix_description_t output_desc, cudaStream_t stream, int gather_sms) { - gather_temp_func( - embedding_gref, embedding_desc, indices, indice_count, output, output_desc, stream, gather_sms); + gather_temp_func(embedding_gref, + embedding_desc, + indices, + indice_count, + gather_with_sorted_ids, + raw_indices, + output, + output_desc, + stream, + gather_sms); } REGISTER_DISPATCH_TWO_TYPES(GatherFuncIntegerInt32, @@ -45,6 +55,8 @@ wholememory_error_code_t gather_integer_int32_func(wholememory_gref_t embedding_ wholememory_matrix_description_t embedding_desc, void* indices, wholememory_array_description_t indices_desc, + bool gather_with_sorted_ids, + void* raw_indices, void* output, wholememory_matrix_description_t output_desc, cudaStream_t stream, @@ -63,6 +75,8 @@ wholememory_error_code_t gather_integer_int32_func(wholememory_gref_t embedding_ static_cast(indices) + indices_desc.storage_offset * wholememory_dtype_get_element_size(indices_desc.dtype), indices_desc.size, + gather_with_sorted_ids, + raw_indices, output, output_desc, stream, diff --git a/cpp/src/wholememory_ops/functions/gather_func_impl_integer_data_int64_indices.cu b/cpp/src/wholememory_ops/functions/gather_func_impl_integer_data_int64_indices.cu index 6a6c7f330..b06ebad9f 100644 --- a/cpp/src/wholememory_ops/functions/gather_func_impl_integer_data_int64_indices.cu +++ b/cpp/src/wholememory_ops/functions/gather_func_impl_integer_data_int64_indices.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -27,13 +27,23 @@ void gather_integer_int64_temp_func(wholememory_gref_t embedding_gref, wholememory_matrix_description_t embedding_desc, void* indices, int64_t indice_count, + bool gather_with_sorted_ids, + void* raw_indices, void* output, wholememory_matrix_description_t output_desc, cudaStream_t stream, int gather_sms) { - gather_temp_func( - embedding_gref, embedding_desc, indices, indice_count, output, output_desc, stream, gather_sms); + gather_temp_func(embedding_gref, + embedding_desc, + indices, + indice_count, + gather_with_sorted_ids, + raw_indices, + output, + output_desc, + stream, + gather_sms); } REGISTER_DISPATCH_TWO_TYPES(GatherFuncIntegerInt64, @@ -45,6 +55,8 @@ wholememory_error_code_t gather_integer_int64_func(wholememory_gref_t embedding_ wholememory_matrix_description_t embedding_desc, void* indices, wholememory_array_description_t indices_desc, + bool gather_with_sorted_ids, + void* raw_indices, void* output, wholememory_matrix_description_t output_desc, cudaStream_t stream, @@ -63,6 +75,8 @@ wholememory_error_code_t gather_integer_int64_func(wholememory_gref_t embedding_ static_cast(indices) + indices_desc.storage_offset * wholememory_dtype_get_element_size(indices_desc.dtype), indices_desc.size, + gather_with_sorted_ids, + raw_indices, output, output_desc, stream, diff --git a/cpp/src/wholememory_ops/functions/gather_scatter_func.cuh b/cpp/src/wholememory_ops/functions/gather_scatter_func.cuh index c7983a6dc..a4979f7be 100644 --- a/cpp/src/wholememory_ops/functions/gather_scatter_func.cuh +++ b/cpp/src/wholememory_ops/functions/gather_scatter_func.cuh @@ -255,6 +255,8 @@ __global__ void gather_func_kernel(wholememory_gref_t embedding_gref, wholememory_matrix_description_t embedding_desc, const IndexT* indices, int64_t indice_count, + bool gather_with_sorted_ids, + const IndexT* raw_indices, OutputT* output, wholememory_matrix_description_t output_desc) { @@ -284,7 +286,9 @@ __global__ void gather_func_kernel(wholememory_gref_t embedding_gref, for (int64_t output_idx = warp_id; output_idx < indice_count; output_idx += gridDim.x * (blockDim.x / 32)) { - OutputT* output_ptr = output + output_desc.storage_offset + output_stride * output_idx; + int64_t raw_output_idx = + gather_with_sorted_ids ? (int64_t)(raw_indices[output_idx]) : output_idx; + OutputT* output_ptr = output + output_desc.storage_offset + output_stride * raw_output_idx; if (!use_shm) { my_shared = output_ptr; } int64_t embedding_table_idx = indices[output_idx]; if (embedding_table_idx < 0) continue; @@ -323,6 +327,8 @@ __global__ void gather_func_sub_warp_kernel(wholememory_gref_t embedding_gref, wholememory_matrix_description_t embedding_desc, const IndexT* indices, int64_t indice_count, + bool gather_with_sorted_ids, + const IndexT* raw_indices, OutputT* output, wholememory_matrix_description_t output_desc) { @@ -345,7 +351,9 @@ __global__ void gather_func_sub_warp_kernel(wholememory_gref_t embedding_gref, typed_data_vector embeddings; typed_data_vector outputs; for (int64_t output_idx = sub_warp_id; output_idx < indice_count; output_idx += sub_warp_num) { - OutputT* output_ptr = output + output_desc.storage_offset + output_stride * output_idx; + int64_t raw_output_idx = + gather_with_sorted_ids ? 
(int64_t)(raw_indices[output_idx]) : output_idx; + OutputT* output_ptr = output + output_desc.storage_offset + output_stride * raw_output_idx; IndexT embedding_table_idx = indices[output_idx]; if (embedding_table_idx < 0) continue; int64_t embedding_offset = @@ -370,6 +378,8 @@ void gather_temp_func(wholememory_gref_t embedding_gref, wholememory_matrix_description_t embedding_desc, void* indices, int64_t indice_count, + bool gather_with_sorted_ids, + void* raw_indices, void* output, wholememory_matrix_description_t output_desc, cudaStream_t stream, @@ -392,6 +402,8 @@ void gather_temp_func(wholememory_gref_t embedding_gref, wholememory_matrix_description_t, const IndexT*, int64_t, + bool, + const IndexT*, OutputT*, wholememory_matrix_description_t) = nullptr; @@ -495,6 +507,8 @@ void gather_temp_func(wholememory_gref_t embedding_gref, embedding_desc, static_cast(indices), indice_count, + gather_with_sorted_ids, + static_cast(raw_indices), static_cast(output), output_desc); WM_CUDA_CHECK(cudaGetLastError()); diff --git a/cpp/src/wholememory_ops/functions/gather_scatter_func.h b/cpp/src/wholememory_ops/functions/gather_scatter_func.h index 0c0b9e4a4..374ea2b39 100644 --- a/cpp/src/wholememory_ops/functions/gather_scatter_func.h +++ b/cpp/src/wholememory_ops/functions/gather_scatter_func.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -30,6 +30,18 @@ wholememory_error_code_t gather_func(wholememory_gref_t embedding_gref, cudaStream_t stream, int gather_sms = -1); +wholememory_error_code_t gather_with_sorted_ids_func( + wholememory_gref_t embedding_gref, + wholememory_matrix_description_t embedding_desc, + void* indices, + wholememory_array_description_t indices_desc, + void* raw_indices, + wholememory_array_description_t raw_indices_desc, + void* output, + wholememory_matrix_description_t output_desc, + cudaStream_t stream, + int gather_sms); + wholememory_error_code_t scatter_func(const void* input, wholememory_matrix_description_t input_desc, void* indices, diff --git a/cpp/src/wholememory_ops/functions/sort_indices_func.cu b/cpp/src/wholememory_ops/functions/sort_indices_func.cu new file mode 100644 index 000000000..4cbbb0837 --- /dev/null +++ b/cpp/src/wholememory_ops/functions/sort_indices_func.cu @@ -0,0 +1,125 @@ +/* + * Copyright (c) 2019-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
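The kernels above only change how the destination row is chosen: when gather_with_sorted_ids is set, the indices arrive pre-sorted, so the embedding table is read in ascending row order, and raw_indices[output_idx] records where each gathered row must land so the output layout is unchanged. A minimal standalone CUDA sketch of that indirection (illustrative names only; it omits the vectorized loads and shared-memory staging used by gather_func_kernel):

#include <cuda_runtime.h>
#include <cstdint>

// Sketch: gather rows by sorted indices, writing each row to the slot given by
// raw_indices so the output order matches the caller's original index order.
template <typename T, typename IndexT>
__global__ void gather_sorted_sketch(const T* table,                // [num_rows, dim] embedding table
                                     int64_t dim,                   // embedding dimension
                                     const IndexT* sorted_indices,  // indices sorted ascending
                                     const IndexT* raw_indices,     // original position of each sorted index
                                     int64_t indice_count,
                                     T* output)                     // [indice_count, dim]
{
  for (int64_t i = blockIdx.x; i < indice_count; i += gridDim.x) {
    IndexT row = sorted_indices[i];
    if (row < 0) continue;                       // negative indices are skipped, as in the kernels above
    T* out_row = output + raw_indices[i] * dim;  // scatter back to the caller-visible output slot
    for (int64_t j = threadIdx.x; j < dim; j += blockDim.x) {
      out_row[j] = table[row * dim + j];
    }
  }
}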
+ */
+#include "sort_indices_func.h"
+
+#include <cub/device/device_radix_sort.cuh>
+#include <thrust/sequence.h>
+
+#include "cuda_macros.hpp"
+#include "error.hpp"
+#include "logger.hpp"
+#include "wholememory_ops/register.hpp"
+
+namespace wholememory_ops {
+
+template <typename IndexT>
+struct UnsignedType {};
+
+template <>
+struct UnsignedType<int> {
+  using UType = unsigned int;
+};
+
+template <>
+struct UnsignedType<int64_t> {
+  using UType = uint64_t;
+};
+
+template <typename IndexT>
+void sort_indices_temp_func(const void* indices_before_sort,
+                            wholememory_array_description_t indices_desc,
+                            void* indices_after_sort,
+                            void* raw_indices,
+                            wm_thrust_allocator* p_thrust_allocator,
+                            wholememory_env_func_t* p_env_fns,
+                            cudaStream_t stream)
+{
+  auto index_type = indices_desc.dtype;
+  WHOLEMEMORY_CHECK(indices_desc.storage_offset == 0);
+  WHOLEMEMORY_CHECK(index_type == WHOLEMEMORY_DT_INT || index_type == WHOLEMEMORY_DT_INT64);
+  wm_thrust_allocator& allocator = *p_thrust_allocator;
+
+  IndexT* seq_indices = reinterpret_cast<IndexT*>(allocator.allocate(
+    wholememory_get_memory_element_count_from_array(&indices_desc) * sizeof(IndexT)));
+  thrust::sequence(thrust::cuda::par_nosync(allocator).on(stream),
+                   seq_indices,
+                   seq_indices + indices_desc.size,
+                   0);
+  // use UTypeT to put negative indices at the end.
+  using UTypeT                  = typename UnsignedType<IndexT>::UType;
+  const UTypeT* indices_to_sort = static_cast<const UTypeT*>(indices_before_sort);
+  UTypeT* sorted_indice         = static_cast<UTypeT*>(indices_after_sort);
+  void* cub_temp_storage        = nullptr;
+  size_t temp_storage_bytes     = 0;
+  cub::DeviceRadixSort::SortPairs(cub_temp_storage,
+                                  temp_storage_bytes,
+                                  indices_to_sort,
+                                  sorted_indice,
+                                  seq_indices,
+                                  static_cast<IndexT*>(raw_indices),
+                                  indices_desc.size,
+                                  0,
+                                  sizeof(UTypeT) * 8,
+                                  stream);
+  cub_temp_storage = allocator.allocate(temp_storage_bytes);
+  cub::DeviceRadixSort::SortPairs(cub_temp_storage,
+                                  temp_storage_bytes,
+                                  indices_to_sort,
+                                  sorted_indice,
+                                  seq_indices,
+                                  static_cast<IndexT*>(raw_indices),
+                                  indices_desc.size,
+                                  0,
+                                  sizeof(UTypeT) * 8,
+                                  stream);
+  allocator.deallocate(reinterpret_cast<char*>(seq_indices),
+                       wholememory_get_memory_size_from_array(&indices_desc));
+  allocator.deallocate(static_cast<char*>(cub_temp_storage), temp_storage_bytes);
+}
+
+REGISTER_DISPATCH_ONE_TYPE(SortIndices, sort_indices_temp_func, SINT3264)
+
+wholememory_error_code_t sort_indices_func(const void* indices_before_sort,
+                                           wholememory_array_description_t indice_desc,
+                                           void* indices_after_sort,
+                                           void* raw_indices,
+                                           wm_thrust_allocator* p_thrust_allocator,
+                                           wholememory_env_func_t* p_env_fns,
+                                           cudaStream_t stream)
+{
+  try {
+    DISPATCH_ONE_TYPE(indice_desc.dtype,
+                      SortIndices,
+                      indices_before_sort,
+                      indice_desc,
+                      indices_after_sort,
+                      raw_indices,
+                      p_thrust_allocator,
+                      p_env_fns,
+                      stream);
+  } catch (wholememory::cuda_error& wce) {
+    WHOLEMEMORY_ERROR("sort_indices_func CUDA LOGIC Error %s\n", wce.what());
+    return WHOLEMEMORY_CUDA_ERROR;
+  } catch (wholememory::logic_error& wle) {
+    WHOLEMEMORY_ERROR("sort_indices_func LOGIC Error %s\n", wle.what());
+    return WHOLEMEMORY_LOGIC_ERROR;
+  } catch (...) {
+    return WHOLEMEMORY_UNKNOW_ERROR;
+  }
+  return WHOLEMEMORY_SUCCESS;
+}
+
+}  // namespace wholememory_ops
diff --git a/cpp/src/wholememory_ops/functions/sort_indices_func.h b/cpp/src/wholememory_ops/functions/sort_indices_func.h
new file mode 100644
index 000000000..98a7932cb
--- /dev/null
+++ b/cpp/src/wholememory_ops/functions/sort_indices_func.h
@@ -0,0 +1,34 @@
+/*
+ * Copyright (c) 2019-2024, NVIDIA CORPORATION.
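sort_indices_temp_func above follows the usual CUB two-call idiom: the first cub::DeviceRadixSort::SortPairs call is made with a null workspace pointer and only fills temp_storage_bytes, and the second call performs the sort; reinterpreting the signed indices as their unsigned counterparts keeps non-negative ids in ascending order and moves negative ids to the end. A standalone sketch of the same pattern, assuming plain cudaMallocAsync instead of the wm_thrust_allocator used here:

#include <cub/device/device_radix_sort.cuh>
#include <cuda_runtime.h>
#include <cstdint>

// Sketch of the two-call CUB pattern: sort keys (indices) with a paired
// 0..n-1 sequence so the original position of each sorted index is preserved.
void sort_pairs_sketch(const uint64_t* d_keys_in, uint64_t* d_keys_out,
                       const int64_t* d_pos_in, int64_t* d_pos_out,
                       int64_t n, cudaStream_t stream)
{
  void* d_temp      = nullptr;
  size_t temp_bytes = 0;
  // First call: d_temp == nullptr, so CUB only reports the required workspace size.
  cub::DeviceRadixSort::SortPairs(d_temp, temp_bytes, d_keys_in, d_keys_out,
                                  d_pos_in, d_pos_out, n, 0, 64, stream);
  cudaMallocAsync(&d_temp, temp_bytes, stream);
  // Second call: same arguments with a real workspace; performs the sort.
  cub::DeviceRadixSort::SortPairs(d_temp, temp_bytes, d_keys_in, d_keys_out,
                                  d_pos_in, d_pos_out, n, 0, 64, stream);
  cudaFreeAsync(d_temp, stream);
}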
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include + +#include +#include + +namespace wholememory_ops { + +wholememory_error_code_t sort_indices_func(const void* indices_before_sort, + wholememory_array_description_t indice_desc, + void* indices_after_sort, + void* raw_indices, + wm_thrust_allocator* p_thrust_allocator, + wholememory_env_func_t* p_env_fns, + cudaStream_t stream); + +} // namespace wholememory_ops diff --git a/cpp/src/wholememory_ops/gather_op.cpp b/cpp/src/wholememory_ops/gather_op.cpp index a6b2e97b5..98d41d222 100644 --- a/cpp/src/wholememory_ops/gather_op.cpp +++ b/cpp/src/wholememory_ops/gather_op.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -27,11 +27,13 @@ wholememory_error_code_t wholememory_gather(wholememory_tensor_t wholememory_ten void* stream, int gather_sms) { - bool const has_handle = wholememory_tensor_has_handle(wholememory_tensor); - wholememory_memory_type_t memory_type = WHOLEMEMORY_MT_NONE; + bool const has_handle = wholememory_tensor_has_handle(wholememory_tensor); + wholememory_memory_type_t memory_type = WHOLEMEMORY_MT_NONE; + wholememory_memory_location_t memory_location = WHOLEMEMORY_ML_NONE; if (has_handle) { - memory_type = - wholememory_get_memory_type(wholememory_tensor_get_memory_handle(wholememory_tensor)); + auto memory_handle = wholememory_tensor_get_memory_handle(wholememory_tensor); + memory_type = wholememory_get_memory_type(memory_handle); + memory_location = wholememory_get_memory_location(memory_handle); } wholememory_matrix_description_t matrix_description; auto tensor_description = *wholememory_tensor_get_tensor_description(wholememory_tensor); @@ -98,12 +100,18 @@ wholememory_error_code_t wholememory_gather(wholememory_tensor_t wholememory_ten wholememory_gref_t gref; WHOLEMEMORY_RETURN_ON_FAIL(wholememory_tensor_get_global_reference(wholememory_tensor, &gref)); + int64_t entry_size = + tensor_description.sizes[1] * wholememory_dtype_get_element_size(tensor_description.dtype); + bool gather_with_sorted_ids = + (memory_location == WHOLEMEMORY_ML_HOST) && (entry_size <= 512) && + (memory_type == WHOLEMEMORY_MT_CHUNKED || memory_type == WHOLEMEMORY_MT_CONTINUOUS); return wholememory_ops::wholememory_gather_mapped(gref, matrix_description, indices, indices_desc, output, output_desc, + gather_with_sorted_ids, p_env_fns, static_cast(stream), gather_sms); diff --git a/cpp/src/wholememory_ops/gather_op_impl.h b/cpp/src/wholememory_ops/gather_op_impl.h index 6f85d6410..21896ff24 100644 --- a/cpp/src/wholememory_ops/gather_op_impl.h +++ b/cpp/src/wholememory_ops/gather_op_impl.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. 
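wholememory_gather above enables the sorted path only for host-resident, continuous or chunked tables whose rows are at most 512 bytes; the intent appears to be that small rows gain the most from locality, while larger rows already amortize the cost of each random host-memory access. The condition, restated as a standalone helper (hypothetical name, not part of the library API):

#include <wholememory/wholememory.h>
#include <cstdint>

// Restatement of the enabling condition added in wholememory_gather:
// sort indices only for small rows that live in host memory.
inline bool use_sorted_gather(wholememory_memory_type_t type,
                              wholememory_memory_location_t location,
                              int64_t embedding_dim,
                              size_t element_size)
{
  const int64_t entry_size = embedding_dim * static_cast<int64_t>(element_size);
  return location == WHOLEMEMORY_ML_HOST && entry_size <= 512 &&
         (type == WHOLEMEMORY_MT_CHUNKED || type == WHOLEMEMORY_MT_CONTINUOUS);
}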
* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -27,6 +27,7 @@ wholememory_error_code_t wholememory_gather_mapped( wholememory_array_description_t indice_desc, void* output, wholememory_matrix_description_t output_desc, + bool gather_with_sorted_ids, wholememory_env_func_t* p_env_fns, cudaStream_t stream, int gather_sms); diff --git a/cpp/src/wholememory_ops/gather_op_impl_mapped.cu b/cpp/src/wholememory_ops/gather_op_impl_mapped.cu index 38e64919d..849005860 100644 --- a/cpp/src/wholememory_ops/gather_op_impl_mapped.cu +++ b/cpp/src/wholememory_ops/gather_op_impl_mapped.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,6 +20,9 @@ #include "cuda_macros.hpp" #include "wholememory_ops/functions/gather_scatter_func.h" +#include "wholememory_ops/functions/sort_indices_func.h" +#include "wholememory_ops/temp_memory_handle.hpp" +#include "wholememory_ops/thrust_allocator.hpp" namespace wholememory_ops { @@ -30,18 +33,46 @@ wholememory_error_code_t wholememory_gather_mapped( wholememory_array_description_t indice_desc, void* output, wholememory_matrix_description_t output_desc, + bool gather_with_sorted_ids, wholememory_env_func_t* p_env_fns, cudaStream_t stream, int gather_sms) { - WHOLEMEMORY_RETURN_ON_FAIL(gather_func(wholememory_gref, - wholememory_desc, - indices, - indice_desc, - output, - output_desc, - stream, - gather_sms)); + if (gather_with_sorted_ids) { + wm_thrust_allocator thrust_allocator(p_env_fns); + temp_memory_handle dev_indices_after_sort(p_env_fns); + void* dev_indices_after_sort_ptr = + dev_indices_after_sort.device_malloc(indice_desc.size, indice_desc.dtype); + temp_memory_handle dev_raw_indices(p_env_fns); + void* dev_raw_indices_ptr = dev_raw_indices.device_malloc(indice_desc.size, indice_desc.dtype); + auto raw_indices_desc = wholememory_create_array_desc(indice_desc.size, 0, indice_desc.dtype); + WHOLEMEMORY_RETURN_ON_FAIL(sort_indices_func(indices, + indice_desc, + dev_indices_after_sort_ptr, + dev_raw_indices_ptr, + &thrust_allocator, + p_env_fns, + stream)); + WHOLEMEMORY_RETURN_ON_FAIL(gather_with_sorted_ids_func(wholememory_gref, + wholememory_desc, + dev_indices_after_sort_ptr, + indice_desc, + dev_raw_indices_ptr, + raw_indices_desc, + output, + output_desc, + stream, + gather_sms)); + } else { + WHOLEMEMORY_RETURN_ON_FAIL(gather_func(wholememory_gref, + wholememory_desc, + indices, + indice_desc, + output, + output_desc, + stream, + gather_sms)); + } WM_CUDA_DEBUG_SYNC_STREAM(stream); return WHOLEMEMORY_SUCCESS; } diff --git a/cpp/tests/wholememory_ops/wholememory_gather_tests.cu b/cpp/tests/wholememory_ops/wholememory_gather_tests.cu index fad314db9..ada9c87e1 100644 --- a/cpp/tests/wholememory_ops/wholememory_gather_tests.cu +++ b/cpp/tests/wholememory_ops/wholememory_gather_tests.cu @@ -301,6 +301,16 @@ INSTANTIATE_TEST_SUITE_P( WholeMemoryGatherTestParam() .set_memory_type(WHOLEMEMORY_MT_DISTRIBUTED) .set_memory_location(WHOLEMEMORY_ML_HOST), + WholeMemoryGatherTestParam() + .set_memory_type(WHOLEMEMORY_MT_CONTINUOUS) + .set_memory_location(WHOLEMEMORY_ML_HOST) + .set_embedding_dim(1) + .set_indices_type(WHOLEMEMORY_DT_INT64), + WholeMemoryGatherTestParam() + .set_memory_type(WHOLEMEMORY_MT_CHUNKED) + .set_memory_location(WHOLEMEMORY_ML_HOST) + 
.set_embedding_dim(1) + .set_indices_type(WHOLEMEMORY_DT_INT64), WholeMemoryGatherTestParam() .set_memory_type(WHOLEMEMORY_MT_CONTINUOUS) .set_embedding_dim(11)
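Taken together, the sorted gather path is: radix-sort the indices together with a 0..n-1 sequence, gather table rows in ascending index order, and write each row to the output slot recorded by that sequence, so the result is bit-identical to the unsorted path. A minimal host-side C++ sketch of the same transformation (illustrative only, no WholeMemory types):

#include <algorithm>
#include <cstdint>
#include <numeric>
#include <vector>

// Host-side illustration of sort-then-gather: same output as a direct gather,
// but the table rows are visited in ascending index order for better locality.
std::vector<float> gather_sorted(const std::vector<float>& table, int64_t dim,
                                 const std::vector<int64_t>& indices)
{
  const int64_t n = static_cast<int64_t>(indices.size());
  std::vector<int64_t> order(n);
  std::iota(order.begin(), order.end(), 0);  // raw positions 0..n-1
  // Sort positions by the index value they refer to (mirrors SortPairs on device).
  std::sort(order.begin(), order.end(),
            [&](int64_t a, int64_t b) { return indices[a] < indices[b]; });

  std::vector<float> output(n * dim);
  for (int64_t i = 0; i < n; ++i) {
    const int64_t raw = order[i];      // where this row belongs in the output
    const int64_t row = indices[raw];  // table row, visited in ascending order
    std::copy_n(table.begin() + row * dim, dim, output.begin() + raw * dim);
  }
  return output;
}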