From a0870854ac8a301ca16d5e62db7a683c4bb150f7 Mon Sep 17 00:00:00 2001
From: Chuang Zhu <111838961+chuangz0@users.noreply.github.com>
Date: Thu, 23 May 2024 20:42:54 +0800
Subject: [PATCH] Fix host view for mnnvl (#166)

Authors:
  - Chuang Zhu (https://github.com/chuangz0)

Approvers:
  - https://github.com/linhu-nv
  - Brad Rees (https://github.com/BradReesWork)

URL: https://github.com/rapidsai/wholegraph/pull/166
---
 cpp/include/wholememory/wholememory.h                   | 6 ++++++
 cpp/src/wholememory/memory_handle.cpp                   | 8 +++++++-
 cpp/src/wholememory/wholememory.cpp                     | 5 +++++
 .../pylibwholegraph/binding/wholememory_binding.pyx     | 6 +++++-
 4 files changed, 23 insertions(+), 2 deletions(-)

diff --git a/cpp/include/wholememory/wholememory.h b/cpp/include/wholememory/wholememory.h
index 66bd993fd..08f16213f 100644
--- a/cpp/include/wholememory/wholememory.h
+++ b/cpp/include/wholememory/wholememory.h
@@ -387,6 +387,12 @@ wholememory_error_code_t wholememory_store_to_file(wholememory_handle_t wholemem
                                                    size_t file_entry_size,
                                                    const char* local_file_name);
 
+/**
+ * @param comm : WholeMemory Comm
+ * @return : true if all ranks of the communicator are on the same node
+ */
+bool wholememory_is_intranode_communicator(wholememory_comm_t comm);
+
 bool wholememory_is_build_with_nvshmem();
 #ifdef WITH_NVSHMEM_SUPPORT
 wholememory_error_code_t wholememory_get_nvshmem_reference(
diff --git a/cpp/src/wholememory/memory_handle.cpp b/cpp/src/wholememory/memory_handle.cpp
index 8024b461a..ca8b0ad75 100644
--- a/cpp/src/wholememory/memory_handle.cpp
+++ b/cpp/src/wholememory/memory_handle.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2019-2023, NVIDIA CORPORATION.
+ * Copyright (c) 2019-2024, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -99,6 +99,12 @@ class wholememory_impl {
     if (local_ptr != nullptr) *local_ptr = local_partition_memory_pointer_;
     if (local_size != nullptr) *local_size = rank_partition_strategy_.local_mem_size;
     if (local_offset != nullptr) *local_offset = rank_partition_strategy_.local_mem_offset;
+    if (location_ == WHOLEMEMORY_ML_HOST && (type_ == WHOLEMEMORY_MT_CONTINUOUS) &&
+        (!(comm_->is_intranode()))) {
+      WHOLEMEMORY_WARN(
+        "Multi-node continuous-type WholeMemory can only be accessed by GPU threads, not CPU "
+        "threads, even when the WholeMemory location is host.");
+    }
   }
   virtual bool get_rank_memory(void** rank_memory_ptr,
                                size_t* rank_memory_size,
diff --git a/cpp/src/wholememory/wholememory.cpp b/cpp/src/wholememory/wholememory.cpp
index 59dcc89bb..180da2f01 100644
--- a/cpp/src/wholememory/wholememory.cpp
+++ b/cpp/src/wholememory/wholememory.cpp
@@ -261,6 +261,11 @@ wholememory_error_code_t wholememory_load_from_hdfs_file(wholememory_handle_t wh
   return WHOLEMEMORY_NOT_IMPLEMENTED;
 }
 
+bool wholememory_is_intranode_communicator(wholememory_comm_t comm)
+{
+  return wholememory::is_intranode_communicator(comm);
+}
+
 bool wholememory_is_build_with_nvshmem()
 {
 #ifdef WITH_NVSHMEM_SUPPORT
diff --git a/python/pylibwholegraph/pylibwholegraph/binding/wholememory_binding.pyx b/python/pylibwholegraph/pylibwholegraph/binding/wholememory_binding.pyx
index 7cbffadd4..feffa9162 100644
--- a/python/pylibwholegraph/pylibwholegraph/binding/wholememory_binding.pyx
+++ b/python/pylibwholegraph/pylibwholegraph/binding/wholememory_binding.pyx
@@ -184,7 +184,7 @@ cdef extern from "wholememory/wholememory.h":
 
     cdef wholememory_distributed_backend_t wholememory_communicator_get_distributed_backend(
         wholememory_comm_t comm)
-
+    cdef bool wholememory_is_intranode_communicator(wholememory_comm_t comm)
 
 cpdef enum WholeMemoryErrorCode:
     Success = WHOLEMEMORY_SUCCESS
@@ -1113,6 +1113,10 @@ cdef class PyWholeMemoryFlattenDlpack:
         cdef wholememory_comm_t comm
         cdef int world_rank
         cdef int world_size
+        if self.device_type == MlHost and mem_type == MtContinuous:
+            check_wholememory_error_code(wholememory_get_communicator(&comm, handle.wholememory_handle))
+            if not wholememory_is_intranode_communicator(comm):
+                raise ValueError('Multi-node continuous-type WholeMemory does not support host_view; use host_view=False even when the location is host.')
         global_size = wholememory_get_total_size(handle.wholememory_handle)
         if global_size % elt_size != 0:
             raise ValueError('global_size=%d not multiple of elt_size=%d' % (global_size, elt_size))
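
Usage note (not part of the patch): a minimal sketch of how a caller might consult the
new public API before requesting a CPU-visible (host_view) mapping, mirroring the check
the patch adds on the C++ side (WHOLEMEMORY_WARN) and the Python side (ValueError).
The helper can_use_host_view is hypothetical; the communicator, memory-type, and
memory-location types are assumed to be the existing ones from wholememory/wholememory.h.

    #include <wholememory/wholememory.h>

    // Sketch only: returns false when a CPU host view would be invalid.
    // can_use_host_view is a hypothetical helper, not part of the patch.
    static bool can_use_host_view(wholememory_comm_t comm,
                                  wholememory_memory_type_t type,
                                  wholememory_memory_location_t location)
    {
      // Multi-node continuous WholeMemory is addressable only from GPU threads,
      // even when its location is host, so a CPU host view must be refused.
      if (location == WHOLEMEMORY_ML_HOST && type == WHOLEMEMORY_MT_CONTINUOUS &&
          !wholememory_is_intranode_communicator(comm)) {
        return false;
      }
      return true;
    }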