Skip to content

Commit

Permalink
Merge pull request #46 from Congyuwang/compat-layer
Browse files Browse the repository at this point in the history
Make it possible to opt out tokio
  • Loading branch information
sticnarf authored Jul 25, 2024
2 parents f1eac4f + 26139a7 commit aa311ea
Show file tree
Hide file tree
Showing 7 changed files with 375 additions and 50 deletions.
5 changes: 4 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@ edition = "2018"
travis-ci = { repository = "sticnarf/tokio-socks" }

[features]
default = ["tokio"]
tor = []
futures-io = ["dep:futures-io"]

[[example]]
name = "socket"
Expand All @@ -28,7 +30,8 @@ required-features = ["tor"]

[dependencies]
futures-util = { version = "0.3", default-features = false }
tokio = { version = "1.0", features = ["io-util", "net"] }
futures-io = { version = "0.3", optional = true }
tokio = { version = "1.0", features = ["io-util", "net"], optional = true }
either = "1"
thiserror = "1.0"

Expand Down
112 changes: 112 additions & 0 deletions src/io/futures.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
//! Compat layer for `futures-io` types.
//!
//! This module provides a compatibility layer for using `futures-io` types with
//! `async-socks5`. AsyncSocket is implemented for Compat<S> where S is an
//! AsyncRead + AsyncWrite + Unpin type from `futures-io`.
use super::AsyncSocket;
use futures_io::{AsyncRead, AsyncWrite};
use std::{
io::Result as IoResult,
ops::{Deref, DerefMut},
pin::Pin,
task::{Context, Poll},
};

/// A compatibility layer for using `futures-io` types with `async-socks5`.
///
/// See `FuturesIoCompatExt` trait for details.
pub struct Compat<S>(S);

impl<S> Compat<S> {
pub(crate) fn new(inner: S) -> Self {
Compat(inner)
}

/// Unwraps this Compat, returning the inner value.
pub fn into_inner(self) -> S {
self.0
}
}

impl<S> Deref for Compat<S> {
type Target = S;

fn deref(&self) -> &Self::Target {
&self.0
}
}

impl<S> DerefMut for Compat<S> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}

/// Import this trait to use socks with `futures-io` compatible runtime.
///
/// Example:
/// ```no_run
/// use async_std::os::unix::net::UnixStream;
/// use tokio_socks::{io::FuturesIoCompatExt as _, tcp::Socks5Stream};
///
/// let socket = UnixStream::connect(proxy_addr)
/// .await
/// .map_err(Error::Io)?
/// .compat(); // Compat<UnixStream>
/// let conn =
/// Socks5Stream::connect_with_password_and_socket(socket, target, username, pswd).await?;
/// // Socks5Stream has implemented futures-io AsyncRead + AsyncWrite.
/// ```
pub trait FuturesIoCompatExt {
fn compat(self) -> Compat<Self>
where
Self: Sized;
}

impl<S> FuturesIoCompatExt for S
where
S: AsyncRead + AsyncWrite + Unpin,
{
fn compat(self) -> Compat<Self> {
Compat::new(self)
}
}

impl<S> AsyncSocket for Compat<S>
where
S: AsyncRead + AsyncWrite + Unpin,
{
fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<IoResult<usize>> {
AsyncRead::poll_read(Pin::new(self.get_mut().deref_mut()), cx, buf)
}

fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<IoResult<usize>> {
AsyncWrite::poll_write(Pin::new(self.get_mut().deref_mut()), cx, buf)
}
}

impl<S> AsyncRead for Compat<S>
where
S: AsyncRead + Unpin,
{
fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<IoResult<usize>> {
AsyncRead::poll_read(Pin::new(self.get_mut().deref_mut()), cx, buf)
}
}

impl<S> AsyncWrite for Compat<S>
where
S: AsyncWrite + Unpin,
{
fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<IoResult<usize>> {
AsyncWrite::poll_write(Pin::new(self.get_mut().deref_mut()), cx, buf)
}

fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
AsyncWrite::poll_flush(Pin::new(self.get_mut().deref_mut()), cx)
}

fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
AsyncWrite::poll_close(Pin::new(self.get_mut().deref_mut()), cx)
}
}
110 changes: 110 additions & 0 deletions src/io/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
//! Asynchronous I/O abstractions for sockets.
#[cfg(feature = "futures-io")]
mod futures;
#[cfg(feature = "tokio")]
mod tokio;

use futures_util::ready;
use std::{
future::Future,
io::{Error, ErrorKind},
mem,
pin::Pin,
task::{Context, Poll},
};

#[cfg(feature = "futures-io")]
pub use futures::{Compat, FuturesIoCompatExt};

/// A trait for asynchronous socket I/O.
///
/// Any type that implements tokio's `AsyncRead` and `AsyncWrite` traits
/// has implemented `AsyncSocket` trait.
///
/// Use `FuturesIoCompatExt` to wrap `futures-io` types as `AsyncSocket` types.
pub trait AsyncSocket {
fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<Result<usize, Error>>;

fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, Error>>;
}

