Skip to content
This repository has been archived by the owner on Nov 25, 2024. It is now read-only.

Commit

Permalink
allow users to choose shm allocation method for chunked/continuous hos…
Browse files Browse the repository at this point in the history
…t memory (#187)

1. The default shm option is still SYSTEMV, but users can switch to the POSIX API by setting an environment variable, e.g. `export WG_USE_POSIX_SHM=1`.
2. `unlink` shm files immediately after `shm_open` to avoid leftover memory in `/dev/shm` in case of a wholegraph crash.

Authors:
  - https://github.com/linhu-nv

Approvers:
  - Chuang Zhu (https://github.com/chuangz0)
  - Brad Rees (https://github.com/BradReesWork)

URL: #187
  • Loading branch information
linhu-nv authored Jun 18, 2024
1 parent ba505af commit 4ee62ba
Showing 1 changed file with 83 additions and 71 deletions.
154 changes: 83 additions & 71 deletions cpp/src/wholememory/memory_handle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -456,76 +456,87 @@ class global_mapped_host_wholememory_impl : public wholememory_impl {
host_memory_full_path.append("_").append("wm_host_").append(std::to_string(tensor_id));
return host_memory_full_path;
}
#define USE_SYSTEMV_SHM

#define SYSTEMV_SHM_PROJ_ID (0xE601EEEE)
void create_and_map_shared_host_memory()
{
WHOLEMEMORY_CHECK(is_intranode_communicator(comm_));
#ifdef USE_SYSTEMV_SHM
std::string shm_full_path = "/tmp/";
shm_full_path.append(get_host_memory_full_path(comm_, handle_->handle_id));
FILE* shm_fp = fopen(shm_full_path.c_str(), "w");
WHOLEMEMORY_CHECK(shm_fp != nullptr);
WHOLEMEMORY_CHECK(fclose(shm_fp) == 0);
auto shm_key = ftok(shm_full_path.c_str(), SYSTEMV_SHM_PROJ_ID);
WHOLEMEMORY_CHECK(shm_key != (key_t)-1);
const char* shm_env_var = std::getenv("WG_USE_POSIX_SHM");
if (shm_env_var == nullptr || shm_env_var[0] == '0') {
use_systemv_shm_ = true;
} else {
use_systemv_shm_ = false;
}
std::string shm_full_path;
if (use_systemv_shm_) {
shm_full_path = "/tmp/";
shm_full_path.append(get_host_memory_full_path(comm_, handle_->handle_id));
FILE* shm_fp = fopen(shm_full_path.c_str(), "w");
WHOLEMEMORY_CHECK(shm_fp != nullptr);
WHOLEMEMORY_CHECK(fclose(shm_fp) == 0);
} else {
shm_full_path = get_host_memory_full_path(comm_, handle_->handle_id);
}
int shm_id = -1;
#else
auto shm_full_path = get_host_memory_full_path(comm_, handle_->handle_id);
int shm_fd = -1;
#endif
int shm_fd = -1;
if (comm_->world_rank == 0) {
#ifdef USE_SYSTEMV_SHM
shm_id = shmget(shm_key, alloc_strategy_.local_alloc_size, 0644 | IPC_CREAT | IPC_EXCL);
if (shm_id == -1) {
WHOLEMEMORY_FATAL(
"Create host shared memory from IPC key %d failed, Reason=%s", shm_key, strerror(errno));
}
#else
shm_fd = shm_open(shm_full_path.c_str(), O_CREAT | O_RDWR, S_IRUSR | S_IWUSR);
if (shm_fd < 0) {
WHOLEMEMORY_FATAL("Create host shared memory from file %s failed, Reason=%s.",
shm_full_path.c_str(),
strerror(errno));
if (use_systemv_shm_) {
auto shm_key = ftok(shm_full_path.c_str(), SYSTEMV_SHM_PROJ_ID);
WHOLEMEMORY_CHECK(shm_key != (key_t)-1);
shm_id = shmget(shm_key, alloc_strategy_.local_alloc_size, 0644 | IPC_CREAT | IPC_EXCL);
if (shm_id == -1) {
WHOLEMEMORY_FATAL("Create host shared memory from IPC key %d failed, Reason=%s",
shm_key,
strerror(errno));
}
} else {
shm_fd = shm_open(shm_full_path.c_str(), O_CREAT | O_RDWR, S_IRUSR | S_IWUSR);
if (shm_fd < 0) {
WHOLEMEMORY_FATAL("Create host shared memory from file %s failed, Reason=%s.",
shm_full_path.c_str(),
strerror(errno));
}
WHOLEMEMORY_CHECK(ftruncate(shm_fd, alloc_strategy_.local_alloc_size) == 0);
}
WHOLEMEMORY_CHECK(ftruncate(shm_fd, alloc_strategy_.local_alloc_size) == 0);
#endif
communicator_barrier(comm_);
} else {
communicator_barrier(comm_);
#ifdef USE_SYSTEMV_SHM
shm_id = shmget(shm_key, alloc_strategy_.local_alloc_size, 0644);
if (shm_id == -1) {
WHOLEMEMORY_FATAL(
"Get host shared memory from IPC key %d failed, Reason=%s", shm_key, strerror(errno));
}
#else
shm_fd = shm_open(shm_full_path.c_str(), O_RDWR, S_IRUSR | S_IWUSR);
if (shm_fd < 0) {
WHOLEMEMORY_FATAL("Rank=%d open host shared memory from file %s failed.",
comm_->world_rank,
shm_full_path.c_str());
if (use_systemv_shm_) {
auto shm_key = ftok(shm_full_path.c_str(), SYSTEMV_SHM_PROJ_ID);
WHOLEMEMORY_CHECK(shm_key != (key_t)-1);
shm_id = shmget(shm_key, alloc_strategy_.local_alloc_size, 0644);
if (shm_id == -1) {
WHOLEMEMORY_FATAL(
"Get host shared memory from IPC key %d failed, Reason=%s", shm_key, strerror(errno));
}
} else {
shm_fd = shm_open(shm_full_path.c_str(), O_RDWR, S_IRUSR | S_IWUSR);
if (shm_fd < 0) {
WHOLEMEMORY_FATAL("Rank=%d open host shared memory from file %s failed.",
comm_->world_rank,
shm_full_path.c_str());
}
}
#endif
}
communicator_barrier(comm_);
if (!use_systemv_shm_ && comm_->world_rank == 0) {
WHOLEMEMORY_CHECK(shm_unlink(shm_full_path.c_str()) == 0);
}
void* mmap_ptr = nullptr;
#ifdef USE_SYSTEMV_SHM
mmap_ptr = shmat(shm_id, nullptr, 0);
WHOLEMEMORY_CHECK(mmap_ptr != (void*)-1);
#else
mmap_ptr = mmap(
nullptr, alloc_strategy_.total_alloc_size, PROT_READ | PROT_WRITE, MAP_SHARED, shm_fd, 0);
WHOLEMEMORY_CHECK(mmap_ptr != (void*)-1);
#endif
if (use_systemv_shm_) {
mmap_ptr = shmat(shm_id, nullptr, 0);
WHOLEMEMORY_CHECK(mmap_ptr != (void*)-1);
} else {
mmap_ptr = mmap(
nullptr, alloc_strategy_.total_alloc_size, PROT_READ | PROT_WRITE, MAP_SHARED, shm_fd, 0);
WHOLEMEMORY_CHECK(mmap_ptr != (void*)-1);
}
memset(static_cast<char*>(mmap_ptr) + rank_partition_strategy_.local_mem_offset,
0,
rank_partition_strategy_.local_mem_size);
WM_CUDA_CHECK_NO_THROW(
cudaHostRegister(mmap_ptr, alloc_strategy_.total_alloc_size, cudaHostRegisterDefault));
#ifndef USE_SYSTEMV_SHM
WHOLEMEMORY_CHECK(close(shm_fd) == 0);
#endif
if (!use_systemv_shm_) WHOLEMEMORY_CHECK(close(shm_fd) == 0);
void* dev_ptr = nullptr;
WM_CUDA_CHECK_NO_THROW(cudaHostGetDevicePointer(&dev_ptr, mmap_ptr, 0));
WHOLEMEMORY_CHECK(dev_ptr == mmap_ptr);
Expand All @@ -540,31 +551,30 @@ class global_mapped_host_wholememory_impl : public wholememory_impl {
void* ptr = shared_host_handle_.shared_host_memory_ptr;
if (ptr == nullptr) return;
WM_CUDA_CHECK(cudaHostUnregister(ptr));
#ifdef USE_SYSTEMV_SHM
std::string shm_full_path = "/tmp/";
shm_full_path.append(get_host_memory_full_path(comm_, handle_->handle_id));
auto shm_key = ftok(shm_full_path.c_str(), SYSTEMV_SHM_PROJ_ID);
WHOLEMEMORY_CHECK(shm_key != (key_t)-1);
int shm_id = shmget(shm_key, alloc_strategy_.local_alloc_size, 0644);
if (shm_id == -1) {
WHOLEMEMORY_FATAL("Get host shared memory from IPC key %d for delete failed, Reason=%s",
shm_key,
strerror(errno));
std::string shm_full_path;
int shm_id = -1;
if (use_systemv_shm_) {
shm_full_path = "/tmp/";
shm_full_path.append(get_host_memory_full_path(comm_, handle_->handle_id));
auto shm_key = ftok(shm_full_path.c_str(), SYSTEMV_SHM_PROJ_ID);
WHOLEMEMORY_CHECK(shm_key != (key_t)-1);
shm_id = shmget(shm_key, alloc_strategy_.local_alloc_size, 0644);
if (shm_id == -1) {
WHOLEMEMORY_FATAL("Get host shared memory from IPC key %d for delete failed, Reason=%s",
shm_key,
strerror(errno));
}
WHOLEMEMORY_CHECK(shmdt(ptr) == 0);
} else {
shm_full_path = get_host_memory_full_path(comm_, handle_->handle_id);
WHOLEMEMORY_CHECK(munmap(ptr, alloc_strategy_.total_alloc_size) == 0);
}
WHOLEMEMORY_CHECK(shmdt(ptr) == 0);
#else
auto shm_full_path = get_host_memory_full_path(comm_, handle_->handle_id);
WHOLEMEMORY_CHECK(munmap(ptr, alloc_strategy_.total_alloc_size) == 0);
#endif
communicator_barrier(comm_);
#ifdef USE_SYSTEMV_SHM
if (comm_->world_rank == 0) {
if (use_systemv_shm_ && comm_->world_rank == 0) {
WHOLEMEMORY_CHECK(shmctl(shm_id, IPC_RMID, nullptr) == 0);
WHOLEMEMORY_CHECK(unlink(shm_full_path.c_str()) == 0);
}
#else
if (comm_->world_rank == 0) { WHOLEMEMORY_CHECK(shm_unlink(shm_full_path.c_str()) == 0); }
#endif

communicator_barrier(comm_);
shared_host_handle_.shared_host_memory_ptr = nullptr;
} catch (const wholememory::logic_error& wle) {
Expand All @@ -579,6 +589,8 @@ class global_mapped_host_wholememory_impl : public wholememory_impl {
// Holds the base address of the mapped shared host memory region.
// nullptr means nothing is currently mapped; the teardown path checks this
// and returns early in that case.
struct shared_host_handle {
void* shared_host_memory_ptr = nullptr;
} shared_host_handle_;

// Selects the shared-memory backend: true = System V (shmget/shmat),
// false = POSIX (shm_open/mmap). create_and_map_shared_host_memory() assigns
// it from the WG_USE_POSIX_SHM environment variable. It is initialized to
// true here (the same default used when the variable is unset) so the flag
// never holds an indeterminate value before that assignment.
bool use_systemv_shm_ = true;
};

// Implementation for continuous device wholememory that need global map.
Expand Down

0 comments on commit 4ee62ba

Please sign in to comment.