Skip to content

Commit

Permalink
Merge pull request #655 from lukemartinlogan/dev
Browse files Browse the repository at this point in the history
Fix graceful runtime stop
  • Loading branch information
lukemartinlogan authored Dec 30, 2023
2 parents b897556 + 317c8b2 commit 99220e5
Show file tree
Hide file tree
Showing 8 changed files with 48 additions and 14 deletions.
28 changes: 24 additions & 4 deletions hrun/include/hrun/hrun_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,10 +90,16 @@ enum class HrunMode {
struct DomainId {
bitfield32_t flags_; /**< Flags indicating how to interpret id */
u32 id_; /**< The domain id, 0 is NULL */
DOMAIN_FLAG_T kLocal = BIT_OPT(u32, 0); /**< Include local node in scheduling decision */
DOMAIN_FLAG_T kGlobal = BIT_OPT(u32, 1); /**< Use all nodes in scheduling decision */
DOMAIN_FLAG_T kSet = BIT_OPT(u32, 2); /**< ID represents node set ID, not a single node */
DOMAIN_FLAG_T kNode = BIT_OPT(u32, 3); /**< ID represents a specific node */
DOMAIN_FLAG_T kLocal =
BIT_OPT(u32, 0); /**< Use local node in scheduling decision */
DOMAIN_FLAG_T kGlobal =
BIT_OPT(u32, 1); /**< Use all nodes in scheduling decision */
DOMAIN_FLAG_T kNoLocal =
BIT_OPT(u32, 4); /**< Don't use local node in scheduling decision */
DOMAIN_FLAG_T kSet =
BIT_OPT(u32, 2); /**< ID represents node set ID, not a single node */
DOMAIN_FLAG_T kNode =
BIT_OPT(u32, 3); /**< ID represents a specific node */

/** Serialize domain id */
template<typename Ar>
Expand Down Expand Up @@ -170,6 +176,20 @@ struct DomainId {
return id;
}

/**
 * Check whether the kNoLocal flag is set on this domain.
 *
 * @return true when kNoLocal is present in flags_ (the local node is
 *         not to be used in the scheduling decision), false otherwise
 */
bool IsNoLocal() const {
  const bool local_excluded = flags_.Any(kNoLocal);
  return local_excluded;
}

/**
 * Build a DomainId that addresses every node except the local one.
 *
 * @return a DomainId with id_ set to 0 and both the kGlobal and
 *         kNoLocal flags set
 */
HSHM_ALWAYS_INLINE
static DomainId GetGlobalMinusLocal() {
  // Global scope combined with the "exclude local node" marker.
  const auto all_but_local = kGlobal | kNoLocal;
  DomainId everyone_else;
  everyone_else.id_ = 0;
  everyone_else.flags_.SetBits(all_but_local);
  return everyone_else;
}

/** DomainId represents a named node set */
HSHM_ALWAYS_INLINE
bool IsSet() const {
Expand Down
12 changes: 10 additions & 2 deletions hrun/src/hrun_runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -168,8 +168,16 @@ Runtime::ResolveDomainId(const DomainId &domain_id) {
std::vector<DomainId> ids;
if (domain_id.IsGlobal()) {
ids.reserve(rpc_.hosts_.size());
for (HostInfo &host_info : rpc_.hosts_) {
ids.push_back(DomainId::GetNode(host_info.node_id_));
if (domain_id.IsNoLocal()) {
for (HostInfo &host_info : rpc_.hosts_) {
if (host_info.node_id_ != rpc_.node_id_) {
ids.push_back(DomainId::GetNode(host_info.node_id_));
}
}
} else {
for (HostInfo &host_info : rpc_.hosts_) {
ids.push_back(DomainId::GetNode(host_info.node_id_));
}
}
} else if (domain_id.IsNode()) {
ids.reserve(1);
Expand Down
2 changes: 1 addition & 1 deletion hrun/src/hrun_stop_runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,5 @@

int main() {
TRANSPARENT_HRUN();
HRUN_ADMIN->StopRuntimeRoot(hrun::DomainId::GetLocal());
HRUN_ADMIN->StopRuntimeRoot();
}
4 changes: 2 additions & 2 deletions hrun/src/work_orchestrator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,8 @@ void WorkOrchestrator::Join() {
kill_requested_.store(true);
for (std::unique_ptr<Worker> &worker : workers_) {
worker->thread_->join();
ABT_xstream_join(xstream_);
ABT_xstream_free(&xstream_);
// ABT_xstream_join(xstream_);
// ABT_xstream_free(&xstream_);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -177,9 +177,10 @@ class Client : public TaskLibClient {
task, task_node, domain_id);
}
HRUN_TASK_NODE_ADMIN_ROOT(StopRuntime);
void StopRuntimeRoot(const DomainId &domain_id) {
FlushRoot(domain_id);
AsyncStopRuntimeRoot(domain_id);
void StopRuntimeRoot() {
FlushRoot(DomainId::GetGlobal());
AsyncStopRuntimeRoot(DomainId::GetGlobalMinusLocal());
AsyncStopRuntimeRoot(DomainId::GetLocal());
}

/** Set work orchestrator queue policy */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ struct StopRuntimeTask : public Task, TaskFlags<TF_SRL_SYM> {
prio_ = TaskPrio::kAdmin;
task_state_ = HRUN_QM_CLIENT->admin_task_state_;
method_ = Method::kStopRuntime;
task_flags_.SetBits(TASK_FIRE_AND_FORGET);
task_flags_.SetBits(TASK_FIRE_AND_FORGET | TASK_FLUSH);
domain_id_ = domain_id;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,11 @@ class Client : public TaskLibClient {
void Disperse(Task *orig_task,
TaskState *exec,
std::vector<DomainId> &domain_ids) {
if (domain_ids.size() == 0) {
orig_task->SetModuleComplete();
return;
}

// Serialize task + create the wait task
orig_task->UnsetStarted();
BinaryOutputArchive<true> ar(DomainId::GetNode(HRUN_CLIENT->node_id_));
Expand Down
2 changes: 1 addition & 1 deletion test/unit/ipc/test_finalize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,5 @@
#include "hrun_admin/hrun_admin.h"

TEST_CASE("TestFinalize") {
HRUN_ADMIN->AsyncStopRuntimeRoot(hrun::DomainId::GetGlobal());
HRUN_ADMIN->StopRuntimeRoot();
}

0 comments on commit 99220e5

Please sign in to comment.