From 6ce81dbece796e24d70621fdf64f48e66c3dbc9b Mon Sep 17 00:00:00 2001 From: barshaul Date: Tue, 27 Aug 2024 08:13:19 +0000 Subject: [PATCH] SlotMap refactor: Added new NodesMap, changed shard addresses to be shard between shard nodes and slot map values --- redis/src/cluster.rs | 34 ++- .../cluster_async/connections_container.rs | 12 +- redis/src/cluster_async/mod.rs | 40 +-- redis/src/cluster_routing.rs | 38 +-- redis/src/cluster_slotmap.rs | 229 +++++++++++------- redis/src/cluster_topology.rs | 51 ++-- redis/src/commands/cluster_scan.rs | 51 ++-- 7 files changed, 245 insertions(+), 210 deletions(-) diff --git a/redis/src/cluster.rs b/redis/src/cluster.rs index 5c0702d857..cec907b122 100644 --- a/redis/src/cluster.rs +++ b/redis/src/cluster.rs @@ -35,14 +35,14 @@ //! .expire(key, 60).ignore() //! .query(&mut connection).unwrap(); //! ``` +use rand::{seq::IteratorRandom, thread_rng, Rng}; use std::cell::RefCell; use std::collections::HashSet; use std::str::FromStr; +use std::sync::Arc; use std::thread; use std::time::Duration; -use rand::{seq::IteratorRandom, thread_rng, Rng}; - use crate::cluster_pipeline::UNROUTABLE_ERROR; use crate::cluster_routing::{ MultipleNodeRoutingInfo, ResponsePolicy, Routable, SingleNodeRoutingInfo, SlotAddr, @@ -343,22 +343,20 @@ where let mut slots = self.slots.borrow_mut(); *slots = self.create_new_slots()?; - let mut nodes = slots.values().flatten().collect::>(); - nodes.sort_unstable(); - nodes.dedup(); - + let nodes = slots.all_node_addresses(); let mut connections = self.connections.borrow_mut(); *connections = nodes .into_iter() .filter_map(|addr| { - if connections.contains_key(addr) { - let mut conn = connections.remove(addr).unwrap(); + let addr = addr.to_string(); + if connections.contains_key(&addr) { + let mut conn = connections.remove(&addr).unwrap(); if conn.check_connection() { return Some((addr.to_string(), conn)); } } - if let Ok(mut conn) = self.connect(addr) { + if let Ok(mut conn) = self.connect(&addr) { if conn.check_connection() { return Some((addr.to_string(), conn)); } @@ -424,7 +422,7 @@ where if let Some(addr) = slots.slot_addr_for_route(route) { Ok(( addr.to_string(), - self.get_connection_by_addr(connections, addr)?, + self.get_connection_by_addr(connections, &addr)?, )) } else { // try a random node next. This is safe if slots are involved @@ -495,13 +493,13 @@ where fn execute_on_all<'a>( &'a self, input: Input, - addresses: HashSet<&'a str>, + addresses: HashSet>, connections: &'a mut HashMap, - ) -> Vec> { + ) -> Vec, Value)>> { addresses .into_iter() .map(|addr| { - let connection = self.get_connection_by_addr(connections, addr)?; + let connection = self.get_connection_by_addr(connections, &addr)?; match input { Input::Slice { cmd, routable: _ } => connection.req_packed_command(cmd), Input::Cmd(cmd) => connection.req_command(cmd), @@ -526,8 +524,8 @@ where input: Input, slots: &'a mut SlotMap, connections: &'a mut HashMap, - ) -> Vec> { - self.execute_on_all(input, slots.addresses_for_all_nodes(), connections) + ) -> Vec, Value)>> { + self.execute_on_all(input, slots.all_node_addresses(), connections) } fn execute_on_all_primaries<'a>( @@ -535,7 +533,7 @@ where input: Input, slots: &'a mut SlotMap, connections: &'a mut HashMap, - ) -> Vec> { + ) -> Vec, Value)>> { self.execute_on_all(input, slots.addresses_for_all_primaries(), connections) } @@ -545,7 +543,7 @@ where slots: &'a mut SlotMap, connections: &'a mut HashMap, routes: &'b [(Route, Vec)], - ) -> Vec> + ) -> Vec, Value)>> where 'b: 'a, { @@ -557,7 +555,7 @@ where ErrorKind::IoError, "Couldn't find connection", )))?; - let connection = self.get_connection_by_addr(connections, addr)?; + let connection = self.get_connection_by_addr(connections, &addr)?; let (_, indices) = routes.get(index).unwrap(); let cmd = crate::cluster_routing::command_for_multi_slot_indices(&input, indices.iter()); diff --git a/redis/src/cluster_async/connections_container.rs b/redis/src/cluster_async/connections_container.rs index 2bfbb8b934..396d2dddf1 100644 --- a/redis/src/cluster_async/connections_container.rs +++ b/redis/src/cluster_async/connections_container.rs @@ -147,18 +147,14 @@ where /// Returns true if the address represents a known primary node. pub(crate) fn is_primary(&self, address: &String) -> bool { - self.connection_for_address(address).is_some() - && self - .slot_map - .values() - .any(|slot_addrs| slot_addrs.primary.as_str() == address) + self.connection_for_address(address).is_some() && self.slot_map.is_primary(address) } fn round_robin_read_from_replica( &self, slot_map_value: &SlotMapValue, ) -> Option> { - let addrs = &slot_map_value.addrs; + let addrs = &slot_map_value.addrs.read().unwrap(); let initial_index = slot_map_value .latest_used_replica .load(std::sync::atomic::Ordering::Relaxed); @@ -185,7 +181,7 @@ where fn lookup_route(&self, route: &Route) -> Option> { let slot_map_value = self.slot_map.slot_value_for_route(route)?; - let addrs = &slot_map_value.addrs; + let addrs = &slot_map_value.addrs.read().unwrap(); if addrs.replicas.is_empty() { return self.connection_for_address(addrs.primary.as_str()); } @@ -232,7 +228,7 @@ where self.slot_map .addresses_for_all_primaries() .into_iter() - .flat_map(|addr| self.connection_for_address(addr)) + .flat_map(|addr| self.connection_for_address(&addr)) } pub(crate) fn node_for_address(&self, address: &str) -> Option> { diff --git a/redis/src/cluster_async/mod.rs b/redis/src/cluster_async/mod.rs index 46c7d4a729..06aa877366 100644 --- a/redis/src/cluster_async/mod.rs +++ b/redis/src/cluster_async/mod.rs @@ -396,7 +396,7 @@ where &self, slot: u16, slot_addr: SlotAddr, - ) -> Option { + ) -> Option> { self.conn_lock .read() .await @@ -444,7 +444,7 @@ where } // return slots of node - pub(crate) async fn get_slots_of_address(&self, node_address: &str) -> Vec { + pub(crate) async fn get_slots_of_address(&self, node_address: Arc) -> Vec { self.conn_lock .read() .await @@ -1020,7 +1020,6 @@ where Self::refresh_slots_and_subscriptions_with_retries( connection.inner.clone(), &RefreshPolicy::NotThrottable, - None, ) .await?; @@ -1164,7 +1163,6 @@ where if let Err(err) = Self::refresh_slots_and_subscriptions_with_retries( inner.clone(), &RefreshPolicy::Throttable, - None, ) .await { @@ -1336,7 +1334,6 @@ where async fn refresh_slots_and_subscriptions_with_retries( inner: Arc>, policy: &RefreshPolicy, - moved_redirect: Option, ) -> RedisResult<()> { let SlotRefreshState { in_progress, @@ -1388,10 +1385,6 @@ where Self::refresh_slots(inner.clone(), curr_retry) }) .await; - } else if moved_redirect.is_some() { - // Update relevant slots in the slots map based on the moved_redirect address, - // rather than refreshing all slots by querying the cluster nodes for their topology view. - Self::update_slots_for_redirect_change(inner.clone(), moved_redirect).await?; } in_progress.store(false, Ordering::Relaxed); @@ -1400,15 +1393,6 @@ where res } - /// Update relevant slots in the slots map based on the moved_redirect address - pub(crate) async fn update_slots_for_redirect_change( - _inner: Arc>, - _moved_redirect: Option, - ) -> RedisResult<()> { - // TODO: Add implementation - Ok(()) - } - /// Determines if the cluster topology has changed and refreshes slots and subscriptions if needed. /// Returns `RedisResult` with `true` if changes were detected and slots were refreshed, /// or `false` if no changes were found. Raises an error if refreshing the topology fails. @@ -1418,7 +1402,7 @@ where ) -> RedisResult { let topology_changed = Self::check_for_topology_diff(inner.clone()).await; if topology_changed { - Self::refresh_slots_and_subscriptions_with_retries(inner.clone(), policy, None).await?; + Self::refresh_slots_and_subscriptions_with_retries(inner.clone(), policy).await?; } Ok(topology_changed) } @@ -1629,21 +1613,20 @@ where .0?; let connections = &*read_guard; // Create a new connection vector of the found nodes - let mut nodes = new_slots.values().flatten().collect::>(); - nodes.sort_unstable(); - nodes.dedup(); + let nodes = new_slots.all_node_addresses(); let nodes_len = nodes.len(); let addresses_and_connections_iter = stream::iter(nodes) .fold( Vec::with_capacity(nodes_len), |mut addrs_and_conns, addr| async move { + let addr = addr.to_string(); if let Some(node) = connections.node_for_address(addr.as_str()) { addrs_and_conns.push((addr, Some(node))); return addrs_and_conns; } // If it's a DNS endpoint, it could have been stored in the existing connections vector using the resolved IP address instead of the DNS endpoint's name. // We shall check if a connection is already exists under the resolved IP name. - let (host, port) = match get_host_and_port_from_addr(addr) { + let (host, port) = match get_host_and_port_from_addr(&addr) { Some((host, port)) => (host, port), None => { addrs_and_conns.push((addr, None)); @@ -1669,10 +1652,10 @@ where |connections, (addr, node)| async { let mut cluster_params = inner.cluster_params.clone(); let subs_guard = inner.subscriptions_by_address.read().await; - cluster_params.pubsub_subscriptions = subs_guard.get(addr).cloned(); + cluster_params.pubsub_subscriptions = subs_guard.get(&addr).cloned(); drop(subs_guard); let node = get_or_create_conn( - addr, + &addr, node, &cluster_params, RefreshConnectionType::AllConnections, @@ -1680,7 +1663,7 @@ where ) .await; if let Ok(node) = node { - connections.0.insert(addr.into(), node); + connections.0.insert(addr, node); } connections }, @@ -2024,7 +2007,6 @@ where *future = Box::pin(Self::refresh_slots_and_subscriptions_with_retries( self.inner.clone(), &RefreshPolicy::Throttable, - None, )); Poll::Ready(Err(err)) } @@ -2271,12 +2253,12 @@ where match ready!(self.poll_complete(cx)) { PollFlushAction::None => return Poll::Ready(Ok(())), - PollFlushAction::RebuildSlots(moved_redirect) => { + PollFlushAction::RebuildSlots(_moved_redirect) => { + // TODO: Add logic to update the slots map based on the MOVED error self.state = ConnectionState::Recover(RecoverFuture::RecoverSlots(Box::pin( ClusterConnInner::refresh_slots_and_subscriptions_with_retries( self.inner.clone(), &RefreshPolicy::Throttable, - moved_redirect, ), ))); } diff --git a/redis/src/cluster_routing.rs b/redis/src/cluster_routing.rs index 27abd54fea..83e7bb4e43 100644 --- a/redis/src/cluster_routing.rs +++ b/redis/src/cluster_routing.rs @@ -1,11 +1,11 @@ -use std::cmp::min; -use std::collections::HashMap; - use crate::cluster_topology::get_slot; use crate::cmd::{Arg, Cmd}; use crate::types::Value; use crate::{ErrorKind, RedisResult}; +use std::cmp::min; +use std::collections::HashMap; use std::iter::Once; +use std::sync::Arc; #[derive(Clone)] pub(crate) enum Redirect { @@ -866,14 +866,6 @@ impl Slot { } } - pub fn start(&self) -> u16 { - self.start - } - - pub fn end(&self) -> u16 { - self.end - } - #[allow(dead_code)] // used in tests pub(crate) fn master(&self) -> &str { self.master.as_str() @@ -902,25 +894,15 @@ pub enum SlotAddr { /// which stores only the master and [optional] replica /// to avoid the need to choose a replica each time /// a command is executed -#[derive(Debug, Eq, PartialEq)] -pub(crate) struct SlotAddrs { - pub(crate) primary: String, - pub(crate) replicas: Vec, -} - -impl SlotAddrs { - pub(crate) fn new(primary: String, replicas: Vec) -> Self { - Self { primary, replicas } - } - - pub(crate) fn from_slot(slot: Slot) -> Self { - SlotAddrs::new(slot.master, slot.replicas) - } +#[derive(Debug, Eq, PartialEq, Clone, PartialOrd, Ord)] +pub(crate) struct ShardAddrs { + pub(crate) primary: Arc, + pub(crate) replicas: Vec>, } -impl<'a> IntoIterator for &'a SlotAddrs { - type Item = &'a String; - type IntoIter = std::iter::Chain, std::slice::Iter<'a, String>>; +impl<'a> IntoIterator for &'a ShardAddrs { + type Item = &'a Arc; + type IntoIter = std::iter::Chain>, std::slice::Iter<'a, Arc>>; fn into_iter(self) -> Self::IntoIter { std::iter::once(&self.primary).chain(self.replicas.iter()) diff --git a/redis/src/cluster_slotmap.rs b/redis/src/cluster_slotmap.rs index 7f1f70af98..d718b947fe 100644 --- a/redis/src/cluster_slotmap.rs +++ b/redis/src/cluster_slotmap.rs @@ -1,28 +1,22 @@ +use std::sync::Arc; +use std::sync::RwLock; use std::{ - collections::{BTreeMap, HashSet}, + collections::{BTreeMap, HashMap, HashSet}, fmt::Display, sync::atomic::AtomicUsize, }; -use crate::cluster_routing::{Route, Slot, SlotAddr, SlotAddrs}; +use crate::cluster_routing::{Route, ShardAddrs, Slot, SlotAddr}; + +pub(crate) type NodesMap = HashMap, Arc>>; #[derive(Debug)] -pub(crate) struct SlotMapValue { +pub struct SlotMapValue { pub(crate) start: u16, - pub(crate) addrs: SlotAddrs, + pub(crate) addrs: Arc>, pub(crate) latest_used_replica: AtomicUsize, } -impl SlotMapValue { - fn from_slot(slot: Slot) -> Self { - Self { - start: slot.start(), - addrs: SlotAddrs::from_slot(slot), - latest_used_replica: AtomicUsize::new(0), - } - } -} - #[derive(Debug, Default, Clone, PartialEq, Copy)] pub(crate) enum ReadFromReplicaStrategy { #[default] @@ -31,8 +25,9 @@ pub(crate) enum ReadFromReplicaStrategy { } #[derive(Debug, Default)] -pub(crate) struct SlotMap { - pub(crate) slots: BTreeMap, +pub struct SlotMap { + pub slots: BTreeMap, + pub nodes_map: NodesMap, read_from_replica: ReadFromReplicaStrategy, } @@ -40,34 +35,75 @@ fn get_address_from_slot( slot: &SlotMapValue, read_from_replica: ReadFromReplicaStrategy, slot_addr: SlotAddr, -) -> &str { - if slot_addr == SlotAddr::Master || slot.addrs.replicas.is_empty() { - return slot.addrs.primary.as_str(); +) -> Arc { + let addrs = slot.addrs.read().unwrap(); + if slot_addr == SlotAddr::Master || addrs.replicas.is_empty() { + return addrs.primary.clone(); } match read_from_replica { - ReadFromReplicaStrategy::AlwaysFromPrimary => slot.addrs.primary.as_str(), + ReadFromReplicaStrategy::AlwaysFromPrimary => addrs.primary.clone(), ReadFromReplicaStrategy::RoundRobin => { let index = slot .latest_used_replica .fetch_add(1, std::sync::atomic::Ordering::Relaxed) - % slot.addrs.replicas.len(); - slot.addrs.replicas[index].as_str() + % addrs.replicas.len(); + addrs.replicas[index].clone() } } } impl SlotMap { - pub(crate) fn new(slots: Vec, read_from_replica: ReadFromReplicaStrategy) -> Self { - let mut this = Self { + pub fn new(slots: Vec, read_from_replica: ReadFromReplicaStrategy) -> Self { + let mut slot_map = SlotMap { slots: BTreeMap::new(), + nodes_map: HashMap::new(), read_from_replica, }; - this.slots.extend( - slots - .into_iter() - .map(|slot| (slot.end(), SlotMapValue::from_slot(slot))), - ); - this + let mut shard_id = 0; + for slot in slots { + let primary = Arc::new(slot.master); + let replicas: Vec> = slot.replicas.into_iter().map(Arc::new).collect(); + + // Get the shard addresses if the primary is already in nodes_map; + // otherwise, create a new ShardAddrs and add it + let shard_addrs_arc = slot_map + .nodes_map + .entry(primary.clone()) + .or_insert_with(|| { + shard_id += 1; + Arc::new(RwLock::new(ShardAddrs { + primary, + replicas: replicas.clone(), + })) + }) + .clone(); + + // Add all replicas to nodes_map with a reference to the same ShardAddrs if not already present + replicas.iter().for_each(|replica| { + slot_map + .nodes_map + .entry(replica.clone()) + .or_insert(shard_addrs_arc.clone()); + }); + + // Insert the slot value into the slots map + slot_map.slots.insert( + slot.end, + SlotMapValue { + addrs: shard_addrs_arc.clone(), + start: slot.start, + latest_used_replica: AtomicUsize::new(0), + }, + ); + } + + slot_map + } + + pub fn is_primary(&self, address: &String) -> bool { + self.nodes_map.get(address).map_or(false, |shard_addrs| { + *shard_addrs.read().unwrap().primary == *address + }) } pub fn slot_value_for_route(&self, route: &Route) -> Option<&SlotMapValue> { @@ -84,40 +120,27 @@ impl SlotMap { }) } - pub fn slot_addr_for_route(&self, route: &Route) -> Option<&str> { + pub fn slot_addr_for_route(&self, route: &Route) -> Option> { self.slot_value_for_route(route).map(|slot_value| { get_address_from_slot(slot_value, self.read_from_replica, route.slot_addr()) }) } - pub fn values(&self) -> impl Iterator { - self.slots.values().map(|slot_value| &slot_value.addrs) - } - - fn all_unique_addresses(&self, only_primaries: bool) -> HashSet<&str> { - let mut addresses = HashSet::new(); - for slot in self.values() { - addresses.insert(slot.primary.as_str()); - if !only_primaries { - addresses.extend(slot.replicas.iter().map(|str| str.as_str())); - } - } - - addresses - } - - pub fn addresses_for_all_primaries(&self) -> HashSet<&str> { - self.all_unique_addresses(true) + pub fn addresses_for_all_primaries(&self) -> HashSet> { + self.nodes_map + .values() + .map(|shard_addrs| shard_addrs.read().unwrap().primary.clone()) + .collect() } - pub fn addresses_for_all_nodes(&self) -> HashSet<&str> { - self.all_unique_addresses(false) + pub fn all_node_addresses(&self) -> HashSet> { + self.nodes_map.keys().cloned().collect() } pub fn addresses_for_multi_slot<'a, 'b>( &'a self, routes: &'b [(Route, Vec)], - ) -> impl Iterator> + 'a + ) -> impl Iterator>> + 'a where 'b: 'a, { @@ -127,13 +150,13 @@ impl SlotMap { } // Returns the slots that are assigned to the given address. - pub(crate) fn get_slots_of_node(&self, node_address: &str) -> Vec { - let node_address = node_address.to_string(); + pub(crate) fn get_slots_of_node(&self, node_address: Arc) -> Vec { self.slots .iter() .filter_map(|(end, slot_value)| { - if slot_value.addrs.primary == node_address - || slot_value.addrs.replicas.contains(&node_address) + let addr_reader = slot_value.addrs.read().unwrap(); + if addr_reader.primary == node_address + || addr_reader.replicas.contains(&node_address) { Some(slot_value.start..(*end + 1)) } else { @@ -148,13 +171,14 @@ impl SlotMap { &self, slot: u16, slot_addr: SlotAddr, - ) -> Option { + ) -> Option> { self.slots.range(slot..).next().and_then(|(_, slot_value)| { if slot_value.start <= slot { - Some( - get_address_from_slot(slot_value, self.read_from_replica, slot_addr) - .to_string(), - ) + Some(get_address_from_slot( + slot_value, + self.read_from_replica, + slot_addr, + )) } else { None } @@ -171,8 +195,8 @@ impl Display for SlotMap { "({}-{}): primary: {}, replicas: {:?}", slot_map_value.start, end, - slot_map_value.addrs.primary, - slot_map_value.addrs.replicas + slot_map_value.addrs.read().unwrap().primary, + slot_map_value.addrs.read().unwrap().replicas )?; } Ok(()) @@ -180,9 +204,22 @@ impl Display for SlotMap { } #[cfg(test)] -mod tests { +mod tests_cluster_slotmap { use super::*; + fn process_expected(expected: Vec<&str>) -> HashSet> { + as IntoIterator>::into_iter(HashSet::from_iter(expected)) + .map(|s| Arc::new(s.to_string())) + .collect() + } + + fn process_expected_with_option(expected: Vec>) -> Vec> { + expected + .into_iter() + .filter_map(|opt| opt.map(|s| Arc::new(s.to_string()))) + .collect() + } + #[test] fn test_slot_map_retrieve_routes() { let slot_map = SlotMap::new( @@ -208,19 +245,19 @@ mod tests { .is_none()); assert_eq!( "node1:6379", - slot_map + *slot_map .slot_addr_for_route(&Route::new(1, SlotAddr::Master)) .unwrap() ); assert_eq!( "node1:6379", - slot_map + *slot_map .slot_addr_for_route(&Route::new(500, SlotAddr::Master)) .unwrap() ); assert_eq!( "node1:6379", - slot_map + *slot_map .slot_addr_for_route(&Route::new(1000, SlotAddr::Master)) .unwrap() ); @@ -230,19 +267,19 @@ mod tests { assert_eq!( "node2:6379", - slot_map + *slot_map .slot_addr_for_route(&Route::new(1002, SlotAddr::Master)) .unwrap() ); assert_eq!( "node2:6379", - slot_map + *slot_map .slot_addr_for_route(&Route::new(1500, SlotAddr::Master)) .unwrap() ); assert_eq!( "node2:6379", - slot_map + *slot_map .slot_addr_for_route(&Route::new(2000, SlotAddr::Master)) .unwrap() ); @@ -293,17 +330,17 @@ mod tests { let addresses = slot_map.addresses_for_all_primaries(); assert_eq!( addresses, - HashSet::from_iter(["node1:6379", "node2:6379", "node3:6379"]) + process_expected(vec!["node1:6379", "node2:6379", "node3:6379"]) ); } #[test] fn test_slot_map_get_all_nodes() { let slot_map = get_slot_map(ReadFromReplicaStrategy::AlwaysFromPrimary); - let addresses = slot_map.addresses_for_all_nodes(); + let addresses = slot_map.all_node_addresses(); assert_eq!( addresses, - HashSet::from_iter([ + process_expected(vec![ "node1:6379", "node2:6379", "node3:6379", @@ -327,11 +364,11 @@ mod tests { let addresses = slot_map .addresses_for_multi_slot(&routes) .collect::>(); - assert!(addresses.contains(&Some("node1:6379"))); + assert!(addresses.contains(&Some(Arc::new("node1:6379".to_string())))); assert!( - addresses.contains(&Some("replica4:6379")) - || addresses.contains(&Some("replica5:6379")) - || addresses.contains(&Some("replica6:6379")) + addresses.contains(&Some(Arc::new("replica4:6379".to_string()))) + || addresses.contains(&Some(Arc::new("replica5:6379".to_string()))) + || addresses.contains(&Some(Arc::new("replica6:6379".to_string()))) ); } @@ -348,19 +385,21 @@ mod tests { (Route::new(3, SlotAddr::ReplicaOptional), vec![]), (Route::new(2003, SlotAddr::Master), vec![]), ]; - let addresses = slot_map + let addresses: Vec> = slot_map .addresses_for_multi_slot(&routes) - .collect::>(); + .flatten() + .collect(); + assert_eq!( addresses, - vec![ + process_expected_with_option(vec![ Some("replica1:6379"), Some("node3:6379"), Some("replica1:6379"), Some("node3:6379"), Some("replica1:6379"), Some("node3:6379") - ] + ]) ); } @@ -373,12 +412,19 @@ mod tests { (Route::new(6000, SlotAddr::ReplicaOptional), vec![]), (Route::new(2002, SlotAddr::Master), vec![]), ]; - let addresses = slot_map + let addresses: Vec> = slot_map .addresses_for_multi_slot(&routes) - .collect::>(); + .flatten() + .collect(); + assert_eq!( addresses, - vec![Some("replica1:6379"), None, None, Some("node3:6379")] + process_expected_with_option(vec![ + Some("replica1:6379"), + None, + None, + Some("node3:6379") + ]) ); } @@ -395,6 +441,9 @@ mod tests { assert_eq!( addresses, vec!["replica4:6379", "replica5:6379", "replica6:6379"] + .into_iter() + .map(|s| Arc::new(s.to_string())) + .collect::>() ); } @@ -402,33 +451,33 @@ mod tests { fn test_get_slots_of_node() { let slot_map = get_slot_map(ReadFromReplicaStrategy::AlwaysFromPrimary); assert_eq!( - slot_map.get_slots_of_node("node1:6379"), + slot_map.get_slots_of_node(Arc::new("node1:6379".to_string())), (1..1001).collect::>() ); assert_eq!( - slot_map.get_slots_of_node("node2:6379"), + slot_map.get_slots_of_node(Arc::new("node2:6379".to_string())), vec![1002..2001, 3001..4001] .into_iter() .flatten() .collect::>() ); assert_eq!( - slot_map.get_slots_of_node("replica3:6379"), + slot_map.get_slots_of_node(Arc::new("replica3:6379".to_string())), vec![1002..2001, 3001..4001] .into_iter() .flatten() .collect::>() ); assert_eq!( - slot_map.get_slots_of_node("replica4:6379"), + slot_map.get_slots_of_node(Arc::new("replica4:6379".to_string())), (2001..3001).collect::>() ); assert_eq!( - slot_map.get_slots_of_node("replica5:6379"), + slot_map.get_slots_of_node(Arc::new("replica5:6379".to_string())), (2001..3001).collect::>() ); assert_eq!( - slot_map.get_slots_of_node("replica6:6379"), + slot_map.get_slots_of_node(Arc::new("replica6:6379".to_string())), (2001..3001).collect::>() ); } diff --git a/redis/src/cluster_topology.rs b/redis/src/cluster_topology.rs index a2ce9ea078..8601a2bcf8 100644 --- a/redis/src/cluster_topology.rs +++ b/redis/src/cluster_topology.rs @@ -300,7 +300,7 @@ pub(crate) fn calculate_topology<'a>( #[cfg(test)] mod tests { use super::*; - use crate::cluster_routing::SlotAddrs; + use crate::cluster_routing::ShardAddrs; #[test] fn test_get_hashtag() { @@ -456,10 +456,11 @@ mod tests { assert_eq!(calculate_hash(&res1), calculate_hash(&res2)); assert_eq!(res1.0, res2.0); assert_eq!(res1.1.len(), res2.1.len()); - let equality_check = - res1.1.iter().zip(&res2.1).all(|(first, second)| { - first.start() == second.start() && first.end() == second.end() - }); + let equality_check = res1 + .1 + .iter() + .zip(&res2.1) + .all(|(first, second)| first.start == second.start && first.end == second.end); assert!(equality_check); let replicas_check = res1 .1 @@ -502,8 +503,24 @@ mod tests { } } - fn get_node_addr(name: &str, port: u16) -> SlotAddrs { - SlotAddrs::new(format!("{name}:{port}"), Vec::new()) + fn get_node_addr(name: &str, port: u16) -> ShardAddrs { + ShardAddrs { + primary: format!("{name}:{port}").into(), + replicas: Vec::new(), + } + } + + fn collect_shard_addrs(slot_map: &SlotMap) -> Vec { + let mut shard_addrs: Vec = slot_map + .nodes_map + .values() + .map(|shard_addrs| { + let addr_reader = shard_addrs.read().unwrap(); + addr_reader.clone() + }) + .collect(); + shard_addrs.sort_unstable(); + shard_addrs } #[test] @@ -524,9 +541,9 @@ mod tests { ReadFromReplicaStrategy::AlwaysFromPrimary, ) .unwrap(); - let res: Vec<_> = topology_view.values().collect(); + let res = collect_shard_addrs(&topology_view); let node_1 = get_node_addr("node1", 6379); - let expected: Vec<&SlotAddrs> = vec![&node_1]; + let expected: Vec = vec![node_1]; assert_eq!(res, expected); } @@ -566,10 +583,10 @@ mod tests { ReadFromReplicaStrategy::AlwaysFromPrimary, ) .unwrap(); - let res: Vec<_> = topology_view.values().collect(); + let res = collect_shard_addrs(&topology_view); let node_1 = get_node_addr("node1", 6379); let node_2 = get_node_addr("node2", 6380); - let expected: Vec<&SlotAddrs> = vec![&node_1, &node_2]; + let expected: Vec = vec![node_1, node_2]; assert_eq!(res, expected); } @@ -589,10 +606,10 @@ mod tests { ReadFromReplicaStrategy::AlwaysFromPrimary, ) .unwrap(); - let res: Vec<_> = topology_view.values().collect(); + let res = collect_shard_addrs(&topology_view); let node_1 = get_node_addr("node1", 6379); let node_2 = get_node_addr("node2", 6380); - let expected: Vec<&SlotAddrs> = vec![&node_1, &node_2]; + let expected: Vec = vec![node_1, node_2]; assert_eq!(res, expected); } @@ -613,10 +630,10 @@ mod tests { ReadFromReplicaStrategy::AlwaysFromPrimary, ) .unwrap(); - let res: Vec<_> = topology_view.values().collect(); + let res = collect_shard_addrs(&topology_view); let node_1 = get_node_addr("node3", 6381); let node_2 = get_node_addr("node4", 6382); - let expected: Vec<&SlotAddrs> = vec![&node_1, &node_2]; + let expected: Vec = vec![node_1, node_2]; assert_eq!(res, expected); } @@ -637,9 +654,9 @@ mod tests { ReadFromReplicaStrategy::AlwaysFromPrimary, ) .unwrap(); - let res: Vec<_> = topology_view.values().collect(); + let res = collect_shard_addrs(&topology_view); let node_1 = get_node_addr("node1", 6379); - let expected: Vec<&SlotAddrs> = vec![&node_1]; + let expected: Vec = vec![node_1]; assert_eq!(res, expected); } } diff --git a/redis/src/commands/cluster_scan.rs b/redis/src/commands/cluster_scan.rs index 9dc626034d..58a568cfb7 100644 --- a/redis/src/commands/cluster_scan.rs +++ b/redis/src/commands/cluster_scan.rs @@ -134,14 +134,14 @@ impl ScanStateRC { #[async_trait] pub(crate) trait ClusterInScan { /// Retrieves the address associated with a given slot in the cluster. - async fn get_address_by_slot(&self, slot: u16) -> RedisResult; + async fn get_address_by_slot(&self, slot: u16) -> RedisResult>; /// Retrieves the epoch of a given address in the cluster. /// The epoch represents the version of the address, which is updated when a failover occurs or slots migrate in. async fn get_address_epoch(&self, address: &str) -> Result; /// Retrieves the slots assigned to a given address in the cluster. - async fn get_slots_of_address(&self, address: &str) -> Vec; + async fn get_slots_of_address(&self, address: Arc) -> Vec; /// Routes a Redis command to a specific address in the cluster. async fn route_command(&self, cmd: Cmd, address: &str) -> RedisResult; @@ -165,7 +165,7 @@ pub(crate) struct ScanState { scanned_slots_map: SlotsBitsArray, // the address that is being scanned currently, based on the next slot set to 0 in the scanned_slots_map, and the address that "owns" the slot // in the SlotMap - pub(crate) address_in_scan: String, + pub(crate) address_in_scan: Arc, // epoch represent the version of the address, when a failover happens or slots migrate in the epoch will be updated to +1 address_epoch: u64, // the status of the scan operation @@ -189,7 +189,7 @@ impl ScanState { pub fn new( cursor: u64, scanned_slots_map: SlotsBitsArray, - address_in_scan: String, + address_in_scan: Arc, address_epoch: u64, scan_status: ScanStateStage, ) -> Self { @@ -206,7 +206,7 @@ impl ScanState { Self { cursor: 0, scanned_slots_map: [0; BITS_ARRAY_SIZE], - address_in_scan: String::new(), + address_in_scan: Default::default(), address_epoch: 0, scan_status: ScanStateStage::Finished, } @@ -310,7 +310,9 @@ impl ScanState { } // If epoch wasn't changed, the slots owned by the address after the refresh are all valid as slots that been scanned // So we will update the scanned_slots_map with the slots owned by the address - let slots_scanned = connection.get_slots_of_address(&self.address_in_scan).await; + let slots_scanned = connection + .get_slots_of_address(self.address_in_scan.clone()) + .await; for slot in slots_scanned { let slot_index = slot as usize / BITS_PER_U64; let slot_bit = slot as usize % BITS_PER_U64; @@ -349,7 +351,7 @@ impl ClusterInScan for Core where C: ConnectionLike + Connect + Clone + Send + Sync + 'static, { - async fn get_address_by_slot(&self, slot: u16) -> RedisResult { + async fn get_address_by_slot(&self, slot: u16) -> RedisResult> { let address = self .get_address_from_slot(slot, SlotAddr::ReplicaRequired) .await; @@ -374,7 +376,7 @@ where async fn get_address_epoch(&self, address: &str) -> Result { self.as_ref().get_address_epoch(address).await } - async fn get_slots_of_address(&self, address: &str) -> Vec { + async fn get_slots_of_address(&self, address: Arc) -> Vec { self.as_ref().get_slots_of_address(address).await } async fn route_command(&self, cmd: Cmd, address: &str) -> RedisResult { @@ -591,7 +593,7 @@ mod tests { let scan_state = ScanState { cursor: 0, scanned_slots_map: [0; BITS_ARRAY_SIZE], - address_in_scan: String::from("address1"), + address_in_scan: String::from("address1").into(), address_epoch: 1, scan_status: ScanStateStage::InProgress, }; @@ -607,7 +609,7 @@ mod tests { let scan_state = ScanState { cursor: 0, scanned_slots_map, - address_in_scan: String::from("address1"), + address_in_scan: String::from("address1").into(), address_epoch: 1, scan_status: ScanStateStage::InProgress, }; @@ -619,7 +621,7 @@ mod tests { let scan_state = ScanState { cursor: 0, scanned_slots_map, - address_in_scan: String::from("address1"), + address_in_scan: String::from("address1").into(), address_epoch: 1, scan_status: ScanStateStage::InProgress, }; @@ -633,14 +635,14 @@ mod tests { async fn refresh_if_topology_changed(&self) -> RedisResult { Ok(true) } - async fn get_address_by_slot(&self, _slot: u16) -> RedisResult { - Ok("mock_address".to_string()) + async fn get_address_by_slot(&self, _slot: u16) -> RedisResult> { + Ok("mock_address".to_string().into()) } async fn get_address_epoch(&self, _address: &str) -> Result { Ok(0) } - async fn get_slots_of_address(&self, address: &str) -> Vec { - if address == "mock_address" { + async fn get_slots_of_address(&self, address: Arc) -> Vec { + if address.as_str() == "mock_address" { vec![3, 4, 5] } else { vec![0, 1, 2] @@ -662,7 +664,10 @@ mod tests { // Assert that the scan state is initialized correctly assert_eq!(scan_state.cursor, 0); assert_eq!(scan_state.scanned_slots_map, [0; BITS_ARRAY_SIZE]); - assert_eq!(scan_state.address_in_scan, "mock_address"); + assert_eq!( + scan_state.address_in_scan, + "mock_address".to_string().into() + ); assert_eq!(scan_state.address_epoch, 0); } @@ -672,7 +677,7 @@ mod tests { let scan_state = ScanState { cursor: 0, scanned_slots_map: [0; BITS_ARRAY_SIZE], - address_in_scan: "".to_string(), + address_in_scan: "".to_string().into(), address_epoch: 0, scan_status: ScanStateStage::InProgress, }; @@ -708,7 +713,10 @@ mod tests { assert_eq!(updated_scan_state.cursor, 0); // address_in_scan should be updated to the new address - assert_eq!(updated_scan_state.address_in_scan, "mock_address"); + assert_eq!( + updated_scan_state.address_in_scan, + "mock_address".to_string().into() + ); // address_epoch should be updated to the new address epoch assert_eq!(updated_scan_state.address_epoch, 0); @@ -720,7 +728,7 @@ mod tests { let scan_state = ScanState::new( 0, [0; BITS_ARRAY_SIZE], - "address".to_string(), + "address".to_string().into(), 0, ScanStateStage::InProgress, ); @@ -731,7 +739,10 @@ mod tests { .unwrap(); assert_eq!(updated_scan_state.scanned_slots_map, scanned_slots_map); assert_eq!(updated_scan_state.cursor, 0); - assert_eq!(updated_scan_state.address_in_scan, "mock_address"); + assert_eq!( + updated_scan_state.address_in_scan, + "mock_address".to_string().into() + ); assert_eq!(updated_scan_state.address_epoch, 0); } }