diff --git a/Cargo.toml b/Cargo.toml index f15ebc6..6a7491b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,8 +29,9 @@ stream = [] url = ["tungstenite/url"] [dependencies] -log = "0.4.17" futures-util = { version = "0.3.28", default-features = false, features = ["sink", "std"] } +log = "0.4.17" +pin-project = "1.1.8" tokio = { version = "1.0.0", default-features = false, features = ["io-util"] } [dependencies.tungstenite] diff --git a/src/compat.rs b/src/compat.rs index 6bbd0e6..edf4fb5 100644 --- a/src/compat.rs +++ b/src/compat.rs @@ -16,7 +16,9 @@ pub(crate) enum ContextWaker { } #[derive(Debug)] +#[pin_project::pin_project] pub(crate) struct AllowStd { + #[pin] inner: S, // We have the problem that external read operations (i.e. the Stream impl) // can trigger both read (AsyncRead) and write (AsyncWrite) operations on @@ -115,21 +117,27 @@ impl task::ArcWake for WakerProxy { } } -impl AllowStd -where - S: Unpin, -{ - fn with_context(&mut self, kind: ContextWaker, f: F) -> Poll> +impl AllowStd { + fn with_context( + self: Pin<&mut Self>, + kind: ContextWaker, + f: F, + ) -> Poll> where F: FnOnce(&mut Context<'_>, Pin<&mut S>) -> Poll>, { trace!("{}:{} AllowStd.with_context", file!(), line!()); + + let this = self.project(); + let waker = match kind { - ContextWaker::Read => task::waker_ref(&self.read_waker_proxy), - ContextWaker::Write => task::waker_ref(&self.write_waker_proxy), + ContextWaker::Read => task::waker_ref(&this.read_waker_proxy), + ContextWaker::Write => task::waker_ref(&this.write_waker_proxy), }; + let mut context = task::Context::from_waker(&waker); - f(&mut context, Pin::new(&mut self.inner)) + + f(&mut context, this.inner) } pub(crate) fn get_mut(&mut self) -> &mut S { @@ -143,12 +151,20 @@ where impl Read for AllowStd where - S: AsyncRead + Unpin, + S: AsyncRead, { fn read(&mut self, buf: &mut [u8]) -> std::io::Result { trace!("{}:{} Read.read", file!(), line!()); + + let this = unsafe { + // SAFETY: we rely on the fact `AllowStd` is only used internally and + // the wrapper that uses it is going to pin it. + Pin::new_unchecked(self) + }; + let mut buf = ReadBuf::new(buf); - match self.with_context(ContextWaker::Read, |ctx, stream| { + + match this.with_context(ContextWaker::Read, |ctx, stream| { trace!("{}:{} Read.with_context read -> poll_read", file!(), line!()); stream.poll_read(ctx, &mut buf) }) { @@ -161,11 +177,18 @@ where impl Write for AllowStd where - S: AsyncWrite + Unpin, + S: AsyncWrite, { fn write(&mut self, buf: &[u8]) -> std::io::Result { trace!("{}:{} Write.write", file!(), line!()); - match self.with_context(ContextWaker::Write, |ctx, stream| { + + let this = unsafe { + // SAFETY: we rely on the fact `AllowStd` is only used internally and + // the wrapper that uses it is going to pin it. + Pin::new_unchecked(self) + }; + + match this.with_context(ContextWaker::Write, |ctx, stream| { trace!("{}:{} Write.with_context write -> poll_write", file!(), line!()); stream.poll_write(ctx, buf) }) { @@ -176,7 +199,14 @@ where fn flush(&mut self) -> std::io::Result<()> { trace!("{}:{} Write.flush", file!(), line!()); - match self.with_context(ContextWaker::Write, |ctx, stream| { + + let this = unsafe { + // SAFETY: we rely on the fact `AllowStd` is only used internally and + // the wrapper that uses it is going to pin it. + Pin::new_unchecked(self) + }; + + match this.with_context(ContextWaker::Write, |ctx, stream| { trace!("{}:{} Write.with_context flush -> poll_flush", file!(), line!()); stream.poll_flush(ctx) }) { diff --git a/src/lib.rs b/src/lib.rs index 679640a..39c3e74 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -189,7 +189,9 @@ where /// them in `futures-rs` crate documentation or have a look on the examples /// and unit tests for this crate. #[derive(Debug)] +#[pin_project::pin_project] pub struct WebSocketStream { + #[pin] inner: WebSocket>, closing: bool, ended: bool, @@ -234,17 +236,23 @@ impl WebSocketStream { Self { inner: ws, closing: false, ended: false, ready: true } } - fn with_context(&mut self, ctx: Option<(ContextWaker, &mut Context<'_>)>, f: F) -> R + fn with_context( + self: Pin<&mut Self>, + ctx: Option<(ContextWaker, &mut Context<'_>)>, + f: F, + ) -> R where - S: Unpin, - F: FnOnce(&mut WebSocket>) -> R, + F: FnOnce(Pin<&mut WebSocket>>) -> R, AllowStd: Read + Write, { trace!("{}:{} WebSocketStream.with_context", file!(), line!()); + let this = self.project(); + if let Some((kind, ctx)) = ctx { - self.inner.get_mut().set_waker(kind, ctx.waker()); + this.inner.get_ref().set_waker(kind, ctx.waker()); } - f(&mut self.inner) + + f(this.inner) } /// Returns a shared reference to the inner stream. @@ -279,7 +287,7 @@ impl WebSocketStream { impl Stream for WebSocketStream where - T: AsyncRead + AsyncWrite + Unpin, + T: AsyncRead + AsyncWrite, { type Item = Result; @@ -293,13 +301,22 @@ where return Poll::Ready(None); } - match futures_util::ready!(self.with_context(Some((ContextWaker::Read, cx)), |s| { - trace!("{}:{} Stream.with_context poll_next -> read()", file!(), line!()); - cvt(s.read()) - })) { + match futures_util::ready!(self.as_mut().with_context( + Some((ContextWaker::Read, cx)), + |s| { + trace!("{}:{} Stream.with_context poll_next -> read()", file!(), line!()); + unsafe { + // SAFETY: library's `Read` impl is going to Pin anyway + cvt(s.get_unchecked_mut().read()) + } + } + )) { Ok(v) => Poll::Ready(Some(Ok(v))), Err(e) => { - self.ended = true; + unsafe { + // SAFETY: not moving out + self.get_unchecked_mut().ended = true; + } if matches!(e, WsError::AlreadyClosed | WsError::ConnectionClosed) { Poll::Ready(None) } else { @@ -312,7 +329,7 @@ where impl FusedStream for WebSocketStream where - T: AsyncRead + AsyncWrite + Unpin, + T: AsyncRead + AsyncWrite, { fn is_terminated(&self) -> bool { self.ended @@ -321,7 +338,7 @@ where impl Sink for WebSocketStream where - T: AsyncRead + AsyncWrite + Unpin, + T: AsyncRead + AsyncWrite, { type Error = WsError; @@ -330,27 +347,47 @@ where Poll::Ready(Ok(())) } else { // Currently blocked so try to flush the blockage away - (*self).with_context(Some((ContextWaker::Write, cx)), |s| cvt(s.flush())).map(|r| { - self.ready = true; - r - }) + self.as_mut() + .with_context(Some((ContextWaker::Write, cx)), |s| unsafe { + // SAFETY: library's `Write` impl is going to Pin anyway + cvt(s.get_unchecked_mut().flush()) + }) + .map(|r| { + unsafe { + // SAFETY: not moving out + self.get_unchecked_mut().ready = true; + } + r + }) } } fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> { - match (*self).with_context(None, |s| s.write(item)) { + match self.as_mut().with_context(None, |s| unsafe { + // SAFETY: library's `Write` impl is going to Pin anyway + s.get_unchecked_mut().write(item) + }) { Ok(()) => { - self.ready = true; + unsafe { + // SAFETY: not moving out + self.get_unchecked_mut().ready = true; + } Ok(()) } Err(WsError::Io(err)) if err.kind() == std::io::ErrorKind::WouldBlock => { // the message was accepted and queued so not an error // but `poll_ready` will now start trying to flush the block - self.ready = false; + unsafe { + // SAFETY: not moving out + self.get_unchecked_mut().ready = false; + } Ok(()) } Err(e) => { - self.ready = true; + unsafe { + // SAFETY: not moving out + self.get_unchecked_mut().ready = true; + } debug!("websocket start_send error: {}", e); Err(e) } @@ -358,23 +395,39 @@ where } fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - (*self).with_context(Some((ContextWaker::Write, cx)), |s| cvt(s.flush())).map(|r| { - self.ready = true; - match r { - // WebSocket connection has just been closed. Flushing completed, not an error. - Err(WsError::ConnectionClosed) => Ok(()), - other => other, - } - }) + self.as_mut() + .with_context(Some((ContextWaker::Write, cx)), |s| unsafe { + // SAFETY: library's `Write` impl is going to Pin anyway + cvt(s.get_unchecked_mut().flush()) + }) + .map(|r| { + unsafe { + self.get_unchecked_mut().ready = true; + } + match r { + // WebSocket connection has just been closed. Flushing completed, not an error. + Err(WsError::ConnectionClosed) => Ok(()), + other => other, + } + }) } fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.ready = true; + unsafe { + self.as_mut().get_unchecked_mut().ready = true; + } + let res = if self.closing { // After queueing it, we call `flush` to drive the close handshake to completion. - (*self).with_context(Some((ContextWaker::Write, cx)), |s| s.flush()) + self.as_mut().with_context(Some((ContextWaker::Write, cx)), |s| unsafe { + // SAFETY: library's `Write` impl is going to Pin anyway + s.get_unchecked_mut().flush() + }) } else { - (*self).with_context(Some((ContextWaker::Write, cx)), |s| s.close(None)) + self.as_mut().with_context(Some((ContextWaker::Write, cx)), |s| unsafe { + // SAFETY: library's `Write` impl is going to Pin anyway + s.get_unchecked_mut().close(None) + }) }; match res { @@ -382,7 +435,10 @@ where Err(WsError::ConnectionClosed) => Poll::Ready(Ok(())), Err(WsError::Io(err)) if err.kind() == std::io::ErrorKind::WouldBlock => { trace!("WouldBlock"); - self.closing = true; + unsafe { + // SAFETY: not moving out + self.get_unchecked_mut().closing = true; + } Poll::Pending } Err(err) => {