Skip to content

Commit

Permalink
Merge pull request #98 from firstbatchxyz/erhant/graceful-cancellatio…
Browse files Browse the repository at this point in the history
…n [skip ci]

Fixed cancellation signal issue
  • Loading branch information
erhant authored Aug 20, 2024
2 parents a487e7f + de9f73a commit cc49f4d
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 33 deletions.
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ pub(crate) mod utils;

/// Crate version of the compute node.
/// This value is attached within the published messages.
pub const VERSION: &str = env!("CARGO_PKG_VERSION");
pub const DRIA_COMPUTE_NODE_VERSION: &str = env!("CARGO_PKG_VERSION");

pub use config::DriaComputeNodeConfig;
pub use node::DriaComputeNode;
60 changes: 53 additions & 7 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use tokio_util::sync::CancellationToken;

use dkn_compute::{DriaComputeNode, DriaComputeNodeConfig};
use tokio::signal::unix::{signal, SignalKind};
use tokio_util::sync::CancellationToken;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
Expand All @@ -13,7 +13,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
.init();
log::info!(
"Initializing Dria Compute Node (version {})",
dkn_compute::VERSION
dkn_compute::DRIA_COMPUTE_NODE_VERSION
);

// create configurations & check required services
Expand All @@ -23,12 +23,58 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
panic!("Service check failed.")
}

let token = CancellationToken::new();

// launch the node
let mut node = DriaComputeNode::new(config, CancellationToken::new()).await?;
if let Err(err) = node.launch().await {
log::error!("Node error: {}", err);
panic!("Node failed.")
let node_token = token.clone();
let node_handle = tokio::spawn(async move {
match DriaComputeNode::new(config, node_token).await {
Ok(mut node) => {
if let Err(err) = node.launch().await {
log::error!("Node launch error: {}", err);
panic!("Node failed.")
};
}
Err(err) => {
log::error!("Node setup error: {}", err);
panic!("Could not setup node.")
}
}
});

// add cancellation check
tokio::spawn(async move {
if let Err(err) = wait_for_termination(token.clone()).await {
log::error!("Error waiting for termination: {}", err);
log::error!("Cancelling due to unexpected error.");
token.cancel();
};
});

// wait for tasks to complete
if let Err(err) = node_handle.await {
log::error!("Node handle error: {}", err);
panic!("Could not exit Node thread handle.");
};

Ok(())
}

/// Waits for a termination signal and cancels the given token once one arrives.
///
/// Listens for SIGTERM and SIGINT concurrently with the token's own cancellation:
/// - when either signal listener yields, the event is logged and `cancel()` is
///   called on `cancellation`;
/// - when `cancellation` is cancelled first, the function returns immediately
///   without cancelling again.
///
/// # Errors
///
/// Returns the `std::io::Error` produced by `signal(..)` if either signal
/// listener cannot be registered.
async fn wait_for_termination(cancellation: CancellationToken) -> std::io::Result<()> {
    let mut sigterm = signal(SignalKind::terminate())?; // Docker sends SIGTERM
    let mut sigint = signal(SignalKind::interrupt())?; // Ctrl+C sends SIGINT
    tokio::select! {
        _ = sigterm.recv() => log::warn!("Received SIGTERM"),
        _ = sigint.recv() => log::warn!("Received SIGINT"),
        _ = cancellation.cancelled() => {
            // The token was cancelled elsewhere, so there is nothing left to wait
            // for or to cancel; this is not likely to happen.
            return Ok(());
        }
    };

    // Only reached after a signal branch fired: propagate the shutdown request
    // to every holder of this token.
    log::info!("Terminating the node...");
    cancellation.cancel();
    Ok(())
}
24 changes: 2 additions & 22 deletions src/node.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
use std::{str::FromStr, time::Duration};

use libp2p::{gossipsub, Multiaddr};
use tokio::signal::unix::{signal, SignalKind};
use std::{str::FromStr, time::Duration};
use tokio_util::sync::CancellationToken;

use crate::{
Expand Down Expand Up @@ -217,7 +215,7 @@ impl DriaComputeNode {

}
},
_ = wait_for_termination(self.cancellation.clone()) => break,
_ = self.cancellation.cancelled() => break,
}
}

Expand Down Expand Up @@ -273,24 +271,6 @@ impl DriaComputeNode {
}
}

/// Waits for a termination signal and cancels the given token once one arrives.
///
/// Listens for SIGTERM and SIGINT concurrently with the token's own cancellation:
/// - when either signal listener yields, the event is logged and `cancel()` is
///   called on `cancellation`;
/// - when `cancellation` is cancelled first, the function returns immediately
///   without cancelling again.
///
/// # Errors
///
/// Returns the `std::io::Error` produced by `signal(..)` if either signal
/// listener cannot be registered.
async fn wait_for_termination(cancellation: CancellationToken) -> std::io::Result<()> {
    let mut sigterm = signal(SignalKind::terminate())?; // Docker sends SIGTERM
    let mut sigint = signal(SignalKind::interrupt())?; // Ctrl+C sends SIGINT
    tokio::select! {
        _ = sigterm.recv() => log::warn!("Received SIGTERM"),
        _ = sigint.recv() => log::warn!("Received SIGINT"),
        _ = cancellation.cancelled() => {
            // The token was cancelled elsewhere, so there is no need to keep waiting.
            return Ok(());
        }
    };

    // Only reached after a signal branch fired: propagate the shutdown request
    // to every holder of this token.
    log::info!("Terminating the node...");
    cancellation.cancel();
    Ok(())
}

#[cfg(test)]
mod tests {
use crate::{p2p::P2PMessage, DriaComputeNode, DriaComputeNodeConfig};
Expand Down
6 changes: 3 additions & 3 deletions src/p2p/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ impl P2PMessage {
Self {
payload: BASE64_STANDARD.encode(payload),
topic: topic.to_string(),
version: crate::VERSION.to_string(),
version: crate::DRIA_COMPUTE_NODE_VERSION.to_string(),
timestamp: get_current_time_nanos(),
}
}
Expand Down Expand Up @@ -197,7 +197,7 @@ mod tests {
"{\"hello\":\"world\"}"
);
assert_eq!(message.topic, "test-topic");
assert_eq!(message.version, crate::VERSION);
assert_eq!(message.version, crate::DRIA_COMPUTE_NODE_VERSION);
assert!(message.timestamp > 0);

let parsed_body = message.parse_payload(false).expect("Should decode");
Expand All @@ -224,7 +224,7 @@ mod tests {
"{\"hello\":\"world\"}"
);
assert_eq!(message.topic, "test-topic");
assert_eq!(message.version, crate::VERSION);
assert_eq!(message.version, crate::DRIA_COMPUTE_NODE_VERSION);
assert!(message.timestamp > 0);

assert!(message.is_signed(&pk).expect("Should check signature"));
Expand Down

0 comments on commit cc49f4d

Please sign in to comment.