From 9040c0797fb35ef881c84cb359beec814923be4b Mon Sep 17 00:00:00 2001 From: ikolomi Date: Tue, 3 Sep 2024 17:08:36 +0300 Subject: [PATCH] CR changes: Add async method to DisconnectNotifier trait, styling and other cleanups --- redis/src/aio/mod.rs | 5 ++ redis/src/cluster_async/mod.rs | 117 ++++++++++++++++++--------------- 2 files changed, 69 insertions(+), 53 deletions(-) diff --git a/redis/src/aio/mod.rs b/redis/src/aio/mod.rs index 737ad82a7c..ffe2c9e3a2 100644 --- a/redis/src/aio/mod.rs +++ b/redis/src/aio/mod.rs @@ -12,6 +12,7 @@ use std::net::SocketAddr; #[cfg(unix)] use std::path::Path; use std::pin::Pin; +use std::time::Duration; /// Enables the async_std compatibility #[cfg(feature = "async-std-comp")] @@ -91,10 +92,14 @@ pub trait ConnectionLike { } /// Implements ability to notify about disconnection events +#[async_trait] pub trait DisconnectNotifier: Send + Sync { /// Notify about disconnect event fn notify_disconnect(&mut self); + /// Wait for disconnect event with timeout + async fn wait_for_disconnect_with_timeout(&self, max_wait: &Duration); + /// Intended to be used with Box fn clone_box(&self) -> Box; } diff --git a/redis/src/cluster_async/mod.rs b/redis/src/cluster_async/mod.rs index 9f2ab6c35a..035b4f16f0 100644 --- a/redis/src/cluster_async/mod.rs +++ b/redis/src/cluster_async/mod.rs @@ -97,6 +97,8 @@ use backoff_tokio::future::retry; use backoff_tokio::{Error as BackoffError, ExponentialBackoff}; #[cfg(feature = "tokio-comp")] use tokio::{sync::Notify, time::timeout}; +#[cfg(feature = "tokio-comp")] +use async_trait::async_trait; use dispose::{Disposable, Dispose}; use futures::{future::BoxFuture, prelude::*, ready}; @@ -379,20 +381,37 @@ where #[cfg(feature = "tokio-comp")] #[derive(Clone)] struct TokioDisconnectNotifier { - pub disconnect_notifier: Arc, + disconnect_notifier: Arc, } #[cfg(feature = "tokio-comp")] +#[async_trait] impl DisconnectNotifier for TokioDisconnectNotifier { fn notify_disconnect(&mut self) { self.disconnect_notifier.notify_one(); } + async fn wait_for_disconnect_with_timeout(&self, max_wait: &Duration) { + let _ = timeout(*max_wait, async { + self.disconnect_notifier.notified().await; + }) + .await; + } + fn clone_box(&self) -> Box { Box::new(self.clone()) } } +#[cfg(feature = "tokio-comp")] +impl TokioDisconnectNotifier { + fn new() -> TokioDisconnectNotifier { + TokioDisconnectNotifier { + disconnect_notifier: Arc::new(Notify::new()), + } + } +} + type ConnectionMap = connections_container::ConnectionsMap>; type ConnectionsContainer = self::connections_container::ConnectionsContainer>; @@ -406,8 +425,6 @@ pub(crate) struct InnerCore { subscriptions_by_address: RwLock>, unassigned_subscriptions: RwLock, glide_connection_options: GlideConnectionOptions, - #[cfg(feature = "tokio-comp")] - tokio_notify: Arc, } pub(crate) type Core = Arc>; @@ -990,27 +1007,24 @@ where cluster_params: ClusterParams, push_sender: Option>, ) -> RedisResult> { - #[cfg(feature = "tokio-comp")] - let tokio_notify = Arc::new(Notify::new()); - let disconnect_notifier = { #[cfg(feature = "tokio-comp")] { - Some::>(Box::new(TokioDisconnectNotifier { - disconnect_notifier: tokio_notify.clone(), - })) + Some::>(Box::new(TokioDisconnectNotifier::new())) } #[cfg(not(feature = "tokio-comp"))] None }; + let glide_connection_options = GlideConnectionOptions { + push_sender, + disconnect_notifier, + }; + let connections = Self::create_initial_connections( initial_nodes, &cluster_params, - GlideConnectionOptions { - push_sender: push_sender.clone(), - disconnect_notifier: disconnect_notifier.clone(), - }, + glide_connection_options.clone(), ) .await?; @@ -1035,12 +1049,7 @@ where }, ), subscriptions_by_address: RwLock::new(Default::default()), - glide_connection_options: GlideConnectionOptions { - push_sender: push_sender.clone(), - disconnect_notifier: disconnect_notifier.clone(), - }, - #[cfg(feature = "tokio-comp")] - tokio_notify, + glide_connection_options, }); let mut connection = ClusterConnInner { inner, @@ -1227,40 +1236,40 @@ where // In addition, the validation is done by peeking at the state of the underlying transport w/o overhead of additional commands to server. async fn validate_all_user_connections(inner: Arc>) { let mut all_valid_conns = HashMap::new(); - let mut all_nodes_with_slots = HashSet::new(); // prep connections and clean out these w/o assigned slots, as we might have established connections to unwanted hosts - { - let mut nodes_to_delete = Vec::new(); - let connections_container = inner.conn_lock.read().await; - - connections_container - .slot_map - .addresses_for_all_nodes() - .iter() - .for_each(|addr| { - all_nodes_with_slots.insert(String::from(*addr)); - }); + let mut nodes_to_delete = Vec::new(); + let connections_container = inner.conn_lock.read().await; - connections_container - .all_node_connections() - .for_each(|(addr, con)| { - if all_nodes_with_slots.contains(&addr) { - all_valid_conns.insert(addr.clone(), con.clone()); - } else { - nodes_to_delete.push(addr.clone()); - } - }); + let all_nodes_with_slots: HashSet = connections_container + .slot_map + .addresses_for_all_nodes() + .iter() + .map(|addr| String::from(*addr)) + .collect(); + + connections_container + .all_node_connections() + .for_each(|(addr, con)| { + if all_nodes_with_slots.contains(&addr) { + all_valid_conns.insert(addr.clone(), con.clone()); + } else { + nodes_to_delete.push(addr.clone()); + } + }); - for addr in &nodes_to_delete { - connections_container.remove_node(addr); - } + for addr in &nodes_to_delete { + connections_container.remove_node(addr); } + drop(connections_container); + // identify nodes with closed connection let mut addrs_to_refresh = Vec::new(); for (addr, con_fut) in &all_valid_conns { let con = con_fut.clone().await; + // connection object might be present despite the transport being closed if con.is_closed() { + // transport is closed, need to refresh addrs_to_refresh.push(addr.clone()); } } @@ -1289,7 +1298,7 @@ where inner: Arc>, addresses: Vec, conn_type: RefreshConnectionType, - try_existing_node: bool, + check_existing_conn: bool, ) { info!("Started refreshing connections to {:?}", addresses); let connections_container = inner.conn_lock.read().await; @@ -1301,10 +1310,10 @@ where .fold( &*connections_container, |connections_container, address| async move { - let node_option = if try_existing_node { + let node_option = if check_existing_conn { connections_container.remove_node(&address) } else { - Option::None + None }; // override subscriptions for this connection @@ -1541,13 +1550,15 @@ where async fn connections_validation_task(inner: Arc>, interval_duration: Duration) { loop { - #[cfg(feature = "tokio-comp")] - let _ = timeout(interval_duration, async { - inner.tokio_notify.notified().await; - }) - .await; - #[cfg(not(feature = "tokio-comp"))] - let _ = boxed_sleep(interval_duration).await; + if let Some(disconnect_notifier) = + inner.glide_connection_options.disconnect_notifier.clone() + { + disconnect_notifier + .wait_for_disconnect_with_timeout(&interval_duration) + .await; + } else { + let _ = boxed_sleep(interval_duration).await; + } Self::validate_all_user_connections(inner.clone()).await; }