diff --git a/src/zhttpsocket.rs b/src/zhttpsocket.rs index c00c4192..b99fa41f 100644 --- a/src/zhttpsocket.rs +++ b/src/zhttpsocket.rs @@ -272,6 +272,11 @@ struct SessionItem { handle_index: usize, } +enum SessionAddError { + Full, + Exists, +} + struct SessionDataInner { items: Slab, items_by_key: HashMap, @@ -292,11 +297,15 @@ impl SessionData { } } - fn add(&self, key: SessionKey, handle_index: usize) -> Result { + fn add(&self, key: SessionKey, handle_index: usize) -> Result { let inner = &mut *self.inner.lock().unwrap(); - if inner.items.len() == inner.items.capacity() || inner.items_by_key.contains_key(&key) { - return Err(()); + if inner.items.len() == inner.items.capacity() { + return Err(SessionAddError::Full); + } + + if inner.items_by_key.contains_key(&key) { + return Err(SessionAddError::Exists); } let item_key = inner.items.insert(SessionItem { @@ -355,7 +364,7 @@ impl SessionTable { } } - fn add(&self, key: SessionKey, handle_index: usize) -> Result { + fn add(&self, key: SessionKey, handle_index: usize) -> Result { self.data.add(key, handle_index) } @@ -843,6 +852,8 @@ impl StreamHandles { } } +struct ReqHandlesSendError(MultipartHeader); + struct ServerReqHandles { nodes: Slab>, list: list::List, @@ -918,6 +929,7 @@ impl ServerReqHandles { } } + // waits until at least one handle is likely writable #[allow(clippy::await_holding_refcell_ref)] async fn check_send(&self) { let mut any_valid = false; @@ -961,9 +973,13 @@ impl ServerReqHandles { } // non-blocking send. caller should use check_send() first - fn send(&self, header: MultipartHeader, msg: &arena::Arc) { + fn send( + &self, + header: MultipartHeader, + msg: &arena::Arc, + ) -> Result<(), ReqHandlesSendError> { if self.nodes.is_empty() { - return; + return Err(ReqHandlesSendError(header)); } let mut skip = self.send_index.get(); @@ -983,19 +999,30 @@ impl ServerReqHandles { } } - if let Some(nkey) = selected { - let n = &self.nodes[nkey]; - let p = &n.value; + let nkey = match selected { + Some(nkey) => nkey, + None => return Err(ReqHandlesSendError(header)), + }; + + let n = &self.nodes[nkey]; + let p = &n.value; - match p.pe.sender.try_send((header, arena::Arc::clone(msg))) { - Ok(_) => {} - Err(mpsc::TrySendError::Full(_)) => error!("req sender is full"), - Err(mpsc::TrySendError::Disconnected(_)) => { + if let Err(e) = p.pe.sender.try_send((header, arena::Arc::clone(msg))) { + let header = match e { + mpsc::TrySendError::Full((header, _)) => header, + mpsc::TrySendError::Disconnected((header, _)) => { p.valid.set(false); + self.need_cleanup.set(true); + + header } - } + }; + + return Err(ReqHandlesSendError(header)); } + + Ok(()) } fn need_cleanup(&self) -> bool { @@ -1026,10 +1053,18 @@ impl ServerReqHandles { } } +enum StreamHandlesSendError { + BadFormat, + NoneReady, + SessionExists, + SessionCapacityFull, +} + struct ServerStreamHandles { nodes: Slab>, list: list::List, recv_scratch: RefCell>, + check_send_any_scratch: RefCell, Session)>>, send_direct_scratch: RefCell>, need_cleanup: Cell, send_index: Cell, @@ -1042,6 +1077,7 @@ impl ServerStreamHandles { nodes: Slab::with_capacity(capacity), list: list::List::default(), recv_scratch: RefCell::new(RecvScratch::new(capacity)), + check_send_any_scratch: RefCell::new(CheckSendScratch::new(capacity)), send_direct_scratch: RefCell::new(Vec::with_capacity(capacity)), need_cleanup: Cell::new(false), send_index: Cell::new(0), @@ -1103,27 +1139,62 @@ impl ServerStreamHandles { } } + // waits until at least one handle is likely writable + #[allow(clippy::await_holding_refcell_ref)] async fn check_send_any(&self) { - let mut at_least_one_writable = false; + let mut any_valid = false; + let mut any_writable = false; for (_, p) in self.list.iter(&self.nodes) { if p.valid.get() { - p.pe.sender_any.wait_writable().await; - at_least_one_writable = true; + any_valid = true; + + if p.pe.sender_any.is_writable() { + any_writable = true; + break; + } } } + if any_writable { + return; + } + // if there are no valid pipes then hang forever. caller can // try again by dropping the future and making a new one - if !at_least_one_writable { + if !any_valid { std::future::pending::<()>().await; } + + // there are valid pipes but none are writable. we'll wait + + let mut scratch = self.check_send_any_scratch.borrow_mut(); + let (mut tasks, slice_scratch) = scratch.get(); + + for (_, p) in self.list.iter(&self.nodes) { + if p.valid.get() { + assert!(tasks.len() < tasks.capacity()); + + tasks.push(p.pe.sender_any.wait_writable()); + } + } + + select_slice(&mut tasks, slice_scratch).await; } // non-blocking send. caller should use check_send_any() first - fn send_any(&self, msg: &arena::Arc, from: &[u8], ids: &[Id<'_>]) { - if self.nodes.is_empty() || ids.is_empty() { - return; + fn send_any( + &self, + msg: &arena::Arc, + from: &[u8], + ids: &[Id], + ) -> Result<(), StreamHandlesSendError> { + if from.len() > FROM_MAX || ids.is_empty() || ids[0].id.len() > REQ_ID_MAX { + return Err(StreamHandlesSendError::BadFormat); + } + + if self.nodes.is_empty() { + return Err(StreamHandlesSendError::NoneReady); } let mut skip = self.send_index.get(); @@ -1143,34 +1214,39 @@ impl ServerStreamHandles { } } - if let Some(nkey) = selected { - let n = &self.nodes[nkey]; - let p = &n.value; + let nkey = match selected { + Some(nkey) => nkey, + None => return Err(StreamHandlesSendError::NoneReady), + }; - let from = match ArrayVec::try_from(from) { - Ok(v) => v, - Err(_) => return, - }; + let n = &self.nodes[nkey]; + let p = &n.value; - let id = match ArrayVec::try_from(ids[0].id) { - Ok(v) => v, - Err(_) => return, - }; + let from = ArrayVec::try_from(from).unwrap(); + let id = ArrayVec::try_from(ids[0].id).unwrap(); - let key = (from, id); + let key = (from, id); - if let Ok(session) = self.sessions.add(key, nkey) { - match p.pe.sender_any.try_send((arena::Arc::clone(msg), session)) { - Ok(_) => {} - Err(mpsc::TrySendError::Full(_)) => error!("stream sender_any is full"), - Err(mpsc::TrySendError::Disconnected(_)) => { - p.valid.set(false); + let session = match self.sessions.add(key, nkey) { + Ok(s) => s, + Err(SessionAddError::Full) => return Err(StreamHandlesSendError::SessionCapacityFull), + Err(SessionAddError::Exists) => return Err(StreamHandlesSendError::SessionExists), + }; - self.need_cleanup.set(true); - } + if let Err(e) = p.pe.sender_any.try_send((arena::Arc::clone(msg), session)) { + match e { + mpsc::TrySendError::Full(_) => {} + mpsc::TrySendError::Disconnected(_) => { + p.valid.set(false); + + self.need_cleanup.set(true); } } + + return Err(StreamHandlesSendError::NoneReady); } + + Ok(()) } #[allow(clippy::await_holding_refcell_ref)] @@ -1969,7 +2045,10 @@ impl ServerSocketManager { let messages_memory = Arc::new(arena::SyncMemory::new(arena_size)); - let sessions_max = stream_maxconn + (HANDLES_MAX * handle_bound); + // sessions are created at the time of attempting to send to a handle, so we need enough + // sessions to max out the workers, and max out all the handle channels, and have one + // left to use when attempting to send + let sessions_max = stream_maxconn + (HANDLES_MAX * handle_bound) + 1; let req_sock = AsyncZmqSocket::new(ZmqSocket::new(&ctx, zmq::ROUTER)); @@ -2155,6 +2234,8 @@ impl ServerSocketManager { trace!("IN server req {}", packet_to_string(&msg)); } + let msg = arena::Arc::new(msg, &messages_memory).unwrap(); + req_in_msg = Some((header, msg)); } Err(e) => error!("server req zmq recv: {}", e), @@ -2176,11 +2257,7 @@ impl ServerSocketManager { req_send = None; } // req_handles_check_send - Select10::R5(()) => { - let (header, msg) = req_in_msg.take().unwrap(); - - Self::handle_req_message(header, msg, &messages_memory, &req_handles); - } + Select10::R5(()) => Self::handle_req_message(&mut req_in_msg, &req_handles), // stream_in_recv Select10::R6(result) => match result { Ok(msg) => { @@ -2188,6 +2265,8 @@ impl ServerSocketManager { trace!("IN server stream {}", packet_to_string(&msg)); } + let msg = arena::Arc::new(msg, &messages_memory).unwrap(); + stream_in_msg = Some(msg); } Err(e) => error!("server stream zmq recv: {}", e), @@ -2222,9 +2301,7 @@ impl ServerSocketManager { } // stream_handles_check_send_any Select10::R10(()) => { - let msg = stream_in_msg.take().unwrap(); - - Self::handle_stream_message_any(msg, &messages_memory, &stream_handles); + Self::handle_stream_message_any(&mut stream_in_msg, &stream_handles); } } @@ -2268,34 +2345,47 @@ impl ServerSocketManager { } fn handle_req_message( - header: MultipartHeader, - msg: zmq::Message, - messages_memory: &Arc>, + next_msg: &mut Option<(MultipartHeader, arena::Arc)>, handles: &ServerReqHandles, ) { - let msg = arena::Arc::new(msg, messages_memory).unwrap(); + let (header, msg) = next_msg.take().unwrap(); - handles.send(header, &msg); + if let Err(ReqHandlesSendError(header)) = handles.send(header, &msg) { + *next_msg = Some((header, msg)); + } } fn handle_stream_message_any( - msg: zmq::Message, - messages_memory: &Arc>, + next_msg: &mut Option>, handles: &ServerStreamHandles, ) { - let msg = arena::Arc::new(msg, messages_memory).unwrap(); + let msg = next_msg.take().unwrap(); - let mut scratch = ParseScratch::new(); + let ret = { + let mut scratch = ParseScratch::new(); - let (from, ids) = match parse_ids(msg.get(), &mut scratch) { - Ok(ret) => ret, - Err(e) => { - warn!("unable to determine packet id(s): {}", e); - return; - } + let (from, ids) = match parse_ids(msg.get(), &mut scratch) { + Ok(ret) => ret, + Err(e) => { + warn!("unable to determine packet id(s): {}", e); + return; + } + }; + + handles.send_any(&msg, from, ids) }; - handles.send_any(&msg, from, ids); + match ret { + Ok(()) => {} + Err(StreamHandlesSendError::BadFormat) => warn!("stream send_any: bad format"), + Err(StreamHandlesSendError::NoneReady) => *next_msg = Some(msg), + Err(StreamHandlesSendError::SessionExists) => { + warn!("stream send_any: session id in use") + } + Err(StreamHandlesSendError::SessionCapacityFull) => { + error!("stream send_any: session capacity full") + } + } } async fn handle_stream_message_direct(