From cdbdd9bb981db4b2ec68e71eff681d3acc4040fb Mon Sep 17 00:00:00 2001 From: Ermal Kaleci Date: Mon, 8 Apr 2024 09:43:58 +0200 Subject: [PATCH] endpoint health (#152) --- benches/bench/main.rs | 1 + config.yml | 8 + eth_config.yml | 12 + src/extensions/client/endpoint.rs | 169 ++++++++++ src/extensions/client/health.rs | 159 +++++++++ src/extensions/client/mock.rs | 10 + src/extensions/client/mod.rs | 517 +++++++++++++++++++----------- src/extensions/client/tests.rs | 74 +++++ src/server.rs | 1 + src/tests/merge_subscription.rs | 1 + src/tests/upstream.rs | 1 + 11 files changed, 763 insertions(+), 190 deletions(-) create mode 100644 src/extensions/client/endpoint.rs create mode 100644 src/extensions/client/health.rs diff --git a/benches/bench/main.rs b/benches/bench/main.rs index ea66c99..2f457c4 100644 --- a/benches/bench/main.rs +++ b/benches/bench/main.rs @@ -217,6 +217,7 @@ fn config() -> Config { format!("ws://{}", SERVER_TWO_ENDPOINT), ], shuffle_endpoints: false, + health_check: None, }), server: Some(ServerConfig { listen_address: SUBWAY_SERVER_ADDR.to_string(), diff --git a/config.yml b/config.yml index c284bf2..890f8ae 100644 --- a/config.yml +++ b/config.yml @@ -3,6 +3,14 @@ extensions: endpoints: - wss://acala-rpc.dwellir.com - wss://acala-rpc-0.aca-api.network + health_check: + interval_sec: 10 # check interval, default is 10s + healthy_response_time_ms: 500 # max response time to be considered healthy, default is 500ms + health_method: system_health + response: # response contains { isSyncing: false } + !contains + - - isSyncing + - !eq false event_bus: substrate_api: stale_timeout_seconds: 180 # rotate endpoint if no new blocks for 3 minutes diff --git a/eth_config.yml b/eth_config.yml index 3363d77..85d6955 100644 --- a/eth_config.yml +++ b/eth_config.yml @@ -2,6 +2,18 @@ extensions: client: endpoints: - wss://eth-rpc-karura-testnet.aca-staging.network + health_check: + interval_sec: 10 # check interval, default is 10s + healthy_response_time_ms: 500 # max response time to be considered healthy, default is 500ms + health_method: net_health # eth-rpc-adapter bodhijs + response: # response contains { isHealthy: true, isRPCOK: true } + !contains + - - isHealthy + - !eq true + - - isRPCOK + - !eq true +# health_method: eth_syncing # eth node +# response: !eq false event_bus: eth_api: stale_timeout_seconds: 180 # rotate endpoint if no new blocks for 3 minutes diff --git a/src/extensions/client/endpoint.rs b/src/extensions/client/endpoint.rs new file mode 100644 index 0000000..8ea351f --- /dev/null +++ b/src/extensions/client/endpoint.rs @@ -0,0 +1,169 @@ +use super::health::{Event, Health}; +use crate::{ + extensions::client::{get_backoff_time, HealthCheckConfig}, + utils::errors, +}; +use jsonrpsee::{ + async_client::Client, + core::client::{ClientT, Subscription, SubscriptionClientT}, + ws_client::WsClientBuilder, +}; +use std::{ + sync::{ + atomic::{AtomicU32, Ordering}, + Arc, + }, + time::Duration, +}; + +pub struct Endpoint { + url: String, + health: Arc, + client_rx: tokio::sync::watch::Receiver>>, + on_client_ready: Arc, + background_tasks: Vec>, +} + +impl Drop for Endpoint { + fn drop(&mut self) { + self.background_tasks.drain(..).for_each(|handle| handle.abort()); + } +} + +impl Endpoint { + pub fn new( + url: String, + request_timeout: Option, + connection_timeout: Option, + health_config: HealthCheckConfig, + ) -> Self { + let (client_tx, client_rx) = tokio::sync::watch::channel(None); + let on_client_ready = Arc::new(tokio::sync::Notify::new()); + let health = Arc::new(Health::new(url.clone(), health_config)); + + let url_ = url.clone(); + let health_ = health.clone(); + let on_client_ready_ = on_client_ready.clone(); + + // This task will try to connect to the endpoint and keep the connection alive + let connection_task = tokio::spawn(async move { + let connect_backoff_counter = Arc::new(AtomicU32::new(0)); + + loop { + tracing::info!("Connecting endpoint: {url_}"); + + let client = WsClientBuilder::default() + .request_timeout(request_timeout.unwrap_or(Duration::from_secs(30))) + .connection_timeout(connection_timeout.unwrap_or(Duration::from_secs(30))) + .max_buffer_capacity_per_subscription(2048) + .max_concurrent_requests(2048) + .max_response_size(20 * 1024 * 1024) + .build(&url_); + + match client.await { + Ok(client) => { + let client = Arc::new(client); + health_.update(Event::ConnectionSuccessful); + _ = client_tx.send(Some(client.clone())); + on_client_ready_.notify_waiters(); + tracing::info!("Endpoint connected: {url_}"); + connect_backoff_counter.store(0, Ordering::Relaxed); + client.on_disconnect().await; + } + Err(err) => { + health_.on_error(&err); + _ = client_tx.send(None); + tracing::warn!("Unable to connect to endpoint: {url_} error: {err}"); + tokio::time::sleep(get_backoff_time(&connect_backoff_counter)).await; + } + } + // Wait a second before trying to reconnect + tokio::time::sleep(Duration::from_secs(1)).await; + } + }); + + // This task will check the health of the endpoint and update the health score + let health_checker = Health::monitor(health.clone(), client_rx.clone(), on_client_ready.clone()); + + Self { + url, + health, + client_rx, + on_client_ready, + background_tasks: vec![connection_task, health_checker], + } + } + + pub fn url(&self) -> &str { + &self.url + } + + pub fn health(&self) -> &Health { + self.health.as_ref() + } + + pub async fn connected(&self) { + if self.client_rx.borrow().is_some() { + return; + } + self.on_client_ready.notified().await; + } + + pub async fn request( + &self, + method: &str, + params: Vec, + timeout: Duration, + ) -> Result { + let client = self + .client_rx + .borrow() + .clone() + .ok_or(errors::failed("client not connected"))?; + + match tokio::time::timeout(timeout, client.request(method, params.clone())).await { + Ok(Ok(response)) => Ok(response), + Ok(Err(err)) => { + self.health.on_error(&err); + Err(err) + } + Err(_) => { + tracing::error!("request timed out method: {method} params: {params:?}"); + self.health.on_error(&jsonrpsee::core::Error::RequestTimeout); + Err(jsonrpsee::core::Error::RequestTimeout) + } + } + } + + pub async fn subscribe( + &self, + subscribe_method: &str, + params: Vec, + unsubscribe_method: &str, + timeout: Duration, + ) -> Result, jsonrpsee::core::Error> { + let client = self + .client_rx + .borrow() + .clone() + .ok_or(errors::failed("client not connected"))?; + + match tokio::time::timeout( + timeout, + client.subscribe(subscribe_method, params.clone(), unsubscribe_method), + ) + .await + { + Ok(Ok(response)) => Ok(response), + Ok(Err(err)) => { + self.health.on_error(&err); + Err(err) + } + Err(_) => { + tracing::error!("subscribe timed out subscribe: {subscribe_method} params: {params:?}"); + self.health.on_error(&jsonrpsee::core::Error::RequestTimeout); + Err(jsonrpsee::core::Error::RequestTimeout) + } + } + } +} diff --git a/src/extensions/client/health.rs b/src/extensions/client/health.rs new file mode 100644 index 0000000..acdef15 --- /dev/null +++ b/src/extensions/client/health.rs @@ -0,0 +1,159 @@ +use crate::extensions::client::HealthCheckConfig; +use jsonrpsee::{async_client::Client, core::client::ClientT}; +use std::{ + sync::{ + atomic::{AtomicU32, Ordering}, + Arc, + }, + time::Duration, +}; + +#[derive(Debug)] +pub enum Event { + ResponseOk, + SlowResponse, + RequestTimeout, + ConnectionSuccessful, + ConnectionFailed, + StaleChain, +} + +impl Event { + pub fn update_score(&self, current: u32) -> u32 { + u32::min( + match self { + Event::ResponseOk => current.saturating_add(2), + Event::SlowResponse => current.saturating_sub(5), + Event::RequestTimeout | Event::ConnectionFailed | Event::StaleChain => 0, + Event::ConnectionSuccessful => MAX_SCORE / 5 * 4, // 80% of max score + }, + MAX_SCORE, + ) + } +} + +#[derive(Debug, Default)] +pub struct Health { + url: String, + config: HealthCheckConfig, + score: AtomicU32, + unhealthy: tokio::sync::Notify, +} + +const MAX_SCORE: u32 = 100; +const THRESHOLD: u32 = MAX_SCORE / 2; + +impl Health { + pub fn new(url: String, config: HealthCheckConfig) -> Self { + Self { + url, + config, + score: AtomicU32::new(0), + unhealthy: tokio::sync::Notify::new(), + } + } + + pub fn score(&self) -> u32 { + self.score.load(Ordering::Relaxed) + } + + pub fn update(&self, event: Event) { + let current_score = self.score.load(Ordering::Relaxed); + let new_score = event.update_score(current_score); + if new_score == current_score { + return; + } + self.score.store(new_score, Ordering::Relaxed); + log::trace!( + "Endpoint {:?} score updated from: {current_score} to: {new_score}", + self.url + ); + + // Notify waiters if the score has dropped below the threshold + if current_score >= THRESHOLD && new_score < THRESHOLD { + log::warn!("Endpoint {:?} became unhealthy", self.url); + self.unhealthy.notify_waiters(); + } + } + + pub fn on_error(&self, err: &jsonrpsee::core::Error) { + log::warn!("Endpoint {:?} responded with error: {err:?}", self.url); + match err { + jsonrpsee::core::Error::RequestTimeout => { + self.update(Event::RequestTimeout); + } + jsonrpsee::core::Error::Transport(_) + | jsonrpsee::core::Error::RestartNeeded(_) + | jsonrpsee::core::Error::MaxSlotsExceeded => { + self.update(Event::ConnectionFailed); + } + _ => {} + }; + } + + pub async fn unhealthy(&self) { + self.unhealthy.notified().await; + } +} + +impl Health { + pub fn monitor( + health: Arc, + client_rx_: tokio::sync::watch::Receiver>>, + on_client_ready: Arc, + ) -> tokio::task::JoinHandle<()> { + tokio::spawn(async move { + // no health method + if health.config.health_method.is_none() { + return; + } + + // Wait for the client to be ready before starting the health check + on_client_ready.notified().await; + + let method_name = health.config.health_method.as_ref().expect("checked above"); + let health_response = health.config.response.clone(); + let interval = Duration::from_secs(health.config.interval_sec); + let healthy_response_time = Duration::from_millis(health.config.healthy_response_time_ms); + + let client = match client_rx_.borrow().clone() { + Some(client) => client, + None => return, + }; + + loop { + // Wait for the next interval + tokio::time::sleep(interval).await; + + let request_start = std::time::Instant::now(); + match client + .request::>(method_name, vec![]) + .await + { + Ok(response) => { + let duration = request_start.elapsed(); + + // Check response + if let Some(ref health_response) = health_response { + if !health_response.validate(&response) { + health.update(Event::StaleChain); + continue; + } + } + + // Check response time + if duration > healthy_response_time { + health.update(Event::SlowResponse); + continue; + } + + health.update(Event::ResponseOk); + } + Err(err) => { + health.on_error(&err); + } + } + } + }) + } +} diff --git a/src/extensions/client/mock.rs b/src/extensions/client/mock.rs index fbe0864..d999093 100644 --- a/src/extensions/client/mock.rs +++ b/src/extensions/client/mock.rs @@ -153,6 +153,16 @@ pub async fn dummy_server() -> ( (addr, handle, rx, sub_rx) } +pub async fn dummy_server_extend(extend: Box) -> (SocketAddr, ServerHandle) { + let mut builder = TestServerBuilder::new(); + + extend(&mut builder); + + let (addr, handle) = builder.build().await; + + (addr, handle) +} + pub enum SinkTask { Sleep(u64), Send(JsonValue), diff --git a/src/extensions/client/mod.rs b/src/extensions/client/mod.rs index a05441d..e016467 100644 --- a/src/extensions/client/mod.rs +++ b/src/extensions/client/mod.rs @@ -1,24 +1,14 @@ use std::{ - sync::{ - atomic::{AtomicU32, AtomicUsize}, - Arc, - }, + sync::{atomic::AtomicU32, Arc}, time::Duration, }; use anyhow::anyhow; use async_trait::async_trait; -use futures::TryFutureExt; -use jsonrpsee::{ - core::{ - client::{ClientT, Subscription, SubscriptionClientT}, - Error, JsonValue, - }, - ws_client::{WsClient, WsClientBuilder}, -}; +use jsonrpsee::core::{client::Subscription, Error, JsonValue}; use opentelemetry::trace::FutureExt; use rand::{seq::SliceRandom, thread_rng}; -use serde::Deserialize; +use serde::{Deserialize, Serialize}; use tokio::sync::Notify; use super::ExtensionRegistry; @@ -28,6 +18,10 @@ use crate::{ utils::{self, errors}, }; +mod endpoint; +mod health; +use endpoint::Endpoint; + #[cfg(test)] pub mod mock; #[cfg(test)] @@ -53,12 +47,72 @@ pub struct ClientConfig { pub endpoints: Vec, #[serde(default = "bool_true")] pub shuffle_endpoints: bool, + pub health_check: Option, } pub fn bool_true() -> bool { true } +#[derive(Deserialize, Debug, Clone)] +pub struct HealthCheckConfig { + #[serde(default = "interval_sec")] + pub interval_sec: u64, + #[serde(default = "healthy_response_time_ms")] + pub healthy_response_time_ms: u64, + pub health_method: Option, + pub response: Option, +} + +impl Default for HealthCheckConfig { + fn default() -> Self { + Self { + interval_sec: interval_sec(), + healthy_response_time_ms: healthy_response_time_ms(), + health_method: None, + response: None, + } + } +} + +pub fn interval_sec() -> u64 { + 10 +} + +pub fn healthy_response_time_ms() -> u64 { + 500 +} + +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum HealthResponse { + Eq(JsonValue), + NotEq(JsonValue), + Contains(Vec<(String, Box)>), +} + +impl HealthResponse { + pub fn validate(&self, response: &JsonValue) -> bool { + match self { + HealthResponse::Eq(value) => value.eq(response), + HealthResponse::NotEq(value) => !value.eq(response), + HealthResponse::Contains(items) => { + for (key, expected) in items { + if let Some(response) = response.get(key) { + if !expected.validate(response) { + return false; + } + } else { + // key missing + return false; + } + } + true + } + } + } +} + #[derive(Debug)] enum Message { Request { @@ -82,12 +136,13 @@ impl Extension for Client { type Config = ClientConfig; async fn from_config(config: &Self::Config, _registry: &ExtensionRegistry) -> Result { + let health_check = config.health_check.clone(); if config.shuffle_endpoints { let mut endpoints = config.endpoints.clone(); endpoints.shuffle(&mut thread_rng()); - Ok(Self::new(endpoints, None, None, None)?) + Ok(Self::new(endpoints, None, None, None, health_check)?) } else { - Ok(Self::new(config.endpoints.clone(), None, None, None)?) + Ok(Self::new(config.endpoints.clone(), None, None, None, health_check)?) } } } @@ -98,14 +153,32 @@ impl Client { request_timeout: Option, connection_timeout: Option, retries: Option, + health_config: Option, ) -> Result { + let health_config = health_config.unwrap_or_default(); let endpoints: Vec<_> = endpoints.into_iter().map(|e| e.as_ref().to_string()).collect(); if endpoints.is_empty() { return Err(anyhow!("No endpoints provided")); } - tracing::debug!("New client with endpoints: {:?}", endpoints); + if let Some(0) = retries { + return Err(anyhow!("Retries need to be at least 1")); + } + + tracing::debug!("New client with endpoints: {endpoints:?}"); + + let endpoints = endpoints + .into_iter() + .map(|e| { + Arc::new(Endpoint::new( + e, + request_timeout, + connection_timeout, + health_config.clone(), + )) + }) + .collect::>(); let (message_tx, mut message_rx) = tokio::sync::mpsc::channel::(100); @@ -115,57 +188,39 @@ impl Client { let rotation_notify_bg = rotation_notify.clone(); let background_task = tokio::spawn(async move { - let connect_backoff_counter = Arc::new(AtomicU32::new(0)); let request_backoff_counter = Arc::new(AtomicU32::new(0)); - let current_endpoint = AtomicUsize::new(0); - - let connect_backoff_counter2 = connect_backoff_counter.clone(); - let build_ws = || async { - let build = || { - let current_endpoint = current_endpoint.fetch_add(1, std::sync::atomic::Ordering::Relaxed); - let url = &endpoints[current_endpoint % endpoints.len()]; - - tracing::info!("Connecting to endpoint: {}", url); - - // TODO: make those configurable - WsClientBuilder::default() - .request_timeout(request_timeout.unwrap_or(Duration::from_secs(30))) - .connection_timeout(connection_timeout.unwrap_or(Duration::from_secs(30))) - .max_buffer_capacity_per_subscription(2048) - .max_concurrent_requests(2048) - .max_response_size(20 * 1024 * 1024) - .build(url) - .map_err(|e| (e, url.to_string())) - }; + // Select next endpoint with the highest health score, excluding the current one if provided + let healthiest_endpoint = |exclude: Option>| async { + if endpoints.len() == 1 { + let selected_endpoint = endpoints[0].clone(); + // Ensure it's connected + selected_endpoint.connected().await; + return selected_endpoint; + } - loop { - match build().await { - Ok(ws) => { - let ws = Arc::new(ws); - tracing::info!("Endpoint connected"); - connect_backoff_counter2.store(0, std::sync::atomic::Ordering::Relaxed); - break ws; - } - Err((e, url)) => { - tracing::warn!("Unable to connect to endpoint: '{url}' error: {e}"); - tokio::time::sleep(get_backoff_time(&connect_backoff_counter2)).await; - } - } + let mut endpoints = endpoints.clone(); + // Remove the current endpoint from the list + if let Some(exclude) = exclude { + endpoints.retain(|e| e.url() != exclude.url()); } + // Sort by health score + endpoints.sort_by_key(|endpoint| std::cmp::Reverse(endpoint.health().score())); + // Pick the first one + let selected_endpoint = endpoints[0].clone(); + // Ensure it's connected + selected_endpoint.connected().await; + selected_endpoint }; - let mut ws = build_ws().await; + let mut selected_endpoint = healthiest_endpoint(None).await; - let handle_message = |message: Message, ws: Arc| { + let handle_message = |message: Message, endpoint: Arc, rotation_notify: Arc| { let tx = message_tx_bg.clone(); let request_backoff_counter = request_backoff_counter.clone(); // total timeout for a request - let task_timeout = request_timeout - .unwrap_or(Duration::from_secs(30)) - // buffer 5 seconds for the request to be processed - .saturating_add(Duration::from_secs(5)); + let task_timeout = request_timeout.unwrap_or(Duration::from_secs(30)); tokio::spawn(async move { match message { @@ -182,71 +237,57 @@ impl Client { return; } - if let Ok(result) = - tokio::time::timeout(task_timeout, ws.request(&method, params.clone())).await - { - match result { - result @ Ok(_) => { - request_backoff_counter.store(0, std::sync::atomic::Ordering::Relaxed); - // make sure it's still connected - if response.is_closed() { - return; - } - let _ = response.send(result); + match endpoint.request(&method, params.clone(), task_timeout).await { + result @ Ok(_) => { + request_backoff_counter.store(0, std::sync::atomic::Ordering::Relaxed); + // make sure it's still connected + if response.is_closed() { + return; } - Err(err) => { - tracing::debug!("Request failed: {:?}", err); - match err { - Error::RequestTimeout - | Error::Transport(_) - | Error::RestartNeeded(_) - | Error::MaxSlotsExceeded => { - tokio::time::sleep(get_backoff_time(&request_backoff_counter)).await; - - // make sure it's still connected - if response.is_closed() { - return; - } - - // make sure we still have retries left - if retries == 0 { - let _ = response.send(Err(Error::RequestTimeout)); - return; - } - - if matches!(err, Error::RequestTimeout) { - tx.send(Message::RotateEndpoint) - .await - .expect("Failed to send rotate message"); - } - - tx.send(Message::Request { - method, - params, - response, - retries, - }) - .await - .expect("Failed to send request message"); + let _ = response.send(result); + } + Err(err) => { + tracing::debug!("Request failed: {err:?}"); + match err { + Error::RequestTimeout + | Error::Transport(_) + | Error::RestartNeeded(_) + | Error::MaxSlotsExceeded => { + // Make sure endpoint is rotated + rotation_notify.notified().await; + + tokio::time::sleep(get_backoff_time(&request_backoff_counter)).await; + + // make sure it's still connected + if response.is_closed() { + return; } - err => { - // make sure it's still connected - if response.is_closed() { - return; - } - // not something we can handle, send it back to the caller - let _ = response.send(Err(err)); + + // make sure we still have retries left + if retries == 0 { + let _ = response.send(Err(Error::RequestTimeout)); + return; + } + + tx.send(Message::Request { + method, + params, + response, + retries, + }) + .await + .expect("Failed to send request message"); + } + err => { + // make sure it's still connected + if response.is_closed() { + return; } + // not something we can handle, send it back to the caller + let _ = response.send(Err(err)); } } } - } else { - tracing::error!("request timed out method: {} params: {:?}", method, params); - // make sure it's still connected - if response.is_closed() { - return; - } - let _ = response.send(Err(Error::RequestTimeout)); } } Message::Subscribe { @@ -258,75 +299,61 @@ impl Client { } => { retries = retries.saturating_sub(1); - if let Ok(result) = tokio::time::timeout( - task_timeout, - ws.subscribe(&subscribe, params.clone(), &unsubscribe), - ) - .await + match endpoint + .subscribe(&subscribe, params.clone(), &unsubscribe, task_timeout) + .await { - match result { - result @ Ok(_) => { - request_backoff_counter.store(0, std::sync::atomic::Ordering::Relaxed); - // make sure it's still connected - if response.is_closed() { - return; - } - let _ = response.send(result); + result @ Ok(_) => { + request_backoff_counter.store(0, std::sync::atomic::Ordering::Relaxed); + // make sure it's still connected + if response.is_closed() { + return; } - Err(err) => { - tracing::debug!("Subscribe failed: {:?}", err); - match err { - Error::RequestTimeout - | Error::Transport(_) - | Error::RestartNeeded(_) - | Error::MaxSlotsExceeded => { - tokio::time::sleep(get_backoff_time(&request_backoff_counter)).await; - - // make sure it's still connected - if response.is_closed() { - return; - } - - // make sure we still have retries left - if retries == 0 { - let _ = response.send(Err(Error::RequestTimeout)); - return; - } - - if matches!(err, Error::RequestTimeout) { - tx.send(Message::RotateEndpoint) - .await - .expect("Failed to send rotate message"); - } - - tx.send(Message::Subscribe { - subscribe, - params, - unsubscribe, - response, - retries, - }) - .await - .expect("Failed to send subscribe message") + let _ = response.send(result); + } + Err(err) => { + tracing::debug!("Subscribe failed: {err:?}"); + match err { + Error::RequestTimeout + | Error::Transport(_) + | Error::RestartNeeded(_) + | Error::MaxSlotsExceeded => { + // Make sure endpoint is rotated + rotation_notify.notified().await; + + tokio::time::sleep(get_backoff_time(&request_backoff_counter)).await; + + // make sure it's still connected + if response.is_closed() { + return; + } + + // make sure we still have retries left + if retries == 0 { + let _ = response.send(Err(Error::RequestTimeout)); + return; } - err => { - // make sure it's still connected - if response.is_closed() { - return; - } - // not something we can handle, send it back to the caller - let _ = response.send(Err(err)); + + tx.send(Message::Subscribe { + subscribe, + params, + unsubscribe, + response, + retries, + }) + .await + .expect("Failed to send subscribe message") + } + err => { + // make sure it's still connected + if response.is_closed() { + return; } + // not something we can handle, send it back to the caller + let _ = response.send(Err(err)); } } } - } else { - tracing::error!("subscribe timed out subscribe: {} params: {:?}", subscribe, params); - // make sure it's still connected - if response.is_closed() { - return; - } - let _ = response.send(Err(Error::RequestTimeout)); } } Message::RotateEndpoint => { @@ -338,20 +365,25 @@ impl Client { loop { tokio::select! { - _ = ws.on_disconnect() => { - tracing::info!("Endpoint disconnected"); - tokio::time::sleep(get_backoff_time(&connect_backoff_counter)).await; - ws = build_ws().await; + _ = selected_endpoint.health().unhealthy() => { + // Current selected endpoint is unhealthy, try to rotate to another one. + // In case of all endpoints are unhealthy, we don't want to keep rotating but stick with the healthiest one. + let new_selected_endpoint = healthiest_endpoint(None).await; + if new_selected_endpoint.url() != selected_endpoint.url() { + tracing::warn!("Switch to endpoint: {new_url}", new_url=new_selected_endpoint.url()); + selected_endpoint = new_selected_endpoint; + rotation_notify_bg.notify_waiters(); + } } message = message_rx.recv() => { tracing::trace!("Received message {message:?}"); match message { Some(Message::RotateEndpoint) => { + tracing::info!("Rotating endpoint ..."); + selected_endpoint = healthiest_endpoint(Some(selected_endpoint.clone())).await; rotation_notify_bg.notify_waiters(); - tracing::info!("Rotate endpoint"); - ws = build_ws().await; } - Some(message) => handle_message(message, ws.clone()), + Some(message) => handle_message(message, selected_endpoint.clone(), rotation_notify_bg.clone()), None => { tracing::debug!("Client dropped"); break; @@ -362,10 +394,6 @@ impl Client { } }); - if let Some(0) = retries { - return Err(anyhow!("Retries need to be at least 1")); - } - Ok(Self { sender: message_tx, rotation_notify, @@ -375,7 +403,7 @@ impl Client { } pub fn with_endpoints(endpoints: impl IntoIterator>) -> Result { - Self::new(endpoints, None, None, None) + Self::new(endpoints, None, None, None, None) } pub async fn request(&self, method: &str, params: Vec) -> CallResult { @@ -435,7 +463,7 @@ impl Client { } } -fn get_backoff_time(counter: &Arc) -> Duration { +pub fn get_backoff_time(counter: &Arc) -> Duration { let min_time = 100u64; let step = 100u64; let max_count = 10u32; @@ -465,3 +493,112 @@ fn test_get_backoff_time() { vec![100, 200, 500, 1000, 1700, 2600, 3700, 5000, 6500, 8200, 10100, 10100] ); } + +#[test] +fn health_response_serialize_deserialize_works() { + let response = HealthResponse::Contains(vec![( + "isSyncing".to_string(), + Box::new(HealthResponse::Eq(false.into())), + )]); + + let expected = serde_yaml::from_str::( + r" + !contains + - - isSyncing + - !eq false + ", + ) + .unwrap(); + + assert_eq!(response, expected); +} + +#[test] +fn health_response_validation_works() { + use serde_json::json; + + let expected = serde_yaml::from_str::( + r" + !eq true + ", + ) + .unwrap(); + assert!(expected.validate(&json!(true))); + assert!(!expected.validate(&json!(false))); + + let expected = serde_yaml::from_str::( + r" + !contains + - - isSyncing + - !eq false + ", + ) + .unwrap(); + let cases = [ + (json!({ "isSyncing": false }), true), + (json!({ "isSyncing": true }), false), + (json!({ "isSyncing": false, "peers": 2 }), true), + (json!({ "isSyncing": true, "peers": 2 }), false), + (json!({}), false), + (json!(true), false), + ]; + for (input, output) in cases { + assert_eq!(expected.validate(&input), output); + } + + // multiple items + let expected = serde_yaml::from_str::( + r" + !contains + - - isSyncing + - !eq false + - - peers + - !eq 3 + ", + ) + .unwrap(); + let cases = [ + (json!({ "isSyncing": false, "peers": 3 }), true), + (json!({ "isSyncing": false, "peers": 2 }), false), + (json!({ "isSyncing": true, "peers": 3 }), false), + ]; + for (input, output) in cases { + assert_eq!(expected.validate(&input), output); + } + + // works with strings + let expected = serde_yaml::from_str::( + r" + !contains + - - foo + - !eq bar + ", + ) + .unwrap(); + assert!(expected.validate(&json!({ "foo": "bar" }))); + assert!(!expected.validate(&json!({ "foo": "bar bar" }))); + + // multiple nested items + let expected = serde_yaml::from_str::( + r" + !contains + - - foo + - !contains + - - one + - !eq subway + - - two + - !not_eq subway + ", + ) + .unwrap(); + let cases = [ + (json!({ "foo": { "one": "subway", "two": "not_subway" } }), true), + (json!({ "foo": { "one": "subway", "two": "subway" } }), false), + (json!({ "foo": { "subway": "one" } }), false), + (json!({ "bar" : { "foo": { "subway": "one", "two": "subway" } }}), false), + (json!({ "foo": "subway" }), false), + ]; + for (input, output) in cases { + assert_eq!(expected.validate(&input), output); + } +} diff --git a/src/extensions/client/tests.rs b/src/extensions/client/tests.rs index c8c9c7b..c76a75d 100644 --- a/src/extensions/client/tests.rs +++ b/src/extensions/client/tests.rs @@ -152,6 +152,7 @@ async fn retry_requests_successful() { Some(Duration::from_millis(100)), None, Some(2), + None, ) .unwrap(); @@ -189,6 +190,7 @@ async fn retry_requests_out_of_retries() { Some(Duration::from_millis(100)), None, Some(2), + None, ) .unwrap(); @@ -216,3 +218,75 @@ async fn retry_requests_out_of_retries() { handle1.stop().unwrap(); handle2.stop().unwrap(); } + +#[tokio::test] +async fn health_check_works() { + let (addr1, handle1) = dummy_server_extend(Box::new(|builder| { + let mut system_health = builder.register_method("system_health"); + tokio::spawn(async move { + loop { + tokio::select! { + Some(req) = system_health.recv() => { + req.respond(json!({ "isSyncing": true, "peers": 1, "shouldHavePeers": true })); + } + } + } + }); + })) + .await; + + let (addr2, handle2) = dummy_server_extend(Box::new(|builder| { + let mut system_health = builder.register_method("system_health"); + tokio::spawn(async move { + loop { + tokio::select! { + Some(req) = system_health.recv() => { + req.respond(json!({ "isSyncing": false, "peers": 1, "shouldHavePeers": true })); + } + } + } + }); + })) + .await; + + let client = Client::new( + [format!("ws://{addr1}"), format!("ws://{addr2}")], + None, + None, + None, + Some(HealthCheckConfig { + interval_sec: 1, + healthy_response_time_ms: 250, + health_method: Some("system_health".into()), + response: Some(HealthResponse::Contains(vec![( + "isSyncing".to_string(), + Box::new(HealthResponse::Eq(false.into())), + )])), + }), + ) + .unwrap(); + + // first endpoint is stale + let res = client.request("system_health", vec![]).await; + assert_eq!( + res.unwrap(), + json!({ "isSyncing": true, "peers": 1, "shouldHavePeers": true }) + ); + + // wait for the health check to run + tokio::spawn(async move { + tokio::time::sleep(Duration::from_millis(1_050)).await; + }) + .await + .unwrap(); + + // second endpoint is healthy + let res = client.request("system_health", vec![]).await; + assert_eq!( + res.unwrap(), + json!({ "isSyncing": false, "peers": 1, "shouldHavePeers": true }) + ); + + handle1.stop().unwrap(); + handle2.stop().unwrap(); +} diff --git a/src/server.rs b/src/server.rs index 0994de3..824a977 100644 --- a/src/server.rs +++ b/src/server.rs @@ -239,6 +239,7 @@ mod tests { client: Some(ClientConfig { endpoints: vec![endpoint], shuffle_endpoints: false, + health_check: None, }), server: Some(ServerConfig { listen_address: "127.0.0.1".to_string(), diff --git a/src/tests/merge_subscription.rs b/src/tests/merge_subscription.rs index a41bb4b..37a7c53 100644 --- a/src/tests/merge_subscription.rs +++ b/src/tests/merge_subscription.rs @@ -49,6 +49,7 @@ async fn merge_subscription_works() { client: Some(ClientConfig { endpoints: vec![format!("ws://{addr}")], shuffle_endpoints: false, + health_check: None, }), server: Some(ServerConfig { listen_address: "0.0.0.0".to_string(), diff --git a/src/tests/upstream.rs b/src/tests/upstream.rs index 84c38f5..9c730da 100644 --- a/src/tests/upstream.rs +++ b/src/tests/upstream.rs @@ -31,6 +31,7 @@ async fn upstream_error_propagate() { client: Some(ClientConfig { endpoints: vec![format!("ws://{addr}")], shuffle_endpoints: false, + health_check: None, }), server: Some(ServerConfig { listen_address: "0.0.0.0".to_string(),