Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

remove Unpin requirement for the inner stream #358

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,9 @@ stream = []
url = ["tungstenite/url"]

[dependencies]
log = "0.4.17"
futures-util = { version = "0.3.28", default-features = false, features = ["sink", "std"] }
log = "0.4.17"
pin-project = "1.1.8"
tokio = { version = "1.0.0", default-features = false, features = ["io-util"] }

[dependencies.tungstenite]
Expand Down
56 changes: 43 additions & 13 deletions src/compat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@ pub(crate) enum ContextWaker {
}

#[derive(Debug)]
#[pin_project::pin_project]
pub(crate) struct AllowStd<S> {
#[pin]
inner: S,
// We have the problem that external read operations (i.e. the Stream impl)
// can trigger both read (AsyncRead) and write (AsyncWrite) operations on
Expand Down Expand Up @@ -115,21 +117,27 @@ impl task::ArcWake for WakerProxy {
}
}

impl<S> AllowStd<S>
where
S: Unpin,
{
fn with_context<F, R>(&mut self, kind: ContextWaker, f: F) -> Poll<std::io::Result<R>>
impl<S> AllowStd<S> {
fn with_context<F, R>(
self: Pin<&mut Self>,
kind: ContextWaker,
f: F,
) -> Poll<std::io::Result<R>>
where
F: FnOnce(&mut Context<'_>, Pin<&mut S>) -> Poll<std::io::Result<R>>,
{
trace!("{}:{} AllowStd.with_context", file!(), line!());

let this = self.project();

let waker = match kind {
ContextWaker::Read => task::waker_ref(&self.read_waker_proxy),
ContextWaker::Write => task::waker_ref(&self.write_waker_proxy),
ContextWaker::Read => task::waker_ref(&this.read_waker_proxy),
ContextWaker::Write => task::waker_ref(&this.write_waker_proxy),
};

let mut context = task::Context::from_waker(&waker);
f(&mut context, Pin::new(&mut self.inner))

f(&mut context, this.inner)
}

pub(crate) fn get_mut(&mut self) -> &mut S {
Expand All @@ -143,12 +151,20 @@ where

impl<S> Read for AllowStd<S>
where
S: AsyncRead + Unpin,
S: AsyncRead,
{
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
trace!("{}:{} Read.read", file!(), line!());

let this = unsafe {
// SAFETY: we rely on the fact `AllowStd` is only used internally and
// the wrapper that uses it is going to pin it.
Pin::new_unchecked(self)
};

let mut buf = ReadBuf::new(buf);
match self.with_context(ContextWaker::Read, |ctx, stream| {

match this.with_context(ContextWaker::Read, |ctx, stream| {
trace!("{}:{} Read.with_context read -> poll_read", file!(), line!());
stream.poll_read(ctx, &mut buf)
}) {
Expand All @@ -161,11 +177,18 @@ where

impl<S> Write for AllowStd<S>
where
S: AsyncWrite + Unpin,
S: AsyncWrite,
{
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
trace!("{}:{} Write.write", file!(), line!());
match self.with_context(ContextWaker::Write, |ctx, stream| {

let this = unsafe {
// SAFETY: we rely on the fact `AllowStd` is only used internally and
// the wrapper that uses it is going to pin it.
Pin::new_unchecked(self)
};

match this.with_context(ContextWaker::Write, |ctx, stream| {
trace!("{}:{} Write.with_context write -> poll_write", file!(), line!());
stream.poll_write(ctx, buf)
}) {
Expand All @@ -176,7 +199,14 @@ where

fn flush(&mut self) -> std::io::Result<()> {
trace!("{}:{} Write.flush", file!(), line!());
match self.with_context(ContextWaker::Write, |ctx, stream| {

let this = unsafe {
// SAFETY: we rely on the fact `AllowStd` is only used internally and
// the wrapper that uses it is going to pin it.
Pin::new_unchecked(self)
};

match this.with_context(ContextWaker::Write, |ctx, stream| {
trace!("{}:{} Write.with_context flush -> poll_flush", file!(), line!());
stream.poll_flush(ctx)
}) {
Expand Down
122 changes: 89 additions & 33 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,9 @@ where
/// them in `futures-rs` crate documentation or have a look on the examples
/// and unit tests for this crate.
#[derive(Debug)]
#[pin_project::pin_project]
pub struct WebSocketStream<S> {
#[pin]
inner: WebSocket<AllowStd<S>>,
closing: bool,
ended: bool,
Expand Down Expand Up @@ -234,17 +236,23 @@ impl<S> WebSocketStream<S> {
Self { inner: ws, closing: false, ended: false, ready: true }
}

fn with_context<F, R>(&mut self, ctx: Option<(ContextWaker, &mut Context<'_>)>, f: F) -> R
fn with_context<F, R>(
self: Pin<&mut Self>,
ctx: Option<(ContextWaker, &mut Context<'_>)>,
f: F,
) -> R
where
S: Unpin,
F: FnOnce(&mut WebSocket<AllowStd<S>>) -> R,
F: FnOnce(Pin<&mut WebSocket<AllowStd<S>>>) -> R,
AllowStd<S>: Read + Write,
{
trace!("{}:{} WebSocketStream.with_context", file!(), line!());
let this = self.project();

if let Some((kind, ctx)) = ctx {
self.inner.get_mut().set_waker(kind, ctx.waker());
this.inner.get_ref().set_waker(kind, ctx.waker());
}
f(&mut self.inner)

f(this.inner)
}

/// Returns a shared reference to the inner stream.
Expand Down Expand Up @@ -279,7 +287,7 @@ impl<S> WebSocketStream<S> {

impl<T> Stream for WebSocketStream<T>
where
T: AsyncRead + AsyncWrite + Unpin,
T: AsyncRead + AsyncWrite,
{
type Item = Result<Message, WsError>;

Expand All @@ -293,13 +301,22 @@ where
return Poll::Ready(None);
}

match futures_util::ready!(self.with_context(Some((ContextWaker::Read, cx)), |s| {
trace!("{}:{} Stream.with_context poll_next -> read()", file!(), line!());
cvt(s.read())
})) {
match futures_util::ready!(self.as_mut().with_context(
Some((ContextWaker::Read, cx)),
|s| {
trace!("{}:{} Stream.with_context poll_next -> read()", file!(), line!());
unsafe {
// SAFETY: library's `Read` impl is going to Pin anyway
cvt(s.get_unchecked_mut().read())
}
}
)) {
Ok(v) => Poll::Ready(Some(Ok(v))),
Err(e) => {
self.ended = true;
unsafe {
// SAFETY: not moving out
self.get_unchecked_mut().ended = true;
}
if matches!(e, WsError::AlreadyClosed | WsError::ConnectionClosed) {
Poll::Ready(None)
} else {
Expand All @@ -312,7 +329,7 @@ where

impl<T> FusedStream for WebSocketStream<T>
where
T: AsyncRead + AsyncWrite + Unpin,
T: AsyncRead + AsyncWrite,
{
fn is_terminated(&self) -> bool {
self.ended
Expand All @@ -321,7 +338,7 @@ where

impl<T> Sink<Message> for WebSocketStream<T>
where
T: AsyncRead + AsyncWrite + Unpin,
T: AsyncRead + AsyncWrite,
{
type Error = WsError;

Expand All @@ -330,59 +347,98 @@ where
Poll::Ready(Ok(()))
} else {
// Currently blocked so try to flush the blockage away
(*self).with_context(Some((ContextWaker::Write, cx)), |s| cvt(s.flush())).map(|r| {
self.ready = true;
r
})
self.as_mut()
.with_context(Some((ContextWaker::Write, cx)), |s| unsafe {
// SAFETY: library's `Write` impl is going to Pin anyway
cvt(s.get_unchecked_mut().flush())
})
.map(|r| {
unsafe {
// SAFETY: not moving out
self.get_unchecked_mut().ready = true;
}
r
})
}
}

fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
match (*self).with_context(None, |s| s.write(item)) {
match self.as_mut().with_context(None, |s| unsafe {
// SAFETY: library's `Write` impl is going to Pin anyway
s.get_unchecked_mut().write(item)
}) {
Ok(()) => {
self.ready = true;
unsafe {
// SAFETY: not moving out
self.get_unchecked_mut().ready = true;
}
Ok(())
}
Err(WsError::Io(err)) if err.kind() == std::io::ErrorKind::WouldBlock => {
// the message was accepted and queued so not an error
// but `poll_ready` will now start trying to flush the block
self.ready = false;
unsafe {
// SAFETY: not moving out
self.get_unchecked_mut().ready = false;
}
Ok(())
}
Err(e) => {
self.ready = true;
unsafe {
// SAFETY: not moving out
self.get_unchecked_mut().ready = true;
}
debug!("websocket start_send error: {}", e);
Err(e)
}
}
}

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
(*self).with_context(Some((ContextWaker::Write, cx)), |s| cvt(s.flush())).map(|r| {
self.ready = true;
match r {
// WebSocket connection has just been closed. Flushing completed, not an error.
Err(WsError::ConnectionClosed) => Ok(()),
other => other,
}
})
self.as_mut()
.with_context(Some((ContextWaker::Write, cx)), |s| unsafe {
// SAFETY: library's `Write` impl is going to Pin anyway
cvt(s.get_unchecked_mut().flush())
})
.map(|r| {
unsafe {
self.get_unchecked_mut().ready = true;
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

missing SAFETY comment:

// SAFETY: not moving out

}
match r {
// WebSocket connection has just been closed. Flushing completed, not an error.
Err(WsError::ConnectionClosed) => Ok(()),
other => other,
}
})
}

fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.ready = true;
unsafe {
self.as_mut().get_unchecked_mut().ready = true;
}

let res = if self.closing {
// After queueing it, we call `flush` to drive the close handshake to completion.
(*self).with_context(Some((ContextWaker::Write, cx)), |s| s.flush())
self.as_mut().with_context(Some((ContextWaker::Write, cx)), |s| unsafe {
// SAFETY: library's `Write` impl is going to Pin anyway
s.get_unchecked_mut().flush()
})
} else {
(*self).with_context(Some((ContextWaker::Write, cx)), |s| s.close(None))
self.as_mut().with_context(Some((ContextWaker::Write, cx)), |s| unsafe {
// SAFETY: library's `Write` impl is going to Pin anyway
s.get_unchecked_mut().close(None)
})
};

match res {
Ok(()) => Poll::Ready(Ok(())),
Err(WsError::ConnectionClosed) => Poll::Ready(Ok(())),
Err(WsError::Io(err)) if err.kind() == std::io::ErrorKind::WouldBlock => {
trace!("WouldBlock");
self.closing = true;
unsafe {
// SAFETY: not moving out
self.get_unchecked_mut().closing = true;
}
Poll::Pending
}
Err(err) => {
Expand Down