From 2c7faec028d3ab4a751c309f6af9d6c706e79864 Mon Sep 17 00:00:00 2001 From: ikolomi Date: Sun, 18 Aug 2024 15:17:43 +0300 Subject: [PATCH 1/3] Introduce a fast reconnect process for async cluster connections. The process is periodic and can be configured via ClusterParams. This process ensures that all expected user connections exist and have not been passively closed. The expected connections are calculated from the current slot map. Additionally, for the Tokio runtime, an instant disconnect notification is available, allowing the reconnect process to be triggered instantly without waiting for the periodic check. This process is especially important for pub/sub support, as passive disconnects can render a pub/sub subscriber inoperative. Three integration tests are introduced with this feature: a generic fast reconnect test, pub/sub resilience to passive disconnects, and pub/sub resilience to scale-out. --- redis-test/src/lib.rs | 4 + redis/examples/async-await.rs | 2 +- redis/examples/async-connection-loss.rs | 4 +- redis/examples/async-multiplexed.rs | 5 +- redis/examples/async-pub-sub.rs | 2 +- redis/examples/async-scan.rs | 2 +- redis/src/aio/connection.rs | 5 + redis/src/aio/connection_manager.rs | 6 + redis/src/aio/mod.rs | 18 + redis/src/aio/multiplexed_connection.rs | 62 +- redis/src/client.rs | 31 + redis/src/cluster_async/connections_logic.rs | 89 +- redis/src/cluster_async/mod.rs | 188 ++- redis/src/cluster_client.rs | 16 + redis/src/sentinel.rs | 6 +- redis/tests/support/mock_cluster.rs | 8 + redis/tests/support/mod.rs | 16 +- redis/tests/test_async.rs | 10 +- redis/tests/test_async_async_std.rs | 2 +- .../test_async_cluster_connections_logic.rs | 8 + redis/tests/test_cluster_async.rs | 1141 +++++++++-------- redis/tests/test_sentinel.rs | 18 +- 22 files changed, 1015 insertions(+), 628 deletions(-) diff --git a/redis-test/src/lib.rs b/redis-test/src/lib.rs index fb21e13bf..cafe8a347 100644 --- a/redis-test/src/lib.rs +++ b/redis-test/src/lib.rs @@ -288,6 +288,10 @@ impl AioConnectionLike for MockRedisConnection { fn get_db(&self) -> i64 { 0 } + + fn is_closed(&self) -> bool { + false + } } #[cfg(test)] diff --git a/redis/examples/async-await.rs b/redis/examples/async-await.rs index 36b8182a8..b52776a46 100644 --- a/redis/examples/async-await.rs +++ b/redis/examples/async-await.rs @@ -4,7 +4,7 @@ use redis::AsyncCommands; #[tokio::main] async fn main() -> redis::RedisResult<()> { let client = redis::Client::open("redis://127.0.0.1/").unwrap(); - let mut con = client.get_multiplexed_async_connection(None).await?; + let mut con = client.get_multiplexed_async_connection(None, None).await?; con.set("key1", b"foo").await?; diff --git a/redis/examples/async-connection-loss.rs b/redis/examples/async-connection-loss.rs index 4c2d54d08..90af361f2 100644 --- a/redis/examples/async-connection-loss.rs +++ b/redis/examples/async-connection-loss.rs @@ -80,7 +80,9 @@ async fn main() -> RedisResult<()> { let client = redis::Client::open("redis://127.0.0.1/").unwrap(); match mode { - Mode::Default => run_multi(client.get_multiplexed_tokio_connection(None).await?).await?, + Mode::Default => { + run_multi(client.get_multiplexed_tokio_connection(None, None).await?).await? + } Mode::Reconnect => run_multi(client.get_connection_manager().await?).await?, #[allow(deprecated)] Mode::Deprecated => run_single(client.get_async_connection(None).await?).await?, diff --git a/redis/examples/async-multiplexed.rs b/redis/examples/async-multiplexed.rs index b057b759c..9c8c73235 100644 --- a/redis/examples/async-multiplexed.rs +++ b/redis/examples/async-multiplexed.rs @@ -34,7 +34,10 @@ async fn test_cmd(con: &MultiplexedConnection, i: i32) -> RedisResult<()> { async fn main() { let client = redis::Client::open("redis://127.0.0.1/").unwrap(); - let con = client.get_multiplexed_tokio_connection(None).await.unwrap(); + let con = client + .get_multiplexed_tokio_connection(None, None) + .await + .unwrap(); let cmds = (0..100).map(|i| test_cmd(&con, i)); let result = future::try_join_all(cmds).await.unwrap(); diff --git a/redis/examples/async-pub-sub.rs b/redis/examples/async-pub-sub.rs index 3dbb7e0f9..15634e2b0 100644 --- a/redis/examples/async-pub-sub.rs +++ b/redis/examples/async-pub-sub.rs @@ -5,7 +5,7 @@ use redis::AsyncCommands; #[tokio::main] async fn main() -> redis::RedisResult<()> { let client = redis::Client::open("redis://127.0.0.1/").unwrap(); - let mut publish_conn = client.get_multiplexed_async_connection(None).await?; + let mut publish_conn = client.get_multiplexed_async_connection(None, None).await?; let mut pubsub_conn = client.get_async_pubsub().await?; pubsub_conn.subscribe("wavephone").await?; diff --git a/redis/examples/async-scan.rs b/redis/examples/async-scan.rs index 55e33d0ea..6f55ac933 100644 --- a/redis/examples/async-scan.rs +++ b/redis/examples/async-scan.rs @@ -5,7 +5,7 @@ use redis::{AsyncCommands, AsyncIter}; #[tokio::main] async fn main() -> redis::RedisResult<()> { let client = redis::Client::open("redis://127.0.0.1/").unwrap(); - let mut con = client.get_multiplexed_async_connection(None).await?; + let mut con = client.get_multiplexed_async_connection(None, None).await?; con.set("async-key1", b"foo").await?; con.set("async-key2", b"foo").await?; diff --git a/redis/src/aio/connection.rs b/redis/src/aio/connection.rs index d78ef0850..6b1f6e657 100644 --- a/redis/src/aio/connection.rs +++ b/redis/src/aio/connection.rs @@ -305,6 +305,11 @@ where fn get_db(&self) -> i64 { self.db } + + fn is_closed(&self) -> bool { + // always false for AsyncRead + AsyncWrite (cant do better) + false + } } /// Represents a `PubSub` connection. diff --git a/redis/src/aio/connection_manager.rs b/redis/src/aio/connection_manager.rs index 0070d9773..741086d76 100644 --- a/redis/src/aio/connection_manager.rs +++ b/redis/src/aio/connection_manager.rs @@ -196,6 +196,7 @@ impl ConnectionManager { response_timeout, connection_timeout, None, + None, ) }) .await @@ -301,4 +302,9 @@ impl ConnectionLike for ConnectionManager { fn get_db(&self) -> i64 { self.client.connection_info().redis.db } + + fn is_closed(&self) -> bool { + // always return false due to automatic reconnect + false + } } diff --git a/redis/src/aio/mod.rs b/redis/src/aio/mod.rs index 04ebe960f..737ad82a7 100644 --- a/redis/src/aio/mod.rs +++ b/redis/src/aio/mod.rs @@ -85,6 +85,24 @@ pub trait ConnectionLike { /// also might be incorrect if the connection like object is not /// actually connected. fn get_db(&self) -> i64; + + /// Returns the state of the connection + fn is_closed(&self) -> bool; +} + +/// Implements ability to notify about disconnection events +pub trait DisconnectNotifier: Send + Sync { + /// Notify about disconnect event + fn notify_disconnect(&mut self); + + /// Intended to be used with Box + fn clone_box(&self) -> Box; +} + +impl Clone for Box { + fn clone(&self) -> Box { + self.clone_box() + } } // Initial setup for every connection. diff --git a/redis/src/aio/multiplexed_connection.rs b/redis/src/aio/multiplexed_connection.rs index 64e1ed7f2..c085461cf 100644 --- a/redis/src/aio/multiplexed_connection.rs +++ b/redis/src/aio/multiplexed_connection.rs @@ -1,5 +1,6 @@ use super::{ConnectionLike, Runtime}; use crate::aio::setup_connection; +use crate::aio::DisconnectNotifier; use crate::cmd::Cmd; #[cfg(any(feature = "tokio-comp", feature = "async-std-comp"))] use crate::parser::ValueCodec; @@ -23,6 +24,7 @@ use std::fmt; use std::fmt::Debug; use std::io; use std::pin::Pin; +use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use std::task::{self, Poll}; use std::time::Duration; @@ -73,19 +75,11 @@ struct PipelineMessage { /// items being output by the `Stream` (the number is specified at time of sending). With the /// interface provided by `Pipeline` an easy interface of request to response, hiding the `Stream` /// and `Sink`. +#[derive(Clone)] struct Pipeline { sender: mpsc::Sender>, - push_manager: Arc>, -} - -impl Clone for Pipeline { - fn clone(&self) -> Self { - Pipeline { - sender: self.sender.clone(), - push_manager: self.push_manager.clone(), - } - } + is_stream_closed: Arc, } impl Debug for Pipeline @@ -104,6 +98,8 @@ pin_project! { in_flight: VecDeque, error: Option, push_manager: Arc>, + disconnect_notifier: Option>, + is_stream_closed: Arc, } } @@ -111,7 +107,12 @@ impl PipelineSink where T: Stream> + 'static, { - fn new(sink_stream: T, push_manager: Arc>) -> Self + fn new( + sink_stream: T, + push_manager: Arc>, + disconnect_notifier: Option>, + is_stream_closed: Arc, + ) -> Self where T: Sink + Stream> + 'static, { @@ -120,6 +121,8 @@ where in_flight: VecDeque::new(), error: None, push_manager, + disconnect_notifier, + is_stream_closed, } } @@ -130,7 +133,15 @@ where Some(result) => result, // The redis response stream is not going to produce any more items so we `Err` // to break out of the `forward` combinator and stop handling requests - None => return Poll::Ready(Err(())), + None => { + // this is the right place to notify about the passive TCP disconnect + // In other places we cannot distinguish between the active destruction of MultiplexedConnection and passive disconnect + if let Some(disconnect_notifier) = self.as_mut().project().disconnect_notifier { + disconnect_notifier.notify_disconnect(); + } + self.is_stream_closed.store(true, Ordering::Relaxed); + return Poll::Ready(Err(())); + } }; self.as_mut().send_result(item); } @@ -296,7 +307,10 @@ impl Pipeline where SinkItem: Send + 'static, { - fn new(sink_stream: T) -> (Self, impl Future) + fn new( + sink_stream: T, + disconnect_notifier: Option>, + ) -> (Self, impl Future) where T: Sink + Stream> + 'static, T: Send + 'static, @@ -308,7 +322,13 @@ where let (sender, mut receiver) = mpsc::channel(BUFFER_SIZE); let push_manager: Arc> = Arc::new(ArcSwap::new(Arc::new(PushManager::default()))); - let sink = PipelineSink::new::(sink_stream, push_manager.clone()); + let is_stream_closed = Arc::new(AtomicBool::new(false)); + let sink = PipelineSink::new::( + sink_stream, + push_manager.clone(), + disconnect_notifier, + is_stream_closed.clone(), + ); let f = stream::poll_fn(move |cx| receiver.poll_recv(cx)) .map(Ok) .forward(sink) @@ -317,6 +337,7 @@ where Pipeline { sender, push_manager, + is_stream_closed, }, f, ) @@ -363,6 +384,10 @@ where async fn set_push_manager(&mut self, push_manager: PushManager) { self.push_manager.store(Arc::new(push_manager)); } + + pub fn is_closed(&self) -> bool { + self.is_stream_closed.load(Ordering::Relaxed) + } } /// A connection object which can be cloned, allowing requests to be be sent concurrently @@ -392,6 +417,7 @@ impl MultiplexedConnection { connection_info: &ConnectionInfo, stream: C, push_sender: Option>, + disconnect_notifier: Option>, ) -> RedisResult<(Self, impl Future)> where C: Unpin + AsyncRead + AsyncWrite + Send + 'static, @@ -401,6 +427,7 @@ impl MultiplexedConnection { stream, std::time::Duration::MAX, push_sender, + disconnect_notifier, ) .await } @@ -412,6 +439,7 @@ impl MultiplexedConnection { stream: C, response_timeout: std::time::Duration, push_sender: Option>, + disconnect_notifier: Option>, ) -> RedisResult<(Self, impl Future)> where C: Unpin + AsyncRead + AsyncWrite + Send + 'static, @@ -429,7 +457,7 @@ impl MultiplexedConnection { let codec = ValueCodec::default() .framed(stream) .and_then(|msg| async move { msg }); - let (mut pipeline, driver) = Pipeline::new(codec); + let (mut pipeline, driver) = Pipeline::new(codec, disconnect_notifier); let driver = boxed(driver); let pm = PushManager::default(); if let Some(sender) = push_sender { @@ -560,6 +588,10 @@ impl ConnectionLike for MultiplexedConnection { fn get_db(&self) -> i64 { self.db } + + fn is_closed(&self) -> bool { + self.pipeline.is_closed() + } } impl MultiplexedConnection { /// Subscribes to a new channel. diff --git a/redis/src/client.rs b/redis/src/client.rs index 7ace00089..534c186d9 100644 --- a/redis/src/client.rs +++ b/redis/src/client.rs @@ -1,5 +1,8 @@ use std::time::Duration; +#[cfg(feature = "aio")] +use crate::aio::DisconnectNotifier; + use crate::{ connection::{connect, Connection, ConnectionInfo, ConnectionLike, IntoConnectionInfo}, push_manager::PushInfo, @@ -147,11 +150,13 @@ impl Client { pub async fn get_multiplexed_async_connection( &self, push_sender: Option>, + disconnect_notifier: Option>, ) -> RedisResult { self.get_multiplexed_async_connection_with_timeouts( std::time::Duration::MAX, std::time::Duration::MAX, push_sender, + disconnect_notifier, ) .await } @@ -167,6 +172,7 @@ impl Client { response_timeout: std::time::Duration, connection_timeout: std::time::Duration, push_sender: Option>, + disconnect_notifier: Option>, ) -> RedisResult { let result = match Runtime::locate() { #[cfg(feature = "tokio-comp")] @@ -177,6 +183,7 @@ impl Client { response_timeout, None, push_sender, + disconnect_notifier, ), ) .await @@ -189,6 +196,7 @@ impl Client { response_timeout, None, push_sender, + disconnect_notifier, ), ) .await @@ -213,6 +221,7 @@ impl Client { pub async fn get_multiplexed_async_connection_and_ip( &self, push_sender: Option>, + disconnect_notifier: Option>, ) -> RedisResult<(crate::aio::MultiplexedConnection, Option)> { match Runtime::locate() { #[cfg(feature = "tokio-comp")] @@ -221,6 +230,7 @@ impl Client { Duration::MAX, None, push_sender, + disconnect_notifier, ) .await } @@ -230,6 +240,7 @@ impl Client { Duration::MAX, None, push_sender, + disconnect_notifier, ) .await } @@ -247,6 +258,7 @@ impl Client { response_timeout: std::time::Duration, connection_timeout: std::time::Duration, push_sender: Option>, + disconnect_notifier: Option>, ) -> RedisResult { let result = Runtime::locate() .timeout( @@ -255,6 +267,7 @@ impl Client { response_timeout, None, push_sender, + disconnect_notifier, ), ) .await; @@ -275,11 +288,13 @@ impl Client { pub async fn get_multiplexed_tokio_connection( &self, push_sender: Option>, + disconnect_notifier: Option>, ) -> RedisResult { self.get_multiplexed_tokio_connection_with_response_timeouts( std::time::Duration::MAX, std::time::Duration::MAX, push_sender, + disconnect_notifier, ) .await } @@ -295,6 +310,7 @@ impl Client { response_timeout: std::time::Duration, connection_timeout: std::time::Duration, push_sender: Option>, + disconnect_notifier: Option>, ) -> RedisResult { let result = Runtime::locate() .timeout( @@ -303,6 +319,7 @@ impl Client { response_timeout, None, push_sender, + disconnect_notifier, ), ) .await; @@ -323,11 +340,13 @@ impl Client { pub async fn get_multiplexed_async_std_connection( &self, push_sender: Option>, + disconnect_notifier: Option>, ) -> RedisResult { self.get_multiplexed_async_std_connection_with_timeouts( std::time::Duration::MAX, std::time::Duration::MAX, push_sender, + disconnect_notifier, ) .await } @@ -344,6 +363,7 @@ impl Client { &self, response_timeout: std::time::Duration, push_sender: Option>, + disconnect_notifier: Option>, ) -> RedisResult<( crate::aio::MultiplexedConnection, impl std::future::Future, @@ -352,6 +372,7 @@ impl Client { response_timeout, None, push_sender, + disconnect_notifier, ) .await .map(|(conn, driver, _ip)| (conn, driver)) @@ -367,6 +388,7 @@ impl Client { pub async fn create_multiplexed_tokio_connection( &self, push_sender: Option>, + disconnect_notifier: Option>, ) -> RedisResult<( crate::aio::MultiplexedConnection, impl std::future::Future, @@ -374,6 +396,7 @@ impl Client { self.create_multiplexed_tokio_connection_with_response_timeout( std::time::Duration::MAX, push_sender, + disconnect_notifier, ) .await .map(|conn_res| (conn_res.0, conn_res.1)) @@ -391,6 +414,7 @@ impl Client { &self, response_timeout: std::time::Duration, push_sender: Option>, + disconnect_notifier: Option>, ) -> RedisResult<( crate::aio::MultiplexedConnection, impl std::future::Future, @@ -399,6 +423,7 @@ impl Client { response_timeout, None, push_sender, + disconnect_notifier, ) .await .map(|(conn, driver, _ip)| (conn, driver)) @@ -414,6 +439,7 @@ impl Client { pub async fn create_multiplexed_async_std_connection( &self, push_sender: Option>, + disconnect_notifier: Option>, ) -> RedisResult<( crate::aio::MultiplexedConnection, impl std::future::Future, @@ -421,6 +447,7 @@ impl Client { self.create_multiplexed_async_std_connection_with_response_timeout( std::time::Duration::MAX, push_sender, + disconnect_notifier, ) .await } @@ -624,6 +651,7 @@ impl Client { response_timeout: std::time::Duration, socket_addr: Option, push_sender: Option>, + disconnect_notifier: Option>, ) -> RedisResult<(crate::aio::MultiplexedConnection, Option)> where T: crate::aio::RedisRuntime, @@ -633,6 +661,7 @@ impl Client { response_timeout, socket_addr, push_sender, + disconnect_notifier, ) .await?; T::spawn(driver); @@ -644,6 +673,7 @@ impl Client { response_timeout: std::time::Duration, socket_addr: Option, push_sender: Option>, + disconnect_notifier: Option>, ) -> RedisResult<( crate::aio::MultiplexedConnection, impl std::future::Future, @@ -658,6 +688,7 @@ impl Client { con, response_timeout, push_sender, + disconnect_notifier, ) .await .map(|res| (res.0, res.1, ip)) diff --git a/redis/src/cluster_async/connections_logic.rs b/redis/src/cluster_async/connections_logic.rs index 96d9965c3..dc3fd82d0 100644 --- a/redis/src/cluster_async/connections_logic.rs +++ b/redis/src/cluster_async/connections_logic.rs @@ -5,7 +5,7 @@ use super::{ Connect, }; use crate::{ - aio::{ConnectionLike, Runtime}, + aio::{ConnectionLike, DisconnectNotifier, Runtime}, cluster::get_connection_info, cluster_client::ClusterParams, push_manager::PushInfo, @@ -57,6 +57,7 @@ pub(crate) async fn get_or_create_conn( params: &ClusterParams, conn_type: RefreshConnectionType, push_sender: Option>, + disconnect_notifier: Option>, ) -> RedisResult> where C: ConnectionLike + Send + Clone + Sync + Connect + 'static, @@ -73,14 +74,23 @@ where conn_type, Some(node), push_sender, + disconnect_notifier, ) .await .get_node(), } } else { - connect_and_check(addr, params.clone(), None, conn_type, None, push_sender) - .await - .get_node() + connect_and_check( + addr, + params.clone(), + None, + conn_type, + None, + push_sender, + disconnect_notifier, + ) + .await + .get_node() } } @@ -102,6 +112,7 @@ pub(crate) async fn connect_and_check_all_connections( params: ClusterParams, socket_addr: Option, push_sender: Option>, + disconnect_notifier: Option>, ) -> ConnectAndCheckResult where C: ConnectionLike + Connect + Send + Sync + 'static + Clone, @@ -113,8 +124,16 @@ where socket_addr, push_sender.clone(), false, + disconnect_notifier.clone(), + ), + create_connection( + addr, + params.clone(), + socket_addr, + push_sender, + true, + disconnect_notifier, ), - create_connection(addr, params.clone(), socket_addr, push_sender, true), ) .await { @@ -160,11 +179,21 @@ async fn connect_and_check_only_management_conn( params: ClusterParams, socket_addr: Option, prev_node: AsyncClusterNode, + disconnect_notifier: Option>, ) -> ConnectAndCheckResult where C: ConnectionLike + Connect + Send + Sync + 'static + Clone, { - match create_connection::(addr, params.clone(), socket_addr, None, true).await { + match create_connection::( + addr, + params.clone(), + socket_addr, + None, + true, + disconnect_notifier, + ) + .await + { Err(conn_err) => failed_management_connection(addr, prev_node.user_connection, conn_err), Ok(mut connection) => { @@ -241,6 +270,7 @@ pub async fn connect_and_check( conn_type: RefreshConnectionType, node: Option>, push_sender: Option>, + disconnect_notifier: Option>, ) -> ConnectAndCheckResult where C: ConnectionLike + Connect + Send + Sync + 'static + Clone, @@ -252,6 +282,7 @@ where params.clone(), socket_addr, push_sender, + disconnect_notifier, ) .await { @@ -265,15 +296,36 @@ where // Refreshing only the management connection requires the node to exist alongside a user connection. Otherwise, refresh all connections. match node { Some(node) => { - connect_and_check_only_management_conn(addr, params, socket_addr, node).await + connect_and_check_only_management_conn( + addr, + params, + socket_addr, + node, + disconnect_notifier, + ) + .await } None => { - connect_and_check_all_connections(addr, params, socket_addr, push_sender).await + connect_and_check_all_connections( + addr, + params, + socket_addr, + push_sender, + disconnect_notifier, + ) + .await } } } RefreshConnectionType::AllConnections => { - connect_and_check_all_connections(addr, params, socket_addr, push_sender).await + connect_and_check_all_connections( + addr, + params, + socket_addr, + push_sender, + disconnect_notifier, + ) + .await } } } @@ -283,12 +335,20 @@ async fn create_and_setup_user_connection( params: ClusterParams, socket_addr: Option, push_sender: Option>, + disconnect_notifier: Option>, ) -> RedisResult> where C: ConnectionLike + Connect + Send + 'static, { - let mut connection: ConnectionWithIp = - create_connection(node, params.clone(), socket_addr, push_sender, false).await?; + let mut connection: ConnectionWithIp = create_connection( + node, + params.clone(), + socket_addr, + push_sender, + false, + disconnect_notifier, + ) + .await?; setup_user_connection(&mut connection.conn, params).await?; Ok(connection) } @@ -328,6 +388,7 @@ async fn create_connection( socket_addr: Option, push_sender: Option>, is_management: bool, + disconnect_notifier: Option>, ) -> RedisResult> where C: ConnectionLike + Connect + Send + 'static, @@ -339,12 +400,18 @@ where params.pubsub_subscriptions = None; } let info = get_connection_info(node, params)?; + // management connection does not require notifications or disconnect notifications C::connect( info, response_timeout, connection_timeout, socket_addr, if !is_management { push_sender } else { None }, + if !is_management { + disconnect_notifier + } else { + None + }, ) .await .map(|conn| conn.into()) diff --git a/redis/src/cluster_async/mod.rs b/redis/src/cluster_async/mod.rs index cf977dd2a..aa5ea47f1 100644 --- a/redis/src/cluster_async/mod.rs +++ b/redis/src/cluster_async/mod.rs @@ -58,7 +58,7 @@ use std::{ use tokio::task::JoinHandle; use crate::{ - aio::{get_socket_addrs, ConnectionLike, MultiplexedConnection, Runtime}, + aio::{get_socket_addrs, ConnectionLike, DisconnectNotifier, MultiplexedConnection, Runtime}, cluster::slot_cmd, cluster_async::connections_logic::{ get_host_and_port_from_addr, get_or_create_conn, ConnectionFuture, RefreshConnectionType, @@ -91,6 +91,8 @@ use backoff_std_async::{Error as BackoffError, ExponentialBackoff}; use backoff_tokio::future::retry; #[cfg(feature = "tokio-comp")] use backoff_tokio::{Error as BackoffError, ExponentialBackoff}; +#[cfg(feature = "tokio-comp")] +use tokio::{sync::Notify, time::timeout}; use dispose::{Disposable, Dispose}; use futures::{future::BoxFuture, prelude::*, ready}; @@ -370,6 +372,23 @@ where } } +#[cfg(feature = "tokio-comp")] +#[derive(Clone)] +struct TokioDisconnectNotifier { + pub disconnect_notifier: Arc, +} + +#[cfg(feature = "tokio-comp")] +impl DisconnectNotifier for TokioDisconnectNotifier { + fn notify_disconnect(&mut self) { + self.disconnect_notifier.notify_one(); + } + + fn clone_box(&self) -> Box { + Box::new(self.clone()) + } +} + type ConnectionMap = connections_container::ConnectionsMap>; type ConnectionsContainer = self::connections_container::ConnectionsContainer>; @@ -383,6 +402,9 @@ pub(crate) struct InnerCore { push_sender: Option>, subscriptions_by_address: RwLock>, unassigned_subscriptions: RwLock, + disconnect_notifier: Option>, + #[cfg(feature = "tokio-comp")] + tokio_notify: Arc, } pub(crate) type Core = Arc>; @@ -461,6 +483,8 @@ pub(crate) struct ClusterConnInner { refresh_error: Option, // Handler of the periodic check task. periodic_checks_handler: Option>, + // Handler of fast connection validation task + connections_validation_handler: Option>, } impl Dispose for ClusterConnInner { @@ -471,6 +495,12 @@ impl Dispose for ClusterConnInner { #[cfg(feature = "tokio-comp")] handle.abort() } + if let Some(handle) = self.connections_validation_handler { + #[cfg(all(not(feature = "tokio-comp"), feature = "async-std-comp"))] + block_on(handle.cancel()); + #[cfg(feature = "tokio-comp")] + handle.abort() + } } } @@ -957,9 +987,27 @@ where cluster_params: ClusterParams, push_sender: Option>, ) -> RedisResult> { - let connections = - Self::create_initial_connections(initial_nodes, &cluster_params, push_sender.clone()) - .await?; + #[cfg(feature = "tokio-comp")] + let tokio_notify = Arc::new(Notify::new()); + + let disconnect_notifier = { + #[cfg(feature = "tokio-comp")] + { + Some::>(Box::new(TokioDisconnectNotifier { + disconnect_notifier: tokio_notify.clone(), + })) + } + #[cfg(not(feature = "tokio-comp"))] + None + }; + + let connections = Self::create_initial_connections( + initial_nodes, + &cluster_params, + push_sender.clone(), + disconnect_notifier.clone(), + ) + .await?; let topology_checks_interval = cluster_params.topology_checks_interval; let slots_refresh_rate_limiter = cluster_params.slots_refresh_rate_limit; @@ -983,6 +1031,9 @@ where }, ), subscriptions_by_address: RwLock::new(Default::default()), + disconnect_notifier: disconnect_notifier.clone(), + #[cfg(feature = "tokio-comp")] + tokio_notify, }); let mut connection = ClusterConnInner { inner, @@ -990,6 +1041,7 @@ where refresh_error: None, state: ConnectionState::PollComplete, periodic_checks_handler: None, + connections_validation_handler: None, }; Self::refresh_slots_and_subscriptions_with_retries( connection.inner.clone(), @@ -1010,6 +1062,22 @@ where } } + let connections_validation_interval = cluster_params.connections_validation_interval; + if let Some(duration) = connections_validation_interval { + let connections_validation_handler = + ClusterConnInner::connections_validation_task(connection.inner.clone(), duration); + #[cfg(feature = "tokio-comp")] + { + connection.connections_validation_handler = + Some(tokio::spawn(connections_validation_handler)); + } + #[cfg(all(not(feature = "tokio-comp"), feature = "async-std-comp"))] + { + connection.connections_validation_handler = + Some(spawn(connections_validation_handler)); + } + } + Ok(Disposable::new(connection)) } @@ -1058,6 +1126,7 @@ where initial_nodes: &[ConnectionInfo], params: &ClusterParams, push_sender: Option>, + disconnect_notifier: Option>, ) -> RedisResult> { let initial_nodes: Vec<(String, Option)> = Self::try_to_expand_initial_nodes(initial_nodes).await; @@ -1067,6 +1136,7 @@ where let push_sender = push_sender.clone(); // set subscriptions to none, they will be applied upon the topology discovery params.pubsub_subscriptions = None; + let disconnect_notifier = disconnect_notifier.clone(); async move { let result = connect_and_check( @@ -1076,6 +1146,7 @@ where RefreshConnectionType::AllConnections, None, push_sender, + disconnect_notifier, ) .await .get_node(); @@ -1122,6 +1193,7 @@ where &inner.initial_nodes, &inner.cluster_params, None, + inner.disconnect_notifier.clone(), ) .await { @@ -1145,22 +1217,93 @@ where } } + // Validate all existing user connections and try to reconnect if nessesary. + // In addition, as a safety measure, drop nodes that do not have any assigned slots. + // This function serves as a cheap alternative to slot_refresh() and thus can be used much more frequently. + // The function does not discover the topology from the cluster and assumes the cached topology is valid. + // In addition, the validation is done by peeking at the state of the underlying transport w/o overhead of additional commands to server. + async fn validate_all_user_connections(inner: Arc>) { + let mut all_valid_conns = HashMap::new(); + let mut all_nodes_with_slots = HashSet::new(); + // prep connections and clean out these w/o assigned slots, as we might have established connections to unwanted hosts + { + let mut nodes_to_delete = Vec::new(); + let connections_container = inner.conn_lock.read().await; + + connections_container + .slot_map + .addresses_for_all_nodes() + .iter() + .for_each(|addr| { + all_nodes_with_slots.insert(String::from(*addr)); + }); + + connections_container + .all_node_connections() + .for_each(|(addr, con)| { + if all_nodes_with_slots.contains(&addr) { + all_valid_conns.insert(addr.clone(), con.clone()); + } else { + nodes_to_delete.push(addr.clone()); + } + }); + + for addr in &nodes_to_delete { + connections_container.remove_node(addr); + } + } + + // identify nodes with closed connection + let mut addrs_to_refresh = Vec::new(); + for (addr, con_fut) in &all_valid_conns { + let con = con_fut.clone().await; + if con.is_closed() { + addrs_to_refresh.push(addr.clone()); + } + } + + // identify missing nodes + addrs_to_refresh.extend( + all_nodes_with_slots + .iter() + .filter(|addr| !all_valid_conns.contains_key(*addr)) + .cloned(), + ); + + if !addrs_to_refresh.is_empty() { + // dont try existing nodes since we know a. it does not exist. b. exist but its connection is closed + Self::refresh_connections( + inner.clone(), + addrs_to_refresh, + RefreshConnectionType::AllConnections, + false, + ) + .await; + } + } + async fn refresh_connections( inner: Arc>, addresses: Vec, conn_type: RefreshConnectionType, + try_existing_node: bool, ) { info!("Started refreshing connections to {:?}", addresses); let connections_container = inner.conn_lock.read().await; let cluster_params = &inner.cluster_params; let subscriptions_by_address = &inner.subscriptions_by_address; let push_sender = &inner.push_sender; + let disconnect_notifier = &inner.disconnect_notifier; stream::iter(addresses.into_iter()) .fold( &*connections_container, |connections_container, address| async move { - let node_option = connections_container.remove_node(&address); + let node_option = if try_existing_node { + connections_container.remove_node(&address) + } else { + Option::None + }; // override subscriptions for this connection let mut cluster_params = cluster_params.clone(); @@ -1173,6 +1316,7 @@ where &cluster_params, conn_type, push_sender.clone(), + disconnect_notifier.clone(), ) .await; match node { @@ -1394,6 +1538,20 @@ where } } + async fn connections_validation_task(inner: Arc>, interval_duration: Duration) { + loop { + #[cfg(feature = "tokio-comp")] + let _ = timeout(interval_duration, async { + inner.tokio_notify.notified().await; + }) + .await; + #[cfg(not(feature = "tokio-comp"))] + let _ = boxed_sleep(interval_duration).await; + + Self::validate_all_user_connections(inner.clone()).await; + } + } + async fn refresh_pubsub_subscriptions(inner: Arc>) { if inner.cluster_params.protocol != crate::types::ProtocolVersion::RESP3 { return; @@ -1471,17 +1629,12 @@ where drop(subs_by_address_guard); if !addrs_to_refresh.is_empty() { - let conns_read_guard = inner.conn_lock.read().await; - // have to remove or otherwise the refresh_connection wont trigger node recreation - for addr_to_refresh in addrs_to_refresh.iter() { - conns_read_guard.remove_node(addr_to_refresh); - } - drop(conns_read_guard); // immediately trigger connection reestablishment Self::refresh_connections( inner.clone(), addrs_to_refresh.into_iter().collect(), RefreshConnectionType::AllConnections, + false, ) .await; } @@ -1517,6 +1670,7 @@ where inner, failed_connections, RefreshConnectionType::OnlyManagementConnection, + true, ) .await; } @@ -1616,6 +1770,7 @@ where &cluster_params, RefreshConnectionType::AllConnections, inner.push_sender.clone(), + inner.disconnect_notifier.clone(), ) .await; if let Ok(node) = node { @@ -1911,6 +2066,7 @@ where RefreshConnectionType::AllConnections, None, core.push_sender.clone(), + core.disconnect_notifier.clone(), ) .await .get_node() @@ -2221,6 +2377,7 @@ where self.inner.clone(), addresses, RefreshConnectionType::OnlyUserConnection, + true, ), ))); } @@ -2329,7 +2486,12 @@ where fn get_db(&self) -> i64 { 0 } + + fn is_closed(&self) -> bool { + false + } } + /// Implements the process of connecting to a Redis server /// and obtaining a connection handle. pub trait Connect: Sized { @@ -2342,6 +2504,7 @@ pub trait Connect: Sized { connection_timeout: Duration, socket_addr: Option, push_sender: Option>, + disconnect_notifier: Option>, ) -> RedisFuture<'a, (Self, Option)> where T: IntoConnectionInfo + Send + 'a; @@ -2354,6 +2517,7 @@ impl Connect for MultiplexedConnection { connection_timeout: Duration, socket_addr: Option, push_sender: Option>, + disconnect_notifier: Option>, ) -> RedisFuture<'a, (MultiplexedConnection, Option)> where T: IntoConnectionInfo + Send + 'a, @@ -2371,6 +2535,7 @@ impl Connect for MultiplexedConnection { response_timeout, socket_addr, push_sender, + disconnect_notifier, ), ) .await? @@ -2382,6 +2547,7 @@ impl Connect for MultiplexedConnection { response_timeout, socket_addr, push_sender, + disconnect_notifier, )) .await? } diff --git a/redis/src/cluster_client.rs b/redis/src/cluster_client.rs index 7c4763179..5815bede1 100644 --- a/redis/src/cluster_client.rs +++ b/redis/src/cluster_client.rs @@ -42,6 +42,8 @@ struct BuilderParams { #[cfg(feature = "cluster-async")] topology_checks_interval: Option, #[cfg(feature = "cluster-async")] + connections_validation_interval: Option, + #[cfg(feature = "cluster-async")] slots_refresh_rate_limit: SlotsRefreshRateLimit, client_name: Option, response_timeout: Option, @@ -138,6 +140,8 @@ pub struct ClusterParams { pub(crate) topology_checks_interval: Option, #[cfg(feature = "cluster-async")] pub(crate) slots_refresh_rate_limit: SlotsRefreshRateLimit, + #[cfg(feature = "cluster-async")] + pub(crate) connections_validation_interval: Option, pub(crate) tls_params: Option, pub(crate) client_name: Option, pub(crate) connection_timeout: Duration, @@ -169,6 +173,8 @@ impl ClusterParams { topology_checks_interval: value.topology_checks_interval, #[cfg(feature = "cluster-async")] slots_refresh_rate_limit: value.slots_refresh_rate_limit, + #[cfg(feature = "cluster-async")] + connections_validation_interval: value.connections_validation_interval, tls_params, client_name: value.client_name, response_timeout: value.response_timeout.unwrap_or(Duration::MAX), @@ -393,6 +399,16 @@ impl ClusterClientBuilder { self } + /// Enables periodic connections checks for this client. + /// If enabled, the conenctions to the cluster nodes will be validated periodicatly, per configured interval. + /// In addition, for tokio runtime, passive disconnections could be detected instantly, + /// triggering reestablishemnt, w/o waiting for the next periodic check. + #[cfg(feature = "cluster-async")] + pub fn periodic_connections_checks(mut self, interval: Duration) -> ClusterClientBuilder { + self.builder_params.connections_validation_interval = Some(interval); + self + } + /// Sets the rate limit for slot refresh operations in the cluster. /// /// This method configures the interval duration between consecutive slot diff --git a/redis/src/sentinel.rs b/redis/src/sentinel.rs index 2e30ec02d..8b853f643 100644 --- a/redis/src/sentinel.rs +++ b/redis/src/sentinel.rs @@ -301,7 +301,7 @@ fn find_valid_master( #[cfg(feature = "aio")] async fn async_check_role(connection_info: &ConnectionInfo, target_role: &str) -> bool { if let Ok(client) = Client::open(connection_info.clone()) { - if let Ok(mut conn) = client.get_multiplexed_async_connection(None).await { + if let Ok(mut conn) = client.get_multiplexed_async_connection(None, None).await { let result: RedisResult> = crate::cmd("ROLE").query_async(&mut conn).await; return check_role_result(&result, target_role); } @@ -366,7 +366,7 @@ async fn async_reconnect( ) -> RedisResult<()> { let sentinel_client = Client::open(connection_info.clone())?; let new_connection = sentinel_client - .get_multiplexed_async_connection(None) + .get_multiplexed_async_connection(None, None) .await?; connection.replace(new_connection); Ok(()) @@ -768,6 +768,6 @@ impl SentinelClient { #[cfg(any(feature = "tokio-comp", feature = "async-std-comp"))] pub async fn get_async_connection(&mut self) -> RedisResult { let client = self.async_get_client().await?; - client.get_multiplexed_async_connection(None).await + client.get_multiplexed_async_connection(None, None).await } } diff --git a/redis/tests/support/mock_cluster.rs b/redis/tests/support/mock_cluster.rs index b9f27710b..93acff5e8 100644 --- a/redis/tests/support/mock_cluster.rs +++ b/redis/tests/support/mock_cluster.rs @@ -29,6 +29,9 @@ use futures::future; #[cfg(feature = "cluster-async")] use tokio::runtime::Runtime; +#[cfg(feature = "aio")] +use redis::aio::DisconnectNotifier; + type Handler = Arc Result<(), RedisResult> + Send + Sync>; pub struct MockConnectionBehavior { @@ -135,6 +138,7 @@ impl cluster_async::Connect for MockConnection { _connection_timeout: Duration, _socket_addr: Option, _push_sender: Option>, + _disconnect_notifier: Option>, ) -> RedisFuture<'a, (Self, Option)> where T: IntoConnectionInfo + Send + 'a, @@ -369,6 +373,10 @@ impl aio::ConnectionLike for MockConnection { fn get_db(&self) -> i64 { 0 } + + fn is_closed(&self) -> bool { + false + } } impl redis::ConnectionLike for MockConnection { diff --git a/redis/tests/support/mod.rs b/redis/tests/support/mod.rs index 24f786c2e..96ce71e6a 100644 --- a/redis/tests/support/mod.rs +++ b/redis/tests/support/mod.rs @@ -501,7 +501,9 @@ impl TestContext { #[cfg(feature = "aio")] pub async fn async_connection(&self) -> redis::RedisResult { - self.client.get_multiplexed_async_connection(None).await + self.client + .get_multiplexed_async_connection(None, None) + .await } #[cfg(feature = "aio")] @@ -513,7 +515,9 @@ impl TestContext { pub async fn async_connection_async_std( &self, ) -> redis::RedisResult { - self.client.get_multiplexed_async_std_connection(None).await + self.client + .get_multiplexed_async_std_connection(None, None) + .await } pub fn stop_server(&mut self) { @@ -531,14 +535,18 @@ impl TestContext { pub async fn multiplexed_async_connection_tokio( &self, ) -> redis::RedisResult { - self.client.get_multiplexed_tokio_connection(None).await + self.client + .get_multiplexed_tokio_connection(None, None) + .await } #[cfg(feature = "async-std-comp")] pub async fn multiplexed_async_connection_async_std( &self, ) -> redis::RedisResult { - self.client.get_multiplexed_async_std_connection(None).await + self.client + .get_multiplexed_async_std_connection(None, None) + .await } pub fn get_version(&self) -> Version { diff --git a/redis/tests/test_async.rs b/redis/tests/test_async.rs index c0fc7fe3e..f7c892a26 100644 --- a/redis/tests/test_async.rs +++ b/redis/tests/test_async.rs @@ -100,7 +100,7 @@ mod basic_async { fn dont_panic_on_closed_multiplexed_connection() { let ctx = TestContext::new(); let client = ctx.client.clone(); - let connect = client.get_multiplexed_async_connection(None); + let connect = client.get_multiplexed_async_connection(None, None); drop(ctx); block_on_all(async move { @@ -584,7 +584,7 @@ mod basic_async { let client = redis::Client::open(coninfo).unwrap(); let err = client - .get_multiplexed_tokio_connection(None) + .get_multiplexed_tokio_connection(None, None) .await .err() .unwrap(); @@ -916,7 +916,7 @@ mod basic_async { let millisecond = std::time::Duration::from_millis(1); let mut retries = 0; loop { - match client.get_multiplexed_async_connection(None).await { + match client.get_multiplexed_async_connection(None, None).await { Err(err) => { if err.is_connection_refusal() { tokio::time::sleep(millisecond).await; @@ -986,7 +986,7 @@ mod basic_async { let client = build_single_client(ctx.server.connection_info(), &ctx.server.tls_paths, true) .unwrap(); - let connect = client.get_multiplexed_async_connection(None); + let connect = client.get_multiplexed_async_connection(None, None); block_on_all(connect.and_then(|mut con| async move { redis::cmd("SET") .arg("key1") @@ -1007,7 +1007,7 @@ mod basic_async { let client = build_single_client(ctx.server.connection_info(), &ctx.server.tls_paths, false) .unwrap(); - let connect = client.get_multiplexed_async_connection(None); + let connect = client.get_multiplexed_async_connection(None, None); let result = block_on_all(connect.and_then(|mut con| async move { redis::cmd("SET") .arg("key1") diff --git a/redis/tests/test_async_async_std.rs b/redis/tests/test_async_async_std.rs index aabe58320..ae2ae8443 100644 --- a/redis/tests/test_async_async_std.rs +++ b/redis/tests/test_async_async_std.rs @@ -61,7 +61,7 @@ fn test_args_async_std() { fn dont_panic_on_closed_multiplexed_connection() { let ctx = TestContext::new(); let client = ctx.client.clone(); - let connect = client.get_multiplexed_async_std_connection(None); + let connect = client.get_multiplexed_async_std_connection(None, None); drop(ctx); block_on_all_using_async_std(async move { diff --git a/redis/tests/test_async_cluster_connections_logic.rs b/redis/tests/test_async_cluster_connections_logic.rs index 2a5bab6ae..07e41a699 100644 --- a/redis/tests/test_async_cluster_connections_logic.rs +++ b/redis/tests/test_async_cluster_connections_logic.rs @@ -73,6 +73,7 @@ mod test_connect_and_check { RefreshConnectionType::AllConnections, None, None, + None, ) .await; let node = assert_full_success(result); @@ -109,6 +110,7 @@ mod test_connect_and_check { RefreshConnectionType::AllConnections, None, None, + None, ) .await; let (node, _) = assert_partial_result(result); @@ -127,6 +129,7 @@ mod test_connect_and_check { RefreshConnectionType::AllConnections, None, None, + None, ) .await; let (node, _) = assert_partial_result(result); @@ -160,6 +163,7 @@ mod test_connect_and_check { RefreshConnectionType::AllConnections, None, None, + None, ) .await; let node = assert_full_success(result); @@ -197,6 +201,7 @@ mod test_connect_and_check { RefreshConnectionType::AllConnections, None, None, + None, ) .await; let err = result.get_error().unwrap(); @@ -248,6 +253,7 @@ mod test_connect_and_check { RefreshConnectionType::OnlyManagementConnection, Some(node), None, + None, ) .await; let node = assert_full_success(result); @@ -295,6 +301,7 @@ mod test_connect_and_check { RefreshConnectionType::OnlyManagementConnection, Some(node), None, + None, ) .await; let (node, _) = assert_partial_result(result); @@ -357,6 +364,7 @@ mod test_connect_and_check { RefreshConnectionType::OnlyUserConnection, Some(node), None, + None, ) .await; let node = assert_full_success(result); diff --git a/redis/tests/test_cluster_async.rs b/redis/tests/test_cluster_async.rs index 7d1249c3e..8c1d0d7e0 100644 --- a/redis/tests/test_cluster_async.rs +++ b/redis/tests/test_cluster_async.rs @@ -21,7 +21,7 @@ mod cluster_async { use std::ops::Add; use redis::{ - aio::{ConnectionLike, MultiplexedConnection}, + aio::{ConnectionLike, DisconnectNotifier, MultiplexedConnection}, cluster::ClusterClient, cluster_async::{testing::MANAGEMENT_CONN_NAME, ClusterConnection, Connect}, cluster_routing::{ @@ -44,6 +44,60 @@ mod cluster_async { )) } + fn validate_subscriptions( + pubsub_subs: &PubSubSubscriptionInfo, + notifications_rx: &mut mpsc::UnboundedReceiver, + allow_disconnects: bool, + ) { + let mut subscribe_cnt = + if let Some(exact_subs) = pubsub_subs.get(&PubSubSubscriptionKind::Exact) { + exact_subs.len() + } else { + 0 + }; + + let mut psubscribe_cnt = + if let Some(pattern_subs) = pubsub_subs.get(&PubSubSubscriptionKind::Pattern) { + pattern_subs.len() + } else { + 0 + }; + + let mut ssubscribe_cnt = + if let Some(sharded_subs) = pubsub_subs.get(&PubSubSubscriptionKind::Sharded) { + sharded_subs.len() + } else { + 0 + }; + + for _ in 0..(subscribe_cnt + psubscribe_cnt + ssubscribe_cnt) { + let result = notifications_rx.try_recv(); + assert!(result.is_ok()); + let PushInfo { kind, data: _ } = result.unwrap(); + assert!( + kind == PushKind::Subscribe + || kind == PushKind::PSubscribe + || kind == PushKind::SSubscribe + || if allow_disconnects { + kind == PushKind::Disconnection + } else { + false + } + ); + if kind == PushKind::Subscribe { + subscribe_cnt -= 1; + } else if kind == PushKind::PSubscribe { + psubscribe_cnt -= 1; + } else if kind == PushKind::SSubscribe { + ssubscribe_cnt -= 1; + } + } + + assert!(subscribe_cnt == 0); + assert!(psubscribe_cnt == 0); + assert!(ssubscribe_cnt == 0); + } + #[test] fn test_async_cluster_basic_cmd() { let cluster = TestClusterContext::new(3, 0); @@ -382,7 +436,7 @@ mod cluster_async { .unwrap_or_else(|e| panic!("Failed to connect to '{addr}': {e}")); let mut conn = client - .get_multiplexed_async_connection(None) + .get_multiplexed_async_connection(None, None) .await .unwrap_or_else(|e| panic!("Failed to get connection: {e}")); @@ -482,6 +536,7 @@ mod cluster_async { connection_timeout: std::time::Duration, socket_addr: Option, push_sender: Option>, + disconnect_notifier: Option>, ) -> RedisFuture<'a, (Self, Option)> where T: IntoConnectionInfo + Send + 'a, @@ -493,6 +548,7 @@ mod cluster_async { connection_timeout, socket_addr, push_sender, + disconnect_notifier, ) .await?; Ok((ErrorConnection { inner }, None)) @@ -521,6 +577,10 @@ mod cluster_async { fn get_db(&self) -> i64 { self.inner.get_db() } + + fn is_closed(&self) -> bool { + true + } } #[test] @@ -2683,546 +2743,522 @@ mod cluster_async { } #[test] - fn test_async_cluster_restore_resp3_pubsub_state_after_complete_server_disconnect() { - // let cluster = TestClusterContext::new_with_cluster_client_builder( - // 3, - // 0, - // |builder| builder.retries(3).use_protocol(ProtocolVersion::RESP3), - // //|builder| builder.retries(3), - // false, - // ); - - // block_on_all(async move { - // let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::(); - // let mut connection = cluster.async_connection(Some(tx.clone())).await; - // // assuming the implementation of TestCluster assigns the slots monotonicaly incerasing with the nodes - // let route_0 = redis::cluster_routing::Route::new(0, redis::cluster_routing::SlotAddr::Master); - // let node_0_route = redis::cluster_routing::SingleNodeRoutingInfo::SpecificNode(route_0); - // let route_2 = redis::cluster_routing::Route::new(16 * 1024 - 1, redis::cluster_routing::SlotAddr::Master); - // let node_2_route = redis::cluster_routing::SingleNodeRoutingInfo::SpecificNode(route_2); - - // let result = connection - // .route_command(&redis::Cmd::new().arg("SUBSCRIBE").arg("test_channel"), RoutingInfo::SingleNode(node_0_route.clone())) - // //.route_command(&redis::Cmd::new().arg("SUBSCRIBE").arg("test_channel"), RoutingInfo::SingleNode(SingleNodeRoutingInfo::Random)) - // .await; - - // assert_eq!( - // result, - // Ok(Value::Push { - // kind: PushKind::Subscribe, - // data: vec![Value::BulkString("test_channel".into()), Value::Int(1)], - // }) - // ); - - // // pull out all the subscribe notification, this push notification is due to the previous subscribe command - // let result = rx.recv().await; - // assert!(result.is_some()); - // let PushInfo { kind, data } = result.unwrap(); - // assert_eq!( - // (kind, data), - // ( - // PushKind::Subscribe, - // vec![ - // Value::BulkString("test_channel".as_bytes().to_vec()), - // Value::Int(1), - // ] - // ) - // ); - - // // ensure subscription, routing on the same node, expected return Int(1) - // let result = connection - // .route_command(&redis::Cmd::new().arg("PUBLISH").arg("test_channel").arg("test_message_from_node_0"), RoutingInfo::SingleNode(node_0_route.clone())) - // .await; - // assert_eq!( - // result, - // Ok(Value::Int(1)) - // ); - - // // ensure subscription, routing on different node, expected return Int(0) - // let result = connection - // .route_command(&redis::Cmd::new().arg("PUBLISH").arg("test_channel").arg("test_message_from_node_2"), RoutingInfo::SingleNode(node_2_route.clone())) - // .await; - // assert_eq!( - // result, - // Ok(Value::Int(0)) - // ); - - // for i in vec![0, 2] { - // let result = rx.recv().await; - // assert!(result.is_some()); - // let PushInfo { kind, data } = result.unwrap(); - // println!("^^^^^^^^^ '{:?} -> {:?}'", kind, data); - // assert_eq!( - // (kind, data), - // ( - // PushKind::Message, - // vec![ - // Value::BulkString("test_channel".into()), - // Value::BulkString(format!("test_message_from_node_{}", i).into()), - // ] - // ) - // ); - // } - - // // drop and recreate cluster and connections - // drop(cluster); - // println!("*********** DROPPED **********"); - - // let cluster = TestClusterContext::new_with_cluster_client_builder( - // 3, - // 0, - // |builder| builder.retries(3).use_protocol(ProtocolVersion::RESP3), - // //|builder| builder.retries(3), - // false, - // ); - - // let result = connection - // .route_command(&redis::Cmd::new().arg("PUBLISH").arg("test_channel").arg("test_message_from_node_0"), RoutingInfo::SingleNode(node_0_route.clone())) - // .await; - // assert_eq!( - // result, - // Ok(Value::Int(1)) - // ); - - // //sleep(futures_time::time::Duration::from_secs(15)).await; - // //return Ok(()); - - // let cluster = TestClusterContext::new_with_cluster_client_builder( - // 3, - // 0, - // |builder| builder.retries(3).use_protocol(ProtocolVersion::RESP3), - // //|builder| builder.retries(3), - // false, - // ); - - // // ensure subscription state restore - // let result = connection - // .route_command(&redis::Cmd::new().arg("PUBLISH").arg("test_channel").arg("test_message_from_node_0"), RoutingInfo::SingleNode(node_0_route.clone())) - // .await; - // assert_eq!( - // result, - // Ok(Value::Int(1)) - // ); - - // // non-subscribed channel - // let result = connection - // .route_command(&redis::Cmd::new().arg("PUBLISH").arg("test_channel_1").arg("should_not_receive"), RoutingInfo::SingleNode(node_0_route.clone())) - // .await; - // assert_eq!( - // result, - // Ok(Value::Int(0)) - // ); - - // // ensure subscription, routing on different node, expected return Int(0) - // let result = connection - // .route_command(&redis::Cmd::new().arg("PUBLISH").arg("test_channel").arg("test_message_from_node_2"), RoutingInfo::SingleNode(node_2_route.clone())) - // .await; - // assert_eq!( - // result, - // Ok(Value::Int(0)) - // ); - - // // should produce an arbitrary number of 'disconnected' notifications - 1 for the intitial try after the drop and an unknown? amout during reconnecting procedure - // // Notifications become available ONLY after we try to send the commands, since push manager does not register TCP disconnect on a idle socket - // // Remove the any amount of 'disconnected' notifications - // sleep(futures_time::time::Duration::from_secs(1)).await; - // //let mut result = rx.recv().await; - // let mut result = rx.try_recv(); - // assert!(result.is_ok()); - // //assert!(result.is_some()); - // loop { - // let kind = result.clone().unwrap().kind; - // if kind != PushKind::Disconnection && kind != PushKind::Subscribe { - // break; - // } - // // result = rx.recv().await; - // // assert!(result.is_some()); - // result = rx.try_recv(); - // assert!(result.is_ok()); - // } - - // // ensure messages test_message_from_node_0 and test_message_from_node_2 - // let mut msg_from_0 = false; - // let mut msg_from_2 = false; - // while !msg_from_0 && !msg_from_2 { - // let mut result = rx.recv().await; - // assert!(result.is_some()); - // let PushInfo { kind, data } = result.unwrap(); - - // assert!(kind == PushKind::Disconnection || kind == PushKind::Subscribe || kind == PushKind::Message); - // if kind == PushKind::Disconnection || kind == PushKind::Subscribe { - // // ignore - // continue; - // } - - // if data == vec![ - // Value::BulkString("test_channel".into()), - // Value::BulkString("test_message_from_node_0".into())] { - // assert!(!msg_from_0); - // msg_from_0 = true; - // } - // else if data == vec![ - // Value::BulkString("test_channel".into()), - // Value::BulkString("test_message_from_node_2".into())] { - // assert!(!msg_from_2); - // msg_from_2 = true; - // } - // else { - // assert!(false, "Unexpected message received"); - // } - // } - - // // let mut msg_from_0 = false; - // // let mut msg_from_2 = false; - // // while !msg_from_2 { - // // let mut result = rx.recv().await; - // // assert!(result.is_some()); - // // let PushInfo { kind, data } = result.unwrap(); - - // // assert!(kind == PushKind::Disconnection || kind == PushKind::Subscribe || kind == PushKind::Message); - // // if kind == PushKind::Disconnection || kind == PushKind::Subscribe { - // // // ignore - // // continue; - // // } - - // // if data == vec![ - // // Value::BulkString("test_channel".into()), - // // Value::BulkString("test_message_from_node_2".into())] { - // // assert!(!msg_from_2); - // // msg_from_2 = true; - // // } - // // else { - // // assert!(false, "Unexpected message received"); - // // } - // // } - - // Ok(()) - // }) - // .unwrap(); + fn test_async_cluster_test_fast_reconnect() { + // Note the 3 seconds connection check to differentiate between notifications and periodic + let cluster = TestClusterContext::new_with_cluster_client_builder( + 3, + 0, + |builder| { + builder + .retries(0) + .periodic_connections_checks(Duration::from_secs(3)) + }, + false, + ); + + // For tokio-comp, do 3 consequtive disconnects and ensure reconnects succeeds in less than 100ms, + // which is more than enough for local connections even with TLS. + // More than 1 run is done to ensure it is the fast reconnect notification that trigger the reconnect + // and not the periodic interval. + // For other async implementation, only periodic connection check is available, hence, + // do 1 run sleeping for periodic connection check interval, allowing it to reestablish connections + block_on_all(async move { + let mut disconnecting_con = cluster.async_connection(None).await; + let mut monitoring_con = cluster.async_connection(None).await; + + #[cfg(feature = "tokio-comp")] + let tries = 0..3; + #[cfg(not(feature = "tokio-comp"))] + let tries = 0..1; + + for _ in tries { + // get connection id + let mut cmd = redis::cmd("CLIENT"); + cmd.arg("ID"); + let res = disconnecting_con + .route_command( + &cmd, + RoutingInfo::SingleNode(SingleNodeRoutingInfo::SpecificNode(Route::new( + 0, + SlotAddr::Master, + ))), + ) + .await; + assert!(res.is_ok()); + let res = res.unwrap(); + let id = { + match res { + Value::Int(id) => id, + _ => { + panic!("Wrong return value for CLIENT ID command: {:?}", res); + } + } + }; + + // ask server to kill the connection + let mut cmd = redis::cmd("CLIENT"); + cmd.arg("KILL").arg("ID").arg(id).arg("SKIPME").arg("NO"); + let res = disconnecting_con + .route_command( + &cmd, + RoutingInfo::SingleNode(SingleNodeRoutingInfo::SpecificNode(Route::new( + 0, + SlotAddr::Master, + ))), + ) + .await; + // assert server has closed connection + assert_eq!(res, Ok(Value::Int(1))); + + #[cfg(feature = "tokio-comp")] + // ensure reconnect happened in less than 100ms + sleep(futures_time::time::Duration::from_millis(100)).await; + + #[cfg(not(feature = "tokio-comp"))] + // no fast notification is available, wait for 1 periodic check + overhead + sleep(futures_time::time::Duration::from_secs(3 + 1)).await; + + let mut cmd = redis::cmd("CLIENT"); + cmd.arg("LIST").arg("TYPE").arg("NORMAL"); + let res = monitoring_con + .route_command( + &cmd, + RoutingInfo::SingleNode(SingleNodeRoutingInfo::SpecificNode(Route::new( + 0, + SlotAddr::Master, + ))), + ) + .await; + assert!(res.is_ok()); + let res = res.unwrap(); + let client_list: String = { + match res { + // RESP2 + Value::BulkString(client_info) => { + // ensure 4 connections - 2 for each client, its save to unwrap here + String::from_utf8(client_info).unwrap() + } + // RESP3 + Value::VerbatimString { format: _, text } => text, + _ => { + panic!("Wrong return type for CLIENT LIST command: {:?}", res); + } + } + }; + assert_eq!(client_list.chars().filter(|&x| x == '\n').count(), 4); + } + Ok(()) + }) + .unwrap(); + } + + #[test] + fn test_async_cluster_restore_resp3_pubsub_state_passive_disconnect() { + let redis_ver = std::env::var("REDIS_VERSION").unwrap_or_default(); + let use_sharded = redis_ver.starts_with("7."); + + let mut client_subscriptions = PubSubSubscriptionInfo::from([( + PubSubSubscriptionKind::Exact, + HashSet::from([PubSubChannelOrPattern::from("test_channel".as_bytes())]), + )]); + + if use_sharded { + client_subscriptions.insert( + PubSubSubscriptionKind::Sharded, + HashSet::from([PubSubChannelOrPattern::from("test_channel_?".as_bytes())]), + ); + } + + // note topology change detection is not activated since no topology change is expected + let cluster = TestClusterContext::new_with_cluster_client_builder( + 3, + 0, + |builder| { + builder + .retries(3) + .use_protocol(ProtocolVersion::RESP3) + .pubsub_subscriptions(client_subscriptions.clone()) + .periodic_connections_checks(Duration::from_secs(1)) + }, + false, + ); + + block_on_all(async move { + let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::(); + let mut _listening_con = cluster.async_connection(Some(tx.clone())).await; + // Note, publishing connection has the same pubsub config + let mut publishing_con = cluster.async_connection(None).await; + + // short sleep to allow the server to push subscription notification + sleep(futures_time::time::Duration::from_secs(1)).await; + + // validate subscriptions + validate_subscriptions(&client_subscriptions, &mut rx, false); + + // validate PUBLISH + let result = cmd("PUBLISH") + .arg("test_channel") + .arg("test_message") + .query_async(&mut publishing_con) + .await; + assert_eq!( + result, + Ok(Value::Int(2)) // 2 connections with the same pubsub config + ); + + sleep(futures_time::time::Duration::from_secs(1)).await; + let result = rx.try_recv(); + assert!(result.is_ok()); + let PushInfo { kind, data } = result.unwrap(); + assert_eq!( + (kind, data), + ( + PushKind::Message, + vec![ + Value::BulkString("test_channel".into()), + Value::BulkString("test_message".into()), + ] + ) + ); + + if use_sharded { + // validate SPUBLISH + let result = cmd("SPUBLISH") + .arg("test_channel_?") + .arg("test_message") + .query_async(&mut publishing_con) + .await; + assert_eq!( + result, + Ok(Value::Int(2)) // 2 connections with the same pubsub config + ); + + sleep(futures_time::time::Duration::from_secs(1)).await; + let result = rx.try_recv(); + assert!(result.is_ok()); + let PushInfo { kind, data } = result.unwrap(); + assert_eq!( + (kind, data), + ( + PushKind::SMessage, + vec![ + Value::BulkString("test_channel_?".into()), + Value::BulkString("test_message".into()), + ] + ) + ); + } + + // simulate passive disconnect + drop(cluster); + + // recreate the cluster, the assumtion is that the cluster is built with exactly the same params (ports, slots map...) + let _cluster = + TestClusterContext::new_with_cluster_client_builder(3, 0, |builder| builder, false); + + // sleep for 1 periodic_connections_checks + overhead + sleep(futures_time::time::Duration::from_secs(1 + 1)).await; + + // new subscription notifications due to resubscriptions + validate_subscriptions(&client_subscriptions, &mut rx, true); + + // validate PUBLISH + let result = cmd("PUBLISH") + .arg("test_channel") + .arg("test_message") + .query_async(&mut publishing_con) + .await; + assert_eq!( + result, + Ok(Value::Int(2)) // 2 connections with the same pubsub config + ); + + sleep(futures_time::time::Duration::from_secs(1)).await; + let result = rx.try_recv(); + assert!(result.is_ok()); + let PushInfo { kind, data } = result.unwrap(); + assert_eq!( + (kind, data), + ( + PushKind::Message, + vec![ + Value::BulkString("test_channel".into()), + Value::BulkString("test_message".into()), + ] + ) + ); + + if use_sharded { + // validate SPUBLISH + let result = cmd("SPUBLISH") + .arg("test_channel_?") + .arg("test_message") + .query_async(&mut publishing_con) + .await; + assert_eq!( + result, + Ok(Value::Int(2)) // 2 connections with the same pubsub config + ); + + sleep(futures_time::time::Duration::from_secs(1)).await; + let result = rx.try_recv(); + assert!(result.is_ok()); + let PushInfo { kind, data } = result.unwrap(); + assert_eq!( + (kind, data), + ( + PushKind::SMessage, + vec![ + Value::BulkString("test_channel_?".into()), + Value::BulkString("test_message".into()), + ] + ) + ); + } + + Ok(()) + }) + .unwrap(); } #[test] - fn test_async_cluster_restore_resp3_pubsub_state_after_scale_in() { - - // let client_subscriptions = PubSubSubscriptionInfo::from( - // [ - // (PubSubSubscriptionKind::Exact, HashSet::from( - // [ - // // test_channel_? is used as it maps to the last node in both 3 and 6 node config - // // (assuming slots allocation is monotonicaly increasing starting from node 0) - // PubSubChannelOrPattern::from(b"test_channel_?") - // ]) - // ) - // ] - // ); - - // let cluster = TestClusterContext::new_with_cluster_client_builder( - // 6, - // 0, - // |builder| builder - // .retries(3) - // .use_protocol(ProtocolVersion::RESP3) - // .pubsub_subscriptions(client_subscriptions.clone()), - // false, - // ); - - // block_on_all(async move { - // let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::(); - // let mut connection = cluster.async_connection(Some(tx.clone())).await; - - // // short sleep to allow the server to push subscription notification - // sleep(futures_time::time::Duration::from_secs(1)).await; - // let result = rx.try_recv(); - // assert!(result.is_ok()); - // let PushInfo { kind, data } = result.unwrap(); - // assert_eq!( - // (kind, data), - // ( - // PushKind::Subscribe, - // vec![ - // Value::BulkString("test_channel_?".into()), - // Value::Int(1), - // ] - // ) - // ); - - // let slot_14212 = get_slot(b"test_channel_?"); - // assert_eq!(slot_14212, 14212); - // let slot_14212_route = redis::cluster_routing::Route::new(slot_14212, redis::cluster_routing::SlotAddr::Master); - // let node_5_route = redis::cluster_routing::SingleNodeRoutingInfo::SpecificNode(slot_14212_route); - - // let result = connection - // //.route_command(&redis::Cmd::new().arg("PUBLISH").arg("test_channel_?").arg("test_msg"), RoutingInfo::SingleNode(node_5_route.clone())) - // .route_command(&redis::Cmd::new().arg("PING"), RoutingInfo::SingleNode(node_5_route.clone())) - // .await; - // // let slot_0_route = redis::cluster_routing::Route::new(0, redis::cluster_routing::SlotAddr::Master); - // // let node_0_route = redis::cluster_routing::SingleNodeRoutingInfo::SpecificNode(slot_0_route); - - // let result = cmd("PUBLISH") - // .arg("test_channel_?") - // .arg("test_message") - // .query_async(&mut connection) - // .await; - // assert_eq!( - // result, - // Ok(Value::Int(1)) - // ); - - // sleep(futures_time::time::Duration::from_secs(1)).await; - // let result = rx.try_recv(); - // assert!(result.is_ok()); - // let PushInfo { kind, data } = result.unwrap(); - // assert_eq!( - // (kind, data), - // ( - // PushKind::Message, - // vec![ - // Value::BulkString("test_channel_?".into()), - // Value::BulkString(format!("test_message").into()), - // ] - // ) - // ); - - // // simulate scale in - // drop(cluster); - // println!("*********** DROPPED **********"); - // let cluster = TestClusterContext::new_with_cluster_client_builder( - // 3, - // 0, - // |builder| builder - // .retries(6) - // .use_protocol(ProtocolVersion::RESP3) - // .pubsub_subscriptions(client_subscriptions.clone()), - // false, - // ); - - // sleep(futures_time::time::Duration::from_secs(3)).await; - - // //ensure subscription notification due to resubscription - // // let result = cmd("PUBLISH") - // // .arg("test_channel_?") - // // .arg("test_message") - // // .query_async(&mut connection) - // // .await; - // let result = connection - // //.route_command(&redis::Cmd::new().arg("PUBLISH").arg("test_channel_?").arg("test_msg"), RoutingInfo::SingleNode(node_5_route.clone())) - // .route_command(&redis::Cmd::new().arg("PING"), RoutingInfo::SingleNode(node_5_route.clone())) - // .await; - // // assert_eq!( - // // result, - // // Ok(Value::Int(1)) - // // ); - - // let slot_14212 = get_slot(b"test_channel_?"); - // assert_eq!(slot_14212, 14212); - // let slot_14212_route = redis::cluster_routing::Route::new(slot_14212, redis::cluster_routing::SlotAddr::Master); - // let node_2_route = redis::cluster_routing::SingleNodeRoutingInfo::SpecificNode(slot_14212_route); - // let result = connection - // .route_command(&redis::Cmd::new().arg("PUBLISH").arg("test_channel_?").arg("test_message"), RoutingInfo::SingleNode(node_2_route.clone())) - // .await; - // assert_eq!( - // result, - // Ok(Value::Int(1)) - // ); - - // sleep(futures_time::time::Duration::from_secs(1)).await; - // let result = rx.try_recv(); - // assert!(result.is_ok()); - // let PushInfo { kind, data } = result.unwrap(); - // assert_eq!( - // (kind, data), - // ( - // PushKind::Subscribe, - // vec![ - // Value::BulkString("test_channel_?".into()), - // Value::Int(1), - // ] - // ) - // ); - - // let result = rx.try_recv(); - // assert!(result.is_ok()); - // let PushInfo { kind, data } = result.unwrap(); - // assert_eq!( - // (kind, data), - // ( - // PushKind::Disconnection, - // vec![], - // ) - // ); - - // return Ok(()); - - // // Subscribe on the slot 14212, this slot will reside on the last node in both 3 and 6 nodes cluster, - // // When the cluster is recreated with 3 nodes, this slot will reside on different network address. - // // Assuming the implementation of TestCluster assigns the slots monotonicaly incerasing with the nodes - // let slot_14212 = get_slot(b"test_channel_?"); - // assert_eq!(slot_14212, 14212); - // let slot_14212_route = redis::cluster_routing::Route::new(slot_14212, redis::cluster_routing::SlotAddr::Master); - // let node_5_route = redis::cluster_routing::SingleNodeRoutingInfo::SpecificNode(slot_14212_route); - - // let slot_0_route = redis::cluster_routing::Route::new(0, redis::cluster_routing::SlotAddr::Master); - // let node_0_route = redis::cluster_routing::SingleNodeRoutingInfo::SpecificNode(slot_0_route); - - // let result = connection - // .route_command(&redis::Cmd::new().arg("SUBSCRIBE").arg("test_channel_?"), RoutingInfo::SingleNode(node_5_route.clone())) - // //.route_command(&redis::Cmd::new().arg("SUBSCRIBE").arg("test_channel"), RoutingInfo::SingleNode(SingleNodeRoutingInfo::Random)) - // .await; - - // assert_eq!( - // result, - // Ok(Value::Push { - // kind: PushKind::Subscribe, - // data: vec![Value::BulkString("test_channel_?".into()), Value::Int(1)], - // }) - // ); - - // // pull out all the subscribe notification, this push notification is due to the previous subscribe command - // let result = rx.recv().await; - // assert!(result.is_some()); - // let PushInfo { kind, data } = result.unwrap(); - // assert_eq!( - // (kind, data), - // ( - // PushKind::Subscribe, - // vec![ - // Value::BulkString("test_channel_?".as_bytes().to_vec()), - // Value::Int(1), - // ] - // ) - // ); - - // // ensure subscription, routing on the last node, expected return Int(1) - // let result = connection - // .route_command(&redis::Cmd::new().arg("PUBLISH").arg("test_channel_?").arg("test_message_from_node_5"), RoutingInfo::SingleNode(node_5_route.clone())) - // .await; - // assert_eq!( - // result, - // Ok(Value::Int(1)) - // ); - - // // ensure subscription, routing on the first node, expected return Int(0) - // let result = connection - // .route_command(&redis::Cmd::new().arg("PUBLISH").arg("test_channel_?").arg("test_message_from_node_0"), RoutingInfo::SingleNode(node_0_route.clone())) - // .await; - // assert_eq!( - // result, - // Ok(Value::Int(0)) - // ); - - // for i in vec![5, 0] { - // let result = rx.recv().await; - // assert!(result.is_some()); - // let PushInfo { kind, data } = result.unwrap(); - // println!("^^^^^^^^^ '{:?} -> {:?}'", kind, data); - // assert_eq!( - // (kind, data), - // ( - // PushKind::Message, - // vec![ - // Value::BulkString("test_channel_?".into()), - // Value::BulkString(format!("test_message_from_node_{}", i).into()), - // ] - // ) - // ); - // } - - // // drop and recreate cluster and connections - // drop(cluster); - // println!("*********** DROPPED **********"); - - // let cluster = TestClusterContext::new_with_cluster_client_builder( - // 3, - // 0, - // |builder| builder.retries(3).use_protocol(ProtocolVersion::RESP3), - // //|builder| builder.retries(3), - // false, - // ); - - // // ensure subscription state restore - // let node_2_route = redis::cluster_routing::SingleNodeRoutingInfo::SpecificNode(slot_14212_route); - // let result = connection - // .route_command(&redis::Cmd::new().arg("PUBLISH").arg("test_channel_?").arg("test_message_from_node_2"), RoutingInfo::SingleNode(node_2_route.clone())) - // .await; - // assert_eq!( - // result, - // Ok(Value::Int(1)) - // ); - - // // non-subscribed channel - // let result = connection - // .route_command(&redis::Cmd::new().arg("PUBLISH").arg("test_channel_1").arg("should_not_receive"), RoutingInfo::SingleNode(node_0_route.clone())) - // .await; - // assert_eq!( - // result, - // Ok(Value::Int(0)) - // ); - - // // ensure subscription, routing on different node, expected return Int(0) - // let result = connection - // .route_command(&redis::Cmd::new().arg("PUBLISH").arg("test_channel_?").arg("test_message_from_node_2"), RoutingInfo::SingleNode(node_0_route.clone())) - // .await; - // assert_eq!( - // result, - // Ok(Value::Int(0)) - // ); - - // // should produce an arbitrary number of 'disconnected' notifications - 1 for the intitial try after the drop and an unknown? amout during reconnecting procedure - // // Notifications become available ONLY after we try to send the commands, since push manager does not register TCP disconnect on a idle socket - // // Remove the any amount of 'disconnected' notifications - // sleep(futures_time::time::Duration::from_secs(1)).await; - // //let mut result = rx.recv().await; - // let mut result = rx.try_recv(); - // assert!(result.is_ok()); - // //assert!(result.is_some()); - // loop { - // let kind = result.clone().unwrap().kind; - // if kind != PushKind::Disconnection && kind != PushKind::Subscribe { - // break; - // } - // // result = rx.recv().await; - // // assert!(result.is_some()); - // result = rx.try_recv(); - // assert!(result.is_ok()); - // } - - // // ensure messages test_message_from_node_0 and test_message_from_node_2 - // let mut msg_from_0 = false; - // let mut msg_from_2 = false; - // while !msg_from_0 && !msg_from_2 { - // let mut result = rx.recv().await; - // assert!(result.is_some()); - // let PushInfo { kind, data } = result.unwrap(); - - // assert!(kind == PushKind::Disconnection || kind == PushKind::Subscribe || kind == PushKind::Message); - // if kind == PushKind::Disconnection || kind == PushKind::Subscribe { - // // ignore - // continue; - // } - - // if data == vec![ - // Value::BulkString("test_channel".into()), - // Value::BulkString("test_message_from_node_0".into())] { - // assert!(!msg_from_0); - // msg_from_0 = true; - // } - // else if data == vec![ - // Value::BulkString("test_channel".into()), - // Value::BulkString("test_message_from_node_2".into())] { - // assert!(!msg_from_2); - // msg_from_2 = true; - // } - // else { - // assert!(false, "Unexpected message received"); - // } - // } - - // Ok(()) - // }) - // .unwrap(); + fn test_async_cluster_restore_resp3_pubsub_state_after_scale_out() { + let redis_ver = std::env::var("REDIS_VERSION").unwrap_or_default(); + let use_sharded = redis_ver.starts_with("7."); + + let mut client_subscriptions = PubSubSubscriptionInfo::from([ + // test_channel_? is used as it maps to 14212 slot, which is the last node in both 3 and 6 node config + // (assuming slots allocation is monotonicaly increasing starting from node 0) + ( + PubSubSubscriptionKind::Exact, + HashSet::from([PubSubChannelOrPattern::from("test_channel_?".as_bytes())]), + ), + ]); + + if use_sharded { + client_subscriptions.insert( + PubSubSubscriptionKind::Sharded, + HashSet::from([PubSubChannelOrPattern::from("test_channel_?".as_bytes())]), + ); + } + + let slot_14212 = get_slot(b"test_channel_?"); + assert_eq!(slot_14212, 14212); + + let cluster = TestClusterContext::new_with_cluster_client_builder( + 3, + 0, + |builder| { + builder + .retries(3) + .use_protocol(ProtocolVersion::RESP3) + .pubsub_subscriptions(client_subscriptions.clone()) + // periodic connection check is required to detect the disconnect from the last node + .periodic_connections_checks(Duration::from_secs(1)) + // periodic topology check is required to detect topology change + .periodic_topology_checks(Duration::from_secs(1)) + .slots_refresh_rate_limit(Duration::from_secs(0), 0) + }, + false, + ); + + block_on_all(async move { + let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::(); + let mut _listening_con = cluster.async_connection(Some(tx.clone())).await; + // Note, publishing connection has the same pubsub config + let mut publishing_con = cluster.async_connection(None).await; + + // short sleep to allow the server to push subscription notification + sleep(futures_time::time::Duration::from_secs(1)).await; + + // validate subscriptions + validate_subscriptions(&client_subscriptions, &mut rx, false); + + // validate PUBLISH + let result = cmd("PUBLISH") + .arg("test_channel_?") + .arg("test_message") + .query_async(&mut publishing_con) + .await; + assert_eq!( + result, + Ok(Value::Int(2)) // 2 connections with the same pubsub config + ); + + sleep(futures_time::time::Duration::from_secs(1)).await; + let result = rx.try_recv(); + assert!(result.is_ok()); + let PushInfo { kind, data } = result.unwrap(); + assert_eq!( + (kind, data), + ( + PushKind::Message, + vec![ + Value::BulkString("test_channel_?".into()), + Value::BulkString("test_message".into()), + ] + ) + ); + + if use_sharded { + // validate SPUBLISH + let result = cmd("SPUBLISH") + .arg("test_channel_?") + .arg("test_message") + .query_async(&mut publishing_con) + .await; + assert_eq!( + result, + Ok(Value::Int(2)) // 2 connections with the same pubsub config + ); + + sleep(futures_time::time::Duration::from_secs(1)).await; + let result = rx.try_recv(); + assert!(result.is_ok()); + let PushInfo { kind, data } = result.unwrap(); + assert_eq!( + (kind, data), + ( + PushKind::SMessage, + vec![ + Value::BulkString("test_channel_?".into()), + Value::BulkString("test_message".into()), + ] + ) + ); + } + + // drop and recreate a cluster with more nodes + drop(cluster); + + // recreate the cluster, the assumtion is that the cluster is built with exactly the same params (ports, slots map...) + let cluster = + TestClusterContext::new_with_cluster_client_builder(6, 0, |builder| builder, false); + + // assume slot 14212 will reside in the last node + let last_server_port = { + let addr = cluster.cluster.servers.last().unwrap().addr.clone(); + match addr { + redis::ConnectionAddr::TcpTls { + host: _, + port, + insecure: _, + tls_params: _, + } => port, + redis::ConnectionAddr::Tcp(_, port) => port, + _ => { + panic!("Wrong server address type: {:?}", addr); + } + } + }; + + // wait for new topology discovery + loop { + let mut cmd = redis::cmd("INFO"); + cmd.arg("SERVER"); + let res = publishing_con + .route_command( + &cmd, + RoutingInfo::SingleNode(SingleNodeRoutingInfo::SpecificNode(Route::new( + slot_14212, + SlotAddr::Master, + ))), + ) + .await; + assert!(res.is_ok()); + let res = res.unwrap(); + match res { + Value::VerbatimString { format: _, text } => { + if text.contains(format!("tcp_port:{}", last_server_port).as_str()) { + // new topology rediscovered + break; + } + } + _ => { + panic!("Wrong return type for INFO SERVER command: {:?}", res); + } + } + sleep(futures_time::time::Duration::from_secs(1)).await; + } + + // sleep for one one cycle of topology refresh + sleep(futures_time::time::Duration::from_secs(1)).await; + + // validate PUBLISH + let result = redis::cmd("PUBLISH") + .arg("test_channel_?") + .arg("test_message") + .query_async(&mut publishing_con) + .await; + assert_eq!( + result, + Ok(Value::Int(2)) // 2 connections with the same pubsub config + ); + + // allow message to propagate + sleep(futures_time::time::Duration::from_secs(1)).await; + + loop { + let result = rx.try_recv(); + assert!(result.is_ok()); + let PushInfo { kind, data } = result.unwrap(); + // ignore disconnection and subscription notifications due to resubscriptions + if kind == PushKind::Message { + assert_eq!( + data, + vec![ + Value::BulkString("test_channel_?".into()), + Value::BulkString("test_message".into()), + ] + ); + break; + } + } + + if use_sharded { + // validate SPUBLISH + let result = cmd("SPUBLISH") + .arg("test_channel_?") + .arg("test_message") + .query_async(&mut publishing_con) + .await; + assert_eq!( + result, + Ok(Value::Int(2)) // 2 connections with the same pubsub config + ); + + // allow message to propagate + sleep(futures_time::time::Duration::from_secs(1)).await; + + let result = rx.try_recv(); + assert!(result.is_ok()); + let PushInfo { kind, data } = result.unwrap(); + assert_eq!( + (kind, data), + ( + PushKind::SMessage, + vec![ + Value::BulkString("test_channel_?".into()), + Value::BulkString("test_message".into()), + ] + ) + ); + } + + drop(publishing_con); + drop(_listening_con); + + Ok(()) + }) + .unwrap(); + + block_on_all(async move { + sleep(futures_time::time::Duration::from_secs(10)).await; + Ok(()) + }) + .unwrap(); } - //#[allow(unreachable_code)] #[test] fn test_async_cluster_resp3_pubsub() { let redis_ver = std::env::var("REDIS_VERSION").unwrap_or_default(); @@ -3268,39 +3304,10 @@ mod cluster_async { // short sleep to allow the server to push subscription notification sleep(futures_time::time::Duration::from_secs(1)).await; - let mut subscribe_cnt = client_subscriptions[&PubSubSubscriptionKind::Exact].len(); - let mut psubscribe_cnt = client_subscriptions[&PubSubSubscriptionKind::Pattern].len(); - let mut ssubscribe_cnt = 0; - if let Some(sharded_shubs) = client_subscriptions.get(&PubSubSubscriptionKind::Sharded) - { - ssubscribe_cnt += sharded_shubs.len() - } - for _ in 0..(subscribe_cnt + psubscribe_cnt + ssubscribe_cnt) { - let result = rx.try_recv(); - assert!(result.is_ok()); - let PushInfo { kind, data: _ } = result.unwrap(); - assert!( - kind == PushKind::Subscribe - || kind == PushKind::PSubscribe - || kind == PushKind::SSubscribe - ); - if kind == PushKind::Subscribe { - subscribe_cnt -= 1; - } else if kind == PushKind::PSubscribe { - psubscribe_cnt -= 1; - } else { - ssubscribe_cnt -= 1; - } - } - - assert!(subscribe_cnt == 0); - assert!(psubscribe_cnt == 0); - assert!(ssubscribe_cnt == 0); + validate_subscriptions(&client_subscriptions, &mut rx, false); let slot_14212 = get_slot(b"test_channel_?"); assert_eq!(slot_14212, 14212); - //let slot_14212_route = redis::cluster_routing::Route::new(slot_14212, redis::cluster_routing::SlotAddr::Master); - //let node_5_route = redis::cluster_routing::SingleNodeRoutingInfo::SpecificNode(slot_14212_route); let slot_0_route = redis::cluster_routing::Route::new(0, redis::cluster_routing::SlotAddr::Master); diff --git a/redis/tests/test_sentinel.rs b/redis/tests/test_sentinel.rs index 53ff86e48..0782c8b6d 100644 --- a/redis/tests/test_sentinel.rs +++ b/redis/tests/test_sentinel.rs @@ -283,7 +283,7 @@ pub mod async_tests { .await .unwrap(); let mut replica_con = replica_client - .get_multiplexed_async_connection(None) + .get_multiplexed_async_connection(None, None) .await .unwrap(); @@ -316,7 +316,7 @@ pub mod async_tests { .await .unwrap(); let mut replica_con = replica_client - .get_multiplexed_async_connection(None) + .get_multiplexed_async_connection(None, None) .await .unwrap(); @@ -338,12 +338,14 @@ pub mod async_tests { let master_client = sentinel .async_master_for(master_name, Some(&node_conn_info)) .await?; - let mut master_con = master_client.get_multiplexed_async_connection(None).await?; + let mut master_con = master_client + .get_multiplexed_async_connection(None, None) + .await?; let mut replica_con = sentinel .async_replica_for(master_name, Some(&node_conn_info)) .await? - .get_multiplexed_async_connection(None) + .get_multiplexed_async_connection(None, None) .await?; async_assert_is_connection_to_master(&mut master_con).await; @@ -367,7 +369,9 @@ pub mod async_tests { let master_client = sentinel .async_master_for(master_name, Some(&node_conn_info)) .await?; - let mut master_con = master_client.get_multiplexed_async_connection(None).await?; + let mut master_con = master_client + .get_multiplexed_async_connection(None, None) + .await?; async_assert_is_connection_to_master(&mut master_con).await; @@ -408,7 +412,9 @@ pub mod async_tests { let master_client = sentinel .async_master_for(master_name, Some(&node_conn_info)) .await?; - let mut master_con = master_client.get_multiplexed_async_connection(None).await?; + let mut master_con = master_client + .get_multiplexed_async_connection(None, None) + .await?; async_assert_is_connection_to_master(&mut master_con).await; From 73ff308dacde912555e6ce28ae3f3a76f7db0fb1 Mon Sep 17 00:00:00 2001 From: ikolomi Date: Tue, 3 Sep 2024 12:53:13 +0300 Subject: [PATCH 2/3] Agregate Glide-specific connection params into a struct --- redis/examples/async-await.rs | 6 +- redis/examples/async-connection-loss.rs | 8 +- redis/examples/async-multiplexed.rs | 4 +- redis/examples/async-pub-sub.rs | 6 +- redis/examples/async-scan.rs | 6 +- redis/src/aio/connection_manager.rs | 4 +- redis/src/aio/multiplexed_connection.rs | 17 ++-- redis/src/client.rs | 94 ++++++++----------- redis/src/cluster_async/connections_logic.rs | 67 +++++-------- redis/src/cluster_async/mod.rs | 57 +++++------ redis/src/lib.rs | 1 + redis/src/sentinel.rs | 15 ++- redis/tests/support/mock_cluster.rs | 10 +- redis/tests/support/mod.rs | 11 ++- redis/tests/test_async.rs | 18 ++-- redis/tests/test_async_async_std.rs | 4 +- .../test_async_cluster_connections_logic.rs | 26 ++--- redis/tests/test_cluster_async.rs | 16 ++-- redis/tests/test_sentinel.rs | 14 +-- 19 files changed, 176 insertions(+), 208 deletions(-) diff --git a/redis/examples/async-await.rs b/redis/examples/async-await.rs index b52776a46..2d829c7d6 100644 --- a/redis/examples/async-await.rs +++ b/redis/examples/async-await.rs @@ -1,10 +1,12 @@ #![allow(unknown_lints, dependency_on_unit_never_type_fallback)] -use redis::AsyncCommands; +use redis::{AsyncCommands, GlideConnectionOptions}; #[tokio::main] async fn main() -> redis::RedisResult<()> { let client = redis::Client::open("redis://127.0.0.1/").unwrap(); - let mut con = client.get_multiplexed_async_connection(None, None).await?; + let mut con = client + .get_multiplexed_async_connection(GlideConnectionOptions::default()) + .await?; con.set("key1", b"foo").await?; diff --git a/redis/examples/async-connection-loss.rs b/redis/examples/async-connection-loss.rs index 90af361f2..a7dba3ab8 100644 --- a/redis/examples/async-connection-loss.rs +++ b/redis/examples/async-connection-loss.rs @@ -13,6 +13,7 @@ use std::time::Duration; use futures::future; use redis::aio::ConnectionLike; +use redis::GlideConnectionOptions; use redis::RedisResult; use tokio::time::interval; @@ -81,7 +82,12 @@ async fn main() -> RedisResult<()> { let client = redis::Client::open("redis://127.0.0.1/").unwrap(); match mode { Mode::Default => { - run_multi(client.get_multiplexed_tokio_connection(None, None).await?).await? + run_multi( + client + .get_multiplexed_tokio_connection(GlideConnectionOptions::default()) + .await?, + ) + .await? } Mode::Reconnect => run_multi(client.get_connection_manager().await?).await?, #[allow(deprecated)] diff --git a/redis/examples/async-multiplexed.rs b/redis/examples/async-multiplexed.rs index 9c8c73235..2e5332359 100644 --- a/redis/examples/async-multiplexed.rs +++ b/redis/examples/async-multiplexed.rs @@ -1,6 +1,6 @@ #![allow(unknown_lints, dependency_on_unit_never_type_fallback)] use futures::prelude::*; -use redis::{aio::MultiplexedConnection, RedisResult}; +use redis::{aio::MultiplexedConnection, GlideConnectionOptions, RedisResult}; async fn test_cmd(con: &MultiplexedConnection, i: i32) -> RedisResult<()> { let mut con = con.clone(); @@ -35,7 +35,7 @@ async fn main() { let client = redis::Client::open("redis://127.0.0.1/").unwrap(); let con = client - .get_multiplexed_tokio_connection(None, None) + .get_multiplexed_tokio_connection(GlideConnectionOptions::default()) .await .unwrap(); diff --git a/redis/examples/async-pub-sub.rs b/redis/examples/async-pub-sub.rs index 15634e2b0..fe84b44fb 100644 --- a/redis/examples/async-pub-sub.rs +++ b/redis/examples/async-pub-sub.rs @@ -1,11 +1,13 @@ #![allow(unknown_lints, dependency_on_unit_never_type_fallback)] use futures_util::StreamExt as _; -use redis::AsyncCommands; +use redis::{AsyncCommands, GlideConnectionOptions}; #[tokio::main] async fn main() -> redis::RedisResult<()> { let client = redis::Client::open("redis://127.0.0.1/").unwrap(); - let mut publish_conn = client.get_multiplexed_async_connection(None, None).await?; + let mut publish_conn = client + .get_multiplexed_async_connection(GlideConnectionOptions::default()) + .await?; let mut pubsub_conn = client.get_async_pubsub().await?; pubsub_conn.subscribe("wavephone").await?; diff --git a/redis/examples/async-scan.rs b/redis/examples/async-scan.rs index 6f55ac933..06a66fe83 100644 --- a/redis/examples/async-scan.rs +++ b/redis/examples/async-scan.rs @@ -1,11 +1,13 @@ #![allow(unknown_lints, dependency_on_unit_never_type_fallback)] use futures::stream::StreamExt; -use redis::{AsyncCommands, AsyncIter}; +use redis::{AsyncCommands, AsyncIter, GlideConnectionOptions}; #[tokio::main] async fn main() -> redis::RedisResult<()> { let client = redis::Client::open("redis://127.0.0.1/").unwrap(); - let mut con = client.get_multiplexed_async_connection(None, None).await?; + let mut con = client + .get_multiplexed_async_connection(GlideConnectionOptions::default()) + .await?; con.set("async-key1", b"foo").await?; con.set("async-key2", b"foo").await?; diff --git a/redis/src/aio/connection_manager.rs b/redis/src/aio/connection_manager.rs index 741086d76..61df9bc31 100644 --- a/redis/src/aio/connection_manager.rs +++ b/redis/src/aio/connection_manager.rs @@ -1,4 +1,5 @@ use super::RedisFuture; +use crate::client::GlideConnectionOptions; use crate::cmd::Cmd; use crate::push_manager::PushManager; use crate::types::{RedisError, RedisResult, Value}; @@ -195,8 +196,7 @@ impl ConnectionManager { client.get_multiplexed_async_connection_with_timeouts( response_timeout, connection_timeout, - None, - None, + GlideConnectionOptions::default(), ) }) .await diff --git a/redis/src/aio/multiplexed_connection.rs b/redis/src/aio/multiplexed_connection.rs index c085461cf..1067bc2df 100644 --- a/redis/src/aio/multiplexed_connection.rs +++ b/redis/src/aio/multiplexed_connection.rs @@ -1,12 +1,13 @@ use super::{ConnectionLike, Runtime}; use crate::aio::setup_connection; use crate::aio::DisconnectNotifier; +use crate::client::GlideConnectionOptions; use crate::cmd::Cmd; #[cfg(any(feature = "tokio-comp", feature = "async-std-comp"))] use crate::parser::ValueCodec; use crate::push_manager::PushManager; use crate::types::{RedisError, RedisFuture, RedisResult, Value}; -use crate::{cmd, ConnectionInfo, ProtocolVersion, PushInfo, PushKind}; +use crate::{cmd, ConnectionInfo, ProtocolVersion, PushKind}; use ::tokio::{ io::{AsyncRead, AsyncWrite}, sync::{mpsc, oneshot}, @@ -416,8 +417,7 @@ impl MultiplexedConnection { pub async fn new( connection_info: &ConnectionInfo, stream: C, - push_sender: Option>, - disconnect_notifier: Option>, + glide_connection_options: GlideConnectionOptions, ) -> RedisResult<(Self, impl Future)> where C: Unpin + AsyncRead + AsyncWrite + Send + 'static, @@ -426,8 +426,7 @@ impl MultiplexedConnection { connection_info, stream, std::time::Duration::MAX, - push_sender, - disconnect_notifier, + glide_connection_options, ) .await } @@ -438,8 +437,7 @@ impl MultiplexedConnection { connection_info: &ConnectionInfo, stream: C, response_timeout: std::time::Duration, - push_sender: Option>, - disconnect_notifier: Option>, + glide_connection_options: GlideConnectionOptions, ) -> RedisResult<(Self, impl Future)> where C: Unpin + AsyncRead + AsyncWrite + Send + 'static, @@ -457,10 +455,11 @@ impl MultiplexedConnection { let codec = ValueCodec::default() .framed(stream) .and_then(|msg| async move { msg }); - let (mut pipeline, driver) = Pipeline::new(codec, disconnect_notifier); + let (mut pipeline, driver) = + Pipeline::new(codec, glide_connection_options.disconnect_notifier); let driver = boxed(driver); let pm = PushManager::default(); - if let Some(sender) = push_sender { + if let Some(sender) = glide_connection_options.push_sender { pm.replace_sender(sender); } diff --git a/redis/src/client.rs b/redis/src/client.rs index 534c186d9..5e3f144e7 100644 --- a/redis/src/client.rs +++ b/redis/src/client.rs @@ -78,6 +78,16 @@ impl Client { } } +/// Glide-specific connection options +#[derive(Clone, Default)] +pub struct GlideConnectionOptions { + /// Queue for RESP3 notifications + pub push_sender: Option>, + #[cfg(feature = "aio")] + /// Passive disconnect notifier + pub disconnect_notifier: Option>, +} + /// To enable async support you need to chose one of the supported runtimes and active its /// corresponding feature: `tokio-comp` or `async-std-comp` #[cfg(feature = "aio")] @@ -149,14 +159,12 @@ impl Client { )] pub async fn get_multiplexed_async_connection( &self, - push_sender: Option>, - disconnect_notifier: Option>, + glide_connection_options: GlideConnectionOptions, ) -> RedisResult { self.get_multiplexed_async_connection_with_timeouts( std::time::Duration::MAX, std::time::Duration::MAX, - push_sender, - disconnect_notifier, + glide_connection_options, ) .await } @@ -171,8 +179,7 @@ impl Client { &self, response_timeout: std::time::Duration, connection_timeout: std::time::Duration, - push_sender: Option>, - disconnect_notifier: Option>, + glide_connection_options: GlideConnectionOptions, ) -> RedisResult { let result = match Runtime::locate() { #[cfg(feature = "tokio-comp")] @@ -182,8 +189,7 @@ impl Client { self.get_multiplexed_async_connection_inner::( response_timeout, None, - push_sender, - disconnect_notifier, + glide_connection_options, ), ) .await @@ -195,8 +201,7 @@ impl Client { self.get_multiplexed_async_connection_inner::( response_timeout, None, - push_sender, - disconnect_notifier, + glide_connection_options, ), ) .await @@ -220,8 +225,7 @@ impl Client { )] pub async fn get_multiplexed_async_connection_and_ip( &self, - push_sender: Option>, - disconnect_notifier: Option>, + glide_connection_options: GlideConnectionOptions, ) -> RedisResult<(crate::aio::MultiplexedConnection, Option)> { match Runtime::locate() { #[cfg(feature = "tokio-comp")] @@ -229,8 +233,7 @@ impl Client { self.get_multiplexed_async_connection_inner::( Duration::MAX, None, - push_sender, - disconnect_notifier, + glide_connection_options, ) .await } @@ -239,8 +242,7 @@ impl Client { self.get_multiplexed_async_connection_inner::( Duration::MAX, None, - push_sender, - disconnect_notifier, + glide_connection_options, ) .await } @@ -257,8 +259,7 @@ impl Client { &self, response_timeout: std::time::Duration, connection_timeout: std::time::Duration, - push_sender: Option>, - disconnect_notifier: Option>, + glide_connection_options: GlideConnectionOptions, ) -> RedisResult { let result = Runtime::locate() .timeout( @@ -266,8 +267,7 @@ impl Client { self.get_multiplexed_async_connection_inner::( response_timeout, None, - push_sender, - disconnect_notifier, + glide_connection_options, ), ) .await; @@ -287,14 +287,12 @@ impl Client { #[cfg_attr(docsrs, doc(cfg(feature = "tokio-comp")))] pub async fn get_multiplexed_tokio_connection( &self, - push_sender: Option>, - disconnect_notifier: Option>, + glide_connection_options: GlideConnectionOptions, ) -> RedisResult { self.get_multiplexed_tokio_connection_with_response_timeouts( std::time::Duration::MAX, std::time::Duration::MAX, - push_sender, - disconnect_notifier, + glide_connection_options, ) .await } @@ -309,8 +307,7 @@ impl Client { &self, response_timeout: std::time::Duration, connection_timeout: std::time::Duration, - push_sender: Option>, - disconnect_notifier: Option>, + glide_connection_options: GlideConnectionOptions, ) -> RedisResult { let result = Runtime::locate() .timeout( @@ -318,8 +315,7 @@ impl Client { self.get_multiplexed_async_connection_inner::( response_timeout, None, - push_sender, - disconnect_notifier, + glide_connection_options, ), ) .await; @@ -339,14 +335,12 @@ impl Client { #[cfg_attr(docsrs, doc(cfg(feature = "async-std-comp")))] pub async fn get_multiplexed_async_std_connection( &self, - push_sender: Option>, - disconnect_notifier: Option>, + glide_connection_options: GlideConnectionOptions, ) -> RedisResult { self.get_multiplexed_async_std_connection_with_timeouts( std::time::Duration::MAX, std::time::Duration::MAX, - push_sender, - disconnect_notifier, + glide_connection_options, ) .await } @@ -362,8 +356,7 @@ impl Client { pub async fn create_multiplexed_tokio_connection_with_response_timeout( &self, response_timeout: std::time::Duration, - push_sender: Option>, - disconnect_notifier: Option>, + glide_connection_options: GlideConnectionOptions, ) -> RedisResult<( crate::aio::MultiplexedConnection, impl std::future::Future, @@ -371,8 +364,7 @@ impl Client { self.create_multiplexed_async_connection_inner::( response_timeout, None, - push_sender, - disconnect_notifier, + glide_connection_options, ) .await .map(|(conn, driver, _ip)| (conn, driver)) @@ -387,16 +379,14 @@ impl Client { #[cfg_attr(docsrs, doc(cfg(feature = "tokio-comp")))] pub async fn create_multiplexed_tokio_connection( &self, - push_sender: Option>, - disconnect_notifier: Option>, + glide_connection_options: GlideConnectionOptions, ) -> RedisResult<( crate::aio::MultiplexedConnection, impl std::future::Future, )> { self.create_multiplexed_tokio_connection_with_response_timeout( std::time::Duration::MAX, - push_sender, - disconnect_notifier, + glide_connection_options, ) .await .map(|conn_res| (conn_res.0, conn_res.1)) @@ -413,8 +403,7 @@ impl Client { pub async fn create_multiplexed_async_std_connection_with_response_timeout( &self, response_timeout: std::time::Duration, - push_sender: Option>, - disconnect_notifier: Option>, + glide_connection_options: GlideConnectionOptions, ) -> RedisResult<( crate::aio::MultiplexedConnection, impl std::future::Future, @@ -422,8 +411,7 @@ impl Client { self.create_multiplexed_async_connection_inner::( response_timeout, None, - push_sender, - disconnect_notifier, + glide_connection_options, ) .await .map(|(conn, driver, _ip)| (conn, driver)) @@ -438,16 +426,14 @@ impl Client { #[cfg_attr(docsrs, doc(cfg(feature = "async-std-comp")))] pub async fn create_multiplexed_async_std_connection( &self, - push_sender: Option>, - disconnect_notifier: Option>, + glide_connection_options: GlideConnectionOptions, ) -> RedisResult<( crate::aio::MultiplexedConnection, impl std::future::Future, )> { self.create_multiplexed_async_std_connection_with_response_timeout( std::time::Duration::MAX, - push_sender, - disconnect_notifier, + glide_connection_options, ) .await } @@ -650,8 +636,7 @@ impl Client { &self, response_timeout: std::time::Duration, socket_addr: Option, - push_sender: Option>, - disconnect_notifier: Option>, + glide_connection_options: GlideConnectionOptions, ) -> RedisResult<(crate::aio::MultiplexedConnection, Option)> where T: crate::aio::RedisRuntime, @@ -660,8 +645,7 @@ impl Client { .create_multiplexed_async_connection_inner::( response_timeout, socket_addr, - push_sender, - disconnect_notifier, + glide_connection_options, ) .await?; T::spawn(driver); @@ -672,8 +656,7 @@ impl Client { &self, response_timeout: std::time::Duration, socket_addr: Option, - push_sender: Option>, - disconnect_notifier: Option>, + glide_connection_options: GlideConnectionOptions, ) -> RedisResult<( crate::aio::MultiplexedConnection, impl std::future::Future, @@ -687,8 +670,7 @@ impl Client { &self.connection_info, con, response_timeout, - push_sender, - disconnect_notifier, + glide_connection_options, ) .await .map(|res| (res.0, res.1, ip)) diff --git a/redis/src/cluster_async/connections_logic.rs b/redis/src/cluster_async/connections_logic.rs index dc3fd82d0..7de249300 100644 --- a/redis/src/cluster_async/connections_logic.rs +++ b/redis/src/cluster_async/connections_logic.rs @@ -6,15 +6,14 @@ use super::{ }; use crate::{ aio::{ConnectionLike, DisconnectNotifier, Runtime}, + client::GlideConnectionOptions, cluster::get_connection_info, cluster_client::ClusterParams, - push_manager::PushInfo, ErrorKind, RedisError, RedisResult, }; use futures::prelude::*; use futures_util::{future::BoxFuture, join}; -use tokio::sync::mpsc; use tracing::warn; pub(crate) type ConnectionFuture = futures::future::Shared>; @@ -56,8 +55,7 @@ pub(crate) async fn get_or_create_conn( node: Option>, params: &ClusterParams, conn_type: RefreshConnectionType, - push_sender: Option>, - disconnect_notifier: Option>, + glide_connection_options: GlideConnectionOptions, ) -> RedisResult> where C: ConnectionLike + Send + Clone + Sync + Connect + 'static, @@ -73,8 +71,7 @@ where None, conn_type, Some(node), - push_sender, - disconnect_notifier, + glide_connection_options, ) .await .get_node(), @@ -86,8 +83,7 @@ where None, conn_type, None, - push_sender, - disconnect_notifier, + glide_connection_options, ) .await .get_node() @@ -111,8 +107,7 @@ pub(crate) async fn connect_and_check_all_connections( addr: &str, params: ClusterParams, socket_addr: Option, - push_sender: Option>, - disconnect_notifier: Option>, + glide_connection_options: GlideConnectionOptions, ) -> ConnectAndCheckResult where C: ConnectionLike + Connect + Send + Sync + 'static + Clone, @@ -122,17 +117,15 @@ where addr, params.clone(), socket_addr, - push_sender.clone(), false, - disconnect_notifier.clone(), + glide_connection_options.clone(), ), create_connection( addr, params.clone(), socket_addr, - push_sender, true, - disconnect_notifier, + glide_connection_options, ), ) .await @@ -188,9 +181,11 @@ where addr, params.clone(), socket_addr, - None, true, - disconnect_notifier, + GlideConnectionOptions { + push_sender: None, + disconnect_notifier, + }, ) .await { @@ -269,8 +264,7 @@ pub async fn connect_and_check( socket_addr: Option, conn_type: RefreshConnectionType, node: Option>, - push_sender: Option>, - disconnect_notifier: Option>, + glide_connection_options: GlideConnectionOptions, ) -> ConnectAndCheckResult where C: ConnectionLike + Connect + Send + Sync + 'static + Clone, @@ -281,8 +275,7 @@ where addr, params.clone(), socket_addr, - push_sender, - disconnect_notifier, + glide_connection_options, ) .await { @@ -301,7 +294,7 @@ where params, socket_addr, node, - disconnect_notifier, + glide_connection_options.disconnect_notifier, ) .await } @@ -310,22 +303,15 @@ where addr, params, socket_addr, - push_sender, - disconnect_notifier, + glide_connection_options, ) .await } } } RefreshConnectionType::AllConnections => { - connect_and_check_all_connections( - addr, - params, - socket_addr, - push_sender, - disconnect_notifier, - ) - .await + connect_and_check_all_connections(addr, params, socket_addr, glide_connection_options) + .await } } } @@ -334,8 +320,7 @@ async fn create_and_setup_user_connection( node: &str, params: ClusterParams, socket_addr: Option, - push_sender: Option>, - disconnect_notifier: Option>, + glide_connection_options: GlideConnectionOptions, ) -> RedisResult> where C: ConnectionLike + Connect + Send + 'static, @@ -344,9 +329,8 @@ where node, params.clone(), socket_addr, - push_sender, false, - disconnect_notifier, + glide_connection_options, ) .await?; setup_user_connection(&mut connection.conn, params).await?; @@ -386,9 +370,8 @@ async fn create_connection( node: &str, mut params: ClusterParams, socket_addr: Option, - push_sender: Option>, is_management: bool, - disconnect_notifier: Option>, + mut glide_connection_options: GlideConnectionOptions, ) -> RedisResult> where C: ConnectionLike + Connect + Send + 'static, @@ -401,17 +384,15 @@ where } let info = get_connection_info(node, params)?; // management connection does not require notifications or disconnect notifications + if is_management { + glide_connection_options.disconnect_notifier = None; + } C::connect( info, response_timeout, connection_timeout, socket_addr, - if !is_management { push_sender } else { None }, - if !is_management { - disconnect_notifier - } else { - None - }, + glide_connection_options, ) .await .map(|conn| conn.into()) diff --git a/redis/src/cluster_async/mod.rs b/redis/src/cluster_async/mod.rs index aa5ea47f1..9f2ab6c35 100644 --- a/redis/src/cluster_async/mod.rs +++ b/redis/src/cluster_async/mod.rs @@ -30,6 +30,7 @@ pub mod testing { pub use super::connections_logic::*; } use crate::{ + client::GlideConnectionOptions, cluster_routing::{Routable, RoutingInfo}, cluster_slotmap::SlotMap, cluster_topology::SLOT_SIZE, @@ -57,8 +58,11 @@ use std::{ #[cfg(feature = "tokio-comp")] use tokio::task::JoinHandle; +#[cfg(feature = "tokio-comp")] +use crate::aio::DisconnectNotifier; + use crate::{ - aio::{get_socket_addrs, ConnectionLike, DisconnectNotifier, MultiplexedConnection, Runtime}, + aio::{get_socket_addrs, ConnectionLike, MultiplexedConnection, Runtime}, cluster::slot_cmd, cluster_async::connections_logic::{ get_host_and_port_from_addr, get_or_create_conn, ConnectionFuture, RefreshConnectionType, @@ -399,10 +403,9 @@ pub(crate) struct InnerCore { pending_requests: Mutex>>, slot_refresh_state: SlotRefreshState, initial_nodes: Vec, - push_sender: Option>, subscriptions_by_address: RwLock>, unassigned_subscriptions: RwLock, - disconnect_notifier: Option>, + glide_connection_options: GlideConnectionOptions, #[cfg(feature = "tokio-comp")] tokio_notify: Arc, } @@ -1004,8 +1007,10 @@ where let connections = Self::create_initial_connections( initial_nodes, &cluster_params, - push_sender.clone(), - disconnect_notifier.clone(), + GlideConnectionOptions { + push_sender: push_sender.clone(), + disconnect_notifier: disconnect_notifier.clone(), + }, ) .await?; @@ -1022,7 +1027,6 @@ where pending_requests: Mutex::new(Vec::new()), slot_refresh_state: SlotRefreshState::new(slots_refresh_rate_limiter), initial_nodes: initial_nodes.to_vec(), - push_sender: push_sender.clone(), unassigned_subscriptions: RwLock::new( if let Some(subs) = cluster_params.pubsub_subscriptions { subs.clone() @@ -1031,7 +1035,10 @@ where }, ), subscriptions_by_address: RwLock::new(Default::default()), - disconnect_notifier: disconnect_notifier.clone(), + glide_connection_options: GlideConnectionOptions { + push_sender: push_sender.clone(), + disconnect_notifier: disconnect_notifier.clone(), + }, #[cfg(feature = "tokio-comp")] tokio_notify, }); @@ -1125,18 +1132,16 @@ where async fn create_initial_connections( initial_nodes: &[ConnectionInfo], params: &ClusterParams, - push_sender: Option>, - disconnect_notifier: Option>, + glide_connection_options: GlideConnectionOptions, ) -> RedisResult> { let initial_nodes: Vec<(String, Option)> = Self::try_to_expand_initial_nodes(initial_nodes).await; let connections = stream::iter(initial_nodes.iter().cloned()) .map(|(node_addr, socket_addr)| { let mut params: ClusterParams = params.clone(); - let push_sender = push_sender.clone(); + let glide_connection_options = glide_connection_options.clone(); // set subscriptions to none, they will be applied upon the topology discovery params.pubsub_subscriptions = None; - let disconnect_notifier = disconnect_notifier.clone(); async move { let result = connect_and_check( @@ -1145,8 +1150,7 @@ where socket_addr, RefreshConnectionType::AllConnections, None, - push_sender, - disconnect_notifier, + glide_connection_options, ) .await .get_node(); @@ -1192,8 +1196,7 @@ where let connection_map = match Self::create_initial_connections( &inner.initial_nodes, &inner.cluster_params, - None, - inner.disconnect_notifier.clone(), + inner.glide_connection_options.clone(), ) .await { @@ -1292,8 +1295,7 @@ where let connections_container = inner.conn_lock.read().await; let cluster_params = &inner.cluster_params; let subscriptions_by_address = &inner.subscriptions_by_address; - let push_sender = &inner.push_sender; - let disconnect_notifier = &inner.disconnect_notifier; + let glide_connection_optons = &inner.glide_connection_options; stream::iter(addresses.into_iter()) .fold( @@ -1315,8 +1317,7 @@ where node_option, &cluster_params, conn_type, - push_sender.clone(), - disconnect_notifier.clone(), + glide_connection_optons.clone(), ) .await; match node { @@ -1769,8 +1770,7 @@ where node, &cluster_params, RefreshConnectionType::AllConnections, - inner.push_sender.clone(), - inner.disconnect_notifier.clone(), + inner.glide_connection_options.clone(), ) .await; if let Ok(node) = node { @@ -2065,8 +2065,7 @@ where None, RefreshConnectionType::AllConnections, None, - core.push_sender.clone(), - core.disconnect_notifier.clone(), + core.glide_connection_options.clone(), ) .await .get_node() @@ -2503,8 +2502,7 @@ pub trait Connect: Sized { response_timeout: Duration, connection_timeout: Duration, socket_addr: Option, - push_sender: Option>, - disconnect_notifier: Option>, + glide_connection_options: GlideConnectionOptions, ) -> RedisFuture<'a, (Self, Option)> where T: IntoConnectionInfo + Send + 'a; @@ -2516,8 +2514,7 @@ impl Connect for MultiplexedConnection { response_timeout: Duration, connection_timeout: Duration, socket_addr: Option, - push_sender: Option>, - disconnect_notifier: Option>, + glide_connection_options: GlideConnectionOptions, ) -> RedisFuture<'a, (MultiplexedConnection, Option)> where T: IntoConnectionInfo + Send + 'a, @@ -2534,8 +2531,7 @@ impl Connect for MultiplexedConnection { client.get_multiplexed_async_connection_inner::( response_timeout, socket_addr, - push_sender, - disconnect_notifier, + glide_connection_options, ), ) .await? @@ -2546,8 +2542,7 @@ impl Connect for MultiplexedConnection { .get_multiplexed_async_connection_inner::( response_timeout, socket_addr, - push_sender, - disconnect_notifier, + glide_connection_options, )) .await? } diff --git a/redis/src/lib.rs b/redis/src/lib.rs index a348fe0c6..4f138c2bb 100644 --- a/redis/src/lib.rs +++ b/redis/src/lib.rs @@ -364,6 +364,7 @@ assert_eq!(result, Ok(("foo".to_string(), b"bar".to_vec()))); // public api pub use crate::client::Client; +pub use crate::client::GlideConnectionOptions; pub use crate::cmd::{cmd, pack_command, pipe, Arg, Cmd, Iter}; pub use crate::commands::{ Commands, ControlFlow, Direction, LposOptions, PubSubCommands, SetOptions, diff --git a/redis/src/sentinel.rs b/redis/src/sentinel.rs index 8b853f643..ac6aac65c 100644 --- a/redis/src/sentinel.rs +++ b/redis/src/sentinel.rs @@ -112,8 +112,8 @@ use rand::Rng; use crate::aio::MultiplexedConnection as AsyncConnection; use crate::{ - connection::ConnectionInfo, types::RedisResult, Client, Cmd, Connection, ErrorKind, - FromRedisValue, IntoConnectionInfo, RedisConnectionInfo, TlsMode, Value, + client::GlideConnectionOptions, connection::ConnectionInfo, types::RedisResult, Client, Cmd, + Connection, ErrorKind, FromRedisValue, IntoConnectionInfo, RedisConnectionInfo, TlsMode, Value, }; /// The Sentinel type, serves as a special purpose client which builds other clients on @@ -301,7 +301,10 @@ fn find_valid_master( #[cfg(feature = "aio")] async fn async_check_role(connection_info: &ConnectionInfo, target_role: &str) -> bool { if let Ok(client) = Client::open(connection_info.clone()) { - if let Ok(mut conn) = client.get_multiplexed_async_connection(None, None).await { + if let Ok(mut conn) = client + .get_multiplexed_async_connection(GlideConnectionOptions::default()) + .await + { let result: RedisResult> = crate::cmd("ROLE").query_async(&mut conn).await; return check_role_result(&result, target_role); } @@ -366,7 +369,7 @@ async fn async_reconnect( ) -> RedisResult<()> { let sentinel_client = Client::open(connection_info.clone())?; let new_connection = sentinel_client - .get_multiplexed_async_connection(None, None) + .get_multiplexed_async_connection(GlideConnectionOptions::default()) .await?; connection.replace(new_connection); Ok(()) @@ -768,6 +771,8 @@ impl SentinelClient { #[cfg(any(feature = "tokio-comp", feature = "async-std-comp"))] pub async fn get_async_connection(&mut self) -> RedisResult { let client = self.async_get_client().await?; - client.get_multiplexed_async_connection(None, None).await + client + .get_multiplexed_async_connection(GlideConnectionOptions::default()) + .await } } diff --git a/redis/tests/support/mock_cluster.rs b/redis/tests/support/mock_cluster.rs index 93acff5e8..ce91988ce 100644 --- a/redis/tests/support/mock_cluster.rs +++ b/redis/tests/support/mock_cluster.rs @@ -1,6 +1,6 @@ use redis::{ cluster::{self, ClusterClient, ClusterClientBuilder}, - ErrorKind, FromRedisValue, PushInfo, RedisError, + ErrorKind, FromRedisValue, GlideConnectionOptions, RedisError, }; use std::{ @@ -18,8 +18,6 @@ use { redis::{IntoConnectionInfo, RedisResult, Value}, }; -use tokio::sync::mpsc; - #[cfg(feature = "cluster-async")] use redis::{aio, cluster_async, RedisFuture}; @@ -29,9 +27,6 @@ use futures::future; #[cfg(feature = "cluster-async")] use tokio::runtime::Runtime; -#[cfg(feature = "aio")] -use redis::aio::DisconnectNotifier; - type Handler = Arc Result<(), RedisResult> + Send + Sync>; pub struct MockConnectionBehavior { @@ -137,8 +132,7 @@ impl cluster_async::Connect for MockConnection { _response_timeout: Duration, _connection_timeout: Duration, _socket_addr: Option, - _push_sender: Option>, - _disconnect_notifier: Option>, + _glide_connection_options: GlideConnectionOptions, ) -> RedisFuture<'a, (Self, Option)> where T: IntoConnectionInfo + Send + 'a, diff --git a/redis/tests/support/mod.rs b/redis/tests/support/mod.rs index 96ce71e6a..8169ee94b 100644 --- a/redis/tests/support/mod.rs +++ b/redis/tests/support/mod.rs @@ -21,6 +21,9 @@ use redis::{ClientTlsConfig, TlsCertificates}; use socket2::{Domain, Socket, Type}; use tempfile::TempDir; +#[cfg(feature = "aio")] +use redis::GlideConnectionOptions; + pub fn use_protocol() -> ProtocolVersion { if env::var("PROTOCOL").unwrap_or_default() == "RESP3" { ProtocolVersion::RESP3 @@ -502,7 +505,7 @@ impl TestContext { #[cfg(feature = "aio")] pub async fn async_connection(&self) -> redis::RedisResult { self.client - .get_multiplexed_async_connection(None, None) + .get_multiplexed_async_connection(GlideConnectionOptions::default()) .await } @@ -516,7 +519,7 @@ impl TestContext { &self, ) -> redis::RedisResult { self.client - .get_multiplexed_async_std_connection(None, None) + .get_multiplexed_async_std_connection(GlideConnectionOptions::default()) .await } @@ -536,7 +539,7 @@ impl TestContext { &self, ) -> redis::RedisResult { self.client - .get_multiplexed_tokio_connection(None, None) + .get_multiplexed_tokio_connection(GlideConnectionOptions::default()) .await } @@ -545,7 +548,7 @@ impl TestContext { &self, ) -> redis::RedisResult { self.client - .get_multiplexed_async_std_connection(None, None) + .get_multiplexed_async_std_connection(GlideConnectionOptions::default()) .await } diff --git a/redis/tests/test_async.rs b/redis/tests/test_async.rs index f7c892a26..d16f1e069 100644 --- a/redis/tests/test_async.rs +++ b/redis/tests/test_async.rs @@ -9,7 +9,8 @@ mod basic_async { use futures::{prelude::*, StreamExt}; use redis::{ aio::{ConnectionLike, MultiplexedConnection}, - cmd, pipe, AsyncCommands, ErrorKind, PushInfo, PushKind, RedisResult, Value, + cmd, pipe, AsyncCommands, ErrorKind, GlideConnectionOptions, PushInfo, PushKind, + RedisResult, Value, }; use tokio::sync::mpsc::error::TryRecvError; @@ -100,7 +101,7 @@ mod basic_async { fn dont_panic_on_closed_multiplexed_connection() { let ctx = TestContext::new(); let client = ctx.client.clone(); - let connect = client.get_multiplexed_async_connection(None, None); + let connect = client.get_multiplexed_async_connection(GlideConnectionOptions::default()); drop(ctx); block_on_all(async move { @@ -584,7 +585,7 @@ mod basic_async { let client = redis::Client::open(coninfo).unwrap(); let err = client - .get_multiplexed_tokio_connection(None, None) + .get_multiplexed_tokio_connection(GlideConnectionOptions::default()) .await .err() .unwrap(); @@ -916,7 +917,10 @@ mod basic_async { let millisecond = std::time::Duration::from_millis(1); let mut retries = 0; loop { - match client.get_multiplexed_async_connection(None, None).await { + match client + .get_multiplexed_async_connection(GlideConnectionOptions::default()) + .await + { Err(err) => { if err.is_connection_refusal() { tokio::time::sleep(millisecond).await; @@ -986,7 +990,8 @@ mod basic_async { let client = build_single_client(ctx.server.connection_info(), &ctx.server.tls_paths, true) .unwrap(); - let connect = client.get_multiplexed_async_connection(None, None); + let connect = + client.get_multiplexed_async_connection(GlideConnectionOptions::default()); block_on_all(connect.and_then(|mut con| async move { redis::cmd("SET") .arg("key1") @@ -1007,7 +1012,8 @@ mod basic_async { let client = build_single_client(ctx.server.connection_info(), &ctx.server.tls_paths, false) .unwrap(); - let connect = client.get_multiplexed_async_connection(None, None); + let connect = + client.get_multiplexed_async_connection(GlideConnectionOptions::default()); let result = block_on_all(connect.and_then(|mut con| async move { redis::cmd("SET") .arg("key1") diff --git a/redis/tests/test_async_async_std.rs b/redis/tests/test_async_async_std.rs index ae2ae8443..656d1979f 100644 --- a/redis/tests/test_async_async_std.rs +++ b/redis/tests/test_async_async_std.rs @@ -3,7 +3,7 @@ use futures::prelude::*; use crate::support::*; -use redis::{aio::MultiplexedConnection, RedisResult}; +use redis::{aio::MultiplexedConnection, GlideConnectionOptions, RedisResult}; mod support; @@ -61,7 +61,7 @@ fn test_args_async_std() { fn dont_panic_on_closed_multiplexed_connection() { let ctx = TestContext::new(); let client = ctx.client.clone(); - let connect = client.get_multiplexed_async_std_connection(None, None); + let connect = client.get_multiplexed_async_std_connection(GlideConnectionOptions::default()); drop(ctx); block_on_all_using_async_std(async move { diff --git a/redis/tests/test_async_cluster_connections_logic.rs b/redis/tests/test_async_cluster_connections_logic.rs index 07e41a699..0230d1de1 100644 --- a/redis/tests/test_async_cluster_connections_logic.rs +++ b/redis/tests/test_async_cluster_connections_logic.rs @@ -5,7 +5,7 @@ mod support; use redis::{ cluster_async::testing::{AsyncClusterNode, RefreshConnectionType}, testing::ClusterParams, - ErrorKind, + ErrorKind, GlideConnectionOptions, }; use std::net::{IpAddr, Ipv4Addr}; use std::sync::Arc; @@ -72,8 +72,7 @@ mod test_connect_and_check { None, RefreshConnectionType::AllConnections, None, - None, - None, + GlideConnectionOptions::default(), ) .await; let node = assert_full_success(result); @@ -109,8 +108,7 @@ mod test_connect_and_check { None, RefreshConnectionType::AllConnections, None, - None, - None, + GlideConnectionOptions::default(), ) .await; let (node, _) = assert_partial_result(result); @@ -128,8 +126,7 @@ mod test_connect_and_check { None, RefreshConnectionType::AllConnections, None, - None, - None, + GlideConnectionOptions::default(), ) .await; let (node, _) = assert_partial_result(result); @@ -162,8 +159,7 @@ mod test_connect_and_check { None, RefreshConnectionType::AllConnections, None, - None, - None, + GlideConnectionOptions::default(), ) .await; let node = assert_full_success(result); @@ -200,8 +196,7 @@ mod test_connect_and_check { None, RefreshConnectionType::AllConnections, None, - None, - None, + GlideConnectionOptions::default(), ) .await; let err = result.get_error().unwrap(); @@ -252,8 +247,7 @@ mod test_connect_and_check { None, RefreshConnectionType::OnlyManagementConnection, Some(node), - None, - None, + GlideConnectionOptions::default(), ) .await; let node = assert_full_success(result); @@ -300,8 +294,7 @@ mod test_connect_and_check { None, RefreshConnectionType::OnlyManagementConnection, Some(node), - None, - None, + GlideConnectionOptions::default(), ) .await; let (node, _) = assert_partial_result(result); @@ -363,8 +356,7 @@ mod test_connect_and_check { None, RefreshConnectionType::OnlyUserConnection, Some(node), - None, - None, + GlideConnectionOptions::default(), ) .await; let node = assert_full_success(result); diff --git a/redis/tests/test_cluster_async.rs b/redis/tests/test_cluster_async.rs index 8c1d0d7e0..4d0883d47 100644 --- a/redis/tests/test_cluster_async.rs +++ b/redis/tests/test_cluster_async.rs @@ -21,7 +21,7 @@ mod cluster_async { use std::ops::Add; use redis::{ - aio::{ConnectionLike, DisconnectNotifier, MultiplexedConnection}, + aio::{ConnectionLike, MultiplexedConnection}, cluster::ClusterClient, cluster_async::{testing::MANAGEMENT_CONN_NAME, ClusterConnection, Connect}, cluster_routing::{ @@ -29,9 +29,9 @@ mod cluster_async { }, cluster_topology::{get_slot, DEFAULT_NUMBER_OF_REFRESH_SLOTS_RETRIES}, cmd, from_owned_redis_value, parse_redis_value, AsyncCommands, Cmd, ErrorKind, - FromRedisValue, InfoDict, IntoConnectionInfo, ProtocolVersion, PubSubChannelOrPattern, - PubSubSubscriptionInfo, PubSubSubscriptionKind, PushInfo, PushKind, RedisError, - RedisFuture, RedisResult, Script, Value, + FromRedisValue, GlideConnectionOptions, InfoDict, IntoConnectionInfo, ProtocolVersion, + PubSubChannelOrPattern, PubSubSubscriptionInfo, PubSubSubscriptionKind, PushInfo, PushKind, + RedisError, RedisFuture, RedisResult, Script, Value, }; use crate::support::*; @@ -436,7 +436,7 @@ mod cluster_async { .unwrap_or_else(|e| panic!("Failed to connect to '{addr}': {e}")); let mut conn = client - .get_multiplexed_async_connection(None, None) + .get_multiplexed_async_connection(GlideConnectionOptions::default()) .await .unwrap_or_else(|e| panic!("Failed to get connection: {e}")); @@ -535,8 +535,7 @@ mod cluster_async { response_timeout: std::time::Duration, connection_timeout: std::time::Duration, socket_addr: Option, - push_sender: Option>, - disconnect_notifier: Option>, + glide_connection_options: GlideConnectionOptions, ) -> RedisFuture<'a, (Self, Option)> where T: IntoConnectionInfo + Send + 'a, @@ -547,8 +546,7 @@ mod cluster_async { response_timeout, connection_timeout, socket_addr, - push_sender, - disconnect_notifier, + glide_connection_options, ) .await?; Ok((ErrorConnection { inner }, None)) diff --git a/redis/tests/test_sentinel.rs b/redis/tests/test_sentinel.rs index 0782c8b6d..24cd13bd6 100644 --- a/redis/tests/test_sentinel.rs +++ b/redis/tests/test_sentinel.rs @@ -239,7 +239,7 @@ pub mod async_tests { use redis::{ aio::MultiplexedConnection, sentinel::{Sentinel, SentinelClient, SentinelNodeConnectionInfo}, - Client, ConnectionAddr, RedisError, + Client, ConnectionAddr, GlideConnectionOptions, RedisError, }; use crate::{assert_is_master_role, assert_replica_role_and_master_addr, support::*}; @@ -283,7 +283,7 @@ pub mod async_tests { .await .unwrap(); let mut replica_con = replica_client - .get_multiplexed_async_connection(None, None) + .get_multiplexed_async_connection(GlideConnectionOptions::default()) .await .unwrap(); @@ -316,7 +316,7 @@ pub mod async_tests { .await .unwrap(); let mut replica_con = replica_client - .get_multiplexed_async_connection(None, None) + .get_multiplexed_async_connection(GlideConnectionOptions::default()) .await .unwrap(); @@ -339,13 +339,13 @@ pub mod async_tests { .async_master_for(master_name, Some(&node_conn_info)) .await?; let mut master_con = master_client - .get_multiplexed_async_connection(None, None) + .get_multiplexed_async_connection(GlideConnectionOptions::default()) .await?; let mut replica_con = sentinel .async_replica_for(master_name, Some(&node_conn_info)) .await? - .get_multiplexed_async_connection(None, None) + .get_multiplexed_async_connection(GlideConnectionOptions::default()) .await?; async_assert_is_connection_to_master(&mut master_con).await; @@ -370,7 +370,7 @@ pub mod async_tests { .async_master_for(master_name, Some(&node_conn_info)) .await?; let mut master_con = master_client - .get_multiplexed_async_connection(None, None) + .get_multiplexed_async_connection(GlideConnectionOptions::default()) .await?; async_assert_is_connection_to_master(&mut master_con).await; @@ -413,7 +413,7 @@ pub mod async_tests { .async_master_for(master_name, Some(&node_conn_info)) .await?; let mut master_con = master_client - .get_multiplexed_async_connection(None, None) + .get_multiplexed_async_connection(GlideConnectionOptions::default()) .await?; async_assert_is_connection_to_master(&mut master_con).await; From 24c19dd79a3c200f26a18f032f747755d4957980 Mon Sep 17 00:00:00 2001 From: ikolomi Date: Tue, 3 Sep 2024 17:08:36 +0300 Subject: [PATCH 3/3] CR changes: Add async method to DisconnectNotifier trait, styling and other cleanups --- redis/src/aio/mod.rs | 5 ++ redis/src/cluster_async/mod.rs | 117 ++++++++++++++++++--------------- 2 files changed, 69 insertions(+), 53 deletions(-) diff --git a/redis/src/aio/mod.rs b/redis/src/aio/mod.rs index 737ad82a7..ffe2c9e3a 100644 --- a/redis/src/aio/mod.rs +++ b/redis/src/aio/mod.rs @@ -12,6 +12,7 @@ use std::net::SocketAddr; #[cfg(unix)] use std::path::Path; use std::pin::Pin; +use std::time::Duration; /// Enables the async_std compatibility #[cfg(feature = "async-std-comp")] @@ -91,10 +92,14 @@ pub trait ConnectionLike { } /// Implements ability to notify about disconnection events +#[async_trait] pub trait DisconnectNotifier: Send + Sync { /// Notify about disconnect event fn notify_disconnect(&mut self); + /// Wait for disconnect event with timeout + async fn wait_for_disconnect_with_timeout(&self, max_wait: &Duration); + /// Intended to be used with Box fn clone_box(&self) -> Box; } diff --git a/redis/src/cluster_async/mod.rs b/redis/src/cluster_async/mod.rs index 9f2ab6c35..2225062b7 100644 --- a/redis/src/cluster_async/mod.rs +++ b/redis/src/cluster_async/mod.rs @@ -91,6 +91,8 @@ use backoff_std_async::future::retry; #[cfg(all(not(feature = "tokio-comp"), feature = "async-std-comp"))] use backoff_std_async::{Error as BackoffError, ExponentialBackoff}; +#[cfg(feature = "tokio-comp")] +use async_trait::async_trait; #[cfg(feature = "tokio-comp")] use backoff_tokio::future::retry; #[cfg(feature = "tokio-comp")] @@ -379,20 +381,37 @@ where #[cfg(feature = "tokio-comp")] #[derive(Clone)] struct TokioDisconnectNotifier { - pub disconnect_notifier: Arc, + disconnect_notifier: Arc, } #[cfg(feature = "tokio-comp")] +#[async_trait] impl DisconnectNotifier for TokioDisconnectNotifier { fn notify_disconnect(&mut self) { self.disconnect_notifier.notify_one(); } + async fn wait_for_disconnect_with_timeout(&self, max_wait: &Duration) { + let _ = timeout(*max_wait, async { + self.disconnect_notifier.notified().await; + }) + .await; + } + fn clone_box(&self) -> Box { Box::new(self.clone()) } } +#[cfg(feature = "tokio-comp")] +impl TokioDisconnectNotifier { + fn new() -> TokioDisconnectNotifier { + TokioDisconnectNotifier { + disconnect_notifier: Arc::new(Notify::new()), + } + } +} + type ConnectionMap = connections_container::ConnectionsMap>; type ConnectionsContainer = self::connections_container::ConnectionsContainer>; @@ -406,8 +425,6 @@ pub(crate) struct InnerCore { subscriptions_by_address: RwLock>, unassigned_subscriptions: RwLock, glide_connection_options: GlideConnectionOptions, - #[cfg(feature = "tokio-comp")] - tokio_notify: Arc, } pub(crate) type Core = Arc>; @@ -990,27 +1007,24 @@ where cluster_params: ClusterParams, push_sender: Option>, ) -> RedisResult> { - #[cfg(feature = "tokio-comp")] - let tokio_notify = Arc::new(Notify::new()); - let disconnect_notifier = { #[cfg(feature = "tokio-comp")] { - Some::>(Box::new(TokioDisconnectNotifier { - disconnect_notifier: tokio_notify.clone(), - })) + Some::>(Box::new(TokioDisconnectNotifier::new())) } #[cfg(not(feature = "tokio-comp"))] None }; + let glide_connection_options = GlideConnectionOptions { + push_sender, + disconnect_notifier, + }; + let connections = Self::create_initial_connections( initial_nodes, &cluster_params, - GlideConnectionOptions { - push_sender: push_sender.clone(), - disconnect_notifier: disconnect_notifier.clone(), - }, + glide_connection_options.clone(), ) .await?; @@ -1035,12 +1049,7 @@ where }, ), subscriptions_by_address: RwLock::new(Default::default()), - glide_connection_options: GlideConnectionOptions { - push_sender: push_sender.clone(), - disconnect_notifier: disconnect_notifier.clone(), - }, - #[cfg(feature = "tokio-comp")] - tokio_notify, + glide_connection_options, }); let mut connection = ClusterConnInner { inner, @@ -1227,40 +1236,40 @@ where // In addition, the validation is done by peeking at the state of the underlying transport w/o overhead of additional commands to server. async fn validate_all_user_connections(inner: Arc>) { let mut all_valid_conns = HashMap::new(); - let mut all_nodes_with_slots = HashSet::new(); // prep connections and clean out these w/o assigned slots, as we might have established connections to unwanted hosts - { - let mut nodes_to_delete = Vec::new(); - let connections_container = inner.conn_lock.read().await; - - connections_container - .slot_map - .addresses_for_all_nodes() - .iter() - .for_each(|addr| { - all_nodes_with_slots.insert(String::from(*addr)); - }); + let mut nodes_to_delete = Vec::new(); + let connections_container = inner.conn_lock.read().await; - connections_container - .all_node_connections() - .for_each(|(addr, con)| { - if all_nodes_with_slots.contains(&addr) { - all_valid_conns.insert(addr.clone(), con.clone()); - } else { - nodes_to_delete.push(addr.clone()); - } - }); + let all_nodes_with_slots: HashSet = connections_container + .slot_map + .addresses_for_all_nodes() + .iter() + .map(|addr| String::from(*addr)) + .collect(); + + connections_container + .all_node_connections() + .for_each(|(addr, con)| { + if all_nodes_with_slots.contains(&addr) { + all_valid_conns.insert(addr.clone(), con.clone()); + } else { + nodes_to_delete.push(addr.clone()); + } + }); - for addr in &nodes_to_delete { - connections_container.remove_node(addr); - } + for addr in &nodes_to_delete { + connections_container.remove_node(addr); } + drop(connections_container); + // identify nodes with closed connection let mut addrs_to_refresh = Vec::new(); for (addr, con_fut) in &all_valid_conns { let con = con_fut.clone().await; + // connection object might be present despite the transport being closed if con.is_closed() { + // transport is closed, need to refresh addrs_to_refresh.push(addr.clone()); } } @@ -1289,7 +1298,7 @@ where inner: Arc>, addresses: Vec, conn_type: RefreshConnectionType, - try_existing_node: bool, + check_existing_conn: bool, ) { info!("Started refreshing connections to {:?}", addresses); let connections_container = inner.conn_lock.read().await; @@ -1301,10 +1310,10 @@ where .fold( &*connections_container, |connections_container, address| async move { - let node_option = if try_existing_node { + let node_option = if check_existing_conn { connections_container.remove_node(&address) } else { - Option::None + None }; // override subscriptions for this connection @@ -1541,13 +1550,15 @@ where async fn connections_validation_task(inner: Arc>, interval_duration: Duration) { loop { - #[cfg(feature = "tokio-comp")] - let _ = timeout(interval_duration, async { - inner.tokio_notify.notified().await; - }) - .await; - #[cfg(not(feature = "tokio-comp"))] - let _ = boxed_sleep(interval_duration).await; + if let Some(disconnect_notifier) = + inner.glide_connection_options.disconnect_notifier.clone() + { + disconnect_notifier + .wait_for_disconnect_with_timeout(&interval_duration) + .await; + } else { + let _ = boxed_sleep(interval_duration).await; + } Self::validate_all_user_connections(inner.clone()).await; }