zhttpsocket: proper handling of is_writable #47882

Merged · 1 commit · Jan 24, 2024
src/zhttpsocket.rs — 222 changes: 156 additions & 66 deletions
@@ -272,6 +272,11 @@ struct SessionItem {
     handle_index: usize,
 }
 
+enum SessionAddError {
+    Full,
+    Exists,
+}
+
 struct SessionDataInner {
     items: Slab<SessionItem>,
     items_by_key: HashMap<SessionKey, usize>,
@@ -292,11 +297,15 @@ impl SessionData {
         }
     }
 
-    fn add(&self, key: SessionKey, handle_index: usize) -> Result<Session, ()> {
+    fn add(&self, key: SessionKey, handle_index: usize) -> Result<Session, SessionAddError> {
         let inner = &mut *self.inner.lock().unwrap();
Review comment (Contributor): mutex handling can be improved by handling the potential error that can occur when locking the mutex, instead of using unwrap()

Reply (Member, author): lock() fails if the mutex is poisoned, which means we have a logic error, and so it's reasonable to abort the program on the spot
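
For readers unfamiliar with the distinction, a minimal sketch of the two options (illustrative only, not code from this PR):

    use std::sync::Mutex;

    // unwrap(): a poisoned mutex means some thread panicked while holding the
    // lock, i.e. a logic error somewhere, so panicking here as well is a
    // deliberate choice rather than an unhandled failure
    fn increment(counter: &Mutex<u64>) {
        let mut guard = counter.lock().unwrap();
        *guard += 1;
    }

    // the explicit alternative: recover the guard and keep going, which risks
    // running on state left inconsistent by the panicking thread
    fn increment_recovering(counter: &Mutex<u64>) {
        let mut guard = match counter.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        *guard += 1;
    }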


-        if inner.items.len() == inner.items.capacity() || inner.items_by_key.contains_key(&key) {
-            return Err(());
+        if inner.items.len() == inner.items.capacity() {
+            return Err(SessionAddError::Full);
+        }
+
+        if inner.items_by_key.contains_key(&key) {
+            return Err(SessionAddError::Exists);
         }
 
         let item_key = inner.items.insert(SessionItem {
@@ -355,7 +364,7 @@ impl SessionTable {
         }
     }
 
-    fn add(&self, key: SessionKey, handle_index: usize) -> Result<Session, ()> {
+    fn add(&self, key: SessionKey, handle_index: usize) -> Result<Session, SessionAddError> {
         self.data.add(key, handle_index)
     }
 
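A hypothetical caller, using the types above, sketching why the split error type matters (the real caller is send_any() later in this diff, which maps the two cases to distinct StreamHandlesSendError variants):

    // the two failure modes call for different reactions,
    // which the old Result<_, ()> could not express
    fn try_register(table: &SessionTable, key: SessionKey, handle_index: usize) {
        match table.add(key, handle_index) {
            Ok(_session) => { /* proceed with the send */ }
            Err(SessionAddError::Full) => { /* capacity exhausted: retry later */ }
            Err(SessionAddError::Exists) => { /* duplicate id: reject the packet */ }
        }
    }
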
@@ -843,6 +852,8 @@ impl StreamHandles {
     }
 }
 
+struct ReqHandlesSendError(MultipartHeader);
+
 struct ServerReqHandles {
     nodes: Slab<list::Node<ServerReqPipe>>,
     list: list::List,
@@ -918,6 +929,7 @@ impl ServerReqHandles {
         }
     }
 
+    // waits until at least one handle is likely writable
     #[allow(clippy::await_holding_refcell_ref)]
     async fn check_send(&self) {
         let mut any_valid = false;
@@ -961,9 +973,13 @@ impl ServerReqHandles {
     }
 
     // non-blocking send. caller should use check_send() first
-    fn send(&self, header: MultipartHeader, msg: &arena::Arc<zmq::Message>) {
+    fn send(
+        &self,
+        header: MultipartHeader,
+        msg: &arena::Arc<zmq::Message>,
+    ) -> Result<(), ReqHandlesSendError> {
         if self.nodes.is_empty() {
-            return;
+            return Err(ReqHandlesSendError(header));
         }
 
         let mut skip = self.send_index.get();
@@ -983,19 +999,30 @@ impl ServerReqHandles {
             }
         }
 
-        if let Some(nkey) = selected {
-            let n = &self.nodes[nkey];
-            let p = &n.value;
+        let nkey = match selected {
+            Some(nkey) => nkey,
+            None => return Err(ReqHandlesSendError(header)),
+        };
 
-            match p.pe.sender.try_send((header, arena::Arc::clone(msg))) {
-                Ok(_) => {}
-                Err(mpsc::TrySendError::Full(_)) => error!("req sender is full"),
-                Err(mpsc::TrySendError::Disconnected(_)) => {
-                    p.valid.set(false);
-
-                    self.need_cleanup.set(true);
-                }
-            }
-        }
+        let n = &self.nodes[nkey];
+        let p = &n.value;
+
+        if let Err(e) = p.pe.sender.try_send((header, arena::Arc::clone(msg))) {
+            let header = match e {
+                mpsc::TrySendError::Full((header, _)) => header,
+                mpsc::TrySendError::Disconnected((header, _)) => {
+                    p.valid.set(false);
+
+                    self.need_cleanup.set(true);
+
+                    header
+                }
+            };
+
+            return Err(ReqHandlesSendError(header));
+        }
+
+        Ok(())
     }
 
     fn need_cleanup(&self) -> bool {
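
The reworked send() follows the same convention as mpsc's TrySendError: on failure, ownership of the moved-in header comes back inside the error, so the caller can retry later without cloning. A generic sketch of the pattern (names hypothetical, not this file's API):

    // error type that carries the rejected value back to the caller
    struct SendError<T>(T);

    fn send_or_give_back<T>(ready: bool, value: T) -> Result<(), SendError<T>> {
        if !ready {
            // hand ownership of `value` back instead of dropping it
            return Err(SendError(value));
        }
        // ... deliver `value` to the selected pipe ...
        Ok(())
    }

    // the caller keeps the value for a later retry, no clone needed
    fn example(header: Vec<u8>) {
        if let Err(SendError(header)) = send_or_give_back(false, header) {
            let _stashed = Some(header);
        }
    }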
@@ -1026,10 +1053,18 @@ impl ServerReqHandles {
     }
 }
 
+enum StreamHandlesSendError {
+    BadFormat,
+    NoneReady,
+    SessionExists,
+    SessionCapacityFull,
+}
+
 struct ServerStreamHandles {
     nodes: Slab<list::Node<ServerStreamPipe>>,
     list: list::List,
     recv_scratch: RefCell<RecvScratch<zmq::Message>>,
+    check_send_any_scratch: RefCell<CheckSendScratch<(arena::Arc<zmq::Message>, Session)>>,
     send_direct_scratch: RefCell<Vec<bool>>,
     need_cleanup: Cell<bool>,
     send_index: Cell<usize>,
@@ -1042,6 +1077,7 @@ impl ServerStreamHandles {
             nodes: Slab::with_capacity(capacity),
             list: list::List::default(),
             recv_scratch: RefCell::new(RecvScratch::new(capacity)),
+            check_send_any_scratch: RefCell::new(CheckSendScratch::new(capacity)),
             send_direct_scratch: RefCell::new(Vec::with_capacity(capacity)),
             need_cleanup: Cell::new(false),
             send_index: Cell::new(0),
@@ -1103,27 +1139,62 @@ impl ServerStreamHandles {
         }
     }
 
+    // waits until at least one handle is likely writable
     #[allow(clippy::await_holding_refcell_ref)]
     async fn check_send_any(&self) {
-        let mut at_least_one_writable = false;
+        let mut any_valid = false;
+        let mut any_writable = false;
 
         for (_, p) in self.list.iter(&self.nodes) {
             if p.valid.get() {
-                p.pe.sender_any.wait_writable().await;
-                at_least_one_writable = true;
+                any_valid = true;
+
+                if p.pe.sender_any.is_writable() {
+                    any_writable = true;
+                    break;
+                }
             }
         }
 
+        if any_writable {
+            return;
+        }
+
         // if there are no valid pipes then hang forever. caller can
         // try again by dropping the future and making a new one
-        if !at_least_one_writable {
+        if !any_valid {
             std::future::pending::<()>().await;
         }
 
+        // there are valid pipes but none are writable. we'll wait
+
+        let mut scratch = self.check_send_any_scratch.borrow_mut();
+        let (mut tasks, slice_scratch) = scratch.get();
+
+        for (_, p) in self.list.iter(&self.nodes) {
+            if p.valid.get() {
+                assert!(tasks.len() < tasks.capacity());
+
+                tasks.push(p.pe.sender_any.wait_writable());
+            }
+        }
+
+        select_slice(&mut tasks, slice_scratch).await;
     }
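
The new check_send_any() is a fast-path/slow-path pattern: poll every valid pipe once, and only if none is writable right now, wait on all of them at once. A rough sketch of the same shape using the futures crate (hypothetical Pipe type; the PR itself uses its select_slice and scratch-buffer machinery to avoid per-call allocation):

    use futures::future::select_all;

    struct Pipe;

    impl Pipe {
        fn is_writable(&self) -> bool {
            true // placeholder
        }

        async fn wait_writable(&self) {
            // placeholder
        }
    }

    async fn wait_any_writable(pipes: &[Pipe]) {
        // fast path: some pipe is writable right now
        if pipes.iter().any(|p| p.is_writable()) {
            return;
        }

        // no pipes at all: hang until the caller drops this future
        if pipes.is_empty() {
            std::future::pending::<()>().await;
        }

        // slow path: sleep until the first pipe becomes writable
        let waits: Vec<_> = pipes.iter().map(|p| Box::pin(p.wait_writable())).collect();
        select_all(waits).await;
    }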

     // non-blocking send. caller should use check_send_any() first
-    fn send_any(&self, msg: &arena::Arc<zmq::Message>, from: &[u8], ids: &[Id<'_>]) {
-        if self.nodes.is_empty() || ids.is_empty() {
-            return;
+    fn send_any(
+        &self,
+        msg: &arena::Arc<zmq::Message>,
+        from: &[u8],
+        ids: &[Id],
+    ) -> Result<(), StreamHandlesSendError> {
+        if from.len() > FROM_MAX || ids.is_empty() || ids[0].id.len() > REQ_ID_MAX {
+            return Err(StreamHandlesSendError::BadFormat);
+        }
+
+        if self.nodes.is_empty() {
+            return Err(StreamHandlesSendError::NoneReady);
         }
 
         let mut skip = self.send_index.get();
@@ -1143,34 +1214,39 @@ impl ServerStreamHandles {
             }
         }
 
-        if let Some(nkey) = selected {
-            let n = &self.nodes[nkey];
-            let p = &n.value;
-
-            let from = match ArrayVec::try_from(from) {
-                Ok(v) => v,
-                Err(_) => return,
-            };
-
-            let id = match ArrayVec::try_from(ids[0].id) {
-                Ok(v) => v,
-                Err(_) => return,
-            };
-
-            let key = (from, id);
-
-            if let Ok(session) = self.sessions.add(key, nkey) {
-                match p.pe.sender_any.try_send((arena::Arc::clone(msg), session)) {
-                    Ok(_) => {}
-                    Err(mpsc::TrySendError::Full(_)) => error!("stream sender_any is full"),
-                    Err(mpsc::TrySendError::Disconnected(_)) => {
-                        p.valid.set(false);
-
-                        self.need_cleanup.set(true);
-                    }
-                }
-            }
-        }
+        let nkey = match selected {
+            Some(nkey) => nkey,
+            None => return Err(StreamHandlesSendError::NoneReady),
+        };
+
+        let n = &self.nodes[nkey];
+        let p = &n.value;
+
+        let from = ArrayVec::try_from(from).unwrap();
+        let id = ArrayVec::try_from(ids[0].id).unwrap();
+
+        let key = (from, id);
+
+        let session = match self.sessions.add(key, nkey) {
+            Ok(s) => s,
+            Err(SessionAddError::Full) => return Err(StreamHandlesSendError::SessionCapacityFull),
+            Err(SessionAddError::Exists) => return Err(StreamHandlesSendError::SessionExists),
+        };
+
+        if let Err(e) = p.pe.sender_any.try_send((arena::Arc::clone(msg), session)) {
+            match e {
+                mpsc::TrySendError::Full(_) => {}
+                mpsc::TrySendError::Disconnected(_) => {
+                    p.valid.set(false);
+
+                    self.need_cleanup.set(true);
+                }
+            }
+
+            return Err(StreamHandlesSendError::NoneReady);
+        }
+
+        Ok(())
     }
 
     #[allow(clippy::await_holding_refcell_ref)]
@@ -1969,7 +2045,10 @@ impl ServerSocketManager {
 
         let messages_memory = Arc::new(arena::SyncMemory::new(arena_size));
 
-        let sessions_max = stream_maxconn + (HANDLES_MAX * handle_bound);
+        // sessions are created at the time of attempting to send to a handle, so we need enough
+        // sessions to max out the workers, and max out all the handle channels, and have one
+        // left to use when attempting to send
+        let sessions_max = stream_maxconn + (HANDLES_MAX * handle_bound) + 1;
 
         let req_sock = AsyncZmqSocket::new(ZmqSocket::new(&ctx, zmq::ROUTER));
 
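A worked instance of the formula with made-up numbers (the real values come from configuration, not from this PR):

    fn main() {
        let stream_maxconn = 10_000; // worst case: every connection holds a session
        let handles_max = 8;         // stand-in for HANDLES_MAX
        let handle_bound = 100;      // bounded channel depth per handle
        // every worker busy + every handle channel full + one send in flight
        let sessions_max = stream_maxconn + (handles_max * handle_bound) + 1;
        assert_eq!(sessions_max, 10_801);
    }
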
@@ -2155,6 +2234,8 @@ impl ServerSocketManager {
                             trace!("IN server req {}", packet_to_string(&msg));
                         }
 
+                        let msg = arena::Arc::new(msg, &messages_memory).unwrap();
+
                         req_in_msg = Some((header, msg));
                     }
                     Err(e) => error!("server req zmq recv: {}", e),
Expand All @@ -2176,18 +2257,16 @@ impl ServerSocketManager {
req_send = None;
}
// req_handles_check_send
Select10::R5(()) => {
let (header, msg) = req_in_msg.take().unwrap();

Self::handle_req_message(header, msg, &messages_memory, &req_handles);
}
Select10::R5(()) => Self::handle_req_message(&mut req_in_msg, &req_handles),
// stream_in_recv
Select10::R6(result) => match result {
Ok(msg) => {
if log_enabled!(log::Level::Trace) {
trace!("IN server stream {}", packet_to_string(&msg));
}

let msg = arena::Arc::new(msg, &messages_memory).unwrap();

stream_in_msg = Some(msg);
}
Err(e) => error!("server stream zmq recv: {}", e),
@@ -2222,9 +2301,7 @@ impl ServerSocketManager {
                 }
                 // stream_handles_check_send_any
                 Select10::R10(()) => {
-                    let msg = stream_in_msg.take().unwrap();
-
-                    Self::handle_stream_message_any(msg, &messages_memory, &stream_handles);
+                    Self::handle_stream_message_any(&mut stream_in_msg, &stream_handles);
                 }
             }
 
@@ -2268,34 +2345,47 @@ impl ServerSocketManager {
     }
 
     fn handle_req_message(
-        header: MultipartHeader,
-        msg: zmq::Message,
-        messages_memory: &Arc<arena::ArcMemory<zmq::Message>>,
+        next_msg: &mut Option<(MultipartHeader, arena::Arc<zmq::Message>)>,
         handles: &ServerReqHandles,
     ) {
-        let msg = arena::Arc::new(msg, messages_memory).unwrap();
+        let (header, msg) = next_msg.take().unwrap();
 
-        handles.send(header, &msg);
+        if let Err(ReqHandlesSendError(header)) = handles.send(header, &msg) {
+            *next_msg = Some((header, msg));
+        }
     }
 
     fn handle_stream_message_any(
-        msg: zmq::Message,
-        messages_memory: &Arc<arena::ArcMemory<zmq::Message>>,
+        next_msg: &mut Option<arena::Arc<zmq::Message>>,
         handles: &ServerStreamHandles,
     ) {
-        let msg = arena::Arc::new(msg, messages_memory).unwrap();
+        let msg = next_msg.take().unwrap();
 
-        let mut scratch = ParseScratch::new();
+        let ret = {
+            let mut scratch = ParseScratch::new();
 
-        let (from, ids) = match parse_ids(msg.get(), &mut scratch) {
-            Ok(ret) => ret,
-            Err(e) => {
-                warn!("unable to determine packet id(s): {}", e);
-                return;
-            }
-        };
+            let (from, ids) = match parse_ids(msg.get(), &mut scratch) {
+                Ok(ret) => ret,
+                Err(e) => {
+                    warn!("unable to determine packet id(s): {}", e);
+                    return;
+                }
+            };
 
-        handles.send_any(&msg, from, ids);
+            handles.send_any(&msg, from, ids)
+        };
+
+        match ret {
+            Ok(()) => {}
+            Err(StreamHandlesSendError::BadFormat) => warn!("stream send_any: bad format"),
+            Err(StreamHandlesSendError::NoneReady) => *next_msg = Some(msg),
Review comment (Contributor): is this *next_msg = Some(msg) for retry purposes?

Reply (Member, author): yup

+            Err(StreamHandlesSendError::SessionExists) => {
+                warn!("stream send_any: session id in use")
+            }
+            Err(StreamHandlesSendError::SessionCapacityFull) => {
+                error!("stream send_any: session capacity full")
+            }
+        }
     }
 
     async fn handle_stream_message_direct(
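
The NoneReady arm above is the retry hook discussed in the review thread: putting the message back into next_msg keeps it pending, and the select loop only polls check_send_any() while a message is pending, so the same message is retried once a handle becomes writable. A stripped-down sketch of that flow (placeholder types, not this file's API):

    // placeholders standing in for arena::Arc<zmq::Message> and ServerStreamHandles
    struct Msg;
    struct Handles;
    enum SendAnyError {
        NoneReady,
    }

    impl Handles {
        fn send_any(&self, _msg: &Msg) -> Result<(), SendAnyError> {
            Err(SendAnyError::NoneReady) // pretend no handle is writable
        }
    }

    // mirrors handle_stream_message_any(): take the pending message, try it,
    // and on NoneReady put it back so the same message is retried after the
    // next check_send_any() wakeup
    fn deliver(pending: &mut Option<Msg>, handles: &Handles) {
        let msg = pending.take().unwrap();

        if let Err(SendAnyError::NoneReady) = handles.send_any(&msg) {
            *pending = Some(msg);
        }
    }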