diff --git a/redis/src/cluster_async/connections_container.rs b/redis/src/cluster_async/connections_container.rs index 2a438aa77..235c13b82 100644 --- a/redis/src/cluster_async/connections_container.rs +++ b/redis/src/cluster_async/connections_container.rs @@ -140,11 +140,7 @@ where /// Returns true if the address represents a known primary node. pub(crate) fn is_primary(&self, address: &ArcStr) -> bool { - self.connection_for_address(address).is_some() - && self - .slot_map - .values() - .any(|slot_addrs| slot_addrs.primary.as_str() == address) + self.connection_for_address(address).is_some() && self.slot_map.is_primary(address) } fn round_robin_read_from_replica( diff --git a/redis/src/cluster_routing.rs b/redis/src/cluster_routing.rs index a37b875be..0ca2c147d 100644 --- a/redis/src/cluster_routing.rs +++ b/redis/src/cluster_routing.rs @@ -889,23 +889,23 @@ pub enum SlotAddr { /// which stores only the master and [optional] replica /// to avoid the need to choose a replica each time /// a command is executed -#[derive(Debug, Eq, PartialEq)] -pub(crate) struct SlotAddrs { +#[derive(Debug, Eq, PartialEq, Clone)] +pub(crate) struct ShardAddrs { pub(crate) primary: String, pub(crate) replicas: Vec, } -impl SlotAddrs { +impl ShardAddrs { pub(crate) fn new(primary: String, replicas: Vec) -> Self { Self { primary, replicas } } pub(crate) fn from_slot(slot: Slot) -> Self { - SlotAddrs::new(slot.master, slot.replicas) + ShardAddrs::new(slot.master, slot.replicas) } } -impl<'a> IntoIterator for &'a SlotAddrs { +impl<'a> IntoIterator for &'a ShardAddrs { type Item = &'a String; type IntoIter = std::iter::Chain, std::slice::Iter<'a, String>>; diff --git a/redis/src/cluster_slotmap.rs b/redis/src/cluster_slotmap.rs index 7f1f70af9..9363f0083 100644 --- a/redis/src/cluster_slotmap.rs +++ b/redis/src/cluster_slotmap.rs @@ -1,28 +1,75 @@ +use std::sync::Arc; use std::{ - collections::{BTreeMap, HashSet}, + collections::{BTreeMap, HashMap, HashSet}, fmt::Display, sync::atomic::AtomicUsize, }; -use crate::cluster_routing::{Route, Slot, SlotAddr, SlotAddrs}; +use std::sync::RwLock; + +use arcstr::ArcStr; + +use crate::cluster_routing::{Route, ShardAddrs, Slot, SlotAddr}; + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct SlotRange { + pub start: u16, + pub end: u16, +} #[derive(Debug)] -pub(crate) struct SlotMapValue { - pub(crate) start: u16, - pub(crate) addrs: SlotAddrs, - pub(crate) latest_used_replica: AtomicUsize, +pub struct ShardInfo { + pub slot_ranges: HashSet, + pub addrs: Arc>, } -impl SlotMapValue { - fn from_slot(slot: Slot) -> Self { - Self { - start: slot.start(), - addrs: SlotAddrs::from_slot(slot), - latest_used_replica: AtomicUsize::new(0), - } - } +#[derive(Debug, Clone, PartialEq)] +pub enum NodeRole { + Primary, + Replica, +} + +#[derive(Debug)] +pub struct NodeInfo { + pub role: NodeRole, + pub shard_info: Arc>, +} + +#[derive(Debug)] +pub struct SlotMapValue { + pub addrs: Arc>, + pub slot_range: SlotRange, + pub latest_used_replica: AtomicUsize, +} + +#[derive(Debug, Default)] +pub struct SlotMap { + pub slots: BTreeMap, + pub nodes_map: HashMap>>, + read_from_replica: ReadFromReplicaStrategy, } +// #[derive(Debug)] +// pub(crate) struct SlotMapValue { +// pub(crate) start: u16, +// pub(crate) addrs: ShardAddrs, +// pub(crate) latest_used_replica: AtomicUsize, +// } + +// impl SlotMapValue { +// fn from_slot(slot: Slot) -> Self { +// Self { +// slot_range: SlotRange { +// start: slot.start(), +// end: slot.end(), +// }, +// start: slot.start(), +// addrs: ShardAddrs::from_slot(slot), +// latest_used_replica: AtomicUsize::new(0), +// } +// } +// } + #[derive(Debug, Default, Clone, PartialEq, Copy)] pub(crate) enum ReadFromReplicaStrategy { #[default] @@ -30,44 +77,119 @@ pub(crate) enum ReadFromReplicaStrategy { RoundRobin, } -#[derive(Debug, Default)] -pub(crate) struct SlotMap { - pub(crate) slots: BTreeMap, - read_from_replica: ReadFromReplicaStrategy, -} +// #[derive(Debug, Default)] +// pub(crate) struct SlotMap { +// /// A mapping of slot numbers to their associated `SlotMapValue`. +// /// +// /// This `BTreeMap` holds information about each slot in the cluster. +// /// The slot number is the key, and the `SlotMapValue` contains details +// /// about the slot's range, addresses, and the latest used replica. +// /// +// /// This map is used to keep track of slot assignments and helps in routing requests to the correct node in the cluster. +// pub(crate) slots: BTreeMap, +// // Maps primary node addresses to their owned slots +// /// This `HashMap` is used to quickly locate which slots are owned by which primary nodes, for example in order to update specific slot owner. +// pub(crate) primary_slots: HashMap>, +// read_from_replica: ReadFromReplicaStrategy, +// } fn get_address_from_slot( slot: &SlotMapValue, read_from_replica: ReadFromReplicaStrategy, slot_addr: SlotAddr, ) -> &str { - if slot_addr == SlotAddr::Master || slot.addrs.replicas.is_empty() { - return slot.addrs.primary.as_str(); + let addr_reader = slot.addrs.read().unwrap(); + if slot_addr == SlotAddr::Master || addr_reader.replicas.is_empty() { + return addr_reader.primary.as_str(); } match read_from_replica { - ReadFromReplicaStrategy::AlwaysFromPrimary => slot.addrs.primary.as_str(), + ReadFromReplicaStrategy::AlwaysFromPrimary => addr_reader.primary.as_str(), ReadFromReplicaStrategy::RoundRobin => { let index = slot .latest_used_replica .fetch_add(1, std::sync::atomic::Ordering::Relaxed) - % slot.addrs.replicas.len(); - slot.addrs.replicas[index].as_str() + % addr_reader.replicas.len(); + addr_reader.replicas[index].as_str() } } } impl SlotMap { - pub(crate) fn new(slots: Vec, read_from_replica: ReadFromReplicaStrategy) -> Self { - let mut this = Self { + // pub(crate) fn new(slots: Vec) -> Self { + // let mut this = Self { + // slots_map: Default::default(), + // nodes_map: Default::default(), + // read_from_replica, + // }; + // this.slots.extend( + // slots + // .into_iter() + // .map(|slot| (slot.end(), SlotMapValue::from_slot(slot))), + // ); + // this + // } + + pub fn new(slots: Vec, read_from_replica: ReadFromReplicaStrategy) -> Self { + let mut slot_map = SlotMap { slots: BTreeMap::new(), + nodes_map: HashMap::new(), read_from_replica, }; - this.slots.extend( - slots - .into_iter() - .map(|slot| (slot.end(), SlotMapValue::from_slot(slot))), - ); - this + + for slot in slots { + let primary = slot.master.clone(); + let replicas = slot.replicas.clone(); + + // Check if primary is already in nodes_map + let shard_info_arc = slot_map + .nodes_map + .entry(primary.clone()) + .or_insert_with(|| { + Arc::new(RwLock::new(ShardInfo { + slot_ranges: HashSet::new(), + addrs: Arc::new(RwLock::new(ShardAddrs { + primary, + replicas: replicas.clone(), + })), + })) + }) + .clone(); + + // Also add replicas with the same ShardInfo + for replica in &replicas { + slot_map + .nodes_map + .entry(replica.clone()) + .or_insert(shard_info_arc.clone()); + } + + let shard_info = shard_info_arc.write().unwrap(); + shard_info.slot_ranges.insert(SlotRange { + start: slot.start, + end: slot.end, + }); + + slot_map.slots.insert( + slot.end, + SlotMapValue { + addrs: shard_info.addrs.clone(), + slot_range: SlotRange { + start: slot.start, + end: slot.end, + }, + latest_used_replica: AtomicUsize::new(0), + }, + ); + } + + slot_map + } + + pub fn is_primary(&self, address: &ArcStr) -> bool { + let address = address.to_string(); + self.nodes_map.get(&address).map_or(false, |shard_info| { + shard_info.read().unwrap().addrs.read().unwrap().primary == address + }) } pub fn slot_value_for_route(&self, route: &Route) -> Option<&SlotMapValue> { @@ -76,7 +198,7 @@ impl SlotMap { .range(slot..) .next() .and_then(|(end, slot_value)| { - if slot <= *end && slot_value.start <= slot { + if slot <= *end && slot_value.slot_range.start <= slot { Some(slot_value) } else { None @@ -90,16 +212,23 @@ impl SlotMap { }) } - pub fn values(&self) -> impl Iterator { - self.slots.values().map(|slot_value| &slot_value.addrs) + pub fn values(&self) -> Vec { + self.slots + .values() + .map(|slot_value| { + let addr_reader = slot_value.addrs.read().unwrap(); + addr_reader.clone() + }) + .collect() } fn all_unique_addresses(&self, only_primaries: bool) -> HashSet<&str> { let mut addresses = HashSet::new(); - for slot in self.values() { - addresses.insert(slot.primary.as_str()); + for addr_rw in self.slots.values().map(|slot_value| slot_value.addrs) { + let addr_read = addr_rw.read().unwrap(); + addresses.insert(addr_read.primary.as_str()); if !only_primaries { - addresses.extend(slot.replicas.iter().map(|str| str.as_str())); + addresses.extend(addr_read.replicas.iter().map(|str| str.as_str())); } } @@ -132,10 +261,11 @@ impl SlotMap { self.slots .iter() .filter_map(|(end, slot_value)| { - if slot_value.addrs.primary == node_address - || slot_value.addrs.replicas.contains(&node_address) + let addr_reader = slot_value.addrs.read().unwrap(); + if addr_reader.primary == node_address + || addr_reader.replicas.contains(&node_address) { - Some(slot_value.start..(*end + 1)) + Some(slot_value.slot_range.start..(*end + 1)) } else { None } @@ -150,7 +280,7 @@ impl SlotMap { slot_addr: SlotAddr, ) -> Option { self.slots.range(slot..).next().and_then(|(_, slot_value)| { - if slot_value.start <= slot { + if slot_value.slot_range.start <= slot { Some( get_address_from_slot(slot_value, self.read_from_replica, slot_addr) .to_string(), @@ -160,6 +290,18 @@ impl SlotMap { } }) } + + pub(crate) fn update_primary_for_slot(&mut self, slot: u16, new_primary: String) { + // Remove the slot from the current primary's slots + // if let Some(curr_primary) = self.get_node_address_for_slot(slot, SlotAddr::Master) { + // self.primary_slots + // .get_mut(&curr_primary) + // .map(|slots_set| slots_set.remove(&slot)); + // self.slots. + // } + // Add the slot to the new primary's slots + todo!() + } } impl Display for SlotMap { @@ -169,10 +311,10 @@ impl Display for SlotMap { writeln!( f, "({}-{}): primary: {}, replicas: {:?}", - slot_map_value.start, + slot_map_value.slot_range.start, end, - slot_map_value.addrs.primary, - slot_map_value.addrs.replicas + slot_map_value.addrs.read().unwrap().primary, + slot_map_value.addrs.read().unwrap().replicas )?; } Ok(()) diff --git a/redis/src/cluster_topology.rs b/redis/src/cluster_topology.rs index a2ce9ea07..d00766774 100644 --- a/redis/src/cluster_topology.rs +++ b/redis/src/cluster_topology.rs @@ -300,7 +300,7 @@ pub(crate) fn calculate_topology<'a>( #[cfg(test)] mod tests { use super::*; - use crate::cluster_routing::SlotAddrs; + use crate::cluster_routing::ShardAddrs; #[test] fn test_get_hashtag() { @@ -502,8 +502,8 @@ mod tests { } } - fn get_node_addr(name: &str, port: u16) -> SlotAddrs { - SlotAddrs::new(format!("{name}:{port}"), Vec::new()) + fn get_node_addr(name: &str, port: u16) -> ShardAddrs { + ShardAddrs::new(format!("{name}:{port}"), Vec::new()) } #[test] @@ -526,7 +526,7 @@ mod tests { .unwrap(); let res: Vec<_> = topology_view.values().collect(); let node_1 = get_node_addr("node1", 6379); - let expected: Vec<&SlotAddrs> = vec![&node_1]; + let expected: Vec<&ShardAddrs> = vec![&node_1]; assert_eq!(res, expected); } @@ -569,7 +569,7 @@ mod tests { let res: Vec<_> = topology_view.values().collect(); let node_1 = get_node_addr("node1", 6379); let node_2 = get_node_addr("node2", 6380); - let expected: Vec<&SlotAddrs> = vec![&node_1, &node_2]; + let expected: Vec<&ShardAddrs> = vec![&node_1, &node_2]; assert_eq!(res, expected); } @@ -592,7 +592,7 @@ mod tests { let res: Vec<_> = topology_view.values().collect(); let node_1 = get_node_addr("node1", 6379); let node_2 = get_node_addr("node2", 6380); - let expected: Vec<&SlotAddrs> = vec![&node_1, &node_2]; + let expected: Vec<&ShardAddrs> = vec![&node_1, &node_2]; assert_eq!(res, expected); } @@ -616,7 +616,7 @@ mod tests { let res: Vec<_> = topology_view.values().collect(); let node_1 = get_node_addr("node3", 6381); let node_2 = get_node_addr("node4", 6382); - let expected: Vec<&SlotAddrs> = vec![&node_1, &node_2]; + let expected: Vec<&ShardAddrs> = vec![&node_1, &node_2]; assert_eq!(res, expected); } @@ -639,7 +639,7 @@ mod tests { .unwrap(); let res: Vec<_> = topology_view.values().collect(); let node_1 = get_node_addr("node1", 6379); - let expected: Vec<&SlotAddrs> = vec![&node_1]; + let expected: Vec<&ShardAddrs> = vec![&node_1]; assert_eq!(res, expected); } }