diff --git a/Cargo.lock b/Cargo.lock index bbe35a3ac..8043d01f2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3162,6 +3162,7 @@ dependencies = [ "dialoguer", "dotenvy", "figment", + "futures-util", "http-body-util", "hyper", "ipnetwork", @@ -3199,6 +3200,7 @@ dependencies = [ "rand_chacha", "reqwest", "rustls", + "sd-notify", "sentry", "sentry-tower", "sentry-tracing", @@ -5323,6 +5325,15 @@ dependencies = [ "sha2", ] +[[package]] +name = "sd-notify" +version = "0.4.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b943eadf71d8b69e661330cb0e2656e31040acf21ee7708e2c238a0ec6af2bf4" +dependencies = [ + "libc", +] + [[package]] name = "sea-query" version = "0.32.1" diff --git a/crates/cli/Cargo.toml b/crates/cli/Cargo.toml index a3785210a..7c20c3b28 100644 --- a/crates/cli/Cargo.toml +++ b/crates/cli/Cargo.toml @@ -23,6 +23,7 @@ console = "0.15.10" dialoguer = { version = "0.11.0", features = ["fuzzy-select"] } dotenvy = "0.15.7" figment.workspace = true +futures-util.workspace = true http-body-util.workspace = true hyper.workspace = true ipnetwork = "0.20.0" @@ -32,6 +33,7 @@ rand.workspace = true rand_chacha = "0.3.1" reqwest.workspace = true rustls.workspace = true +sd-notify = "0.4.5" serde_json.workspace = true serde_yaml = "0.9.34" sqlx.workspace = true diff --git a/crates/cli/src/commands/server.rs b/crates/cli/src/commands/server.rs index 53bfa899e..565c6bfda 100644 --- a/crates/cli/src/commands/server.rs +++ b/crates/cli/src/commands/server.rs @@ -24,11 +24,10 @@ use tracing::{info, info_span, warn, Instrument}; use crate::{ app_state::AppState, - shutdown::ShutdownManager, + lifecycle::LifecycleManager, util::{ database_pool_from_config, mailer_from_config, password_manager_from_config, - policy_factory_from_config, register_sighup, site_config_from_config, - templates_from_config, + policy_factory_from_config, site_config_from_config, templates_from_config, }, }; @@ -57,7 +56,7 @@ impl Options { #[allow(clippy::too_many_lines)] 
pub async fn run(self, figment: &Figment) -> anyhow::Result<ExitCode> { let span = info_span!("cli.run.init").entered(); - let shutdown = ShutdownManager::new()?; + let mut shutdown = LifecycleManager::new()?; let config = AppConfig::extract(figment)?; info!(version = crate::VERSION, "Starting up"); @@ -145,6 +144,7 @@ impl Options { // Load and compile the templates let templates = templates_from_config(&config.templates, &site_config, &url_builder).await?; + shutdown.register_reloadable(&templates); let http_client = mas_http::reqwest_client(); @@ -186,6 +186,9 @@ impl Options { shutdown.task_tracker(), shutdown.soft_shutdown_token(), ); + + shutdown.register_reloadable(&activity_tracker); + let trusted_proxies = config.http.trusted_proxies.clone(); // Build a rate limiter. @@ -197,9 +200,6 @@ impl Options { // Explicitly the config to properly zeroize secret keys drop(config); - // Listen for SIGHUP - register_sighup(&templates, &activity_tracker)?; - limiter.start(); let graphql_schema = mas_handlers::graphql_schema( diff --git a/crates/cli/src/commands/worker.rs b/crates/cli/src/commands/worker.rs index 3e3df8402..33cf52b5b 100644 --- a/crates/cli/src/commands/worker.rs +++ b/crates/cli/src/commands/worker.rs @@ -14,7 +14,7 @@ use mas_router::UrlBuilder; use tracing::{info, info_span}; use crate::{ - shutdown::ShutdownManager, + lifecycle::LifecycleManager, util::{ database_pool_from_config, mailer_from_config, site_config_from_config, templates_from_config, @@ -26,7 +26,7 @@ pub(super) struct Options {} impl Options { pub async fn run(self, figment: &Figment) -> anyhow::Result<ExitCode> { - let shutdown = ShutdownManager::new()?; + let shutdown = LifecycleManager::new()?; let span = info_span!("cli.worker.init").entered(); let config = AppConfig::extract(figment)?; diff --git a/crates/cli/src/lifecycle.rs b/crates/cli/src/lifecycle.rs new file mode 100644 index 000000000..aa79d2114 --- /dev/null +++ b/crates/cli/src/lifecycle.rs @@ -0,0 +1,239 @@ +// Copyright 2024, 2025 New
Vector Ltd. +// +// SPDX-License-Identifier: AGPL-3.0-only +// Please see LICENSE in the repository root for full details. + +use std::{future::Future, process::ExitCode, time::Duration}; + +use futures_util::future::{BoxFuture, Either}; +use mas_handlers::ActivityTracker; +use mas_templates::Templates; +use tokio::signal::unix::{Signal, SignalKind}; +use tokio_util::{sync::CancellationToken, task::TaskTracker}; + +/// A helper to manage the lifecycle of the service, including handling graceful +/// shutdowns and configuration reloads. +/// +/// It will listen for SIGTERM and SIGINT signals, and will trigger a soft +/// shutdown on the first signal, and a hard shutdown on the second signal or +/// after a timeout. +/// +/// Users of this manager should use the `soft_shutdown_token` to react to a +/// soft shutdown, which should gracefully finish requests and close +/// connections, and the `hard_shutdown_token` to react to a hard shutdown, +/// which should drop all connections and finish all requests. +/// +/// They should also use the `task_tracker` to make it track things running, so +/// that it knows when the soft shutdown is over and worked. +/// +/// It also integrates with [`sd_notify`] to notify the service manager of the +/// state of the service.
+pub struct LifecycleManager { + hard_shutdown_token: CancellationToken, + soft_shutdown_token: CancellationToken, + task_tracker: TaskTracker, + sigterm: Signal, + sigint: Signal, + sighup: Signal, + timeout: Duration, + reload_handlers: Vec<Box<dyn Fn() -> BoxFuture<'static, ()>>>, +} + +/// Represents a thing that can be reloaded with a SIGHUP +pub trait Reloadable: Clone + Send { + fn reload(&self) -> impl Future<Output = ()> + Send; +} + +impl Reloadable for ActivityTracker { + async fn reload(&self) { + self.flush().await; + } +} + +impl Reloadable for Templates { + async fn reload(&self) { + if let Err(err) = self.reload().await { + tracing::error!( + error = &err as &dyn std::error::Error, + "Failed to reload templates" + ); + } + } +} + +/// A wrapper around [`sd_notify::notify`] that logs any errors +fn notify(states: &[sd_notify::NotifyState]) { + if let Err(e) = sd_notify::notify(false, states) { + tracing::error!( + error = &e as &dyn std::error::Error, + "Failed to notify service manager" + ); + } +} + +impl LifecycleManager { + /// Create a new shutdown manager, installing the signal handlers + /// + /// # Errors + /// + /// Returns an error if the signal handler could not be installed + pub fn new() -> Result<Self, std::io::Error> { + let hard_shutdown_token = CancellationToken::new(); + let soft_shutdown_token = hard_shutdown_token.child_token(); + let sigterm = tokio::signal::unix::signal(SignalKind::terminate())?; + let sigint = tokio::signal::unix::signal(SignalKind::interrupt())?; + let sighup = tokio::signal::unix::signal(SignalKind::hangup())?; + let timeout = Duration::from_secs(60); + let task_tracker = TaskTracker::new(); + + notify(&[sd_notify::NotifyState::MainPid(std::process::id())]); + + Ok(Self { + hard_shutdown_token, + soft_shutdown_token, + task_tracker, + sigterm, + sigint, + sighup, + timeout, + reload_handlers: Vec::new(), + }) + } + + /// Add a handler to be called when the server gets a SIGHUP + pub fn register_reloadable(&mut self, reloadable: &(impl Reloadable + 'static)) {
+ let reloadable = reloadable.clone(); + self.reload_handlers.push(Box::new(move || { + let reloadable = reloadable.clone(); + Box::pin(async move { reloadable.reload().await }) + })); + } + + /// Get a reference to the task tracker + #[must_use] + pub fn task_tracker(&self) -> &TaskTracker { + &self.task_tracker + } + + /// Get a cancellation token that can be used to react to a hard shutdown + #[must_use] + pub fn hard_shutdown_token(&self) -> CancellationToken { + self.hard_shutdown_token.clone() + } + + /// Get a cancellation token that can be used to react to a soft shutdown + #[must_use] + pub fn soft_shutdown_token(&self) -> CancellationToken { + self.soft_shutdown_token.clone() + } + + /// Run until we finish completely shutting down. + pub async fn run(mut self) -> ExitCode { + notify(&[sd_notify::NotifyState::Ready]); + + // This will be `Some` if we have the watchdog enabled, and `None` if not + let mut watchdog_interval = { + let mut watchdog_usec = 0; + if sd_notify::watchdog_enabled(false, &mut watchdog_usec) { + Some(tokio::time::interval(Duration::from_micros( + watchdog_usec / 2, + ))) + } else { + None + } + }; + + // Wait for a first shutdown signal and trigger the soft shutdown + let likely_crashed = loop { + // This makes a Future that will either yield the watchdog tick if enabled, or a + // pending Future if not + let watchdog_tick = if let Some(watchdog_interval) = &mut watchdog_interval { + Either::Left(watchdog_interval.tick()) + } else { + Either::Right(futures_util::future::pending()) + }; + + tokio::select! { + () = self.soft_shutdown_token.cancelled() => { + tracing::warn!("Another task triggered a shutdown, it likely crashed! 
Shutting down"); + break true; + }, + + _ = self.sigterm.recv() => { + tracing::info!("Shutdown signal received (SIGTERM), shutting down"); + break false; + }, + + _ = self.sigint.recv() => { + tracing::info!("Shutdown signal received (SIGINT), shutting down"); + break false; + }, + + _ = watchdog_tick => { + notify(&[ + sd_notify::NotifyState::Watchdog, + ]); + }, + + _ = self.sighup.recv() => { + tracing::info!("Reload signal received (SIGHUP), reloading"); + + notify(&[ + sd_notify::NotifyState::Reloading, + sd_notify::NotifyState::monotonic_usec_now() + .expect("Failed to read monotonic clock") + ]); + + // XXX: if one handler takes a long time, it will block the + // rest of the shutdown process, which is not ideal. We + // should probably have a timeout here + futures_util::future::join_all( + self.reload_handlers + .iter() + .map(|handler| handler()) + ).await; + + notify(&[sd_notify::NotifyState::Ready]); + + tracing::info!("Reloading done"); + }, + } + }; + + notify(&[sd_notify::NotifyState::Stopping]); + + self.soft_shutdown_token.cancel(); + self.task_tracker.close(); + + // Start the timeout + let timeout = tokio::time::sleep(self.timeout); + tokio::select! 
{ + _ = self.sigterm.recv() => { + tracing::warn!("Second shutdown signal received (SIGTERM), abort"); + }, + _ = self.sigint.recv() => { + tracing::warn!("Second shutdown signal received (SIGINT), abort"); + }, + () = timeout => { + tracing::warn!("Shutdown timeout reached, abort"); + }, + () = self.task_tracker.wait() => { + // This is the "happy path", we have gracefully shutdown + }, + } + + self.hard_shutdown_token().cancel(); + + // TODO: we may want to have a time out on the task tracker, in case we have + // really stuck tasks on it + self.task_tracker().wait().await; + + tracing::info!("All tasks are done, exitting"); + + if likely_crashed { + ExitCode::FAILURE + } else { + ExitCode::SUCCESS + } + } +} diff --git a/crates/cli/src/main.rs b/crates/cli/src/main.rs index f487853bc..049a2fdc0 100644 --- a/crates/cli/src/main.rs +++ b/crates/cli/src/main.rs @@ -18,8 +18,8 @@ use tracing_subscriber::{ mod app_state; mod commands; +mod lifecycle; mod server; -mod shutdown; mod sync; mod telemetry; mod util; diff --git a/crates/cli/src/shutdown.rs b/crates/cli/src/shutdown.rs deleted file mode 100644 index 5386166af..000000000 --- a/crates/cli/src/shutdown.rs +++ /dev/null @@ -1,130 +0,0 @@ -// Copyright 2024 New Vector Ltd. -// -// SPDX-License-Identifier: AGPL-3.0-only -// Please see LICENSE in the repository root for full details. - -use std::{process::ExitCode, time::Duration}; - -use tokio::signal::unix::{Signal, SignalKind}; -use tokio_util::{sync::CancellationToken, task::TaskTracker}; - -/// A helper to manage graceful shutdowns and track tasks that gracefully -/// shutdown. -/// -/// It will listen for SIGTERM and SIGINT signals, and will trigger a soft -/// shutdown on the first signal, and a hard shutdown on the second signal or -/// after a timeout. 
-/// -/// Users of this manager should use the `soft_shutdown_token` to react to a -/// soft shutdown, which should gracefully finish requests and close -/// connections, and the `hard_shutdown_token` to react to a hard shutdown, -/// which should drop all connections and finish all requests. -/// -/// They should also use the `task_tracker` to make it track things running, so -/// that it knows when the soft shutdown is over and worked. -pub struct ShutdownManager { - hard_shutdown_token: CancellationToken, - soft_shutdown_token: CancellationToken, - task_tracker: TaskTracker, - sigterm: Signal, - sigint: Signal, - timeout: Duration, -} - -impl ShutdownManager { - /// Create a new shutdown manager, installing the signal handlers - /// - /// # Errors - /// - /// Returns an error if the signal handler could not be installed - pub fn new() -> Result<Self, std::io::Error> { - let hard_shutdown_token = CancellationToken::new(); - let soft_shutdown_token = hard_shutdown_token.child_token(); - let sigterm = tokio::signal::unix::signal(SignalKind::terminate())?; - let sigint = tokio::signal::unix::signal(SignalKind::interrupt())?; - let timeout = Duration::from_secs(60); - let task_tracker = TaskTracker::new(); - - Ok(Self { - hard_shutdown_token, - soft_shutdown_token, - task_tracker, - sigterm, - sigint, - timeout, - }) - } - - /// Get a reference to the task tracker - #[must_use] - pub fn task_tracker(&self) -> &TaskTracker { - &self.task_tracker - } - - /// Get a cancellation token that can be used to react to a hard shutdown - #[must_use] - pub fn hard_shutdown_token(&self) -> CancellationToken { - self.hard_shutdown_token.clone() - } - - /// Get a cancellation token that can be used to react to a soft shutdown - #[must_use] - pub fn soft_shutdown_token(&self) -> CancellationToken { - self.soft_shutdown_token.clone() - } - - /// Run until we finish completely shutting down.
- pub async fn run(mut self) -> ExitCode { - // Wait for a first signal and trigger the soft shutdown - let likely_crashed = tokio::select! { - () = self.soft_shutdown_token.cancelled() => { - tracing::warn!("Another task triggered a shutdown, it likely crashed! Shutting down"); - true - }, - - _ = self.sigterm.recv() => { - tracing::info!("Shutdown signal received (SIGTERM), shutting down"); - false - }, - - _ = self.sigint.recv() => { - tracing::info!("Shutdown signal received (SIGINT), shutting down"); - false - }, - }; - - self.soft_shutdown_token.cancel(); - self.task_tracker.close(); - - // Start the timeout - let timeout = tokio::time::sleep(self.timeout); - tokio::select! { - _ = self.sigterm.recv() => { - tracing::warn!("Second shutdown signal received (SIGTERM), abort"); - }, - _ = self.sigint.recv() => { - tracing::warn!("Second shutdown signal received (SIGINT), abort"); - }, - () = timeout => { - tracing::warn!("Shutdown timeout reached, abort"); - }, - () = self.task_tracker.wait() => { - // This is the "happy path", we have gracefully shutdown - }, - } - - self.hard_shutdown_token().cancel(); - - // TODO: we may want to have a time out on the task tracker, in case we have - // really stuck tasks on it - self.task_tracker().wait().await; - - tracing::info!("All tasks are done, exitting"); - - if likely_crashed { - ExitCode::FAILURE - } else { - ExitCode::SUCCESS - } - } -} diff --git a/crates/cli/src/util.rs b/crates/cli/src/util.rs index 3d3d8f676..b117d0177 100644 --- a/crates/cli/src/util.rs +++ b/crates/cli/src/util.rs @@ -14,7 +14,7 @@ use mas_config::{ }; use mas_data_model::SiteConfig; use mas_email::{MailTransport, Mailer}; -use mas_handlers::{passwords::PasswordManager, ActivityTracker}; +use mas_handlers::passwords::PasswordManager; use mas_policy::PolicyFactory; use mas_router::UrlBuilder; use mas_templates::{SiteConfigExt, TemplateLoadingError, Templates}; @@ -22,7 +22,7 @@ use sqlx::{ postgres::{PgConnectOptions, PgPoolOptions}, 
ConnectOptions, PgConnection, PgPool, }; -use tracing::{error, info, log::LevelFilter}; +use tracing::log::LevelFilter; pub async fn password_manager_from_config( config: &PasswordsConfig, @@ -313,37 +313,6 @@ pub async fn database_connection_from_config( .context("could not connect to the database") } -/// Reload templates on SIGHUP -pub fn register_sighup( - templates: &Templates, - activity_tracker: &ActivityTracker, -) -> anyhow::Result<()> { - #[cfg(unix)] - { - let mut signal = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::hangup())?; - let templates = templates.clone(); - let activity_tracker = activity_tracker.clone(); - - tokio::spawn(async move { - loop { - if signal.recv().await.is_none() { - // No more signals will be received, breaking - break; - }; - - info!("SIGHUP received, reloading templates & flushing activity tracker"); - - activity_tracker.flush().await; - templates.clone().reload().await.unwrap_or_else(|err| { - error!(?err, "Error while reloading templates"); - }); - } - }); - } - - Ok(()) -} - #[cfg(test)] mod tests { use rand::SeedableRng;