Merge branch 'rapidsai:branch-24.02' into branch-24.02
chang-l authored Mar 2, 2024
2 parents e8a6ce4 + cecd3ff commit 923756c
Showing 47 changed files with 391 additions and 165 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -87,3 +87,4 @@ cpp/.idea/
cpp/cmake-build-debug/
pylibwholegraph/.idea/
pylibwholegraph/cmake-build-debug/
compile_commands.json
25 changes: 25 additions & 0 deletions CHANGELOG.md
@@ -1,3 +1,28 @@
# wholegraph 24.02.00 (12 Feb 2024)

## 🐛 Bug Fixes

- Revert "Exclude tests from builds ([#127)" (#130](https://github.com/rapidsai/wholegraph/pull/127)" (#130)) [@raydouglass](https://github.com/raydouglass)
- Exclude tests from builds ([#127](https://github.com/rapidsai/wholegraph/pull/127)) [@vyasr](https://github.com/vyasr)
- fix a bug for embedding optimizer, which leads to undefined behavior ([#108](https://github.com/rapidsai/wholegraph/pull/108)) [@linhu-nv](https://github.com/linhu-nv)
- fix inferencesample option ([#107](https://github.com/rapidsai/wholegraph/pull/107)) [@chuangz0](https://github.com/chuangz0)

## 🚀 New Features

- allow users to control gather/scatter sms ([#124](https://github.com/rapidsai/wholegraph/pull/124)) [@linhu-nv](https://github.com/linhu-nv)

## 🛠️ Improvements

- Logging level ([#123](https://github.com/rapidsai/wholegraph/pull/123)) [@linhu-nv](https://github.com/linhu-nv)
- Fix pip dependencies ([#118](https://github.com/rapidsai/wholegraph/pull/118)) [@trxcllnt](https://github.com/trxcllnt)
- Remove usages of rapids-env-update ([#117](https://github.com/rapidsai/wholegraph/pull/117)) [@KyleFromNVIDIA](https://github.com/KyleFromNVIDIA)
- refactor CUDA versions in dependencies.yaml ([#115](https://github.com/rapidsai/wholegraph/pull/115)) [@jameslamb](https://github.com/jameslamb)
- Don't overwrite wholegraph_ROOT if provided ([#114](https://github.com/rapidsai/wholegraph/pull/114)) [@vyasr](https://github.com/vyasr)
- added Direct IO support for WholeMemory loading ([#113](https://github.com/rapidsai/wholegraph/pull/113)) [@dongxuy04](https://github.com/dongxuy04)
- Align versions for cudnn, clang-tools, cython, and doxygen with the rest of RAPIDS. ([#112](https://github.com/rapidsai/wholegraph/pull/112)) [@bdice](https://github.com/bdice)
- Reset WholeGraph communicators during the finalize call ([#111](https://github.com/rapidsai/wholegraph/pull/111)) [@chang-l](https://github.com/chang-l)
- Forward-merge branch-23.12 to branch-24.02 ([#102](https://github.com/rapidsai/wholegraph/pull/102)) [@bdice](https://github.com/bdice)

# wholegraph 23.12.00 (6 Dec 2023)

## 🐛 Bug Fixes
4 changes: 3 additions & 1 deletion cpp/include/wholememory/embedding.h
@@ -130,6 +130,7 @@ wholememory_error_code_t wholememory_destroy_embedding_cache_policy(
* @param memory_location : Memory Location of the underlying WholeMemory
* @param optimizer : Optimizer to use for training; pass nullptr if the embedding is not trained
* @param cache_policy : Cache policy for this embedding; pass nullptr if no cache is used
* @param user_defined_sms : User-defined number of SMs for raw embedding gather/scatter kernels; -1 uses the default
* @return : wholememory_error_code_t
*/
wholememory_error_code_t wholememory_create_embedding(
@@ -139,7 +140,8 @@ wholememory_error_code_t wholememory_create_embedding(
wholememory_memory_type_t memory_type,
wholememory_memory_location_t memory_location,
wholememory_embedding_optimizer_t optimizer,
wholememory_embedding_cache_policy_t cache_policy);
wholememory_embedding_cache_policy_t cache_policy,
int user_defined_sms = -1);

/**
* Destroy WholeMemory Embedding
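
A minimal usage sketch for the new user_defined_sms parameter. The leading parameters (output handle, tensor description, communicator) are not part of the excerpt above, so their order here is assumed from the full header, and the memory type/location values are illustrative choices; passing -1 keeps the library-selected gather/scatter launch configuration.

#include <wholememory/embedding.h>

// Sketch only: create a non-cached, non-trained embedding while capping the number
// of SMs available to the raw gather/scatter kernels.
wholememory_error_code_t create_embedding_with_sm_cap(
  wholememory_embedding_t* embedding,
  wholememory_tensor_description_t* embedding_description,
  wholememory_comm_t comm,
  int sm_cap)
{
  return wholememory_create_embedding(embedding,
                                      embedding_description,
                                      comm,
                                      WHOLEMEMORY_MT_DISTRIBUTED,  // memory_type (illustrative choice)
                                      WHOLEMEMORY_ML_DEVICE,       // memory_location (illustrative choice)
                                      nullptr,                     // optimizer: embedding is not trained
                                      nullptr,                     // cache_policy: no cache
                                      sm_cap);                     // user_defined_sms, -1 = default
}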
3 changes: 2 additions & 1 deletion cpp/include/wholememory/wholememory.h
@@ -83,9 +83,10 @@ enum wholememory_distributed_backend_t {
/**
* Initialize WholeMemory library
* @param flags : reserved should be 0
* @param wm_log_level : WholeMemory log level; the default level is 3 ("info")
* @return : wholememory_error_code_t
*/
wholememory_error_code_t wholememory_init(unsigned int flags);
wholememory_error_code_t wholememory_init(unsigned int flags, unsigned int wm_log_level = 3);

/**
* Finalize WholeMemory library
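
A minimal sketch of initializing WholeMemory with the new log-level argument. The verbosity mapping is applied inside wholememory::init (see initialize.cpp below); treating 3 as "info" follows the header comment above.

#include <wholememory/wholememory.h>

int main()
{
  // flags is reserved and must be 0; wm_log_level defaults to 3 ("info").
  if (wholememory_init(0, 3) != WHOLEMEMORY_SUCCESS) { return 1; }

  // ... create communicators, allocate WholeMemory, run workloads ...

  return wholememory_finalize() == WHOLEMEMORY_SUCCESS ? 0 : 1;
}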
8 changes: 6 additions & 2 deletions cpp/include/wholememory/wholememory_op.h
@@ -30,13 +30,15 @@ extern "C" {
* @param output_tensor : output tensor to gather to, should NOT be WholeMemoryTensor
* @param p_env_fns : pointers to environment functions.
* @param stream : cudaStream_t to use.
* @param gather_sms : number of streaming multiprocessors (SMs) to use in the gather kernel; -1 uses the default
* @return : wholememory_error_code_t
*/
wholememory_error_code_t wholememory_gather(wholememory_tensor_t wholememory_tensor,
wholememory_tensor_t indices_tensor,
wholememory_tensor_t output_tensor,
wholememory_env_func_t* p_env_fns,
void* stream);
void* stream,
int gather_sms = -1);

/**
* Scatter Op
@@ -45,13 +47,15 @@ wholememory_error_code_t wholememory_gather(wholememory_tensor_t wholememory_ten
* @param wholememory_tensor : WholeMemory Tensor of embedding table.
* @param p_env_fns : pointers to environment functions.
* @param stream : cudaStream_t to use.
* @param scatter_sms : number of streaming multiprocessors (SMs) to use in the scatter kernel; -1 uses the default
* @return : wholememory_error_code_t
*/
wholememory_error_code_t wholememory_scatter(wholememory_tensor_t input_tensor,
wholememory_tensor_t indices_tensor,
wholememory_tensor_t wholememory_tensor,
wholememory_env_func_t* p_env_fns,
void* stream);
void* stream,
int scatter_sms = -1);

/**
* Just a test function,
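
A hedged sketch of calling the gather op with the new SM knob; the tensors and the environment-function table are assumed to be created elsewhere by the caller, and the CUDA stream is passed as void* per the signature above. The scatter op takes scatter_sms the same way.

#include <cuda_runtime_api.h>
#include <wholememory/wholememory_op.h>

// Gather rows of a WholeMemory embedding table, optionally capping the number of
// SMs used by the gather kernel (gather_sms = -1 keeps the default).
wholememory_error_code_t gather_rows(wholememory_tensor_t embedding_tensor,
                                     wholememory_tensor_t indices_tensor,
                                     wholememory_tensor_t output_tensor,
                                     wholememory_env_func_t* env_fns,
                                     cudaStream_t stream,
                                     int gather_sms)
{
  return wholememory_gather(embedding_tensor,
                            indices_tensor,
                            output_tensor,
                            env_fns,
                            static_cast<void*>(stream),  // public API takes the stream as void*
                            gather_sms);
}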
23 changes: 19 additions & 4 deletions cpp/src/wholememory/embedding.cpp
@@ -403,6 +403,21 @@ wholememory_error_code_t embedding_base::destroy_optimizer_states() noexcept
return WHOLEMEMORY_SUCCESS;
}

wholememory_error_code_t embedding_base::set_gather_sms(int sms) noexcept
{
if (sms != -1) {
if (sms <= 0) {
WHOLEMEMORY_WARN("Illegal SM number for gather/scatter! Will use default size.");
sms = -1;
} else if (sms > 1568) {
WHOLEMEMORY_WARN("SM number for gather/scatter is too large! Will use default size.");
sms = -1;
}
}
gather_sms_ = sms;
return WHOLEMEMORY_SUCCESS;
}

void embedding_base::deallocate() noexcept
{
if (optimizer != nullptr) {
@@ -477,7 +492,7 @@ wholememory_error_code_t noncached_embedding::gather(wholememory_tensor_t indice
cudaStream_t stream) noexcept
{
WHOLEMEMORY_RETURN_ON_FAIL(
wholememory_gather(allocated_embedding, indices, output, p_env_fns, stream));
wholememory_gather(allocated_embedding, indices, output, p_env_fns, stream, gather_sms_));
return WHOLEMEMORY_SUCCESS;
}

@@ -845,7 +860,8 @@ wholememory_error_code_t wholememory_create_embedding(
wholememory_memory_type_t memory_type,
wholememory_memory_location_t memory_location,
wholememory_embedding_optimizer_t optimizer,
wholememory_embedding_cache_policy_t cache_policy)
wholememory_embedding_cache_policy_t cache_policy,
int user_defined_sms)
{
wholememory_matrix_description_t embedding_matrix_description;
if (!wholememory_convert_tensor_desc_to_matrix(&embedding_matrix_description,
@@ -909,10 +925,9 @@ } else {
} else {
embedding_impl_ptr = new wholememory::noncached_embedding();
}

WHOLEMEMORY_RETURN_ON_FAIL(embedding_impl_ptr->allocate(
&embedding_matrix_description, comm, memory_type, memory_location, cache_policy, optimizer));

embedding_impl_ptr->set_gather_sms(user_defined_sms);
*wholememory_embedding = static_cast<wholememory_embedding_t>(embedding_impl_ptr);
return WHOLEMEMORY_SUCCESS;
}
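
For reference, a standalone restatement of the SM-count sanitization that set_gather_sms performs, in case a caller wants to mirror the same validation before calling the API; the 1568 upper bound is the one used in this commit, and -1 means "use the default".

// Sketch: mirror embedding_base::set_gather_sms validation outside the library.
int sanitize_sm_count(int sms)
{
  if (sms == -1) { return -1; }    // explicit "use the default launch config"
  if (sms <= 0) { return -1; }     // non-positive values fall back to the default
  if (sms > 1568) { return -1; }   // implausibly large values fall back to the default
  return sms;                      // accepted user-defined SM count
}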
2 changes: 2 additions & 0 deletions cpp/src/wholememory/embedding.hpp
@@ -81,6 +81,7 @@ class embedding_base : public wholememory_embedding_ {
virtual wholememory_error_code_t drop_all_caches(cudaStream_t stream) const noexcept;

wholememory::embedding_cache_base* get_cache_ptr() const { return cache_ptr_; }
wholememory_error_code_t set_gather_sms(int sms) noexcept;

protected:
virtual wholememory_error_code_t init_optimizer_states() noexcept
@@ -96,6 +97,7 @@ class embedding_base : public wholememory_embedding_ {
wholememory_error_code_t create_optimizer_states() noexcept;
wholememory_error_code_t destroy_optimizer_states() noexcept;

int gather_sms_;
wholememory_comm_t raw_embedding_comm_ = nullptr;
wholememory::embedding_cache_base* cache_ptr_ = nullptr;
wholememory::embedding_optimizer_impl_base* optimizer_impl_base_ = nullptr;
4 changes: 3 additions & 1 deletion cpp/src/wholememory/initialize.cpp
@@ -17,6 +17,7 @@

#include <cuda.h>
#include <cuda_runtime_api.h>
#include <math.h>
#include <nccl.h>

#include "communicator.hpp"
@@ -32,7 +33,7 @@ static bool is_wm_init = false;
static const std::string RAFT_NAME = "wholememory";
static cudaDeviceProp* device_props = nullptr;

wholememory_error_code_t init(unsigned int flags) noexcept
wholememory_error_code_t init(unsigned int flags, unsigned int wm_log_level) noexcept
{
try {
std::unique_lock<std::mutex> lock(mu);
@@ -50,6 +51,7 @@ wholememory_error_code_t init(unsigned int flags) noexcept
WM_CUDA_CHECK(cudaGetDeviceProperties(device_props + i, i));
}
is_wm_init = true;
wholememory::set_log_level(std::pow(10, wm_log_level));
return WHOLEMEMORY_SUCCESS;
} catch (raft::logic_error& logic_error) {
WHOLEMEMORY_ERROR("init failed, logic_error=%s", logic_error.what());
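
For illustration, the level translation applied above: the public API takes a small integer (default 3) and the internal logger levels are assumed to be powers of ten, so wm_log_level = 3 maps to 1000, documented as "info".

#include <cmath>

// Sketch of the conversion used in wholememory::init.
int to_internal_log_level(unsigned int wm_log_level)
{
  return static_cast<int>(std::pow(10.0, static_cast<double>(wm_log_level)));
}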
2 changes: 1 addition & 1 deletion cpp/src/wholememory/initialize.hpp
@@ -21,7 +21,7 @@

namespace wholememory {

wholememory_error_code_t init(unsigned int flags) noexcept;
wholememory_error_code_t init(unsigned int flags, unsigned int wm_log_level) noexcept;

wholememory_error_code_t finalize() noexcept;

5 changes: 4 additions & 1 deletion cpp/src/wholememory/wholememory.cpp
@@ -25,7 +25,10 @@
extern "C" {
#endif

wholememory_error_code_t wholememory_init(unsigned int flags) { return wholememory::init(flags); }
wholememory_error_code_t wholememory_init(unsigned int flags, unsigned int wm_log_level)
{
return wholememory::init(flags, wm_log_level);
}

wholememory_error_code_t wholememory_finalize() { return wholememory::finalize(); }

28 changes: 20 additions & 8 deletions cpp/src/wholememory_ops/functions/gather_func.cu
@@ -26,36 +26,41 @@ wholememory_error_code_t gather_integer_int32_func(wholememory_gref_t embedding_
wholememory_array_description_t indices_desc,
void* output,
wholememory_matrix_description_t output_desc,
cudaStream_t stream);
cudaStream_t stream,
int gather_sms);
wholememory_error_code_t gather_integer_int64_func(wholememory_gref_t embedding_gref,
wholememory_matrix_description_t embedding_desc,
void* indices,
wholememory_array_description_t indices_desc,
void* output,
wholememory_matrix_description_t output_desc,
cudaStream_t stream);
cudaStream_t stream,
int gather_sms);
wholememory_error_code_t gather_floating_int32_func(wholememory_gref_t embedding_gref,
wholememory_matrix_description_t embedding_desc,
void* indices,
wholememory_array_description_t indices_desc,
void* output,
wholememory_matrix_description_t output_desc,
cudaStream_t stream);
cudaStream_t stream,
int gather_sms);
wholememory_error_code_t gather_floating_int64_func(wholememory_gref_t embedding_gref,
wholememory_matrix_description_t embedding_desc,
void* indices,
wholememory_array_description_t indices_desc,
void* output,
wholememory_matrix_description_t output_desc,
cudaStream_t stream);
cudaStream_t stream,
int gather_sms);

wholememory_error_code_t gather_func(wholememory_gref_t embedding_gref,
wholememory_matrix_description_t embedding_desc,
void* indices,
wholememory_array_description_t indices_desc,
void* output,
wholememory_matrix_description_t output_desc,
cudaStream_t stream)
cudaStream_t stream,
int gather_sms)
{
try {
bool embedding_is_float = wholememory_dtype_is_floating_number(embedding_desc.dtype);
@@ -73,7 +78,8 @@ wholememory_error_code_t gather_func(wholememory_gref_t embedding_gref,
wholememory_array_description_t,
void*,
wholememory_matrix_description_t,
cudaStream_t) = nullptr;
cudaStream_t,
int) = nullptr;
if (embedding_is_float) {
if (indices_desc.dtype == WHOLEMEMORY_DT_INT) {
p_gather_func = gather_floating_int32_func;
@@ -87,8 +93,14 @@ wholememory_error_code_t gather_func(wholememory_gref_t embedding_gref,
p_gather_func = gather_integer_int64_func;
}
}
return p_gather_func(
embedding_gref, embedding_desc, indices, indices_desc, output, output_desc, stream);
return p_gather_func(embedding_gref,
embedding_desc,
indices,
indices_desc,
output,
output_desc,
stream,
gather_sms);
} catch (const wholememory::cuda_error& rle) {
return WHOLEMEMORY_LOGIC_ERROR;
} catch (const wholememory::logic_error& le) {
@@ -29,10 +29,11 @@ void gather_floating_int32_temp_func(wholememory_gref_t embedding_gref,
int64_t indice_count,
void* output,
wholememory_matrix_description_t output_desc,
cudaStream_t stream)
cudaStream_t stream,
int gather_sms)
{
gather_temp_func<EmbeddingT, int32_t, OutputT>(
embedding_gref, embedding_desc, indices, indice_count, output, output_desc, stream);
embedding_gref, embedding_desc, indices, indice_count, output, output_desc, stream, gather_sms);
}

REGISTER_DISPATCH_TWO_TYPES(GatherFuncFloatingInt32,
@@ -46,7 +47,8 @@ wholememory_error_code_t gather_floating_int32_func(wholememory_gref_t embedding
wholememory_array_description_t indices_desc,
void* output,
wholememory_matrix_description_t output_desc,
cudaStream_t stream)
cudaStream_t stream,
int gather_sms)
{
try {
WHOLEMEMORY_CHECK(wholememory_dtype_is_floating_number(embedding_desc.dtype));
@@ -63,7 +65,8 @@ wholememory_error_code_t gather_floating_int32_func(wholememory_gref_t embedding
indices_desc.size,
output,
output_desc,
stream);
stream,
gather_sms);
} catch (const wholememory::cuda_error& wle) {
WHOLEMEMORY_ERROR("gather CUDA LOGIC Error %s\n", wle.what());
return WHOLEMEMORY_LOGIC_ERROR;
@@ -29,10 +29,11 @@ void gather_floating_int64_temp_func(wholememory_gref_t embedding_gref,
int64_t indice_count,
void* output,
wholememory_matrix_description_t output_desc,
cudaStream_t stream)
cudaStream_t stream,
int gather_sms)
{
gather_temp_func<EmbeddingT, int64_t, OutputT>(
embedding_gref, embedding_desc, indices, indice_count, output, output_desc, stream);
embedding_gref, embedding_desc, indices, indice_count, output, output_desc, stream, gather_sms);
}

REGISTER_DISPATCH_TWO_TYPES(GatherFuncFloatingInt64,
@@ -46,7 +47,8 @@ wholememory_error_code_t gather_floating_int64_func(wholememory_gref_t embedding
wholememory_array_description_t indices_desc,
void* output,
wholememory_matrix_description_t output_desc,
cudaStream_t stream)
cudaStream_t stream,
int gather_sms)
{
try {
WHOLEMEMORY_CHECK(wholememory_dtype_is_floating_number(embedding_desc.dtype));
@@ -63,7 +65,8 @@ wholememory_error_code_t gather_floating_int64_func(wholememory_gref_t embedding
indices_desc.size,
output,
output_desc,
stream);
stream,
gather_sms);
} catch (const wholememory::cuda_error& wle) {
WHOLEMEMORY_ERROR("gather CUDA LOGIC Error %s\n", wle.what());
return WHOLEMEMORY_LOGIC_ERROR;
@@ -29,10 +29,11 @@ void gather_integer_int32_temp_func(wholememory_gref_t embedding_gref,
int64_t indice_count,
void* output,
wholememory_matrix_description_t output_desc,
cudaStream_t stream)
cudaStream_t stream,
int gather_sms)
{
gather_temp_func<EmbeddingT, int32_t, OutputT>(
embedding_gref, embedding_desc, indices, indice_count, output, output_desc, stream);
embedding_gref, embedding_desc, indices, indice_count, output, output_desc, stream, gather_sms);
}

REGISTER_DISPATCH_TWO_TYPES(GatherFuncIntegerInt32,
@@ -46,7 +47,8 @@ wholememory_error_code_t gather_integer_int32_func(wholememory_gref_t embedding_
wholememory_array_description_t indices_desc,
void* output,
wholememory_matrix_description_t output_desc,
cudaStream_t stream)
cudaStream_t stream,
int gather_sms)
{
try {
WHOLEMEMORY_CHECK(wholememory_dtype_is_integer_number(embedding_desc.dtype));
@@ -63,7 +65,8 @@ wholememory_error_code_t gather_integer_int32_func(wholememory_gref_t embedding_
indices_desc.size,
output,
output_desc,
stream);
stream,
gather_sms);
} catch (const wholememory::cuda_error& wle) {
WHOLEMEMORY_ERROR("gather CUDA LOGIC Error %s\n", wle.what());
return WHOLEMEMORY_LOGIC_ERROR;
