diff --git a/boring/src/ssl/bio.rs b/boring/src/ssl/bio.rs index b88fdaff..3395130d 100644 --- a/boring/src/ssl/bio.rs +++ b/boring/src/ssl/bio.rs @@ -166,9 +166,13 @@ unsafe extern "C" fn ctrl( let state = state::(bio); if cmd == BIO_CTRL_FLUSH { + BIO_clear_retry_flags(bio); match catch_unwind(AssertUnwindSafe(|| state.stream.flush())) { Ok(Ok(())) => 1, Ok(Err(err)) => { + if retriable_error(&err) { + BIO_set_retry_write(bio); + } state.error = Some(err); 0 } diff --git a/boring/src/ssl/test/mod.rs b/boring/src/ssl/test/mod.rs index 200ced24..f345cd3d 100644 --- a/boring/src/ssl/test/mod.rs +++ b/boring/src/ssl/test/mod.rs @@ -563,6 +563,68 @@ fn test_select_cert_alpn_extension() { ); } +#[test] +fn test_io_retry() { + #[derive(Debug)] + struct RetryStream { + inner: TcpStream, + first_read: bool, + first_write: bool, + first_flush: bool, + } + + impl Read for RetryStream { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + if mem::replace(&mut self.first_read, false) { + Err(io::Error::new(io::ErrorKind::WouldBlock, "first read")) + } else { + self.inner.read(buf) + } + } + } + + impl Write for RetryStream { + fn write(&mut self, buf: &[u8]) -> io::Result { + if mem::replace(&mut self.first_write, false) { + Err(io::Error::new(io::ErrorKind::WouldBlock, "first write")) + } else { + self.inner.write(buf) + } + } + + fn flush(&mut self) -> io::Result<()> { + if mem::replace(&mut self.first_flush, false) { + Err(io::Error::new(io::ErrorKind::WouldBlock, "first flush")) + } else { + self.inner.flush() + } + } + } + + let server = Server::builder().build(); + + let stream = RetryStream { + inner: server.connect_tcp(), + first_read: true, + first_write: true, + first_flush: true, + }; + + let ctx = SslContext::builder(SslMethod::tls()).unwrap(); + let mut s = match Ssl::new(&ctx.build()).unwrap().connect(stream) { + Ok(mut s) => return s.read_exact(&mut [0]).unwrap(), + Err(HandshakeError::WouldBlock(s)) => s, + Err(_) => panic!("should not fail on setup"), + }; + loop { + match s.handshake() { + Ok(mut s) => return s.read_exact(&mut [0]).unwrap(), + Err(HandshakeError::WouldBlock(mid_s)) => s = mid_s, + Err(_) => panic!("should not fail on handshake"), + } + } +} + #[test] #[should_panic(expected = "blammo")] fn write_panic() {