Skip to content

Commit

Permalink
Handle PubSub commands routing (#176)
Browse files Browse the repository at this point in the history
---------
Signed-off-by: shohame <[email protected]>
  • Loading branch information
shohamazon authored Jul 29, 2024
1 parent b8c921d commit de53b2b
Show file tree
Hide file tree
Showing 3 changed files with 125 additions and 12 deletions.
7 changes: 7 additions & 0 deletions redis/src/cluster.rs
Original file line number Diff line number Diff line change
Expand Up @@ -661,6 +661,13 @@ where
_ => crate::cluster_routing::combine_array_results(results),
}
}
Some(ResponsePolicy::CombineMaps) => {
let results = results
.into_iter()
.map(|res| res.map(|(_, val)| val))
.collect::<RedisResult<Vec<_>>>()?;
crate::cluster_routing::combine_map_results(results)
}
Some(ResponsePolicy::Special) | None => {
// This is our assumption - if there's no coherent way to aggregate the responses, we just map each response to the sender, and pass it to the user.
// TODO - once Value::Error is merged, we can use join_all and report separate errors and also pass successes.
Expand Down
5 changes: 5 additions & 0 deletions redis/src/cluster_async/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1290,6 +1290,11 @@ where
_ => crate::cluster_routing::combine_array_results(results),
})
}
Some(ResponsePolicy::CombineMaps) => {
future::try_join_all(receivers.into_iter().map(get_receiver))
.await
.and_then(crate::cluster_routing::combine_map_results)
}
Some(ResponsePolicy::Special) | None => {
// This is our assumption - if there's no coherent way to aggregate the responses, we just map each response to the sender, and pass it to the user.
// TODO - once Value::Error is merged, we can use join_all and report separate errors and also pass successes.
Expand Down
125 changes: 113 additions & 12 deletions redis/src/cluster_routing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ pub enum ResponsePolicy {
CombineArrays,
/// Handling is not defined by the Redis standard. Will receive a special case
Special,
/// Combines multiple map responses into a single map.
CombineMaps,
}

/// Defines whether a request should be routed to a single node, or multiple ones.
Expand Down Expand Up @@ -187,8 +189,42 @@ pub fn logical_aggregate(values: Vec<Value>, op: LogicalAggregateOp) -> RedisRes
.collect(),
))
}
/// Aggregate array responses into a single map.
pub fn combine_map_results(values: Vec<Value>) -> RedisResult<Value> {
let mut map: HashMap<Vec<u8>, i64> = HashMap::new();

/// Aggreagte arrau responses into a single array.
for value in values {
match value {
Value::Array(elements) => {
let mut iter = elements.into_iter();

while let Some(key) = iter.next() {
if let Value::BulkString(key_bytes) = key {
if let Some(Value::Int(value)) = iter.next() {
*map.entry(key_bytes).or_insert(0) += value;
} else {
return Err((ErrorKind::TypeError, "expected integer value").into());
}
} else {
return Err((ErrorKind::TypeError, "expected string key").into());
}
}
}
_ => {
return Err((ErrorKind::TypeError, "expected array of values as response").into());
}
}
}

let result_vec: Vec<(Value, Value)> = map
.into_iter()
.map(|(k, v)| (Value::BulkString(k), Value::Int(v)))
.collect();

Ok(Value::Map(result_vec))
}

