This repository has been archived by the owner on Nov 25, 2024. It is now read-only.

quick fix to a map_indice bug && add comment for parameter round_robin_size #172

Merged
merged 1 commit on May 28, 2024
8 changes: 4 additions & 4 deletions cpp/src/wholememory/file_io.cpp
@@ -97,7 +97,7 @@ static size_t get_handle_partial_size(size_t handle_size,
  * @param suggested_buffer_size : Suggested buffer size to read.
  * @param wm_rank : WholeMemory rank.
  * @param wm_world_size : WholeMemory world size.
- * @param round_robin_size : continuous embedding size of a rank using round robin shard stratehy.
+ * @param round_robin_size : continuous embedding size of a rank using round robin shard strategy.
  */
 static void read_file_list_to_local_memory_roundrobin(char* local_ptr,
                                                       size_t local_size,
@@ -407,7 +407,7 @@ static void read_file_list_to_local_memory(char* local_ptr,
  * @param suggested_buffer_size : Suggested buffer size to read.
  * @param wm_rank : WholeMemory rank.
  * @param wm_world_size : WholeMemory world size.
- * @param round_robin_size : continuous embedding size of a rank using round robin shard stratehy.
+ * @param round_robin_size : continuous embedding size of a rank using round robin shard strategy.
  * @param dev_id : the device bound to the rank.
  */
 static void read_file_list_to_local_memory_roundrobin_with_multi_threads(
@@ -878,7 +878,7 @@ static void read_file_list_to_local_memory_with_multi_threads(char* local_ptr,
  * @param suggested_buffer_size : Suggested buffer size to read.
  * @param wm_rank : WholeMemory rank.
  * @param wm_world_size : WholeMemory world size.
- * @param round_robin_size : continuous embedding size of a rank using round robin shard stratehy.
+ * @param round_robin_size : continuous embedding size of a rank using round robin shard strategy.
  */
 static void read_file_list_to_local_memory_roundrobin_directio(
   char* local_ptr,
@@ -1546,7 +1546,7 @@ static void read_file_list_to_local_memory_directio_with_multi_thread(
  * @param suggested_buffer_size : Suggested buffer size to read.
  * @param wm_rank : WholeMemory rank.
  * @param wm_world_size : WholeMemory world size.
- * @param round_robin_size : continuous embedding size of a rank using round robin shard stratehy.
+ * @param round_robin_size : continuous embedding size of a rank using round robin shard strategy.
  * @param dev_id : the device bound to the rank.
  */
 static void read_file_list_to_local_memory_roundrobin_directio_with_multi_threads(
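The corrected doc comment is terse, so a concrete picture of the parameter may help. Below is a minimal sketch of one plausible round-robin layout consistent with the kernel arguments appearing later in this PR (world_size, entry_per_rank, round_robin_size); the helper name and the exact index math are illustrative assumptions, not WholeGraph's implementation.

// A minimal sketch, assuming round-robin sharding deals out fixed-size
// blocks of consecutive rows to ranks in turn. Names and index math are
// illustrative, not WholeGraph's code. With round_robin_size = 0 the
// layout degenerates to plain contiguous block sharding.
#include <cstdint>

inline int64_t storage_to_embedding_idx(int64_t storage_idx,
                                        int64_t round_robin_size,  // rows per dealt block
                                        int64_t world_size,        // number of ranks
                                        int64_t entry_per_rank)    // rows owned per rank
{
  if (round_robin_size == 0) { return storage_idx; }  // contiguous layout: identity
  int64_t block  = storage_idx / round_robin_size;    // which block the row sits in
  int64_t rank   = block % world_size;                // blocks are dealt in turn
  int64_t offset = (block / world_size) * round_robin_size  // earlier blocks on this rank
                 + storage_idx % round_robin_size;          // position inside the block
  return rank * entry_per_rank + offset;  // global index in rank-major order
}

Read this way, round_robin_size is the number of consecutive embedding rows a rank holds before ownership passes to the next rank, which is what the new comment states.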
2 changes: 1 addition & 1 deletion cpp/src/wholememory_ops/functions/map_indices_func.cu
@@ -58,7 +58,7 @@ void storage_idx2wm_emb_idx_temp_fn(void* indice_ptr,
   if (block_num > 1568) block_num = 1568;
   IndexT* indice = static_cast<IndexT*>(indice_ptr);
   IndexT* mapped_indice = static_cast<IndexT*>(mapped_indice_ptr);
-  storage_idx2wm_emb_idx_kernel<<<block_num, block_size>>>(
+  storage_idx2wm_emb_idx_kernel<<<block_num, block_size, 0, stream>>>(
     indice, mapped_indice, indice_size, world_size, entry_per_rank, round_robin_size);
   WM_CUDA_CHECK(cudaStreamSynchronize(stream));
   return;
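This one-line .cu change is the actual bug fix: <<<block_num, block_size>>> launches on the default stream, yet the code then synchronizes `stream`, so the synchronization did not order against the kernel. A minimal sketch of the pattern, with an illustrative dummy_kernel standing in for storage_idx2wm_emb_idx_kernel:

// Minimal sketch of the bug class fixed here, using standard CUDA launch
// semantics; the kernel and function names are illustrative, not WholeGraph's.
#include <cuda_runtime.h>

__global__ void dummy_kernel(int* data, int n)
{
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) data[i] += 1;
}

void launch_on_stream(int* d_data, int n, cudaStream_t stream)
{
  int block_size = 256;
  int block_num  = (n + block_size - 1) / block_size;
  // Buggy form: dummy_kernel<<<block_num, block_size>>>(...) runs on the
  // default stream, so cudaStreamSynchronize(stream) below would not wait
  // for it. Fixed form: pass the stream (and an explicit 0 bytes of dynamic
  // shared memory) in the launch configuration.
  dummy_kernel<<<block_num, block_size, 0, stream>>>(d_data, n);
  cudaStreamSynchronize(stream);  // now actually orders against the kernel
}

With the stream in the launch configuration, the synchronize call is a real barrier for the kernel, and the mapping work stays ordered with other operations enqueued on the same stream.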
4 changes: 3 additions & 1 deletion python/pylibwholegraph/pylibwholegraph/torch/embedding.py
@@ -407,7 +407,7 @@ def create_embedding(
     cache_policy: Union[WholeMemoryCachePolicy, None] = None,
     random_init: bool = False,
     gather_sms: int = -1,
-    round_robin_size=0,
+    round_robin_size: int = 0,
 ):
     r"""
     Create embedding
@@ -419,6 +419,7 @@
     :param optimizer: optimizer
     :param cache_policy: cache policy
     :param gather_sms: the number of SMs used in gather process
+    :param round_robin_size: continuous embedding size of a rank using round robin shard strategy
     :return: WholeMemoryEmbedding
     """
     if optimizer is None:
@@ -491,6 +492,7 @@ def create_embedding_from_filelist(
     :param optimizer: optimizer
     :param cache_policy: cache policy
     :param gather_sms: the number of SMs used in gather process
+    :param round_robin_size: continuous embedding size of a rank using round robin shard strategy
     :return:
     """
     if isinstance(filelist, str):
1 change: 1 addition & 0 deletions python/pylibwholegraph/pylibwholegraph/torch/tensor.py
@@ -156,6 +156,7 @@ def from_filelist(self, filelist: Union[List[str], str], round_robin_size: int =
         """
         Load WholeMemory Tensor from file lists
         :param filelist: file list to load from
+        :param round_robin_size: continuous embedding size of a rank using round robin shard strategy
         :return: None
         """
         if isinstance(filelist, str):