Skip to content
This repository has been archived by the owner on Nov 25, 2024. It is now read-only.

Forward-merge branch-24.06 into branch-24.08 #177

Merged
merged 1 commit into from
May 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 80 additions & 1 deletion cpp/src/wholememory_ops/functions/gather_func.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION.
* Copyright (c) 2019-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -24,6 +24,8 @@ wholememory_error_code_t gather_integer_int32_func(wholememory_gref_t embedding_
wholememory_matrix_description_t embedding_desc,
void* indices,
wholememory_array_description_t indices_desc,
bool gather_with_sorted_ids,
void* raw_indices,
void* output,
wholememory_matrix_description_t output_desc,
cudaStream_t stream,
Expand All @@ -32,6 +34,8 @@ wholememory_error_code_t gather_integer_int64_func(wholememory_gref_t embedding_
wholememory_matrix_description_t embedding_desc,
void* indices,
wholememory_array_description_t indices_desc,
bool gather_with_sorted_ids,
void* raw_indices,
void* output,
wholememory_matrix_description_t output_desc,
cudaStream_t stream,
Expand All @@ -40,6 +44,8 @@ wholememory_error_code_t gather_floating_int32_func(wholememory_gref_t embedding
wholememory_matrix_description_t embedding_desc,
void* indices,
wholememory_array_description_t indices_desc,
bool gather_with_sorted_ids,
void* raw_indices,
void* output,
wholememory_matrix_description_t output_desc,
cudaStream_t stream,
Expand All @@ -48,6 +54,8 @@ wholememory_error_code_t gather_floating_int64_func(wholememory_gref_t embedding
wholememory_matrix_description_t embedding_desc,
void* indices,
wholememory_array_description_t indices_desc,
bool gather_with_sorted_ids,
void* raw_indices,
void* output,
wholememory_matrix_description_t output_desc,
cudaStream_t stream,
Expand Down Expand Up @@ -76,6 +84,75 @@ wholememory_error_code_t gather_func(wholememory_gref_t embedding_gref,
wholememory_matrix_description_t,
void* indices,
wholememory_array_description_t,
bool,
void*,
void*,
wholememory_matrix_description_t,
cudaStream_t,
int) = nullptr;
if (embedding_is_float) {
if (indices_desc.dtype == WHOLEMEMORY_DT_INT) {
p_gather_func = gather_floating_int32_func;
} else {
p_gather_func = gather_floating_int64_func;
}
} else {
if (indices_desc.dtype == WHOLEMEMORY_DT_INT) {
p_gather_func = gather_integer_int32_func;
} else {
p_gather_func = gather_integer_int64_func;
}
}
return p_gather_func(embedding_gref,
embedding_desc,
indices,
indices_desc,
false,
nullptr,
output,
output_desc,
stream,
gather_sms);
} catch (const wholememory::cuda_error& rle) {
return WHOLEMEMORY_LOGIC_ERROR;
} catch (const wholememory::logic_error& le) {
return WHOLEMEMORY_LOGIC_ERROR;
} catch (...) {
return WHOLEMEMORY_LOGIC_ERROR;
}
return WHOLEMEMORY_SUCCESS;
}

wholememory_error_code_t gather_with_sorted_ids_func(
wholememory_gref_t embedding_gref,
wholememory_matrix_description_t embedding_desc,
void* indices,
wholememory_array_description_t indices_desc,
void* raw_indices,
wholememory_array_description_t raw_indices_desc,
void* output,
wholememory_matrix_description_t output_desc,
cudaStream_t stream,
int gather_sms)
{
try {
bool embedding_is_float = wholememory_dtype_is_floating_number(embedding_desc.dtype);
WHOLEMEMORY_CHECK(embedding_is_float ||
wholememory_dtype_is_integer_number(embedding_desc.dtype));
bool output_is_float = wholememory_dtype_is_floating_number(output_desc.dtype);
WHOLEMEMORY_CHECK(output_is_float || wholememory_dtype_is_integer_number(output_desc.dtype));
WHOLEMEMORY_EXPECTS(
embedding_is_float == output_is_float,
"embedding and output should be same number type, e.g. floating number or integer number.");
if (indices_desc.size == 0) { return WHOLEMEMORY_SUCCESS; }
WHOLEMEMORY_CHECK(indices_desc.size == raw_indices_desc.size);
WHOLEMEMORY_CHECK(indices_desc.dtype == raw_indices_desc.dtype);
wholememory_error_code_t (*p_gather_func)(wholememory_gref_t,
wholememory_matrix_description_t,
void* indices,
wholememory_array_description_t,
bool,
void*,
void*,
wholememory_matrix_description_t,
cudaStream_t,
Expand All @@ -97,6 +174,8 @@ wholememory_error_code_t gather_func(wholememory_gref_t embedding_gref,
embedding_desc,
indices,
indices_desc,
true,
raw_indices,
output,
output_desc,
stream,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION.
* Copyright (c) 2019-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -27,13 +27,23 @@ void gather_floating_int32_temp_func(wholememory_gref_t embedding_gref,
wholememory_matrix_description_t embedding_desc,
void* indices,
int64_t indice_count,
bool gather_with_sorted_ids,
void* raw_indices,
void* output,
wholememory_matrix_description_t output_desc,
cudaStream_t stream,
int gather_sms)
{
gather_temp_func<EmbeddingT, int32_t, OutputT>(
embedding_gref, embedding_desc, indices, indice_count, output, output_desc, stream, gather_sms);
gather_temp_func<EmbeddingT, int32_t, OutputT>(embedding_gref,
embedding_desc,
indices,
indice_count,
gather_with_sorted_ids,
raw_indices,
output,
output_desc,
stream,
gather_sms);
}

REGISTER_DISPATCH_TWO_TYPES(GatherFuncFloatingInt32,
Expand All @@ -45,6 +55,8 @@ wholememory_error_code_t gather_floating_int32_func(wholememory_gref_t embedding
wholememory_matrix_description_t embedding_desc,
void* indices,
wholememory_array_description_t indices_desc,
bool gather_with_sorted_ids,
void* raw_indices,
void* output,
wholememory_matrix_description_t output_desc,
cudaStream_t stream,
Expand All @@ -63,6 +75,8 @@ wholememory_error_code_t gather_floating_int32_func(wholememory_gref_t embedding
static_cast<char*>(indices) +
indices_desc.storage_offset * wholememory_dtype_get_element_size(indices_desc.dtype),
indices_desc.size,
gather_with_sorted_ids,
raw_indices,
output,
output_desc,
stream,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION.
* Copyright (c) 2019-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -27,13 +27,23 @@ void gather_floating_int64_temp_func(wholememory_gref_t embedding_gref,
wholememory_matrix_description_t embedding_desc,
void* indices,
int64_t indice_count,
bool gather_with_sorted_ids,
void* raw_indices,
void* output,
wholememory_matrix_description_t output_desc,
cudaStream_t stream,
int gather_sms)
{
gather_temp_func<EmbeddingT, int64_t, OutputT>(
embedding_gref, embedding_desc, indices, indice_count, output, output_desc, stream, gather_sms);
gather_temp_func<EmbeddingT, int64_t, OutputT>(embedding_gref,
embedding_desc,
indices,
indice_count,
gather_with_sorted_ids,
raw_indices,
output,
output_desc,
stream,
gather_sms);
}

REGISTER_DISPATCH_TWO_TYPES(GatherFuncFloatingInt64,
Expand All @@ -45,6 +55,8 @@ wholememory_error_code_t gather_floating_int64_func(wholememory_gref_t embedding
wholememory_matrix_description_t embedding_desc,
void* indices,
wholememory_array_description_t indices_desc,
bool gather_with_sorted_ids,
void* raw_indices,
void* output,
wholememory_matrix_description_t output_desc,
cudaStream_t stream,
Expand All @@ -63,6 +75,8 @@ wholememory_error_code_t gather_floating_int64_func(wholememory_gref_t embedding
static_cast<char*>(indices) +
indices_desc.storage_offset * wholememory_dtype_get_element_size(indices_desc.dtype),
indices_desc.size,
gather_with_sorted_ids,
raw_indices,
output,
output_desc,
stream,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION.
* Copyright (c) 2019-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -27,13 +27,23 @@ void gather_integer_int32_temp_func(wholememory_gref_t embedding_gref,
wholememory_matrix_description_t embedding_desc,
void* indices,
int64_t indice_count,
bool gather_with_sorted_ids,
void* raw_indices,
void* output,
wholememory_matrix_description_t output_desc,
cudaStream_t stream,
int gather_sms)
{
gather_temp_func<EmbeddingT, int32_t, OutputT>(
embedding_gref, embedding_desc, indices, indice_count, output, output_desc, stream, gather_sms);
gather_temp_func<EmbeddingT, int32_t, OutputT>(embedding_gref,
embedding_desc,
indices,
indice_count,
gather_with_sorted_ids,
raw_indices,
output,
output_desc,
stream,
gather_sms);
}

REGISTER_DISPATCH_TWO_TYPES(GatherFuncIntegerInt32,
Expand All @@ -45,6 +55,8 @@ wholememory_error_code_t gather_integer_int32_func(wholememory_gref_t embedding_
wholememory_matrix_description_t embedding_desc,
void* indices,
wholememory_array_description_t indices_desc,
bool gather_with_sorted_ids,
void* raw_indices,
void* output,
wholememory_matrix_description_t output_desc,
cudaStream_t stream,
Expand All @@ -63,6 +75,8 @@ wholememory_error_code_t gather_integer_int32_func(wholememory_gref_t embedding_
static_cast<char*>(indices) +
indices_desc.storage_offset * wholememory_dtype_get_element_size(indices_desc.dtype),
indices_desc.size,
gather_with_sorted_ids,
raw_indices,
output,
output_desc,
stream,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION.
* Copyright (c) 2019-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -27,13 +27,23 @@ void gather_integer_int64_temp_func(wholememory_gref_t embedding_gref,
wholememory_matrix_description_t embedding_desc,
void* indices,
int64_t indice_count,
bool gather_with_sorted_ids,
void* raw_indices,
void* output,
wholememory_matrix_description_t output_desc,
cudaStream_t stream,
int gather_sms)
{
gather_temp_func<EmbeddingT, int64_t, OutputT>(
embedding_gref, embedding_desc, indices, indice_count, output, output_desc, stream, gather_sms);
gather_temp_func<EmbeddingT, int64_t, OutputT>(embedding_gref,
embedding_desc,
indices,
indice_count,
gather_with_sorted_ids,
raw_indices,
output,
output_desc,
stream,
gather_sms);
}

REGISTER_DISPATCH_TWO_TYPES(GatherFuncIntegerInt64,
Expand All @@ -45,6 +55,8 @@ wholememory_error_code_t gather_integer_int64_func(wholememory_gref_t embedding_
wholememory_matrix_description_t embedding_desc,
void* indices,
wholememory_array_description_t indices_desc,
bool gather_with_sorted_ids,
void* raw_indices,
void* output,
wholememory_matrix_description_t output_desc,
cudaStream_t stream,
Expand All @@ -63,6 +75,8 @@ wholememory_error_code_t gather_integer_int64_func(wholememory_gref_t embedding_
static_cast<char*>(indices) +
indices_desc.storage_offset * wholememory_dtype_get_element_size(indices_desc.dtype),
indices_desc.size,
gather_with_sorted_ids,
raw_indices,
output,
output_desc,
stream,
Expand Down
Loading
Loading