Skip to content

Commit

Permalink
zhttpsocket: proper handling of is_writable (#47882)
Browse files Browse the repository at this point in the history
  • Loading branch information
jkarneges authored Jan 24, 2024
1 parent 864f20b commit 42ac17c
Showing 1 changed file with 156 additions and 66 deletions.
222 changes: 156 additions & 66 deletions src/zhttpsocket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,11 @@ struct SessionItem {
handle_index: usize,
}

enum SessionAddError {
Full,
Exists,
}

struct SessionDataInner {
items: Slab<SessionItem>,
items_by_key: HashMap<SessionKey, usize>,
Expand All @@ -292,11 +297,15 @@ impl SessionData {
}
}

fn add(&self, key: SessionKey, handle_index: usize) -> Result<Session, ()> {
fn add(&self, key: SessionKey, handle_index: usize) -> Result<Session, SessionAddError> {
let inner = &mut *self.inner.lock().unwrap();

if inner.items.len() == inner.items.capacity() || inner.items_by_key.contains_key(&key) {
return Err(());
if inner.items.len() == inner.items.capacity() {
return Err(SessionAddError::Full);
}

if inner.items_by_key.contains_key(&key) {
return Err(SessionAddError::Exists);
}

let item_key = inner.items.insert(SessionItem {
Expand Down Expand Up @@ -355,7 +364,7 @@ impl SessionTable {
}
}

fn add(&self, key: SessionKey, handle_index: usize) -> Result<Session, ()> {
fn add(&self, key: SessionKey, handle_index: usize) -> Result<Session, SessionAddError> {
self.data.add(key, handle_index)
}

Expand Down Expand Up @@ -843,6 +852,8 @@ impl StreamHandles {
}
}

struct ReqHandlesSendError(MultipartHeader);

struct ServerReqHandles {
nodes: Slab<list::Node<ServerReqPipe>>,
list: list::List,
Expand Down Expand Up @@ -918,6 +929,7 @@ impl ServerReqHandles {
}
}

// waits until at least one handle is likely writable
#[allow(clippy::await_holding_refcell_ref)]
async fn check_send(&self) {
let mut any_valid = false;
Expand Down Expand Up @@ -961,9 +973,13 @@ impl ServerReqHandles {
}

// non-blocking send. caller should use check_send() first
fn send(&self, header: MultipartHeader, msg: &arena::Arc<zmq::Message>) {
fn send(
&self,
header: MultipartHeader,
msg: &arena::Arc<zmq::Message>,
) -> Result<(), ReqHandlesSendError> {
if self.nodes.is_empty() {
return;
return Err(ReqHandlesSendError(header));
}

let mut skip = self.send_index.get();
Expand All @@ -983,19 +999,30 @@ impl ServerReqHandles {
}
}

if let Some(nkey) = selected {
let n = &self.nodes[nkey];
let p = &n.value;
let nkey = match selected {
Some(nkey) => nkey,
None => return Err(ReqHandlesSendError(header)),
};

let n = &self.nodes[nkey];
let p = &n.value;

match p.pe.sender.try_send((header, arena::Arc::clone(msg))) {
Ok(_) => {}
Err(mpsc::TrySendError::Full(_)) => error!("req sender is full"),
Err(mpsc::TrySendError::Disconnected(_)) => {
if let Err(e) = p.pe.sender.try_send((header, arena::Arc::clone(msg))) {
let header = match e {
mpsc::TrySendError::Full((header, _)) => header,
mpsc::TrySendError::Disconnected((header, _)) => {
p.valid.set(false);

self.need_cleanup.set(true);

header
}
}
};

return Err(ReqHandlesSendError(header));
}

Ok(())
}

fn need_cleanup(&self) -> bool {
Expand Down Expand Up @@ -1026,10 +1053,18 @@ impl ServerReqHandles {
}
}

enum StreamHandlesSendError {
BadFormat,
NoneReady,
SessionExists,
SessionCapacityFull,
}

struct ServerStreamHandles {
nodes: Slab<list::Node<ServerStreamPipe>>,
list: list::List,
recv_scratch: RefCell<RecvScratch<zmq::Message>>,
check_send_any_scratch: RefCell<CheckSendScratch<(arena::Arc<zmq::Message>, Session)>>,
send_direct_scratch: RefCell<Vec<bool>>,
need_cleanup: Cell<bool>,
send_index: Cell<usize>,
Expand All @@ -1042,6 +1077,7 @@ impl ServerStreamHandles {
nodes: Slab::with_capacity(capacity),
list: list::List::default(),
recv_scratch: RefCell::new(RecvScratch::new(capacity)),
check_send_any_scratch: RefCell::new(CheckSendScratch::new(capacity)),
send_direct_scratch: RefCell::new(Vec::with_capacity(capacity)),
need_cleanup: Cell::new(false),
send_index: Cell::new(0),
Expand Down Expand Up @@ -1103,27 +1139,62 @@ impl ServerStreamHandles {
}
}

// waits until at least one handle is likely writable
#[allow(clippy::await_holding_refcell_ref)]
async fn check_send_any(&self) {
let mut at_least_one_writable = false;
let mut any_valid = false;
let mut any_writable = false;

for (_, p) in self.list.iter(&self.nodes) {
if p.valid.get() {
p.pe.sender_any.wait_writable().await;
at_least_one_writable = true;
any_valid = true;

if p.pe.sender_any.is_writable() {
any_writable = true;
break;
}
}
}

if any_writable {
return;
}

// if there are no valid pipes then hang forever. caller can
// try again by dropping the future and making a new one
if !at_least_one_writable {
if !any_valid {
std::future::pending::<()>().await;
}

// there are valid pipes but none are writable. we'll wait

let mut scratch = self.check_send_any_scratch.borrow_mut();
let (mut tasks, slice_scratch) = scratch.get();

for (_, p) in self.list.iter(&self.nodes) {
if p.valid.get() {
assert!(tasks.len() < tasks.capacity());

tasks.push(p.pe.sender_any.wait_writable());
}
}

select_slice(&mut tasks, slice_scratch).await;
}

// non-blocking send. caller should use check_send_any() first
fn send_any(&self, msg: &arena::Arc<zmq::Message>, from: &[u8], ids: &[Id<'_>]) {
if self.nodes.is_empty() || ids.is_empty() {
return;
fn send_any(
&self,
msg: &arena::Arc<zmq::Message>,
from: &[u8],
ids: &[Id],
) -> Result<(), StreamHandlesSendError> {
if from.len() > FROM_MAX || ids.is_empty() || ids[0].id.len() > REQ_ID_MAX {
return Err(StreamHandlesSendError::BadFormat);
}

if self.nodes.is_empty() {
return Err(StreamHandlesSendError::NoneReady);
}

let mut skip = self.send_index.get();
Expand All @@ -1143,34 +1214,39 @@ impl ServerStreamHandles {
}
}

if let Some(nkey) = selected {
let n = &self.nodes[nkey];
let p = &n.value;
let nkey = match selected {
Some(nkey) => nkey,
None => return Err(StreamHandlesSendError::NoneReady),
};

let from = match ArrayVec::try_from(from) {
Ok(v) => v,
Err(_) => return,
};
let n = &self.nodes[nkey];
let p = &n.value;

let id = match ArrayVec::try_from(ids[0].id) {
Ok(v) => v,
Err(_) => return,
};
let from = ArrayVec::try_from(from).unwrap();
let id = ArrayVec::try_from(ids[0].id).unwrap();

let key = (from, id);
let key = (from, id);

if let Ok(session) = self.sessions.add(key, nkey) {
match p.pe.sender_any.try_send((arena::Arc::clone(msg), session)) {
Ok(_) => {}
Err(mpsc::TrySendError::Full(_)) => error!("stream sender_any is full"),
Err(mpsc::TrySendError::Disconnected(_)) => {
p.valid.set(false);
let session = match self.sessions.add(key, nkey) {
Ok(s) => s,
Err(SessionAddError::Full) => return Err(StreamHandlesSendError::SessionCapacityFull),
Err(SessionAddError::Exists) => return Err(StreamHandlesSendError::SessionExists),
};

self.need_cleanup.set(true);
}
if let Err(e) = p.pe.sender_any.try_send((arena::Arc::clone(msg), session)) {
match e {
mpsc::TrySendError::Full(_) => {}
mpsc::TrySendError::Disconnected(_) => {
p.valid.set(false);

self.need_cleanup.set(true);
}
}

return Err(StreamHandlesSendError::NoneReady);
}

Ok(())
}

#[allow(clippy::await_holding_refcell_ref)]
Expand Down Expand Up @@ -1969,7 +2045,10 @@ impl ServerSocketManager {

let messages_memory = Arc::new(arena::SyncMemory::new(arena_size));

let sessions_max = stream_maxconn + (HANDLES_MAX * handle_bound);
// sessions are created at the time of attempting to send to a handle, so we need enough
// sessions to max out the workers, and max out all the handle channels, and have one
// left to use when attempting to send
let sessions_max = stream_maxconn + (HANDLES_MAX * handle_bound) + 1;

let req_sock = AsyncZmqSocket::new(ZmqSocket::new(&ctx, zmq::ROUTER));

Expand Down Expand Up @@ -2155,6 +2234,8 @@ impl ServerSocketManager {
trace!("IN server req {}", packet_to_string(&msg));
}

let msg = arena::Arc::new(msg, &messages_memory).unwrap();

req_in_msg = Some((header, msg));
}
Err(e) => error!("server req zmq recv: {}", e),
Expand All @@ -2176,18 +2257,16 @@ impl ServerSocketManager {
req_send = None;
}
// req_handles_check_send
Select10::R5(()) => {
let (header, msg) = req_in_msg.take().unwrap();

Self::handle_req_message(header, msg, &messages_memory, &req_handles);
}
Select10::R5(()) => Self::handle_req_message(&mut req_in_msg, &req_handles),
// stream_in_recv
Select10::R6(result) => match result {
Ok(msg) => {
if log_enabled!(log::Level::Trace) {
trace!("IN server stream {}", packet_to_string(&msg));
}

let msg = arena::Arc::new(msg, &messages_memory).unwrap();

stream_in_msg = Some(msg);
}
Err(e) => error!("server stream zmq recv: {}", e),
Expand Down Expand Up @@ -2222,9 +2301,7 @@ impl ServerSocketManager {
}
// stream_handles_check_send_any
Select10::R10(()) => {
let msg = stream_in_msg.take().unwrap();

Self::handle_stream_message_any(msg, &messages_memory, &stream_handles);
Self::handle_stream_message_any(&mut stream_in_msg, &stream_handles);
}
}

Expand Down Expand Up @@ -2268,34 +2345,47 @@ impl ServerSocketManager {
}

fn handle_req_message(
header: MultipartHeader,
msg: zmq::Message,
messages_memory: &Arc<arena::ArcMemory<zmq::Message>>,
next_msg: &mut Option<(MultipartHeader, arena::Arc<zmq::Message>)>,
handles: &ServerReqHandles,
) {
let msg = arena::Arc::new(msg, messages_memory).unwrap();
let (header, msg) = next_msg.take().unwrap();

handles.send(header, &msg);
if let Err(ReqHandlesSendError(header)) = handles.send(header, &msg) {
*next_msg = Some((header, msg));
}
}

fn handle_stream_message_any(
msg: zmq::Message,
messages_memory: &Arc<arena::ArcMemory<zmq::Message>>,
next_msg: &mut Option<arena::Arc<zmq::Message>>,
handles: &ServerStreamHandles,
) {
let msg = arena::Arc::new(msg, messages_memory).unwrap();
let msg = next_msg.take().unwrap();

let mut scratch = ParseScratch::new();
let ret = {
let mut scratch = ParseScratch::new();

let (from, ids) = match parse_ids(msg.get(), &mut scratch) {
Ok(ret) => ret,
Err(e) => {
warn!("unable to determine packet id(s): {}", e);
return;
}
let (from, ids) = match parse_ids(msg.get(), &mut scratch) {
Ok(ret) => ret,
Err(e) => {
warn!("unable to determine packet id(s): {}", e);
return;
}
};

handles.send_any(&msg, from, ids)
};

handles.send_any(&msg, from, ids);
match ret {
Ok(()) => {}
Err(StreamHandlesSendError::BadFormat) => warn!("stream send_any: bad format"),
Err(StreamHandlesSendError::NoneReady) => *next_msg = Some(msg),
Err(StreamHandlesSendError::SessionExists) => {
warn!("stream send_any: session id in use")
}
Err(StreamHandlesSendError::SessionCapacityFull) => {
error!("stream send_any: session capacity full")
}
}
}

async fn handle_stream_message_direct(
Expand Down

0 comments on commit 42ac17c

Please sign in to comment.