This repository has been archived by the owner on Nov 25, 2024. It is now read-only.

Fix host view for mnnvl (#166)
chuangz0 authored May 23, 2024
1 parent afb6a11 commit a087085
Showing 4 changed files with 23 additions and 2 deletions.
6 changes: 6 additions & 0 deletions cpp/include/wholememory/wholememory.h
@@ -387,6 +387,12 @@ wholememory_error_code_t wholememory_store_to_file(wholememory_handle_t wholemem
size_t file_entry_size,
const char* local_file_name);

/**
 * Check whether a WholeMemory Communicator only contains ranks on a single node.
 * @param comm : WholeMemory Communicator
 * @return : true if all ranks of the communicator are on the same node, false otherwise
 */
bool wholememory_is_intranode_communicator(wholememory_comm_t comm);

bool wholememory_is_build_with_nvshmem();
#ifdef WITH_NVSHMEM_SUPPORT
wholememory_error_code_t wholememory_get_nvshmem_reference(
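The new query gives callers a way to decide at runtime whether direct CPU access (a host view) of a wholememory allocation is safe. A minimal usage sketch follows; it is not part of this commit and assumes a valid wholememory_handle_t obtained elsewhere, using wholememory_get_communicator from the same header:

#include <wholememory/wholememory.h>

// Sketch (not from this commit): a host view of continuous-type host memory is
// only safe when the communicator does not span multiple nodes.
static bool host_view_is_safe(wholememory_handle_t handle)
{
  wholememory_comm_t comm;
  if (wholememory_get_communicator(&comm, handle) != WHOLEMEMORY_SUCCESS) { return false; }
  return wholememory_is_intranode_communicator(comm);
}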
8 changes: 7 additions & 1 deletion cpp/src/wholememory/memory_handle.cpp
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION.
* Copyright (c) 2019-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -99,6 +99,12 @@ class wholememory_impl {
if (local_ptr != nullptr) *local_ptr = local_partition_memory_pointer_;
if (local_size != nullptr) *local_size = rank_partition_strategy_.local_mem_size;
if (local_offset != nullptr) *local_offset = rank_partition_strategy_.local_mem_offset;
if (location_ == WHOLEMEMORY_ML_HOST && (type_ == WHOLEMEMORY_MT_CONTINUOUS) &&
(!(comm_->is_intranode()))) {
WHOLEMEMORY_WARN(
  "Multi-node continuous-type wholememory can only be accessed by GPU threads, not CPU "
  "threads, even when the location of the wholememory is host.");
}
}
virtual bool get_rank_memory(void** rank_memory_ptr,
size_t* rank_memory_size,
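The practical consequence of this warning for callers is that the local pointer of multi-node continuous host memory may only be dereferenced from GPU code. A hedged caller-side sketch follows; the signatures of wholememory_get_local_memory, wholememory_get_memory_type, and wholememory_get_memory_location are assumed from the public header and are not part of this commit:

#include <wholememory/wholememory.h>

// Sketch with assumed public-header signatures: obtain the local partition and
// avoid CPU reads when the allocation is multi-node continuous host memory.
static void use_local_partition(wholememory_handle_t handle)
{
  wholememory_comm_t comm;
  void* local_ptr     = nullptr;
  size_t local_size   = 0;
  size_t local_offset = 0;
  if (wholememory_get_communicator(&comm, handle) != WHOLEMEMORY_SUCCESS) { return; }
  if (wholememory_get_local_memory(&local_ptr, &local_size, &local_offset, handle) !=
      WHOLEMEMORY_SUCCESS) {
    return;
  }
  bool multi_node = !wholememory_is_intranode_communicator(comm);
  if (multi_node && wholememory_get_memory_type(handle) == WHOLEMEMORY_MT_CONTINUOUS &&
      wholememory_get_memory_location(handle) == WHOLEMEMORY_ML_HOST) {
    // local_ptr is GPU-accessible only: pass it to a CUDA kernel rather than
    // reading it on the CPU.
  } else if (wholememory_get_memory_location(handle) == WHOLEMEMORY_ML_HOST) {
    // Outside the multi-node continuous case, host-located local memory can
    // also be read from CPU threads.
  }
}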
5 changes: 5 additions & 0 deletions cpp/src/wholememory/wholememory.cpp
@@ -261,6 +261,11 @@ wholememory_error_code_t wholememory_load_from_hdfs_file(wholememory_handle_t wh
return WHOLEMEMORY_NOT_IMPLEMENTED;
}

bool wholememory_is_intranode_communicator(wholememory_comm_t comm)
{
return wholememory::is_intranode_communicator(comm);
}

bool wholememory_is_build_with_nvshmem()
{
#ifdef WITH_NVSHMEM_SUPPORT
6 changes: 5 additions & 1 deletion (Cython binding)
@@ -184,7 +184,7 @@ cdef extern from "wholememory/wholememory.h":

cdef wholememory_distributed_backend_t wholememory_communicator_get_distributed_backend(
wholememory_comm_t comm)

cdef bool wholememory_is_intranode_communicator(wholememory_comm_t comm)

cpdef enum WholeMemoryErrorCode:
Success = WHOLEMEMORY_SUCCESS
@@ -1113,6 +1113,10 @@ cdef class PyWholeMemoryFlattenDlpack:
cdef wholememory_comm_t comm
cdef int world_rank
cdef int world_size
if self.device_type == MlHost and mem_type == MtContinuous:
check_wholememory_error_code(wholememory_get_communicator(&comm, handle.wholememory_handle))
if not wholememory_is_intranode_communicator(comm):
    raise ValueError('Multi-node continuous type wholememory does not support host_view; '
                     'only host_view=False is supported, regardless of whether the location is host.')
global_size = wholememory_get_total_size(handle.wholememory_handle)
if global_size % elt_size != 0:
raise ValueError('global_size=%d not multiple of elt_size=%d' % (global_size, elt_size))
