Skip to content

Commit

Permalink
Merge pull request #6 from halzy/halzy/3-SendAll
Browse files Browse the repository at this point in the history
feat: #3 Improve efficiency of outgoing packets.
  • Loading branch information
halzy authored Feb 25, 2020
2 parents 37ac8de + e03f5a6 commit 6d621b1
Show file tree
Hide file tree
Showing 15 changed files with 1,101 additions and 553 deletions.
4 changes: 4 additions & 0 deletions .rgignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
.git
.rgignore
.gitignore

14 changes: 11 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,21 +1,29 @@
[package]
name = "stream_multiplexer"
version = "0.1.0"
version = "0.2.0"
authors = ["Benjamin Halsted <[email protected]>"]
edition = "2018"
license = "MIT OR Apache-2.0"

categories = ["asynchronous", "network-programming"]
description = "Combines many streams into a few."
documentation = "https://docs.rs/stream_multiplexer"
keywords = ["async", "asynchronous", "multiplex", "stream", "tokio"]
readme = "README.md"
repository = "https://github.com/halzy/stream_multiplexer"

[dependencies]
byteorder = "1.3"
bytes = "0.5"
futures = { version = "0.3", default-features = false, features = ["alloc"] }
thiserror = "1.0"
tokio = { version = "0.2", features = ["full"] }
tokio-util = { version = "0.2", features = ["codec"] }
tracing = { version = "0.1", features = ["log"] }
byteorder = "1.3"
bytes = "0.5"
tracing-futures = "0.2"

[dev-dependencies]
tokio = { version = "0.2", features = ["full", "test-util"] }
matches = "0.1"
tracing-subscriber = "0.2"
futures = { version = "0.3", default-features = false, features = ["alloc","std"] }
18 changes: 15 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,18 @@
# stream_multiplexer

