From 67e1ac8a81d7732fcde85c33a179e2e5a9c7c122 Mon Sep 17 00:00:00 2001 From: Rob Date: Thu, 8 Feb 2024 17:10:30 -0500 Subject: [PATCH] major functionality --- Cargo.lock | 9 +- broker/Cargo.toml | 3 +- broker/src/lib.rs | 519 +++++++++++++++++++++---- broker/src/main.rs | 7 +- broker/src/state.rs | 285 +++++++------- client/src/lib.rs | 65 +--- client/src/main.rs | 152 ++++++-- client/src/retry.rs | 240 ++++++------ marshal/src/lib.rs | 12 +- proto/schema/messages.capnp | 14 +- proto/src/connection/auth/broker.rs | 48 ++- proto/src/connection/auth/marshal.rs | 8 +- proto/src/connection/auth/mod.rs | 1 + proto/src/connection/auth/user.rs | 16 +- proto/src/connection/batch.rs | 300 ++++++++++++++ proto/src/connection/mod.rs | 1 + proto/src/connection/protocols/mod.rs | 133 +++++-- proto/src/connection/protocols/quic.rs | 300 ++++++-------- proto/src/connection/protocols/tcp.rs | 300 ++++++-------- proto/src/message.rs | 28 +- 20 files changed, 1567 insertions(+), 874 deletions(-) create mode 100644 proto/src/connection/batch.rs diff --git a/Cargo.lock b/Cargo.lock index 6a5b089..5bf4a46 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -433,10 +433,9 @@ version = "0.1.0" dependencies = [ "ark-serialize", "clap", - "either", + "hex", "jf-primitives", "local-ip-address", - "parking_lot", "proto", "tokio", "tracing", @@ -954,6 +953,12 @@ version = "0.3.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5d3d0e0f38255e7fa3cf31335b3a56f05febd18025f4db5ef7a0cfb4f8da651f" +[[package]] +name = "hex" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" + [[package]] name = "idna" version = "0.5.0" diff --git a/broker/Cargo.toml b/broker/Cargo.toml index 5faf968..047b3c0 100644 --- a/broker/Cargo.toml +++ b/broker/Cargo.toml @@ -12,7 +12,6 @@ tokio.workspace = true tracing.workspace = true ark-serialize.workspace = true tracing-subscriber.workspace = true -parking_lot = "0.12.1" clap.workspace = true local-ip-address = "0.5.7" -either = "1.9.0" +hex = "0.4.3" \ No newline at end of file diff --git a/broker/src/lib.rs b/broker/src/lib.rs index 5cefb9b..c9ac507 100644 --- a/broker/src/lib.rs +++ b/broker/src/lib.rs @@ -1,11 +1,11 @@ //! This file contains the implementation of the `Broker`, which routes messages //! for the Push CDN. -// TODO: convert QUIC to locked single sender/reciver +// TODO: massive cleanup on this file mod state; -use std::{marker::PhantomData, sync::Arc, time::Duration}; +use std::{collections::HashSet, marker::PhantomData, sync::Arc, time::Duration}; use jf_primitives::signatures::SignatureScheme as JfSignatureScheme; // TODO: figure out if we should use Tokio's here @@ -13,17 +13,19 @@ use proto::{ authenticate_with_broker, bail, connection::{ auth::broker::BrokerAuth, - protocols::{Connection, Listener, Protocol}, + batch::BatchedSender, + protocols::{Listener, Protocol, Receiver}, }, crypto::{KeyPair, Serializable}, error::{Error, Result}, + message::{Message, Subscribe, UsersConnected, UsersDisconnected}, parse_socket_address, redis::{self, BrokerIdentifier}, verify_broker, }; -use state::{ConnectionLookup, ConnectionWithQueue}; +use state::ConnectionLookup; use tokio::{select, spawn, sync::RwLock, time::sleep}; -use tracing::{error, warn}; +use tracing::{error, info, warn}; /// The broker's configuration. We need this when we create a new one. /// TODO: clean up these generics. 
could be a generic type that implements both @@ -72,12 +74,20 @@ struct Inner< /// The underlying (public) verification key, used to authenticate with the server. Checked /// against the stake table. - /// TODO: verif & signing key in one struct - pub keypair: KeyPair, + keypair: KeyPair, + + /// A set of all other brokers. We need this to send to all connected brokers. + other_brokers: RwLock>)>>, - pub broker_connections: RwLock>, + /// A map of interests to their possible broker connections. We use this to facilitate + /// where messages go. They need to be separate because of possible separate protocol + /// types. + broker_connection_lookup: RwLock>, - pub user_connections: RwLock>, + /// A map of interests to their possible user connections. We use this to facilitate + /// where messages go. They need to be separate because of possible separate protocol + /// types. + user_connection_lookup: RwLock>, // connected_keys: LoggedSet, /// The `PhantomData` that we need to be generic over protocol types. @@ -108,6 +118,84 @@ pub struct Broker< broker_listener: BrokerProtocolType::Listener, } +macro_rules! remove_local_and_exit_on_error { + ($operation: expr, $inner: expr, $sender: expr, $object: ident) => { + match $operation { + Ok(op) => op, + Err(_) => { + remove_local!($inner, $sender, $object); + return; + } + } + }; +} + +macro_rules! remove_local { + ($inner: expr, $sender: expr, user) => { + // Remove all connections associated with the user + $inner + .user_connection_lookup + .write() + .await + .unsubscribe_connection(&$sender); + }; + + ($inner: expr, $sender: expr, broker) => { + // Remove all connections associated with the broker + $inner + .broker_connection_lookup + .write() + .await + .unsubscribe_connection(&$sender.clone()); + + // Remove from "all brokers" + $inner.other_brokers.write().await.retain(|(_, broker)| broker != &$sender); + }; +} + +macro_rules! remove_remote_and_exit_on_error { + ($operation: expr, $inner: expr, $key: expr) => { + match $operation { + Ok(op) => op, + Err(_) => { + remove_remote!($inner, $key); + return; + } + } + }; +} + +macro_rules! 
remove_remote { + ($inner:expr, $key: expr) => { + // Tell all other brokers that we're done with the user + // TODO IMP: IF REMOVE TOPIC, SEND THAT TOPIC IS UNSUBSCRIBE + let brokers:Vec<(BrokerIdentifier, Arc>)> = + $inner.other_brokers.read().await.iter().cloned().collect(); + + // For all brokers, send the disconect message + if !brokers.is_empty() { + let disconnected_message = + // TODO: see if we need clone here + Message::UsersDisconnected(UsersDisconnected { users: vec![$key.clone()] }); + + // TODO: DOCUMENT THIS EXPECT + // Serialize the message + let disconnected_message:Arc> = Arc::from( + disconnected_message + .serialize() + .expect("serialization to succeed"), + ); + + // Send the message + for broker in brokers { + // If we fail to send it, remove the broker + // TODO: remove brokers here on error + let _ = broker.1.queue_message_back(disconnected_message.clone()); + } + } + }; +} + impl< BrokerSignatureScheme: JfSignatureScheme, BrokerProtocolType: Protocol, @@ -160,7 +248,7 @@ where // Create the user (public) listener let user_bind_address = parse_socket_address!(user_bind_address); let user_listener = bail!( - UserProtocolType::Listener::bind( + UserProtocolType::bind( user_bind_address, maybe_tls_cert_path.clone(), maybe_tls_key_path.clone(), @@ -176,12 +264,8 @@ where // Create the broker (private) listener let broker_bind_address = parse_socket_address!(broker_bind_address); let broker_listener = bail!( - BrokerProtocolType::Listener::bind( - broker_bind_address, - maybe_tls_cert_path, - maybe_tls_key_path, - ) - .await, + BrokerProtocolType::bind(broker_bind_address, maybe_tls_cert_path, maybe_tls_key_path,) + .await, Connection, format!( "failed to bind to public (user) bind address {}", @@ -195,8 +279,9 @@ where redis_client, identifier, keypair, - broker_connections: RwLock::from(ConnectionLookup::default()), - user_connections: RwLock::from(ConnectionLookup::default()), + other_brokers: RwLock::default(), + broker_connection_lookup: RwLock::default(), + user_connection_lookup: RwLock::default(), pd: PhantomData, }), user_listener, @@ -204,14 +289,12 @@ where }) } - /// This function handles a broker (private) connection. We take the following steps: - /// 1. Authenticate the broker - /// 2. TODO + /// This function is the callback for handling a broker (private) connection. 
async fn handle_broker_connection( inner: Arc< Inner, >, - connection: BrokerProtocolType::Connection, + mut connection: (BrokerProtocolType::Sender, BrokerProtocolType::Receiver), is_outbound: bool, ) { // Depending on which way the direction came in, we will want to authenticate with a different @@ -227,31 +310,168 @@ where authenticate_with_broker!(connection, inner) }; - // Create a new queued connection - let connection = Arc::from( - ConnectionWithQueue::::from_connection_and_params( - connection, - Duration::from_millis(50), - 5000, - ), - ); + // Create new batch sender + let (sender, mut receiver) = connection; + // TODO: parameterize max interval and max size + let sender = Arc::new(BatchedSender::from(sender, Duration::from_millis(50), 1500)); + + // Freeze the sender before adding it so we don't receive messages out of order + let _ = sender.freeze(); + + // Add to "other brokers" so we can start adding relevant messages to the queue + inner + .other_brokers + .write() + .await + .insert((broker_address.clone(), sender.clone())); + + // Create and serialize a message with the keys we're connected to + // TODO: macro for this + let users = inner.user_connection_lookup.read().await.get_all_keys(); + let message = Message::UsersConnected(UsersConnected { users }); + + // If we fail serialization, remove from the "all/other brokers" map. + let message = Arc::from(remove_local_and_exit_on_error!( + message.serialize(), + inner, + sender, + broker + )); + + // Put the message at the front of the queue so that we send messages in order + let _ = sender.queue_message_front(message); + + // Create and serialize a message with the topics we're interested in + let topics = inner.user_connection_lookup.read().await.get_all_topics(); + let message = Message::Subscribe(Subscribe { topics }); + + // If we fail serialization, remove from the "all/other brokers" map. + let message = Arc::from(remove_local_and_exit_on_error!( + message.serialize(), + inner, + sender, + broker + )); + + // Put the message at the front of the queue so that we send messages in order + let _ = sender.queue_message_front(message); + + // Unfreeze our queue, which flushes it and lets us finally send (in order) messages. + let _ = sender.unfreeze(); + + info!("received connection from broker {}", broker_address); + + // The message receive loop. On exit, remove the broker's connection everywhere + while let Ok(message) = receiver.recv_message().await { + // See what type of message this is + match message { + // A direct message. We want this to go to either the associated broker or user. + Message::Direct(direct) => { + // Find out where the message is supposed to go + let possible_user = inner + .user_connection_lookup + .read() + .await + .get_connection_by_key(&direct.recipient); + + // If user is connected, queue the message for sending + // TODO: max queue size before force quit + if let Some(user) = possible_user { + // Create a new `Data` and `Arc` it. + let message = Arc::new(remove_local_and_exit_on_error!( + Message::serialize(&Message::Direct(direct)), + inner, + sender, + broker + )); + + // Send them the message. 
If we fail, remove them + if user.queue_message_back(message).is_err() { + remove_local!(inner, user, user); + } + } + } + + Message::Broadcast(broadcast) => { + // TODO: macro this + // Find out where the message is supposed to go + let connections = inner + .user_connection_lookup + .read() + .await + .get_connections_by_topic(broadcast.topics.clone()); + + // If there are any users, queue the message to send for all of them + if !connections.is_empty() { + // Create a new `Data` and `Arc` it. + let message = Arc::new(remove_local_and_exit_on_error!( + Message::serialize(&Message::Broadcast(broadcast)), + inner, + sender, + broker + )); + + // For each user, them the message. If we fail, remove them + for user in connections { + if user.queue_message_back(message.clone()).is_err() { + remove_local!(inner, user, user); + } + } + } + } - // Add the connection to our map + // If we receive a subscribe message from a broker, subscribe them to those topics + Message::Subscribe(subscribe) => inner + .broker_connection_lookup + .write() + .await + .subscribe_connection_to_topics(sender.clone(), subscribe.topics), + + // If we receive an unsubscribe message from a broker, unsubscribe them from those topics + Message::Unsubscribe(unsubscribe) => inner + .broker_connection_lookup + .write() + .await + .unsubscribe_connection_from_topics(&sender, unsubscribe.topics), + + // If we receive a `UsersConnected` message, subscribe that connection to the keys it presented + Message::UsersConnected(users_connected) => inner + .broker_connection_lookup + .write() + .await + .subscribe_connection_to_keys(&sender, users_connected.users), + + // If we receive a `UsersConnected` message, unsubscribe that connection from the keys it presented + Message::UsersDisconnected(users_disconnected) => inner + .broker_connection_lookup + .write() + .await + .unsubscribe_connection_from_keys(users_disconnected.users), + + // We should not be receiving any of these messages + Message::AuthenticateResponse(_) + | Message::AuthenticateWithKey(_) + | Message::AuthenticateWithPermit(_) => { + remove_local!(inner, sender, broker); + return; + } + } + } + + remove_local!(inner, sender, broker); } - /// This function handles a user (public) connection. We take the following steps: - /// 1. Authenticate the user - /// 2. TODO + /// This function handles a user (public) connection. async fn handle_user_connection( inner: Arc< Inner, >, - connection: UserProtocolType::Connection, + mut connection: (UserProtocolType::Sender, UserProtocolType::Receiver), ) { // Verify (authenticate) the connection let Ok((verification_key, topics)) = BrokerAuth::::verify_user( - &connection, + &mut connection, &inner.identifier, &mut inner.redis_client.clone(), ) @@ -260,29 +480,194 @@ where return; }; - // Create a new queued connection - let connection = Arc::from( - ConnectionWithQueue::::from_connection_and_params( - connection, - Duration::from_millis(50), - 5000, - ), - ); - - println!("user subbed to {:?}", topics); + // Create new batch sender + let (sender, mut receiver) = connection; + let sender = Arc::new(BatchedSender::::from( + sender, + Duration::from_millis(50), + 1500, + )); + + // Send the information to other brokers, if any. 
+ // Remove them if we failed to send to them, + // TODO: WAL HERE maybe + let brokers: Vec<(BrokerIdentifier, Arc>)> = + inner.other_brokers.read().await.iter().cloned().collect(); + if !brokers.is_empty() { + let connected_message = Message::UsersConnected(UsersConnected { + users: vec![verification_key.clone()], + }); + let subscribed_message = Message::Subscribe(Subscribe { + topics: topics.clone(), + }); + + // Arc and serialize the messages, prepare for sending + let connected_message: Arc> = Arc::from(remove_remote_and_exit_on_error!( + connected_message.serialize(), + inner, + verification_key + )); + let subscribed_message: Arc> = Arc::from(remove_remote_and_exit_on_error!( + subscribed_message.serialize(), + inner, + verification_key + )); + + // Send the messages + for broker in brokers { + let _ = broker.1.queue_message_back(connected_message.clone()); + let _ = broker.1.queue_message_back(subscribed_message.clone()); + } + }; - // Add the connection to our maps + // Add the user for their topics inner - .user_connections + .user_connection_lookup .write() .await - .subscribe_connection_to_broadcast(connection.clone(), topics); + .subscribe_connection_to_topics(sender.clone(), topics); + // Add the user for their keys inner - .user_connections + .user_connection_lookup .write() .await - .subscribe_connection_to_direct(connection.clone(), verification_key) + .subscribe_connection_to_keys(&sender, vec![verification_key.clone()]); + + info!( + "received connection from user {:?}", + hex::encode(&verification_key) + ); + + // The message receive loop. On exit, remove the broker's connection everywhere + while let Ok(message) = receiver.recv_message().await { + // See what type of message this is + match message { + // A direct message. This is supposed to go to the interested broker AND/OR interested user only + Message::Direct(direct) => { + // Find out where the message is supposed to go + let broker_connection = inner + .broker_connection_lookup + .read() + .await + .get_connection_by_key(&direct.recipient); + + let user_connection = inner + .user_connection_lookup + .read() + .await + .get_connection_by_key(&direct.recipient); + + if let Some(connection) = user_connection { + // `Arc` and serialize it. + // TODO IMP DOCUMENT INVARIANT + let message = Arc::new( + Message::Direct(direct) + .serialize() + .expect("serialization to succeed"), + ); + + // If we fail to send the message, remove the user + if connection.queue_message_back(message).is_err() { + remove_local!(inner, sender, user); + remove_remote!(inner, verification_key); + } + } else if let Some(connection) = broker_connection { + // `Arc` and serialize it. + let message = Arc::new( + Message::Direct(direct) + .serialize() + .expect("serialization to succeed"), + ); + + // If we fail to send the message, remove the broker + if connection.queue_message_back(message).is_err() { + remove_local!(inner, connection, broker); + }; + }; + } + + // A broadcast message. 
This is supposed to go to the interested brokers AND/OR interested users only + Message::Broadcast(broadcast) => { + // Figure out which brokers this message should go to + let broker_connections = inner + .broker_connection_lookup + .read() + .await + .get_connections_by_topic(broadcast.topics.clone()); + + // Figure out which users the message is supposed to go + let user_connections = inner + .user_connection_lookup + .read() + .await + .get_connections_by_topic(broadcast.topics.clone()); + + // If there are any users, queue the message to send for all of them + if !(broker_connections.is_empty() && user_connections.is_empty()) { + // Create a new `Data` and `Arc` it. + let message = Arc::new( + Message::Broadcast(broadcast) + .serialize() + .expect("serialization failed"), + ); + + // For each broker, send them the message. If we fail, remove them + for broker in broker_connections { + if broker.queue_message_back(message.clone()).is_err() { + remove_local!(inner, broker, broker); + } + } + + // For each broker, send them the message. If we fail, remove them + for user in user_connections { + if user.queue_message_back(message.clone()).is_err() { + remove_local!(inner, sender, user); + remove_remote!(inner, verification_key); + } + } + } + } + + // If we receive a subscription from the user, send to other brokers and update locally + Message::Subscribe(subscribe) => { + // Send the information to other brokers, if any. + // Remove them if we failed to send to them, + // TODO: WAL HERE maybe + let brokers: Vec<(BrokerIdentifier, Arc>)> = + inner.other_brokers.read().await.iter().cloned().collect(); + if !brokers.is_empty() { + let subscribed_message = Message::Subscribe(subscribe); + + // Arc and serialize the messages, prepare for sending + let subscribed_message: Arc> = Arc::from( + subscribed_message + .serialize() + .expect("serialization to succeed"), + ); + + // Send the messages + for broker in brokers { + // TODO: consider failing here + let _ = broker.1.queue_message_back(subscribed_message.clone()); + } + }; + } + + // If we receive an unsubscription from the user, send to other brokers and update locally + // TODO: THIS + Message::Unsubscribe(_) + | Message::AuthenticateResponse(_) + | Message::AuthenticateWithKey(_) + | Message::AuthenticateWithPermit(_) + | Message::UsersConnected(_) + | Message::UsersDisconnected(_) => {} + } + } + + // If we fail, remove locally and remote (with other brokers) + remove_local!(inner, sender, user); + remove_remote!(inner, verification_key); } /// The main loop for a broker. 
@@ -307,7 +692,7 @@ where if let Err(err) = redis_client .perform_heartbeat( // todo: actually pull in this number - 0, + inner.user_connection_lookup.read().await.get_key_count() as u64, Duration::from_secs(60), ) .await @@ -320,8 +705,15 @@ where match redis_client.get_other_brokers().await { Ok(brokers) => { // Calculate the difference, spawn tasks to connect to them - // TODO for broker in brokers.difference(&inner.brokers_connected.read()) { - for broker in brokers { + for broker in brokers.difference( + &inner + .other_brokers + .read() + .await + .iter() + .map(|(identifier, _)| identifier.clone()) + .collect(), + ) { // TODO: make this into a separate function // Extrapolate the address to connect to let to_connect_address = broker.broker_advertise_address.clone(); @@ -332,17 +724,14 @@ where // Spawn task to connect to a broker we haven't seen spawn(async move { // Connect to the broker - let connection = match BrokerProtocolType::Connection::connect( - to_connect_address, - ) - .await - { - Ok(connection) => connection, - Err(err) => { - error!("failed to connect to broker: {err}"); - return; - } - }; + let connection = + match BrokerProtocolType::connect(to_connect_address).await { + Ok(connection) => connection, + Err(err) => { + error!("failed to connect to broker: {err}"); + return; + } + }; // Handle the broker connection Self::handle_broker_connection(inner, connection, true).await; diff --git a/broker/src/main.rs b/broker/src/main.rs index 54f04d4..6439693 100644 --- a/broker/src/main.rs +++ b/broker/src/main.rs @@ -54,8 +54,8 @@ async fn main() -> Result<()> { redis_endpoint: args.redis_endpoint, keypair: proto::crypto::KeyPair { - signing_key, verification_key, + signing_key, }, // TODO: clap this @@ -64,10 +64,11 @@ async fn main() -> Result<()> { }; // Create new `Broker` - let marshal = Broker::::new(broker_config).await?; + // Uses TCP from broker connections and Quic for user connections. + let broker = Broker::::new(broker_config).await?; // Start the main loop, consuming it - marshal.start().await?; + broker.start().await?; Ok(()) } diff --git a/broker/src/state.rs b/broker/src/state.rs index 67763f2..2b4169d 100644 --- a/broker/src/state.rs +++ b/broker/src/state.rs @@ -1,192 +1,175 @@ +//! The following crate defines the internal state-tracking primitives as used +//! by the broker. + use std::{ collections::{HashMap, HashSet}, - hash::Hash, - sync::{ - atomic::{AtomicU64, Ordering}, - Arc, - }, - time::Duration, + sync::Arc, }; -use jf_primitives::signatures::SignatureScheme as JfSignatureScheme; -// TODO: maybe use Tokio's RwLock -use parking_lot::RwLock; use proto::{ - connection::protocols::{Connection, Protocol}, - crypto::Serializable, + connection::{batch::BatchedSender, protocols::Protocol}, message::Topic, }; -use tokio::{spawn, sync::Mutex, time::Instant}; - -pub struct ConnectionLookup< - SignatureScheme: JfSignatureScheme, - ProtocolType: Protocol, -> where - SignatureScheme::VerificationKey: Serializable, -{ - direct_message_lookup: - HashMap>>, - broadcast_message_lookup: HashMap>>>, - inverse_broadcast_message_lookup: - HashMap>, HashSet>, + +/// `ConnectionLookup` is what we use as a broker to "look up" where messages are supposed +/// to be directed to. +pub struct ConnectionLookup { + /// What we use to look up direct messages. The mapping is key -> sender + key_to_connection: HashMap, Arc>>, + /// Map is sender -> key. 
Helps us remove a sender on disconnection + connection_to_keys: HashMap>, HashSet>>, + /// What we use to look up broadcast messages. The mapping is topic -> set[sender] + topic_to_connections: HashMap>>>, + /// What we use when removing a key in O(1) from the forward broadcast map. The mapping is + /// sender -> set[topic]. + connection_to_topics: HashMap>, HashSet>, } -impl< - SignatureScheme: JfSignatureScheme, - ProtocolType: Protocol, - > Default for ConnectionLookup -where - SignatureScheme::Signature: Serializable, - SignatureScheme::VerificationKey: Serializable, - SignatureScheme::SigningKey: Serializable, -{ +impl Default for ConnectionLookup { + /// The default imeplementation is to just return empty maps. We need this because + /// of the trait bounds. fn default() -> Self { Self { - direct_message_lookup: HashMap::default(), - broadcast_message_lookup: HashMap::default(), - inverse_broadcast_message_lookup: HashMap::default(), + key_to_connection: HashMap::default(), + connection_to_keys: HashMap::default(), + topic_to_connections: HashMap::default(), + connection_to_topics: HashMap::default(), } } } -impl< - SignatureScheme: JfSignatureScheme, - ProtocolType: Protocol, - > ConnectionLookup -where - SignatureScheme::VerificationKey: Serializable, -{ - pub fn subscribe_connection_to_broadcast( +impl ConnectionLookup { + /// Get the count of all keys + pub fn get_key_count(&self) -> usize{ + self.key_to_connection.len() + } + + /// This returns all keys we are currently responsible for. + pub fn get_all_keys(&self) -> Vec> { + // Iterate over every key in the direct lookup and return it. + self.key_to_connection.keys().cloned().collect() + } + + /// This returns all topics that we are currently responsible for. + pub fn get_all_topics(&self) -> Vec { + // TODO: figure out if we need a clone here + // Iterate over every key in the broadcast lookup and return it. + self.topic_to_connections.keys().cloned().collect() + } + + /// This gets the associated connection for a direct message (if existing) + pub fn get_connection_by_key(&self, key: &Vec) -> Option>> { + // Look up the direct message key and return it. + self.key_to_connection.get(key).cloned() + } + + /// Subscribe a connection to some keys. This is used on the broker end + /// when we receive either a connection from that user, or the message that a broker + /// is interested in a particular user. + pub fn subscribe_connection_to_keys( + &mut self, + connection: &Arc>, + keys: Vec>, + ) { + for key in keys { + // Insert to the direct message lookup + self.key_to_connection.insert(key.clone(), connection.clone()); + + // Insert to the inverse + self.connection_to_keys + .entry(connection.clone()) + .or_default() + .insert(key); + }; + } + + /// Unsubscribe a connection from a particular key. This is used on the broker end + /// when we lose a connection to a user, or a broker says they're not interested anymore. + pub fn unsubscribe_connection_from_keys(&mut self, keys: Vec>) { + for key in keys { + // Remove the key from the lookup + self.key_to_connection.remove(&key); + } + } + + /// Fully unsubscribes a connection from all messages. Used to completely wipe a connection + /// when we are disconnected. 
+ pub fn unsubscribe_connection(&mut self, connection: &Arc>) { + if let Some(keys) = self.connection_to_keys.remove(connection) { + for key in keys { + self.key_to_connection.remove(&key); + } + }; + + if let Some(topics) = self.connection_to_topics.remove(connection) { + for topic in topics { + self.topic_to_connections.remove(&topic); + } + } + } + + /// Look up the connections that are interested in a particular topic so we can + /// broadcast messages. + pub fn get_connections_by_topic( + &self, + topics: Vec, + ) -> HashSet>> { + let mut all_connections = HashSet::new(); + + // Since we don't want the intersection, iterate and add over every topic + for topic in topics { + // If the topic exists, add to our collection of connections. + if let Some(connections) = self.topic_to_connections.get(&topic) { + all_connections.extend(connections.clone()); + } + } + + all_connections + } + + /// This subscribes a particular connection to some topics. + pub fn subscribe_connection_to_topics( &mut self, - connection: Arc>, + connection: Arc>, topics: Vec, ) { - //topic -> [connection] + // Add the connection to each topic. + // topic -> [connection] for topic in topics.clone() { - self.broadcast_message_lookup + self.topic_to_connections .entry(topic) .or_default() .insert(connection.clone()); } - //connection -> [topic] - self.inverse_broadcast_message_lookup + // Add each topic to the connection (this is for O(1) removal later) + // connection -> [topic] + self.connection_to_topics .entry(connection) .or_default() .extend(topics); } - pub fn unsubscribe_connection_from_broadcast( + /// This unsubscribes a particular connection from a topic. + pub fn unsubscribe_connection_from_topics( &mut self, - connection: Arc>, + connection: &Arc>, topics: Vec, ) { - //topic -> [connection] + // For each topic, remove connection from it. + // topic -> [connection] for topic in topics.clone() { - // remove connection from topic, and remove topic if empty - if let Some(connections) = self.broadcast_message_lookup.get_mut(&topic) { - connections.remove(&connection); + // Remove connection from topic + if let Some(connections) = self.topic_to_connections.get_mut(&topic) { + connections.remove(connection); } } - //key -> [topic] - if let Some(connection_topics) = self.inverse_broadcast_message_lookup.get_mut(&connection) - { + // Remove the topic from the connection, if existing. + // key -> [topic] + if let Some(connection_topics) = self.connection_to_topics.get_mut(connection) { for topic in topics { connection_topics.remove(&topic); } } } - - pub fn subscribe_connection_to_direct( - &mut self, - connection: Arc>, - key: SignatureScheme::VerificationKey, - ) { - self.direct_message_lookup.insert(key, connection); - } - - pub fn unsubscribe_connection_from_direct(&mut self, key: SignatureScheme::VerificationKey) { - self.direct_message_lookup.remove(&key); - } -} - -pub struct ConnectionWithQueue { - queue: Mutex>>>, - connection: ProtocolType::Connection, - - current_size: AtomicU64, - last_sent: RwLock, - - min_duration: Duration, - min_size: u64, -} - -impl PartialEq for ConnectionWithQueue { - fn eq(&self, other: &Self) -> bool { - self.connection == other.connection - } -} - -impl Eq for ConnectionWithQueue { - fn assert_receiver_is_total_eq(&self) {} -} - -impl Hash for ConnectionWithQueue { - fn hash(&self, state: &mut H) { - self.connection.hash(state); - } - - /// This just calls `hash` on each item in the slice. 
- fn hash_slice(data: &[Self], state: &mut H) - where - Self: Sized, - { - data.iter().for_each(|item| item.hash(state)); - } -} - -impl ConnectionWithQueue { - pub fn from_connection_and_params( - connection: ProtocolType::Connection, - min_duration: Duration, - min_size: u64, - ) -> Self { - Self { - queue: Mutex::default(), - connection, - current_size: AtomicU64::default(), - last_sent: RwLock::from(Instant::now()), - min_duration, - min_size, - } - } - - pub async fn add_or_queue_message(&self, message: Arc>) { - // Push the reference to the message - let message_length = message.len() as u64; - let mut queue_guard = self.queue.lock().await; - queue_guard.push(message); - - // Update our size - let before_send_size = self - .current_size - .fetch_add(message_length, Ordering::Relaxed); - - // Bounds check to see if we should send - if (before_send_size + message_length) >= self.min_size - || self.last_sent.read().elapsed() >= self.min_duration - { - // Move messages out - // TODO: VEC WITH CAPACITY HERE - let messages = std::mem::replace(&mut *queue_guard, Vec::new()); - - // Spawn a task to flush our queue - // TODO: see if it's faster to not have this here - let connection = self.connection.clone(); - spawn(async move { - // Send the entire batch of messages - let _ = connection.send_messages_raw(messages).await; - }); - } - } } diff --git a/client/src/lib.rs b/client/src/lib.rs index 949a52e..f2b8ba9 100644 --- a/client/src/lib.rs +++ b/client/src/lib.rs @@ -37,29 +37,13 @@ where SignatureScheme::VerificationKey: Serializable, SignatureScheme::SigningKey: Serializable, { - /// Creates a new client from the given `Config`. Immediately will attempt - /// a conection if none is supplied. - /// + /// Creates a new `Retry` from a configuration. + /// /// # Errors - /// Errors if the downstream `Retry` object was unable to be made. - /// This usually happens when we can't bind to the specified endpoint. + /// If the initial connection fails pub async fn new(config: Config) -> Result { - Self::new_with_connection(config, Option::None).await - } - - /// Creates a new client from the given `Config` and an optional `Connection`. - /// Proxies the config to the `Retry` constructor since a `Client` is just a - /// light wrapper. - /// - /// # Errors - /// Errors if the downstream `Retry` object was unable to be created. - /// This usually happens when we can't bind to the specified endpoint. - pub async fn new_with_connection( - config: Config, - connection: Option, - ) -> Result { Ok(Self(bail!( - Retry::from_config_and_connection(config, connection).await, + Retry::from_config(config).await, Connection, "failed to create client" ))) @@ -78,44 +62,39 @@ where } /// Sends a pre-serialized message to the server, denoting recipients in the form - /// of a vector of topics. Use `send_message_raw` when the message is already - /// formed. If it fails, we return an error but try to initiate a new connection + /// of a vector of topics. If it fails, we return an error but try to initiate a new connection /// in the background. 
/// /// # Errors /// If the connection or serialization has failed - pub async fn send_broadcast_message(&self, topics: Vec, message: Vec) -> Result<()> { - // TODO: conditionally match error on whether deserialization OR the connection failed - + pub fn send_broadcast_message(&self, topics: Vec, message: Vec) -> Result<()> { // Form and send the single message - self.send_message(Message::Broadcast(Broadcast { topics, message })) - .await + self.send_message(&Message::Broadcast(Broadcast { topics, message })) } /// Sends a pre-serialized message to the server, denoting interest in delivery - /// to a single recipient. Use `send_message_raw` when the message is already formed. + /// to a single recipient. /// /// # Errors /// If the connection or serialization has failed - pub async fn send_direct_message( + pub fn send_direct_message( &self, - recipient: SignatureScheme::VerificationKey, + recipient: &SignatureScheme::VerificationKey, message: Vec, ) -> Result<()> { // Serialize recipient to a byte array before sending the message // TODO: maybe we can cache this. let recipient_bytes = bail!( - crypto::serialize(&recipient), + crypto::serialize(recipient), Serialize, "failed to serialize recipient" ); // Form and send the single message - self.send_message(Message::Direct(Direct { + self.send_message(&Message::Direct(Direct { recipient: recipient_bytes, message, })) - .await } /// Sends a message to the server that asserts that this client is interested in @@ -123,8 +102,6 @@ where /// /// # Errors /// If the connection or serialization has failed - /// - /// TODO IMPORTANT: see if we want this, or if we'd prefer `set_subscriptions()` pub async fn subscribe(&self, topics: Vec) -> Result<()> { // Lock subscriptions here so we maintain parity during a reconnection let mut subscribed_guard = self.0.inner.subscribed_topics.write().await; @@ -137,10 +114,9 @@ where // Send the topics bail!( - self.send_message(Message::Subscribe(Subscribe { + self.send_message(&Message::Subscribe(Subscribe { topics: topics_to_send.clone() - })) - .await, + })), Connection, "failed to send subscription message" ); @@ -173,10 +149,9 @@ where // Send the topics bail!( - self.send_message(Message::Unsubscribe(Unsubscribe { + self.send_message(&Message::Unsubscribe(Unsubscribe { topics: topics_to_send.clone() - })) - .await, + })), Connection, "failed to send unsubscription message" ); @@ -192,12 +167,12 @@ where Ok(()) } - /// Sends a pre-formed message over the wire. Various functions make use - /// of this one downstream. + /// Sends a message over the wire. Various functions make use + /// of this one upstream. /// /// # Errors /// - if the downstream message sending fails. - pub async fn send_message(&self, message: Message) -> Result<()> { - self.0.send_message(message).await + pub fn send_message(&self, message: &Message) -> Result<()> { + self.0.send_message(message) } } diff --git a/client/src/main.rs b/client/src/main.rs index cec85c1..ccd3897 100644 --- a/client/src/main.rs +++ b/client/src/main.rs @@ -1,50 +1,144 @@ //! The following is an example of a Push CDN client implementation. +//! We spawn two clients. In a single-broker run, this lets them connect +//! cross-broker. 
-use std::{marker::PhantomData, time::Duration}; +use std::{marker::PhantomData, sync::Arc}; use client::{Client, Config}; use proto::{ - connection::protocols::quic::Quic, + connection::protocols::{quic::Quic}, crypto::{self, KeyPair}, error::Result, - message::Topic, + message::{Message, Topic}, }; use jf_primitives::signatures::bls_over_bn254::BLSOverBN254CurveSignatureScheme as BLS; use rand::{rngs::StdRng, SeedableRng}; -use tokio::time::sleep; +use tokio::{join, spawn}; #[tokio::main] async fn main() -> Result<()> { // Initialize tracing tracing_subscriber::fmt::init(); - // Generate a random keypair - let (signing_key, verification_key) = + // Generate two random keypairs, one for each client + let (signing_key_1, verification_key_1) = crypto::generate_random_keypair::(StdRng::from_entropy())?; - // Create a client + let (signing_key_2, verification_key_2) = + crypto::generate_random_keypair::(StdRng::from_entropy())?; + + // Create our first client // TODO: constructors for config - let client = Client::::new(Config { - endpoint: "127.0.0.1:8082".to_string(), - keypair: KeyPair { - signing_key, - verification_key, - }, - subscribed_topics: vec![Topic::DA, Topic::Global], - - pd: PhantomData, - }) - .await?; - - loop { - // Send a direct message to ourselves - let _ = client - .send_direct_message(verification_key, vec![123]) - .await; - - // Receive the direct message (from ourselves) - println!("{:?}", client.receive_message().await); - sleep(Duration::from_secs(3)).await; - } + let client1 = Arc::new( + // We are running with the `BLS` key signing algorithm + // and `Quic` as a networking protocol. + Client::::new(Config { + // Our marshal address, locally running on port 8082 + endpoint: "127.0.0.1:8082".to_string(), + keypair: KeyPair { + signing_key: signing_key_1, + verification_key: verification_key_1, + }, + + // The topics we want to subscribe to initially + subscribed_topics: vec![Topic::DA, Topic::Global], + + // TODO: remove this via means of constructor + pd: PhantomData, + }) + .await?, + ); + + // Create our second client + let client2 = Arc::new( + Client::::new(Config { + // This is the same marshal, but a possibly different broker. + endpoint: "127.0.0.1:8082".to_string(), + keypair: KeyPair { + signing_key: signing_key_2, + verification_key: verification_key_2, + }, + subscribed_topics: vec![Topic::DA, Topic::Global], + pd: PhantomData, + }) + .await?, + ); + + // Run our first client, which sends a message to our second. + let client1 = spawn(async move { + // Clone our client + let client1_ = client1.clone(); + + // The sending side + let jh1 = spawn(async move { + // Send a message to client 2 + let message = "hello client2"; + client1_ + .send_direct_message(&verification_key_2, "hello client2".as_bytes().to_vec()) + .expect("failed to send message"); + + println!("client 1 sent \"{message}\""); + }); + + // The receiving side + let jh2 = spawn(async move { + let message = client1 + .receive_message() + .await + .expect("failed to receive message"); + + if let Message::Direct(direct) = message { + println!( + "client 1 received {}", + String::from_utf8(direct.message).expect("failed to deserialize message") + ); + } else { + panic!("received wrong message type"); + } + }); + + let _ = tokio::join!(jh1, jh2); + }); + + // Run our second client, which sends a message to our first. 
+ let client2 = spawn(async move { + // Clone our client + let client2_ = client2.clone(); + + // The sending side + let jh1 = spawn(async move { + // Send a message to client 2 + let message = "hello client1"; + client2_ + .send_direct_message(&verification_key_1, "hello client1".as_bytes().to_vec()) + .expect("failed to send message"); + + println!("client 2 sent \"{message}\""); + }); + + // The receiving side + let jh2 = spawn(async move { + let message = client2 + .receive_message() + .await + .expect("failed to receive message"); + + if let Message::Direct(direct) = message { + println!( + "client 2 received {}", + String::from_utf8(direct.message).expect("failed to deserialize message") + ); + } else { + panic!("received wrong message type") + } + }); + + let _ = tokio::join!(jh1, jh2); + }); + + // Wait for both to finish + let _ = join!(client1, client2); + + Ok(()) } diff --git a/client/src/retry.rs b/client/src/retry.rs index 2b779e2..c97e7bf 100644 --- a/client/src/retry.rs +++ b/client/src/retry.rs @@ -12,14 +12,15 @@ use jf_primitives::signatures::SignatureScheme as JfSignatureScheme; use proto::{ connection::{ auth::user::UserAuth, - protocols::{Connection, Protocol}, + batch::BatchedSender, + protocols::{Protocol, Receiver}, }, crypto::{KeyPair, Serializable}, error::{Error, Result}, message::{Message, Topic}, }; use tokio::{ - sync::{RwLock, Semaphore}, + sync::{Mutex, RwLock, Semaphore}, time::sleep, }; use tracing::error; @@ -48,8 +49,13 @@ pub struct Inner< /// or a marshal. endpoint: String, - /// The underlying connection, which we modify to facilitate reconnections. - connection: RwLock, + /// The send-side of the connection. We can `RwLock` here because the BatchedSender + /// is already using interior mutability tricks. + sender: RwLock>, + + /// The receive side of the connection. We need a write-lock here because it needs to be + /// mutable. TODO: do something like `BatchedSender` but with `OwnedReader` or something. + receiver: Mutex, /// The task that runs in the background that reconnects us when we need /// to be. This is so we don't spawn multiple tasks at once @@ -57,9 +63,9 @@ pub struct Inner< pub keypair: KeyPair, - /// The topics we're currently subscribed to. We need this here so we can send our subscriptions + /// The topics we're currently subscribed to. We need this so we can send our subscriptions /// when we connect to a new server. - pub subscribed_topics: Arc>>, + pub subscribed_topics: RwLock>, /// Phantom data that lets us use `ProtocolType`, `AuthFlow`, and /// `SignatureScheme` downstream. @@ -91,69 +97,60 @@ pub struct Config< /// and receiving messages. You can specify the operation and it /// will reconnect on the operation's failure, while handling all /// reconnection logic and synchronization patterns. -/// -/// TODO: document invariant with "messages will not retry" macro_rules! try_with_reconnect { - ($self: expr, $operation: ident, $($arg:tt)*) => {{ - // Acquire read guard for sending and receiving messages - let Ok(read_guard) = $self.inner.connection.try_read() else { - return Err(Error::Connection("message failed: reconnection in progress".to_string())); - }; - + ($self: expr, $subject: ident, $out: expr) => {{ // Perform operation, see if it errors - let operation = read_guard.$operation($($arg)*).await; - match operation{ + match $out { Ok(res) => res, Err(err) => { - // Acquire semaphore. If another task is doing this, just return an error - // TODO: global sleep. 
If we try to connect twice, it happens sequentially without waiting (because the sleep - // only happens on the failed case. We can maybe store a variable somewhere and wait for that. - if $self.inner.reconnect_semaphore.try_acquire().is_ok() { - // Acquire write guard, drop read guard - drop(read_guard); - - // Clone everything we need to connect - // TODO: we want to minimize cloning this. We should sign a message - // earlier. - let inner = $self.inner.clone(); - - tokio::spawn(async move{ - // Get write guard on connection so we can write to it - let mut write_guard = inner.connection.write().await; - - // Loop to connect and authenticate - let connection = loop { - // Create a connection - match connect_and_authenticate::( - &inner.endpoint, - &inner.keypair, - inner.subscribed_topics.read().await.clone() - ) - .await{ - Ok(connection) => break connection, - Err(err) => { - error!("failed connection: {err}"); - // Sleep so we don't overload the server - sleep(Duration::from_secs(5)).await; - } - } - }; - - // Set connection to new connection - *write_guard = connection; - - // Drop here so other tasks can start sending messages - drop(write_guard); - }); + // Acquire semaphore. If another task is doing this, just return an error + if $self.inner.reconnect_semaphore.try_acquire().is_ok() { + // Clone everything we need to connect + let inner = $self.inner.clone(); + + // Spawn a task to reconnect + tokio::spawn(async move { + // Get write guard on connection so we can write to it + let mut send_guard = inner.sender.write().await; + let mut receive_guard = inner.receiver.lock().await; + + // Loop to connect and authenticate + let connection = loop { + // Create a connection + match connect_and_authenticate::( + &inner.endpoint, + &inner.keypair, + inner.subscribed_topics.read().await.clone(), + ) + .await + { + Ok(connection) => break connection, + Err(err) => { + error!("failed connection: {err}"); + // Sleep so we don't overload the server + sleep(Duration::from_secs(5)).await; + } + } + }; + + // Update sender and receiver + // TODO: parameterize duration and size + *send_guard = + BatchedSender::from(connection.0, Duration::from_millis(50), 1500); + *receive_guard = connection.1; + + drop(send_guard); + drop(receive_guard); + }); + } + + // If somebody is already trying to reconnect, fail instantly + return Err(Error::Connection(format!( + "connection failed, reconnecting to endpoint: {err}" + ))); + } } - - // If somebody is already trying to reconnect, fail instantly - return Err(Error::Connection(format!( - "connection failed, reconnecting to endpoint: {err}" - ))); - } - } -}}; + }}; } impl< @@ -165,20 +162,15 @@ where SignatureScheme::VerificationKey: Serializable, SignatureScheme::SigningKey: Serializable, { - /// Creates a new `Retry` connection from a `Config` and an (optional) pre-existing - /// `Fallible` connection. - /// - /// This allows us to create elastic clients that always try to maintain a connection - /// with each other. + /// Creates a new `Retry` connection from a `Config` + /// Attempts to make an initial connection. + /// This allows us to create elastic clients that always try to maintain a connection. /// /// # Errors /// - If we are unable to either parse or bind an endpoint to the local address. 
/// - If we are unable to make the initial connection /// TODO: figure out if we want retries here - pub async fn from_config_and_connection( - config: Config, - maybe_connection: Option, - ) -> Result { + pub async fn from_config(config: Config) -> Result { // Extrapolate values from the underlying client configuration let Config { endpoint, @@ -188,36 +180,31 @@ where } = config; // Wrap subscribed topics so we can use it now and later - let subscribed_topics = Arc::new(RwLock::new(HashSet::from_iter(subscribed_topics))); - - // Perform the initial connection and authentication if not provided. - // This is to validate that we have correct parameters and all. - // - // TODO: cancel conditionally depending on what kind of error, or retry- - // based. - // - // TODO: clean this up - let connection = if let Some(connection) = maybe_connection { - connection - } else { - bail!( - connect_and_authenticate::( - &endpoint, - &keypair, - subscribed_topics.read().await.clone() - ) - .await, - Connection, - "failed initial connection" + let subscribed_topics = RwLock::new(HashSet::from_iter(subscribed_topics)); + + // Perform the initial connection and authentication + let connection = bail!( + connect_and_authenticate::( + &endpoint, + &keypair, + subscribed_topics.read().await.clone() ) - }; + .await, + Connection, + "failed initial connection" + ); // Return the slightly transformed connection. Ok(Self { inner: Arc::from(Inner { endpoint, - // Use the existing connection - connection: RwLock::from(connection), + // TODO: parameterize batch params + sender: RwLock::from(BatchedSender::from( + connection.0, + Duration::from_millis(50), + 1500, + )), + receiver: Mutex::from(connection.1), reconnect_semaphore: Semaphore::const_new(1), keypair, subscribed_topics, @@ -226,23 +213,30 @@ where }) } - /// Sends a message to the underlying fallible connection. Reconnection logic is here, - /// but retry logic needs to be handled by the caller (e.g. re-send messages) + /// Sends a message to the underlying fallible connection. Reconnection is handled under + /// the hood. Messages will fail if the connection is currently closed or reconnecting. /// /// # Errors - /// - If we fail to serialize the message - /// - If we are in the middle of reconnecting - /// - If the message sending failed - pub async fn send_message(&self, message: Message) -> Result<()> { + /// If the message sending fails. For example: + /// - If we are reconnecting + /// - If we are disconnected + pub fn send_message(&self, message: &Message) -> Result<()> { // Serialize the message - let message = Arc::from(bail!( + let message = bail!( message.serialize(), Serialize, "failed to serialize message" - )); + ); - // Try to send the message, reconnecting if needed - Ok(try_with_reconnect!(self, send_message_raw, message,)) + // Try to acquire the read lock. If we can't, we are reconnecting. + if let Ok(send_lock) = self.inner.sender.try_read() { + // Continue if we were able to acquire the lock + let out = send_lock.queue_message_back(Arc::from(message)); + Ok(try_with_reconnect!(self, send_lock, out)) + } else { + // Return an error if we're reconnecting + Err(Error::Connection("reconnection in progress".to_string())) + } } /// Receives a message from the underlying fallible connection. 
Reconnection logic is here, @@ -252,8 +246,17 @@ where /// - If we are in the middle of reconnecting /// - If the message receiving failed pub async fn receive_message(&self) -> Result { - // Try to send the message, reconnecting if needed - Ok(try_with_reconnect!(self, recv_message,)) + // We can use `try_lock` here because only two tasks will be using it: + // either we're receiving or somebody is reconnecting us. + if let Ok(mut receiver_guard) = self.inner.receiver.try_lock() { + // We were able to get the lock, we aren't reconnecting + let out = receiver_guard.recv_message().await; + Ok(try_with_reconnect!(self, send_lock, out)) + } else { + // We couldn't get the lock, we are reconnecting + // Return an error + Err(Error::Connection("reconnection in progress".to_string())) + } } } @@ -264,30 +267,33 @@ async fn connect_and_authenticate< marshal_endpoint: &str, keypair: &KeyPair, subscribed_topics: HashSet, -) -> Result +) -> Result<(ProtocolType::Sender, ProtocolType::Receiver)> where SignatureScheme::Signature: Serializable, SignatureScheme::VerificationKey: Serializable, SignatureScheme::SigningKey: Serializable, { // Make the connection to the marshal - let connection = bail!( - ProtocolType::Connection::connect(marshal_endpoint.to_owned()).await, + let mut connection = bail!( + ProtocolType::connect(marshal_endpoint.to_owned()).await, Connection, "failed to connect to endpoint" ); // Authenticate the connection to the marshal (if not provided) let (broker_address, permit) = bail!( - UserAuth::::authenticate_with_marshal(&connection, keypair) - .await, + UserAuth::::authenticate_with_marshal( + &mut connection, + keypair + ) + .await, Authentication, "failed to authenticate to marshal" ); // Make the connection to the broker - let connection = bail!( - ProtocolType::Connection::connect(broker_address).await, + let mut connection = bail!( + ProtocolType::connect(broker_address).await, Connection, "failed to connect to broker" ); @@ -295,7 +301,7 @@ where // Authenticate the connection to the broker bail!( UserAuth::::authenticate_with_broker( - &connection, + &mut connection, permit, subscribed_topics ) diff --git a/marshal/src/lib.rs b/marshal/src/lib.rs index 1189f9b..6ca48f2 100644 --- a/marshal/src/lib.rs +++ b/marshal/src/lib.rs @@ -10,7 +10,7 @@ use proto::{ bail, connection::{ auth::marshal::MarshalAuth, - protocols::{Listener, Protocol}, + protocols::{Listener, Protocol, Sender}, }, crypto::Serializable, error::{Error, Result}, @@ -66,8 +66,7 @@ where // Create the `Listener` from the bind address let listener = bail!( - ProtocolType::Listener::bind(bind_address, maybe_tls_cert_path, maybe_tls_key_path) - .await, + ProtocolType::bind(bind_address, maybe_tls_cert_path, maybe_tls_key_path).await, Connection, format!("failed to listen to address {}", bind_address) ); @@ -89,15 +88,18 @@ where /// Handles a user's connection, including authentication. pub async fn handle_connection( - connection: ProtocolType::Connection, + mut connection: (ProtocolType::Sender, ProtocolType::Receiver), mut redis_client: redis::Client, ) { // Verify (authenticate) the connection let _ = MarshalAuth::::verify_user( - &connection, + &mut connection, &mut redis_client, ) .await; + + // We don't care about this, just drop the connection immediately. + let _ = connection.0.finish().await; } /// The main loop for a marshal. 
diff --git a/proto/schema/messages.capnp b/proto/schema/messages.capnp index a9f2cf0..49c8fce 100644 --- a/proto/schema/messages.capnp +++ b/proto/schema/messages.capnp @@ -25,6 +25,7 @@ struct Message { usersConnected @7 :UsersConnected; # The wrapper for an `UsersDisconnected` message usersDisconnected @8 :UsersDisconnected; + } } @@ -100,19 +101,14 @@ struct Unsubscribe { topics @0: List(Topic); } - # A message that is used to convey to other brokers that user(s) have connected to us. struct UsersConnected { - # The users connected to us - users @0: List(User); + # The user [keys] connected to us + users @0: List(Data); } # A message that is used to convey to other brokers that user(s) have disconnected from us. struct UsersDisconnected { - # The users that have disconnected from us - users @0: List(User); -} - -struct User { - key @0: Data; + # The user [keys] that have disconnected from us + users @0: List(Data); } \ No newline at end of file diff --git a/proto/src/connection/auth/broker.rs b/proto/src/connection/auth/broker.rs index 66a6750..5558333 100644 --- a/proto/src/connection/auth/broker.rs +++ b/proto/src/connection/auth/broker.rs @@ -10,7 +10,7 @@ use tracing::error; use crate::{ bail, - connection::protocols::{Connection, Protocol}, + connection::protocols::{Protocol, Receiver, Sender}, crypto::{self, DeterministicRng, KeyPair, Serializable}, error::{Error, Result}, fail_verification_with_message, @@ -29,12 +29,13 @@ pub struct BrokerAuth< } /// We use this macro upstream to conditionally order broker authentication flows +/// TODO: do something else with these macros #[macro_export] macro_rules! authenticate_with_broker { ($connection: expr, $inner: expr) => { - // Prove to the other broker + // Authenticate with the other broker, returning their reconnect address match BrokerAuth::::authenticate_with_broker( - &$connection, + &mut $connection, &$inner.keypair, ) .await @@ -52,9 +53,9 @@ macro_rules! authenticate_with_broker { #[macro_export] macro_rules! 
verify_broker { ($connection: expr, $inner: expr) => { - // Wait for other brokers' proof + // Verify the other broker's authentication if let Err(err) = BrokerAuth::::verify_broker( - &$connection, + &mut $connection, &$inner.identifier, &$inner.keypair.verification_key, ) @@ -84,20 +85,19 @@ where /// - If authentication fails /// - If our connection fails pub async fn verify_user( - connection: &ProtocolType::Connection, + connection: &mut (ProtocolType::Sender, ProtocolType::Receiver), broker_identifier: &BrokerIdentifier, redis_client: &mut redis::Client, - ) -> Result<(SignatureScheme::VerificationKey, Vec)> { + ) -> Result<(Vec, Vec)> { // Receive the permit let auth_message = bail!( - connection.recv_message().await, + connection.1.recv_message().await, Connection, "failed to receive message from user" ); // See if we're the right type of message let Message::AuthenticateWithPermit(auth_message) = auth_message else { - // TODO: macro for this error thing fail_verification_with_message!(connection, "wrong message type"); }; @@ -128,10 +128,10 @@ where }); // Send the successful response to the user - let _ = connection.send_message(response_message).await; + let _ = connection.0.send_message(response_message).await; - // Serialize the verification key - let verification_key = bail!( + // Try to serialize the verification key + bail!( crypto::deserialize(&serialized_verification_key), Crypto, "failed to deserialize verification key" @@ -139,7 +139,7 @@ where // Receive the subscribed topics let subscribed_topics_message = bail!( - connection.recv_message().await, + connection.1.recv_message().await, Connection, "failed to receive message from user" ); @@ -151,7 +151,10 @@ where }; // Return the verification key - Ok((verification_key, subscribed_topics_message.topics)) + Ok(( + serialized_verification_key, + subscribed_topics_message.topics, + )) } /// Authenticate with a broker (as a broker). @@ -162,7 +165,7 @@ where /// - If we fail to authenticate /// - If we have a connection failure pub async fn authenticate_with_broker( - connection: &ProtocolType::Connection, + connection: &mut (ProtocolType::Sender, ProtocolType::Receiver), keypair: &KeyPair, ) -> Result { // Get the current timestamp, which we sign to avoid replay attacks @@ -208,14 +211,14 @@ where // Create and send the authentication message from the above operations bail!( - connection.send_message(message).await, + connection.0.send_message(message).await, Connection, "failed to send auth message to broker" ); // Wait for the response with the permit and address let response = bail!( - connection.recv_message().await, + connection.1.recv_message().await, Connection, "failed to receive message from broker" ); @@ -247,14 +250,19 @@ where Ok(broker_address) } + /// Verify a broker as a broker. + /// Will fail verification if it does not match our verification key. 
+ /// + /// # Errors + /// - If verification has failed pub async fn verify_broker( - connection: &ProtocolType::Connection, + connection: &mut (ProtocolType::Sender, ProtocolType::Receiver), our_identifier: &BrokerIdentifier, our_verification_key: &SignatureScheme::VerificationKey, ) -> Result<()> { // Receive the signed message from the user let auth_message = bail!( - connection.recv_message().await, + connection.1.recv_message().await, Connection, "failed to receive message from user" ); @@ -309,7 +317,7 @@ where }); // Send the permit to the user, along with the public broker advertise address - let _ = connection.send_message(response_message).await; + let _ = connection.0.send_message(response_message).await; Ok(()) } diff --git a/proto/src/connection/auth/marshal.rs b/proto/src/connection/auth/marshal.rs index e6bd187..9c7f2de 100644 --- a/proto/src/connection/auth/marshal.rs +++ b/proto/src/connection/auth/marshal.rs @@ -10,7 +10,7 @@ use tracing::error; use crate::{ bail, - connection::protocols::{Connection, Protocol}, + connection::protocols::{Protocol, Receiver, Sender}, crypto::{self, Serializable}, error::{Error, Result}, fail_verification_with_message, @@ -47,12 +47,12 @@ where /// - If authentication fails /// - If our connection fails pub async fn verify_user( - connection: &ProtocolType::Connection, + connection: &mut (ProtocolType::Sender, ProtocolType::Receiver), redis_client: &mut redis::Client, ) -> Result<()> { // Receive the signed message from the user let auth_message = bail!( - connection.recv_message().await, + connection.1.recv_message().await, Connection, "failed to receive message from user" ); @@ -130,7 +130,7 @@ where }); // Send the permit to the user, along with the public broker advertise address - let _ = connection.send_message(response_message).await; + let _ = connection.0.send_message(response_message).await; Ok(()) } diff --git a/proto/src/connection/auth/mod.rs b/proto/src/connection/auth/mod.rs index 0f9af55..2b171c0 100644 --- a/proto/src/connection/auth/mod.rs +++ b/proto/src/connection/auth/mod.rs @@ -13,6 +13,7 @@ macro_rules! 
fail_verification_with_message { ($connection: expr, $context: expr) => { // Send the error message let _ = $connection + .0 .send_message(Message::AuthenticateResponse(AuthenticateResponse { permit: 0, context: $context.to_string(), diff --git a/proto/src/connection/auth/user.rs b/proto/src/connection/auth/user.rs index 38ff717..acc428d 100644 --- a/proto/src/connection/auth/user.rs +++ b/proto/src/connection/auth/user.rs @@ -10,7 +10,7 @@ use jf_primitives::signatures::SignatureScheme as JfSignatureScheme; use crate::{ bail, - connection::protocols::{Connection, Protocol}, + connection::protocols::{Protocol, Receiver, Sender}, crypto::{self, DeterministicRng, KeyPair, Serializable}, error::{Error, Result}, message::{AuthenticateWithKey, AuthenticateWithPermit, Message, Subscribe, Topic}, @@ -44,7 +44,7 @@ where /// - If we fail authentication /// - If our connection fails pub async fn authenticate_with_marshal( - connection: &ProtocolType::Connection, + connection: &mut (ProtocolType::Sender, ProtocolType::Receiver), keypair: &KeyPair, ) -> Result<(String, u64)> { // Get the current timestamp, which we sign to avoid replay attacks @@ -90,14 +90,14 @@ where // Create and send the authentication message from the above operations bail!( - connection.send_message(message).await, + connection.0.send_message(message).await, Connection, "failed to send auth message to marshal" ); // Wait for the response with the permit and address let response = bail!( - connection.recv_message().await, + connection.1.recv_message().await, Connection, "failed to receive message from marshal" ); @@ -132,7 +132,7 @@ where /// - If authentication fails /// - If our connection fails pub async fn authenticate_with_broker( - connection: &ProtocolType::Connection, + connection: &mut (ProtocolType::Sender, ProtocolType::Receiver), permit: u64, subscribed_topics: HashSet, ) -> Result<()> { @@ -141,14 +141,14 @@ where // Send the authentication message to the broker bail!( - connection.send_message(auth_message).await, + connection.0.send_message(auth_message).await, Connection, "failed to send message to broker" ); // Wait for a response let response_message = bail!( - connection.recv_message().await, + connection.1.recv_message().await, Connection, "failed to receive response message from broker" ); @@ -173,7 +173,7 @@ where topics: Vec::from_iter(subscribed_topics), }); bail!( - connection.send_message(topic_message).await, + connection.0.send_message(topic_message).await, Connection, "failed to send topics to broker" ); diff --git a/proto/src/connection/batch.rs b/proto/src/connection/batch.rs new file mode 100644 index 0000000..dffc702 --- /dev/null +++ b/proto/src/connection/batch.rs @@ -0,0 +1,300 @@ +//! This crate defines a batching system for sending messages, wherein +//! we spawn a task that owns the sender and have a handle to a channel it's +//! listening on. +//! +//! TODO: dynamic batch size and time + +use std::{collections::VecDeque, hash::Hash, marker::PhantomData, sync::Arc, time::Duration}; + +use rand::{rngs::StdRng, RngCore, SeedableRng}; +use tokio::{ + select, spawn, + sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender}, +}; +use tracing::error; + +use crate::{ + bail, + connection::protocols::Sender, + error::{Error, Result}, +}; + +use super::protocols::Protocol; + +/// This is the format we send messages to the task with. Can either be a control message +/// or a data message. +enum QueueMessage { + /// A data message is something we actually want to send. 
+    Data(Arc<Vec<u8>>, Position),
+    /// A control message is an enshrined message that is supposed to control the stream.
+    Control(Control),
+}
+
+/// This is coupled with a data message, where we denote the position in the queue we want
+/// our message to go to. This can be useful for implementing monotonic writes or write-ahead
+/// logging.
+enum Position {
+    /// Add our message to the front of the batch
+    Front,
+    /// Add our message to the back of the batch
+    Back,
+}
+
+/// These are our control messages, which we use to tell the task to do something other than
+/// send the message.
+enum Control {
+    /// Freeze sending messages in the queue. This can be useful when coupled with `Position::Front`
+    /// so we can delay and order messages the way we want to.
+    Freeze,
+    /// Unfreeze the queue. Allows for sending messages again
+    Unfreeze,
+    /// Shut down the task. The task will be killed
+    Shutdown,
+}
+
+/// `BatchedSender` is a wrapper around a send stream that owns the sender. It allows us
+/// to queue messages for sending with a minimum time or size. It is clonable through an `Arc`.
+pub struct BatchedSender<ProtocolType: Protocol> {
+    /// The underlying channel that we receive messages over.
+    channel: UnboundedSender<QueueMessage>,
+    /// A unique, randomly generated ID that we use to compare and hash against.
+    stable_id: u64,
+    /// The `PhantomData` we need to use a generic protocol type.
+    pd: PhantomData<ProtocolType>,
+}
+
+/// `PartialEq` here uses the randomly generated `stable_id`
+impl<ProtocolType: Protocol> PartialEq for BatchedSender<ProtocolType> {
+    fn eq(&self, other: &Self) -> bool {
+        self.stable_id == other.stable_id
+    }
+}
+
+/// Asserts that `PartialEq` == `Eq`
+impl<ProtocolType: Protocol> Eq for BatchedSender<ProtocolType> {}
+
+/// `Hash` here uses the randomly generated `stable_id`
+impl<ProtocolType: Protocol> Hash for BatchedSender<ProtocolType> {
+    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
+        self.stable_id.hash(state);
+    }
+}
+
+/// The underlying queue object that a `BatchedSender`'s task operates over.
+/// Contains the queue as fields for data tracking purposes.
+pub struct Queue {
+    /// The actual message queue
+    inner: VecDeque<Arc<Vec<u8>>>,
+
+    /// The current size of the queue, in bytes
+    current_size: u64,
+
+    /// The maximum duration to wait before sending a message.
+    max_duration: Duration,
+    /// The maximum message size before sending, in bytes
+    max_size_in_bytes: u64,
+
+    /// Whether or not the queue is currently frozen
+    frozen: bool,
+}
+
+macro_rules! flush_queue {
+    ($queue:expr, $sender:expr) => {
+        // Replace the inner queue with a new `VecDeque::default()`
+        let messages = std::mem::take(&mut $queue.inner);
+        // Reset the size
+        $queue.current_size = 0;
+
+        // Send the replaced messages
+        if let Err(e) = $sender.send_messages(messages).await {
+            error!("message send failed: {e}");
+            return;
+        };
+    };
+}
+
+impl<ProtocolType: Protocol> BatchedSender<ProtocolType> {
+    /// Freeze sending messages in the queue. This can be useful when coupled with `Position::Front`
+    /// so we can delay and order messages the way we want to.
+    ///
+    /// # Errors
+    /// - If the send-side is closed.
+    pub fn freeze(&self) -> Result<()> {
+        // Send a control message to freeze the queue
+        bail!(
+            self.channel.send(QueueMessage::Control(Control::Freeze)),
+            Connection,
+            "connection closed"
+        );
+
+        Ok(())
+    }
+
+    /// Unfreeze message sending operations in the queue.
+    ///
+    /// # Errors
+    /// - If the send-side is closed.
+    pub fn unfreeze(&self) -> Result<()> {
+        // Send a control message to unfreeze the queue
+        bail!(
+            self.channel.send(QueueMessage::Control(Control::Unfreeze)),
+            Connection,
+            "connection closed"
+        );
+
+        Ok(())
+    }
+
+    /// Queue a serialized message to the front of the queue. Can be useful in conjunction with
+    /// `.freeze()` or if we have an important message.
+    ///
+    /// # Errors
+    /// - If the send-side is closed.
+    pub fn queue_message_front(&self, message: Arc<Vec<u8>>) -> Result<()> {
+        // Send a data message
+        bail!(
+            self.channel
+                .send(QueueMessage::Data(message, Position::Front)),
+            Connection,
+            "connection closed"
+        );
+
+        Ok(())
+    }
+
+    /// Queue a serialized message to the back of the queue.
+    ///
+    /// # Errors
+    /// - If the send-side is closed.
+    pub fn queue_message_back(&self, message: Arc<Vec<u8>>) -> Result<()> {
+        // Send a data message
+        bail!(
+            self.channel
+                .send(QueueMessage::Data(message, Position::Back)),
+            Connection,
+            "connection closed"
+        );
+
+        Ok(())
+    }
+
+    /// Create a `BatchedSender` from a normal sender, along with a maximum duration and maximum
+    /// queue size before we flush.
+    pub fn from(
+        sender: ProtocolType::Sender,
+        max_duration: Duration,
+        max_size_in_bytes: u64,
+    ) -> Self {
+        // Create the send and receive sides of a channel.
+        let (send_side, receive_side) = unbounded_channel();
+
+        // Create a new queue from our parameters and defaults
+        let batch_params = Queue {
+            inner: VecDeque::default(),
+            current_size: 0,
+            max_duration,
+            max_size_in_bytes,
+            frozen: false,
+        };
+
+        // Spawn the sending task that the send handle moves into. We would normally use a
+        // `JoinHandle` for shutdown, but we don't have that luxury with the `async_compatibility_layer`
+        spawn(Self::batch_loop(sender, receive_side, batch_params));
+
+        // Return a sender with a unique `stable_id`.
+        Self {
+            stable_id: StdRng::from_entropy().next_u64(),
+            channel: send_side,
+            pd: PhantomData,
+        }
+    }
+
+    /// This is the main loop that the send-side runs. This is where we deal with incoming
+    /// data and control messages.
+    async fn batch_loop(
+        mut sender: ProtocolType::Sender,
+        mut receiver: UnboundedReceiver<QueueMessage>,
+        mut queue: Queue,
+    ) {
+        // Create a timer that ticks every max interval. We reset it if we actually send a message.
+        let mut timer = tokio::time::interval(queue.max_duration);
+
+        loop {
+            // Select on either a new message or a timer event
+            select! {
+                // Receive a message. Will return `None` if the send side is closed.
+                possible_message = receiver.recv() => {
+                    let Some(message) = possible_message else {
+                        // If the send-side is closed, drop everything and stop.
+                        return
+                    };
+
+                    // See what type of message we have
+                    match message {
+                        // A data message. This is a message that we actually want to add
+                        // to the queue.
+                        QueueMessage::Data(data, position) => {
+                            // Get the current length of the data.
+                            let data_len = data.len();
+
+                            // See in which position we wanted to add to the queue
+                            match position {
+                                Position::Front => {
+                                    queue.inner.push_front(data);
+                                }
+                                Position::Back => {
+                                    queue.inner.push_back(data);
+                                }
+                            }
+
+                            // Increase the queue size by the data length
+                            queue.current_size += data_len as u64;
+
+                            // If we're frozen, don't continue to any sending logic
+                            if queue.frozen {
+                                continue
+                            }
+
+                            // Bounds check to see if we should send
+                            if queue.current_size >= queue.max_size_in_bytes {
+                                // Flush the queue, sending all in-flight messages.
+                                flush_queue!(queue, sender);
+                            }
+                        }
+
+                        // We got a control message; a message we don't actually want to send.
+ QueueMessage::Control(control) => { + match control { + Control::Freeze => {queue.frozen = true} + Control::Unfreeze => {queue.frozen = false} + // Return if we see a shutdown message + Control::Shutdown => return + } + } + } + } + // We hit this when the timer expires without having sent a message. + _ = timer.tick() => { + // Don't do anything if the queue is currently frozen + if queue.frozen{ + continue + } + + // Flush the queue, sending all in-flight messages. + flush_queue!(queue, sender); + } + } + + // Reset the timer when we are done sending the message. + timer.reset(); + } + } +} + +// When we drop, we want to send the shutdown message to the sender. +impl Drop for BatchedSender { + fn drop(&mut self) { + // Shut down the channel + let _ = self.channel.send(QueueMessage::Control(Control::Shutdown)); + } +} diff --git a/proto/src/connection/mod.rs b/proto/src/connection/mod.rs index 0051cb3..f889199 100644 --- a/proto/src/connection/mod.rs +++ b/proto/src/connection/mod.rs @@ -2,4 +2,5 @@ //! for any network protocol. pub mod auth; +pub mod batch; pub mod protocols; diff --git a/proto/src/connection/protocols/mod.rs b/proto/src/connection/protocols/mod.rs index c2ca254..075050e 100644 --- a/proto/src/connection/protocols/mod.rs +++ b/proto/src/connection/protocols/mod.rs @@ -1,6 +1,6 @@ //! This module defines connections, listeners, and their implementations. -use std::{hash::Hash, net::SocketAddr, sync::Arc}; +use std::{collections::VecDeque, net::SocketAddr, sync::Arc}; use async_trait::async_trait; @@ -12,65 +12,132 @@ pub mod tcp; /// TODO: find out if there is a better way than the `u64` cast const _: [(); 0 - (!(usize::BITS >= u64::BITS)) as usize] = []; -pub trait Protocol: Send + Sync + 'static + Clone { - type Connection: Connection + Clone + PartialEq + Hash; - type Listener: Listener; -} - +/// The `Protocol` trait lets us be generic over a connection type (Tcp, Quic, etc). #[async_trait] -pub trait Connection: Send + Sync + 'static { - /// Receive a single message from the connection. +pub trait Protocol: Send + Sync + 'static { + // TODO: make these generic over reader/writer + // TODO: make these connection type that defines into_split + type Sender: Sender + Send + Sync; + type Receiver: Receiver + Send + Sync; + + type Listener: Listener + Send + Sync; + + /// Connect to a remote address, returning an instance of `Self`. /// /// # Errors - /// Errors if we either fail to receive the message. This usually means a connection problem. - async fn recv_message(&self) -> Result; + /// Errors if we fail to connect or if we fail to bind to the interface we want. + async fn connect(remote_endpoint: String) -> Result<(Self::Sender, Self::Receiver)>; - /// Send a single message over the connection. + /// Bind to the local address, returning an instance of `Listener`. /// /// # Errors - /// - If we fail to deliver the message - /// - If we fail to serialize the message - async fn send_message(&self, message: Message) -> Result<()>; + /// If we fail to bind to the given socket address + async fn bind( + bind_address: SocketAddr, + maybe_tls_cert_path: Option, + maybe_tls_key_path: Option, + ) -> Result; +} - /// Send a pre-formed message over the connection. +#[async_trait] +pub trait Sender { + /// Send a message over the connection. /// /// # Errors /// - If we fail to deliver the message. This usually means a connection problem. 
- async fn send_message_raw(&self, message: Arc>) -> Result<()>; + async fn send_message(&mut self, message: Message) -> Result<()>; /// Send a vector of pre-formed messages over the connection. /// /// # Errors /// - If we fail to deliver any of the messages. This usually means a connection problem. - async fn send_messages_raw(&self, messages: Vec>>) -> Result<()>; + async fn send_messages(&mut self, messages: VecDeque>>) -> Result<()>; - /// Connect to a remote address, returning an instance of `Self`. + /// Gracefully shuts down the outgoing stream, ensuring all data + /// has been written. /// /// # Errors - /// Errors if we fail to connect or if we fail to bind to the interface we want. - async fn connect(remote_endpoint: String) -> Result - where - Self: Sized; + /// - If we could not shut down the stream. + async fn finish(&mut self) -> Result<()>; } #[async_trait] -pub trait Listener: Send + Sync + 'static { - /// Bind to the local address, returning an instance of `Self`. +pub trait Receiver { + /// Receives a single message over the stream and deserializes + /// it. /// /// # Errors - /// If we fail to bind to the given socket address - async fn bind( - bind_address: SocketAddr, - maybe_tls_cert_path: Option, - maybe_tls_key_path: Option, - ) -> Result - where - Self: Sized; + /// - if we fail to receive the message + /// - if we fail deserialization + async fn recv_message(&mut self) -> Result; + /// Receives a single message over the stream without deserializing + /// it. + /// + /// # Errors + /// - if we fail to receive the message + async fn recv_message_raw(&mut self) -> Result>; +} + +#[async_trait] +pub trait Listener { /// Accept a connection from the local, bound socket. /// Returns a connection or an error if we encountered one. /// /// # Errors /// If we fail to accept a connection - async fn accept(&self) -> Result; + async fn accept(&self) -> Result<(Sender, Receiver)>; +} + +/// A macro to write a length-delimited (serialized) message to a stream. +#[macro_export] +macro_rules! write_length_delimited { + ($stream: expr, $message:expr) => { + // Write the message size to the stream + bail!( + $stream.write_u64($message.len() as u64).await, + Connection, + "failed to send message size" + ); + + // Write the message to the stream + bail!( + $stream.write_all(&$message).await, + Connection, + "failed to send message" + ); + }; +} + +/// A macro to read a length-delimited (serialized) message from a stream. +/// Has a bounds check for if the message is too big +#[macro_export] +macro_rules! 
read_length_delimited { + ($stream: expr) => {{ + // Read the message size from the stream + let message_size = bail!( + $stream.read_u64().await, + Connection, + "failed to read message size" + ); + + // Make sure the message isn't too big + if message_size > MAX_MESSAGE_SIZE { + return Err(Error::Connection( + "expected to receive message that was too big".to_string(), + )); + } + + // Create buffer of the proper size + let mut buffer = vec![0; usize::try_from(message_size).expect("64 bit system")]; + + // Read the message from the stream + bail!( + $stream.read_exact(&mut buffer).await, + Connection, + "failed to receive message from connection" + ); + + buffer + }}; } diff --git a/proto/src/connection/protocols/quic.rs b/proto/src/connection/protocols/quic.rs index 0587a31..f5f3409 100644 --- a/proto/src/connection/protocols/quic.rs +++ b/proto/src/connection/protocols/quic.rs @@ -4,177 +4,31 @@ use async_trait::async_trait; use quinn::{ClientConfig, Endpoint, ServerConfig}; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; use crate::{ bail, bail_option, crypto::{self, SkipServerVerification}, error::{Error, Result}, message::Message, - MAX_MESSAGE_SIZE, + read_length_delimited, write_length_delimited, MAX_MESSAGE_SIZE, }; -use core::hash::Hash; -use std::{net::ToSocketAddrs, sync::Arc}; +use std::{collections::VecDeque, net::ToSocketAddrs, sync::Arc}; -use super::{Connection, Listener, Protocol}; +use super::{Listener, Protocol, Receiver, Sender}; /// The `Quic` protocol. We use this to define commonalities between QUIC /// listeners, connections, etc. -#[derive(Clone, PartialEq)] +#[derive(Clone, PartialEq, Eq)] pub struct Quic; -/// We define the `Quic` protocol as being composed of both a QUIC listener -/// and connection. +#[async_trait] impl Protocol for Quic { - type Connection = QuicConnection; + type Sender = QuicSender; + type Receiver = QuicReceiver; type Listener = QuicListener; -} - -/// `QuicConnection` is a thin wrapper around `quinn::Connection` that implements -/// `Connection`. -#[derive(Clone)] -pub struct QuicConnection(pub quinn::Connection); - -/// `PartialEq` for a `QuicConnection` connection is determined by the `stable_id` since it -/// will not change for the duration of the connection. -impl PartialEq for QuicConnection { - fn eq(&self, other: &Self) -> bool { - self.0.stable_id() == other.0.stable_id() - } -} - -/// Assertion for `QuicConnection` that `PartialEq` == `Eq` -impl Eq for QuicConnection { - fn assert_receiver_is_total_eq(&self) {} -} - -/// `Hash` for a `QuicConnection` connection is determined by the `stable_id` since it -/// will not change for the duration of the connection. We just want to hash that. -impl Hash for QuicConnection { - fn hash(&self, state: &mut H) { - self.0.stable_id().hash(state); - } - - /// This just calls `hash` on each item in the slice. - fn hash_slice(data: &[Self], state: &mut H) - where - Self: Sized, - { - data.iter().for_each(|item| item.hash(state)); - } -} - -#[async_trait] -impl Connection for QuicConnection { - /// Receives a single message from the QUIC connection. Since we use - /// virtual streams as a message framing method, this function first accepts a stream - /// and then reads and deserializes a single message from it. - /// - /// # Errors - /// Errors if we either failed to accept the stream or receive the message over that stream. - /// This usually means a connection problem. 
- async fn recv_message(&self) -> Result { - // Accept the incoming unidirectional stream - let mut stream = bail!( - self.0.accept_uni().await, - Connection, - "failed to accept unidirectional stream" - ); - - // Read the full message, until the sender closes the stream - let message_bytes = bail!( - stream - .read_to_end(usize::try_from(MAX_MESSAGE_SIZE).expect("64 bit system")) - .await, - Connection, - "failed to read from stream" - ); - - // Deserialize and return the message - Ok(bail!( - Message::deserialize(&message_bytes), - Deserialize, - "failed to deserialize message" - )) - } - - /// Sends a single message to the QUIC connection. This function first opens a - /// stream and then serializes and sends a single message to it. - /// - /// # Errors - /// - If we fail to serialize the message - /// - If we fail to open the stream - /// - If we fail to send the message over that stream - /// This usually means a connection problem. - async fn send_message(&self, message: Message) -> Result<()> { - // Serialize the message - let message_bytes = bail!( - message.serialize(), - Serialize, - "failed to serialize message" - ); - - // Send the message - self.send_message_raw(Arc::from(message_bytes)).await - } - - /// Send a pre-formed message over the connection. - /// - /// # Errors - /// - If we fail to deliver the message. This usually means a connection problem. - async fn send_message_raw(&self, message: Arc>) -> Result<()> { - // Open the outgoing unidirectional stream - let mut stream = bail!( - self.0.open_uni().await, - Connection, - "failed to open unidirectional stream" - ); - - // Write the full message to the stream - bail!( - stream.write_all(&message).await, - Connection, - "failed to write to stream" - ); - - // Finish the stream, denoting to the peer that the - // message has been fully written - Ok(bail!( - stream.finish().await, - Connection, - "failed to finish stream" - )) - } - /// Send a vector of pre-formed message over the connection. - /// - /// TODO: FIGURE OUT IF WE WANT TO FRAME LIKE THIS. it may be more performant with batching - /// to not do it this way. - /// - /// # Errors - /// - If we fail to deliver any of the messages. This usually means a connection problem. - async fn send_messages_raw(&self, messages: Vec>>) -> Result<()> { - // Send each message over the connection - for message in messages { - bail!( - self.send_message_raw(message).await, - Connection, - "failed to send message" - ); - } - - Ok(()) - } - - /// Connect to a remote endpoint, returning an instance of `Self`. With QUIC, - /// this requires creating an endpoint, binding to it, and then attempting - /// a connection. - /// - /// # Errors - /// Errors if we fail to connect or if we fail to bind to the interface we want. 
- async fn connect(remote_endpoint: String) -> Result - where - Self: Sized, - { + async fn connect(remote_endpoint: String) -> Result<(QuicSender, QuicReceiver)> { // Parse the socket address let remote_address = bail_option!( bail!( @@ -219,7 +73,7 @@ impl Connection for QuicConnection { endpoint.set_default_client_config(config); // Connect with QUIC endpoint to remote address - let connection = Self(bail!( + let connection = bail!( bail!( endpoint.connect(remote_address, domain_name), Connection, @@ -228,18 +82,18 @@ impl Connection for QuicConnection { .await, Connection, "failed to connect to remote address" - )); + ); - Ok(connection) - } -} + // Open a bidirectional stream over the connection + let (sender, receiver) = bail!( + connection.open_bi().await, + Connection, + "failed to open bidirectional stream" + ); -/// The listener struct. Needed to receive messages over QUIC. Is a light -/// wrapper around `quinn::Endpoint`. -pub struct QuicListener(pub quinn::Endpoint); + Ok((QuicSender(sender), QuicReceiver(receiver))) + } -#[async_trait] -impl Listener for QuicListener { /// Binds to a local endpoint. Uses `maybe_tls_cert_path` and `maybe_tls_cert_key` /// to conditionally load or generate the given (or not given) certificate. /// @@ -250,10 +104,7 @@ impl Listener for QuicListener { bind_address: std::net::SocketAddr, maybe_tls_cert_path: Option, maybe_tls_key_path: Option, - ) -> Result - where - Self: Sized, - { + ) -> Result { // Conditionally load or generate a certificate and key let (certificates, key) = bail!( crypto::load_or_self_sign_tls_certificate_and_key( @@ -273,22 +124,116 @@ impl Listener for QuicListener { // Create endpoint from the given server configuration and // bind address - Ok(Self(bail!( + Ok(QuicListener(bail!( Endpoint::server(server_config, bind_address), Connection, "failed to bind to local address" ))) } +} + +pub struct QuicSender(quinn::SendStream); + +#[async_trait] +impl Sender for QuicSender { + /// Send a message over the connection. + /// + /// # Errors + /// - If we fail to deliver the message. This usually means a connection problem. + async fn send_message(&mut self, message: Message) -> Result<()> { + // Serialize the message + let message = bail!( + message.serialize(), + Serialize, + "failed to serialize message" + ); + + // Write the message to the stream + write_length_delimited!(self.0, message); + + Ok(()) + } + + /// Send a vector of pre-formed messages over the connection. + /// + /// # Errors + /// - If we fail to deliver any of the messages. This usually means a connection problem. + async fn send_messages(&mut self, messages: VecDeque>>) -> Result<()> { + // Write each message (length-delimited) + for message in messages { + write_length_delimited!(self.0, message); + } + + Ok(()) + } + /// Gracefully shuts down the outgoing stream, ensuring all data + /// has been written. + /// + /// # Errors + /// - If we could not shut down the stream. + async fn finish(&mut self) -> Result<()> { + bail!( + self.0.finish().await, + Connection, + "failed to finish connection" + ); + + Ok(()) + } +} + +pub struct QuicReceiver(quinn::RecvStream); + +#[async_trait] +impl Receiver for QuicReceiver { + /// Receives a single message over the stream and deserializes + /// it. 
+ /// + /// # Errors + /// - if we fail to receive the message + /// - if we fail deserialization + async fn recv_message(&mut self) -> Result { + // Receive the raw message + let raw_message = bail!( + self.recv_message_raw().await, + Connection, + "failed to receive message" + ); + + // Deserialize and return the message + Ok(bail!( + Message::deserialize(&raw_message), + Deserialize, + "failed to deserialize message" + )) + } + + /// Receives a single message over the stream without deserializing + /// it. + /// + /// # Errors + /// - if we fail to receive the message + async fn recv_message_raw(&mut self) -> Result> { + Ok(read_length_delimited!(self.0)) + } +} + +/// The listener struct. Needed to receive messages over QUIC. Is a light +/// wrapper around `quinn::Endpoint`. +pub struct QuicListener(pub quinn::Endpoint); + +#[async_trait] +impl Listener for QuicListener { /// Accept a connection from the listener. /// /// # Errors /// - If we fail to accept a connection from the listener. /// TODO: be more descriptive with this /// TODO: match on whether the endpoint is closed, return a different error - async fn accept(&self) -> Result { + async fn accept(&self) -> Result<(QuicSender, QuicReceiver)> { // Try to accept a connection from the QUIC endpoint - Ok(QuicConnection(bail!( + let connection = bail!( bail_option!( self.0.accept().await, Connection, @@ -297,6 +242,15 @@ impl Listener for QuicListener { .await, Connection, "failed to accept connection" - ))) + ); + + // Accept a bidirectional stream from the connection + let (sender, receiver) = bail!( + connection.accept_bi().await, + Connection, + "failed to accept bidirectional stream" + ); + + Ok((QuicSender(sender), QuicReceiver(receiver))) } } diff --git a/proto/src/connection/protocols/tcp.rs b/proto/src/connection/protocols/tcp.rs index ef02068..2cdf6b5 100644 --- a/proto/src/connection/protocols/tcp.rs +++ b/proto/src/connection/protocols/tcp.rs @@ -3,200 +3,41 @@ //! logic. use async_trait::async_trait; -use rand::{rngs::StdRng, RngCore, SeedableRng}; use tokio::{ io::{AsyncReadExt, AsyncWriteExt}, net::{ tcp::{OwnedReadHalf, OwnedWriteHalf}, TcpSocket, }, - sync::Mutex, }; use crate::{ bail, bail_option, error::{Error, Result}, message::Message, - MAX_MESSAGE_SIZE, + read_length_delimited, write_length_delimited, MAX_MESSAGE_SIZE, }; -use std::{hash::Hash, net::ToSocketAddrs, sync::Arc}; +use std::{collections::VecDeque, net::ToSocketAddrs, sync::Arc}; -use super::{Connection, Listener, Protocol}; +use super::{Listener, Protocol, Receiver, Sender}; /// The `Tcp` protocol. We use this to define commonalities between TCP /// listeners, connections, etc. -#[derive(Clone, PartialEq)] +#[derive(Clone, PartialEq, Eq)] pub struct Tcp; -/// We define the `Tcp` protocol as being composed of both a TCP listener -/// and connection. +#[async_trait] impl Protocol for Tcp { - type Connection = TcpConnection; + type Sender = TcpSender; + type Receiver = TcpReceiver; type Listener = TcpListener; -} - -/// `TcpConnection` is a thin wrapper around `OwnedReadHalf` and `OwnedWriteHalf` that implements -/// `Connection`. -#[derive(Clone)] -pub struct TcpConnection { - pub receiver: Arc>, - pub sender: Arc>, - pub stable_id: u64, -} - -/// `PartialEq` for a `TcpConnection` connection is determined by the `stable_id` since it -/// will not change for the duration of the connection. 
-impl PartialEq for TcpConnection { - fn eq(&self, other: &Self) -> bool { - self.stable_id == other.stable_id - } -} - -/// Assertion for `QuicConnection` that `PartialEq` == `Eq` -impl Eq for TcpConnection { - fn assert_receiver_is_total_eq(&self) {} -} - -/// `Hash` for a `TcpConnection` connection is determined by the `stable_id` since it -/// will not change for the duration of the connection. We just want to hash that. -impl Hash for TcpConnection { - fn hash(&self, state: &mut H) { - self.stable_id.hash(state); - } - - /// This just calls `hash` on each item in the slice. - fn hash_slice(data: &[Self], state: &mut H) - where - Self: Sized, - { - data.iter().for_each(|item| item.hash(state)); - } -} - -#[async_trait] -impl Connection for TcpConnection { - /// Receives a single message from the TCP connection. It reads the size - /// of the message from the stream, reads the message, and then - /// deserializes and returns it. - /// - /// # Errors - /// Errors if we either failed to receive or deserialize the message. - /// This usually means a connection problem. - async fn recv_message(&self) -> Result { - // Lock the stream so we don't receive message/message sizes interleaved - let mut receiver_guard = self.receiver.lock().await; - - // Read the message size from the stream - let message_size = bail!( - receiver_guard.read_u64().await, - Connection, - "failed to read message size" - ); - // Make sure the message isn't too big - if message_size > MAX_MESSAGE_SIZE { - return Err(Error::Connection( - "expected to receive message that was too big".to_string(), - )); - } - - // Create buffer of the proper size - let mut buffer = vec![0; usize::try_from(message_size).expect("64 bit system")]; - - // Read the message from the stream - bail!( - receiver_guard.read_exact(&mut buffer).await, - Connection, - "failed to receive message from connection" - ); - drop(receiver_guard); - - // Deserialize and return the message - Ok(bail!( - Message::deserialize(&buffer), - Deserialize, - "failed to deserialize message" - )) - } - - /// Sends a single (deserialized) message over the TCP connection - /// - /// # Errors - /// - If we fail to serialize the message - /// - If we fail to send the message - async fn send_message(&self, message: Message) -> Result<()> { - // Serialize the message - let serialized_message = bail!( - message.serialize(), - Serialize, - "failed to serialize message" - ); - - // Send the serialized message - self.send_message_raw(Arc::from(serialized_message)).await - } - - /// Send a pre-formed message over the connection. - /// - /// # Errors - /// - If we fail to deliver the message. This usually means a connection problem. - async fn send_message_raw(&self, message: Arc>) -> Result<()> { - // Lock the stream so we don't send message/message sizes interleaved - let mut sender_guard = self.sender.lock().await; - - // Write the message size to the stream - bail!( - sender_guard.write_u64(message.len() as u64).await, - Connection, - "failed to send message size" - ); - - // Write the message to the stream - bail!( - sender_guard.write_all(&message).await, - Connection, - "failed to send message" - ); - drop(sender_guard); - - Ok(()) - } - - /// Send a vector pre-formed messages over the connection. - /// - /// # Errors - /// - If we fail to deliver the message. This usually means a connection problem. 
- async fn send_messages_raw(&self, messages: Vec>>) -> Result<()> { - // Lock the stream so we don't send message/message sizes interleaved - let mut sender_guard = self.sender.lock().await; - - // For each message: - for message in messages { - // Write the message size to the stream - bail!( - sender_guard.write_u64(message.len() as u64).await, - Connection, - "failed to send message size" - ); - - // Write the message to the stream - bail!( - sender_guard.write_all(&message).await, - Connection, - "failed to send message" - ); - } - - drop(sender_guard); - - Ok(()) - } /// Connect to a remote endpoint, returning an instance of `Self`. /// With TCP, this requires just connecting to the remote endpoint. /// /// # Errors /// Errors if we fail to connect or if we fail to bind to the interface we want. - async fn connect(remote_endpoint: String) -> Result + async fn connect(remote_endpoint: String) -> Result<(Self::Sender, Self::Receiver)> where Self: Sized, { @@ -230,21 +71,9 @@ impl Connection for TcpConnection { // concurrently over both let (read_half, write_half) = stream.into_split(); - // `Mutex` and `Arc` each side - Ok(Self { - receiver: Arc::from(Mutex::from(read_half)), - sender: Arc::from(Mutex::from(write_half)), - stable_id: StdRng::from_entropy().next_u64(), - }) + Ok((TcpSender(write_half), TcpReceiver(read_half))) } -} -/// The listener struct. Needed to receive messages over TCP. Is a light -/// wrapper around `tokio::net::TcpListener`. -pub struct TcpListener(pub tokio::net::TcpListener); - -#[async_trait] -impl Listener for TcpListener { /// Binds to a local endpoint. Does not use a TLS configuration. /// /// # Errors @@ -253,25 +82,118 @@ impl Listener for TcpListener { bind_address: std::net::SocketAddr, _maybe_tls_cert_path: Option, _maybe_tls_key_path: Option, - ) -> Result - where - Self: Sized, - { + ) -> Result { // Try to bind to the local address - Ok(Self(bail!( + Ok(TcpListener(bail!( tokio::net::TcpListener::bind(bind_address).await, Connection, "failed to bind to local address" ))) } +} + +/// This struct is a light wrapper over the send half of a TCP connection. +pub struct TcpSender(OwnedWriteHalf); +#[async_trait] +impl Sender for TcpSender { + /// Send an unserialized message over the connection. + /// + /// # Errors + /// - If we fail to deliver the message. This usually means a connection problem. + async fn send_message(&mut self, message: Message) -> Result<()> { + // Serialize the message + let message = bail!( + message.serialize(), + Serialize, + "failed to serialize message" + ); + + // Write the message to the stream + write_length_delimited!(self.0, message); + + Ok(()) + } + + /// Send a vector of pre-formed messages over the connection. + /// + /// # Errors + /// - If we fail to deliver the message. This usually means a connection problem. + async fn send_messages(&mut self, messages: VecDeque>>) -> Result<()> { + // Write each message (length-delimited) + for message in messages { + write_length_delimited!(self.0, message); + } + + Ok(()) + } + + /// Gracefully shuts down the outgoing stream, ensuring all data + /// has been written. + /// + /// # Errors + /// - If we could not shut down the stream. 
+ async fn finish(&mut self) -> Result<()> { + bail!( + self.0.shutdown().await, + Connection, + "failed to finish connection" + ); + + Ok(()) + } +} + +/// This is a light wrapper over the read half of a TCP connection +pub struct TcpReceiver(OwnedReadHalf); + +#[async_trait] +impl Receiver for TcpReceiver { + /// Receives a single message over the stream and deserializes + /// it. + /// + /// # Errors + /// - if we fail to receive the message + /// - if we fail deserialization + async fn recv_message(&mut self) -> Result { + // Receive the raw message + let raw_message = bail!( + self.recv_message_raw().await, + Connection, + "failed to receive message" + ); + + // Deserialize and return the message + Ok(bail!( + Message::deserialize(&raw_message), + Deserialize, + "failed to deserialize message" + )) + } + + /// Receives a single message over the stream without deserializing + /// it. + /// + /// # Errors + /// - if we fail to receive the message + async fn recv_message_raw(&mut self) -> Result> { + Ok(read_length_delimited!(self.0)) + } +} + +/// The listener struct. Needed to receive messages over TCP. Is a light +/// wrapper around `tokio::net::TcpListener`. +pub struct TcpListener(pub tokio::net::TcpListener); + +#[async_trait] +impl Listener for TcpListener { /// Accept a connection from the listener. /// /// # Errors /// - If we fail to accept a connection from the listener. /// TODO: be more descriptive with this /// TODO: match on whether the endpoint is closed, return a different error - async fn accept(&self) -> Result { + async fn accept(&self) -> Result<(TcpSender, TcpReceiver)> { // Try to accept a connection from the underlying endpoint // Split into reader and writer half let (receiver, sender) = bail!( @@ -283,10 +205,6 @@ impl Listener for TcpListener { .into_split(); // Wrap our halves so they can be used across threads - Ok(TcpConnection { - receiver: Arc::from(Mutex::from(receiver)), - sender: Arc::from(Mutex::from(sender)), - stable_id: StdRng::from_entropy().next_u64(), - }) + Ok((TcpSender(sender), TcpReceiver(receiver))) } } diff --git a/proto/src/message.rs b/proto/src/message.rs index 613fdad..51e5ec4 100644 --- a/proto/src/message.rs +++ b/proto/src/message.rs @@ -61,14 +61,7 @@ macro_rules! deserialize { Deserialize, "failed to deserialize users" ) { - users.push( - bail!( - user.get_key(), - Deserialize, - "failed to deserialize user key" - ) - .to_vec(), - ); + users.push(bail!(user, Deserialize, "failed to deserialize user key").to_vec()); } users }}; @@ -112,10 +105,13 @@ impl Message { /// /// # Errors /// Errors if the downstream serialization fails. 
+ /// + /// # Panics + /// If we can't cast from a usize to a u32 pub fn serialize(&self) -> Result> { // Create a new root message, our message base - let mut message = capnp::message::Builder::new_default(); - let root: messages_capnp::message::Builder = message.init_root(); + let mut default_message = capnp::message::Builder::new_default(); + let root: messages_capnp::message::Builder = default_message.init_root(); // Conditional logic based on what kind of message we passed in match self { @@ -212,12 +208,11 @@ impl Message { // Init the users let mut users = message .reborrow() - .init_users(to_serialize.users.len() as u32); + .init_users(u32::try_from(to_serialize.users.len()).expect("serialization failed")); // For each user, reborrow and serialize for (i, user) in to_serialize.users.iter().enumerate() { - let mut cur_user = users.reborrow().get(i as u32); - cur_user.set_key(user); + users.reborrow().set(u32::try_from(i).expect("serialization failed"), user); } } @@ -228,17 +223,16 @@ impl Message { // Init the users let mut users = message .reborrow() - .init_users(to_serialize.users.len() as u32); + .init_users(u32::try_from(to_serialize.users.len()).expect("serialization failed")); // For each user, reborrow and serialize for (i, user) in to_serialize.users.iter().enumerate() { - let mut cur_user = users.reborrow().get(i as u32); - cur_user.set_key(user); + users.reborrow().set(u32::try_from(i).expect("serialization failed"), user); } } } - Ok(write_message_segments_to_words(&message)) + Ok(write_message_segments_to_words(&default_message)) } /// `deserialize` is used to deserialize a message. It returns a