/// Aggregate array responses into a single array.
pub fn combine_array_results(values: Vec<Value>) -> RedisResult<Value> {
let mut results = Vec::new();

Expand Down Expand Up @@ -302,7 +338,9 @@ impl ResponsePolicy {
b"SCRIPT EXISTS" => Some(ResponsePolicy::AggregateLogical(LogicalAggregateOp::And)),

b"DBSIZE" | b"DEL" | b"EXISTS" | b"SLOWLOG LEN" | b"TOUCH" | b"UNLINK"
| b"LATENCY RESET" => Some(ResponsePolicy::Aggregate(AggregateOp::Sum)),
| b"LATENCY RESET" | b"PUBSUB NUMPAT" => {
Some(ResponsePolicy::Aggregate(AggregateOp::Sum))
}

b"WAIT" => Some(ResponsePolicy::Aggregate(AggregateOp::Min)),

Expand All @@ -314,7 +352,10 @@ impl ResponsePolicy {
Some(ResponsePolicy::AllSucceeded)
}

b"KEYS" | b"MGET" | b"SLOWLOG GET" => Some(ResponsePolicy::CombineArrays),
b"KEYS" | b"MGET" | b"SLOWLOG GET" | b"PUBSUB CHANNELS" | b"PUBSUB SHARDCHANNELS" => {
Some(ResponsePolicy::CombineArrays)
}
b"PUBSUB NUMSUB" | b"PUBSUB SHARDNUMSUB" => Some(ResponsePolicy::CombineMaps),

b"FUNCTION KILL" | b"SCRIPT KILL" => Some(ResponsePolicy::OneSucceeded),

Expand Down Expand Up @@ -354,11 +395,30 @@ enum RouteBy {

fn base_routing(cmd: &[u8]) -> RouteBy {
match cmd {
b"ACL SETUSER" | b"ACL DELUSER" | b"ACL SAVE" | b"CLIENT SETNAME" | b"CLIENT SETINFO"
| b"SLOWLOG GET" | b"SLOWLOG LEN" | b"SLOWLOG RESET" | b"CONFIG SET"
| b"CONFIG RESETSTAT" | b"CONFIG REWRITE" | b"SCRIPT FLUSH" | b"SCRIPT LOAD"
| b"LATENCY RESET" | b"LATENCY GRAPH" | b"LATENCY HISTOGRAM" | b"LATENCY HISTORY"
| b"LATENCY DOCTOR" | b"LATENCY LATEST" => RouteBy::AllNodes,
b"ACL SETUSER"
| b"ACL DELUSER"
| b"ACL SAVE"
| b"CLIENT SETNAME"
| b"CLIENT SETINFO"
| b"SLOWLOG GET"
| b"SLOWLOG LEN"
| b"SLOWLOG RESET"
| b"CONFIG SET"
| b"CONFIG RESETSTAT"
| b"CONFIG REWRITE"
| b"SCRIPT FLUSH"
| b"SCRIPT LOAD"
| b"LATENCY RESET"
| b"LATENCY GRAPH"
| b"LATENCY HISTOGRAM"
| b"LATENCY HISTORY"
| b"LATENCY DOCTOR"
| b"LATENCY LATEST"
| b"PUBSUB NUMPAT"
| b"PUBSUB CHANNELS"
| b"PUBSUB NUMSUB"
| b"PUBSUB SHARDCHANNELS"
| b"PUBSUB SHARDNUMSUB" => RouteBy::AllNodes,

b"DBSIZE"
| b"FLUSHALL"
Expand Down Expand Up @@ -463,10 +523,6 @@ fn base_routing(cmd: &[u8]) -> RouteBy {
| b"MODULE LOAD"
| b"MODULE LOADEX"
| b"MODULE UNLOAD"
| b"PUBSUB CHANNELS"
| b"PUBSUB NUMPAT"
| b"PUBSUB NUMSUB"
| b"PUBSUB SHARDCHANNELS"
| b"READONLY"
| b"READWRITE"
| b"SAVE"
Expand Down Expand Up @@ -1233,4 +1289,49 @@ mod tests {
])
);
}

#[test]
fn test_combine_map_results() {
let input = vec![];
let result = super::combine_map_results(input).unwrap();
assert_eq!(result, Value::Map(vec![]));

let input = vec![
Value::Array(vec![
Value::BulkString(b"key1".to_vec()),
Value::Int(5),
Value::BulkString(b"key2".to_vec()),
Value::Int(10),
]),
Value::Array(vec![
Value::BulkString(b"key1".to_vec()),
Value::Int(3),
Value::BulkString(b"key3".to_vec()),
Value::Int(15),
]),
];
let result = super::combine_map_results(input).unwrap();
let mut expected = vec![
(Value::BulkString(b"key1".to_vec()), Value::Int(8)),
(Value::BulkString(b"key2".to_vec()), Value::Int(10)),
(Value::BulkString(b"key3".to_vec()), Value::Int(15)),
];
expected.sort_unstable_by(|a, b| match (&a.0, &b.0) {
(Value::BulkString(a_bytes), Value::BulkString(b_bytes)) => a_bytes.cmp(b_bytes),
_ => std::cmp::Ordering::Equal,
});
let mut result_vec = match result {
Value::Map(v) => v,
_ => panic!("Expected Map"),
};
result_vec.sort_unstable_by(|a, b| match (&a.0, &b.0) {
(Value::BulkString(a_bytes), Value::BulkString(b_bytes)) => a_bytes.cmp(b_bytes),
_ => std::cmp::Ordering::Equal,
});
assert_eq!(result_vec, expected);

let input = vec![Value::Int(5)];
let result = super::combine_map_results(input);
assert!(result.is_err());
}
}

0 comments on commit de53b2b

Please sign in to comment.