![Rust](https://github.com/halzy/stream_multiplexer/workflows/Rust/badge.svg)
![Crates.io](https://img.shields.io/crates/v/stream_multiplexer)
[![Build Status][actions_badge]][actions]
[![Latest Version][crates_badge]][crates]
[![Rust Documentation][docs_badge]][docs]

Highly unstable API!

This library multiplexes many streams into fewer streams.
New streams are assigned an identifier. Data from those streams are wrapped in a data structure that contains the Id and Bytes, and then funneled into another stream.

[docs_badge]: https://docs.rs/stream_multiplexer/badge.svg
[docs]: https://docs.rs/stream_multiplexer
[crates_badge]: https://img.shields.io/crates/v/stream_multiplexer.svg
[crates]: https://crates.io/crates/stream_multiplexer
[actions_badge]: https://github.com/halzy/stream_multiplexer/workflows/Rust/badge.svg
[actions]: https://github.com/halzy/stream_multiplexer/actions

// FIXME: ...
25 changes: 25 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
/// A collection of errors that can be returned.
#[derive(thiserror::Error, Debug)]
pub enum MultiplexerError {
// /// Sending can fail to enqueue a message to a stream.
// #[error("Could not send to the stream")]
// Send(#[from] tokio::sync::mpsc::error::TrySendError<Result<OutgoingMessage<OV>, ()>>),
//
// FIXME: outgoing error stream ?
// /// If the stream that is trying to be sent to has gone away
// #[error("Sending to nonexistent stream {0}")]
// SendNoStream(StreamId),

// #[error("Sending to full stream {0}")]
// StreamFull(StreamId),

// #[error("Sending to full stream {0}")]
// StreamClosed(StreamId),
/// Wrapper around std::io::Error
#[error("IoError")]
IoError(#[from] std::io::Error),

/// Nothing to see here
#[error("Should never happen")]
UnitError,
}
144 changes: 51 additions & 93 deletions src/halt.rs
Original file line number Diff line number Diff line change
@@ -1,28 +1,11 @@
use futures::prelude::*;
use futures::task::{AtomicWaker, Context, Poll};
use tokio::io::{
AsyncRead, AsyncWrite, Error as TioError, ErrorKind as TioErrorKind, ReadHalf,
Result as TioResult, WriteHalf,
};
use tokio::net::TcpStream;
use tokio::sync::oneshot;

use std::net::Shutdown;

use std::pin::Pin;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::Ordering::Relaxed;
use std::sync::Arc;

pub trait StreamShutdown {
fn shutdown(&self) -> TioResult<()>;
}

impl StreamShutdown for TcpStream {
#[tracing::instrument(level = "trace", skip(self))]
fn shutdown(&self) -> TioResult<()> {
self.shutdown(Shutdown::Read)
}
}

#[derive(Debug)]
struct Inner {
waker: AtomicWaker,
Expand All @@ -37,17 +20,15 @@ pub struct HaltRead {
impl HaltRead {
#[tracing::instrument(level = "trace", skip(self))]
pub fn signal(&self) {
tracing::trace!("setting atomic bool, triggering waker");
self.inner.set.store(true, Relaxed);
self.inner.waker.wake();
}

#[tracing::instrument(level = "trace", skip(read, writer))]
pub fn wrap<T>(
read: ReadHalf<T>,
writer: oneshot::Receiver<WriteHalf<T>>,
) -> (Self, HaltAsyncRead<T>)
#[tracing::instrument(level = "trace", skip(read))]
pub fn wrap<St>(read: St) -> (Self, HaltAsyncRead<St>)
where
T: AsyncRead + AsyncWrite,
St: Stream,
{
let inner = Arc::new(Inner {
waker: AtomicWaker::new(),
Expand All @@ -60,83 +41,64 @@ impl HaltRead {
HaltAsyncRead {
inner,
read: Some(read),
writer,
},
)
}
}

#[derive(Debug)]
pub struct HaltAsyncRead<T> {
pub struct HaltAsyncRead<St> {
inner: Arc<Inner>,
read: Option<ReadHalf<T>>,
writer: oneshot::Receiver<WriteHalf<T>>,
read: Option<St>,
}
impl<T> HaltAsyncRead<T>
impl<St> HaltAsyncRead<St>
where
T: StreamShutdown,
St: Stream,
{
#[tracing::instrument(level = "trace", skip(self))]
fn shutdown(&mut self) -> Poll<TioResult<usize>> {
use tokio::sync::oneshot::error::TryRecvError;

fn shutdown(&mut self) -> Poll<Option<St::Item>> {
match self.read {
None => Poll::Ready(Err(TioError::new(
TioErrorKind::Other,
"Double shutdown on stream.",
))),
None => {
tracing::error!("stream already shutdown");
}
Some(_) => {
// _ = reader, it's taken below
match self.writer.try_recv() {
Err(TryRecvError::Empty) => {
// Return pending if we do not yet have the write half
Poll::Pending
}
Err(TryRecvError::Closed) => {
// Because we take() the reader below and guard against none above.
unreachable!()
}
Ok(writer) => {
let reader = self
.read
.take()
.expect("Reader should still exist, was checked above");
let stream = reader.unsplit(writer);
stream.shutdown()?;
Poll::Ready(Ok(0)) // returning 0 will signal EOF and close the connection.
}
}
let _ = self.read.take();
}
}

Poll::Ready(None)
}
}
impl<T> Unpin for HaltAsyncRead<T> {}
impl<T> AsyncRead for HaltAsyncRead<T>

impl<St> Unpin for HaltAsyncRead<St> where St: Stream + Unpin {}
impl<St> Stream for HaltAsyncRead<St>
where
T: StreamShutdown + AsyncRead + AsyncWrite,
St: Stream + Unpin,
{
#[tracing::instrument(level = "trace", skip(self))]
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context,
buf: &mut [u8],
) -> Poll<TioResult<usize>> {
type Item = St::Item;

#[tracing::instrument(level = "trace", skip(self, ctx))]
fn poll_next(mut self: Pin<&mut Self>, ctx: &mut Context) -> Poll<Option<Self::Item>> {
// quick check to avoid registration if already done.
if self.inner.set.load(Relaxed) {
tracing::trace!("pre-waker shutdown");
return self.shutdown();
}

self.inner.waker.register(cx.waker());
tracing::trace!("waker registration");
self.inner.waker.register(ctx.waker());

// Need to check condition **after** `register` to avoid a race
// condition that would result in lost notifications.
if self.inner.set.load(Relaxed) {
tracing::trace!("shutting down");
self.shutdown()
} else {
// is only ever Some() here because inner.set being true
// causes self.read to become none, and we take the other
// branches.
Pin::new(&mut self.read.as_mut().unwrap()).poll_read(cx, buf)
tracing::trace!("self.read.poll_read()");
Pin::new(&mut self.read.as_mut().unwrap()).poll_next(ctx)
}
}
}
Expand All @@ -145,50 +107,46 @@ where
mod tests {
use super::*;

use tokio::io::AsyncReadExt as _;
use tokio::io::ErrorKind;

use futures::future::FutureExt;
use bytes::Bytes;
use tokio_util::codec::length_delimited::LengthDelimitedCodec;

use std::io::Cursor;

impl StreamShutdown for Cursor<Vec<u8>> {
fn shutdown(&self) -> TioResult<()> {
Ok(())
}
}

#[tokio::test(basic_scheduler)]
async fn halt() {
//crate::tests::init_logging();

// Stream of u8, from 0 to 15
let cursor: Cursor<Vec<u8>> = Cursor::new((0..16).into_iter().collect());
let (reader, writer) = tokio::io::split(cursor);
let (reader, _writer) = tokio::io::split(cursor);
let framed_reader = LengthDelimitedCodec::builder()
.length_field_length(1)
.new_read(reader);

let (stop_tx, stop_rx) = oneshot::channel();
let (halt, mut reader) = HaltRead::wrap(reader, stop_rx);
let (halt, mut reader) = HaltRead::wrap(framed_reader);

assert_eq!(0_u8, reader.read_u8().await.unwrap());
assert_eq!(1_u8, reader.read_u8().await.unwrap());
// Zero bytes,
assert_eq!(Bytes::from(vec![]), reader.next().await.unwrap().unwrap());

// 1 byte, value of 2
assert_eq!(
Bytes::from(vec![2_u8]),
reader.next().await.unwrap().unwrap()
);

// Shut down the read stream
halt.signal();

// Check that we can't read while waitng for the writer
assert!(reader.read_u8().now_or_never().is_none());

// Send the writer to finish the shutdown
stop_tx.send(writer).unwrap();
assert!(reader.next().await.is_none());

// check that reading has stopped
assert_eq!(
reader.read_u8().await.unwrap_err().kind(),
ErrorKind::UnexpectedEof
);
assert!(reader.next().await.is_none());

// Shut down the read stream
halt.signal();

// Ensure the double shutdown error is returned
assert_eq!(reader.read_u8().await.unwrap_err().kind(), ErrorKind::Other);
assert!(reader.next().await.is_none());
}
}
8 changes: 8 additions & 0 deletions src/id_gen.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,23 @@
use crate::StreamId;

/// Provided to MultiplexerSenders to override the default incrementing generator
pub trait IdGen: Default {
/// Produces a new Id
fn next(&mut self) -> StreamId;

/// The current Id
fn id(&self) -> StreamId;

/// Useful for setting a random seed, or a starting value.
fn seed(&mut self, _seed: usize) {}
}

/// The default IdGen for MultiplexerSenders
#[derive(Default, Copy, Clone, PartialEq, Debug)]
pub struct IncrementIdGen {
id: StreamId,
}
impl Unpin for IncrementIdGen {}
impl IdGen for IncrementIdGen {
/// Find the next available StreamId
#[tracing::instrument(level = "trace", skip(self))]
Expand Down
Loading

0 comments on commit 6d621b1

Please sign in to comment.