pub(crate) trait AsyncSocketExt {
fn read_exact<'a>(&'a mut self, buf: &'a mut [u8]) -> ReadExact<'a, Self>
where
Self: Sized;

fn write_all<'a>(&'a mut self, buf: &'a [u8]) -> WriteAll<'a, Self>
where
Self: Sized;
}

impl<S: AsyncSocket> AsyncSocketExt for S {
fn read_exact<'a>(&'a mut self, buf: &'a mut [u8]) -> ReadExact<'a, Self>
where
Self: Sized,
{
let capacity = buf.len();
ReadExact {
reader: self,
buf,
capacity,
}
}

fn write_all<'a>(&'a mut self, buf: &'a [u8]) -> WriteAll<'a, Self>
where
Self: Sized,
{
WriteAll { writer: self, buf }
}
}

pub(crate) struct ReadExact<'a, R> {
reader: &'a mut R,
buf: &'a mut [u8],
capacity: usize,
}

impl<R: AsyncSocket + Unpin> Future for ReadExact<'_, R> {
type Output = Result<usize, Error>;

fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = &mut *self;
while !this.buf.is_empty() {
let n = ready!(Pin::new(&mut *this.reader).poll_read(cx, this.buf))?;
{
let (_, rest) = mem::take(&mut this.buf).split_at_mut(n);
this.buf = rest;
}
if n == 0 {
return Poll::Ready(Err(ErrorKind::UnexpectedEof.into()));
}
}
Poll::Ready(Ok(this.capacity))
}
}

pub(crate) struct WriteAll<'a, W> {
writer: &'a mut W,
buf: &'a [u8],
}

impl<W: AsyncSocket + Unpin> Future for WriteAll<'_, W> {
type Output = Result<(), Error>;

fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = &mut *self;
while !this.buf.is_empty() {
let n = ready!(Pin::new(&mut *this.writer).poll_write(cx, this.buf))?;
{
let (_, rest) = mem::take(&mut this.buf).split_at(n);
this.buf = rest;
}
if n == 0 {
return Poll::Ready(Err(ErrorKind::WriteZero.into()));
}
}

Poll::Ready(Ok(()))
}
}
25 changes: 25 additions & 0 deletions src/io/tokio.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
//! AsyncSocket trait implementation for tokio's AsyncRead + AsyncWrite
//! traits.
use super::AsyncSocket;
use futures_util::ready;
use std::{
io::Result as IoResult,
pin::Pin,
task::{Context, Poll},
};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};

impl<S> AsyncSocket for S
where
S: AsyncRead + AsyncWrite,
{
fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<IoResult<usize>> {
let mut buf = ReadBuf::new(buf);
ready!(AsyncRead::poll_read(self, cx, &mut buf))?;
Poll::Ready(Ok(buf.filled().len()))
}

fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<IoResult<usize>> {
AsyncWrite::poll_write(self, cx, buf)
}
}
9 changes: 5 additions & 4 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use futures_util::{
};
use std::{
borrow::Cow,
io,
io::Result as IoResult,
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6, ToSocketAddrs},
pin::Pin,
task::{Context, Poll},
Expand Down Expand Up @@ -51,7 +51,7 @@ impl<'a> ToProxyAddrs for &'a [SocketAddr] {
type Output = ProxyAddrsStream;

fn to_proxy_addrs(&self) -> Self::Output {
ProxyAddrsStream(Some(io::Result::Ok(self.to_vec().into_iter())))
ProxyAddrsStream(Some(IoResult::Ok(self.to_vec().into_iter())))
}
}

Expand Down Expand Up @@ -79,7 +79,7 @@ impl<'a, T: ToProxyAddrs + ?Sized> ToProxyAddrs for &'a T {
}
}

pub struct ProxyAddrsStream(Option<io::Result<vec::IntoIter<SocketAddr>>>);
pub struct ProxyAddrsStream(Option<IoResult<vec::IntoIter<SocketAddr>>>);

impl Stream for ProxyAddrsStream {
type Item = Result<SocketAddr>;
Expand Down Expand Up @@ -123,7 +123,7 @@ impl<'a> TargetAddr<'a> {
impl<'a> ToSocketAddrs for TargetAddr<'a> {
type Iter = Either<std::option::IntoIter<SocketAddr>, std::vec::IntoIter<SocketAddr>>;

fn to_socket_addrs(&self) -> io::Result<Self::Iter> {
fn to_socket_addrs(&self) -> IoResult<Self::Iter> {
Ok(match self {
TargetAddr::Ip(addr) => Either::Left(addr.to_socket_addrs()?),
TargetAddr::Domain(domain, port) => Either::Right((&**domain, *port).to_socket_addrs()?),
Expand Down Expand Up @@ -260,6 +260,7 @@ impl<'a> Authentication<'a> {
}

mod error;
pub mod io;
pub mod tcp;

#[cfg(test)]
Expand Down
Loading

0 comments on commit aa311ea

Please sign in to comment.