Skip to content

Commit

Permalink
Merge branch 'master' into serde_and_borsh_support
Browse files Browse the repository at this point in the history
  • Loading branch information
Rohit Kulkarni committed Feb 19, 2025
2 parents 12f0b41 + 8737caf commit 76c76f4
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 25 deletions.
64 changes: 57 additions & 7 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use futures_util::{
use serde::{Deserialize, Serialize};
use std::{
borrow::Cow,
fmt,
io::Result as IoResult,
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6, ToSocketAddrs},
pin::Pin,
Expand Down Expand Up @@ -48,7 +49,7 @@ trivial_impl_to_proxy_addrs!((Ipv6Addr, u16));
trivial_impl_to_proxy_addrs!(SocketAddrV4);
trivial_impl_to_proxy_addrs!(SocketAddrV6);

impl<'a> ToProxyAddrs for &'a [SocketAddr] {
impl ToProxyAddrs for &[SocketAddr] {
type Output = ProxyAddrsStream;

fn to_proxy_addrs(&self) -> Self::Output {
Expand All @@ -65,15 +66,15 @@ impl ToProxyAddrs for str {
}
}

impl<'a> ToProxyAddrs for (&'a str, u16) {
impl ToProxyAddrs for (&str, u16) {
type Output = ProxyAddrsStream;

fn to_proxy_addrs(&self) -> Self::Output {
ProxyAddrsStream(Some(self.to_socket_addrs()))
}
}

impl<'a, T: ToProxyAddrs + ?Sized> ToProxyAddrs for &'a T {
impl<T: ToProxyAddrs + ?Sized> ToProxyAddrs for &T {
type Output = T::Output;

fn to_proxy_addrs(&self) -> Self::Output {
Expand All @@ -99,7 +100,7 @@ impl Stream for ProxyAddrsStream {
}

/// A SOCKS connection target.
#[derive(Debug, PartialEq, Eq, Serialize, Deserialize, BorshSerialize, BorshDeserialize)]
#[derive(Debug, PartialEq, Eq, Serialize, Deserialize, BorshSerialize, BorshDeserialize, Clone)]
pub enum TargetAddr<'a> {
/// Connect to an IP address.
Ip(SocketAddr),
Expand All @@ -111,7 +112,16 @@ pub enum TargetAddr<'a> {
Domain(Cow<'a, str>, u16),
}

impl<'a> TargetAddr<'a> {
impl fmt::Display for TargetAddr<'_> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
TargetAddr::Ip(addr) => write!(f, "{}", addr),
TargetAddr::Domain(domain, port) => write!(f, "{}:{}", domain, port),
}
}
}

impl TargetAddr<'_> {
/// Creates owned `TargetAddr` by cloning. It is usually used to eliminate
/// the lifetime bound.
pub fn to_owned(&self) -> TargetAddr<'static> {
Expand All @@ -122,7 +132,7 @@ impl<'a> TargetAddr<'a> {
}
}

impl<'a> ToSocketAddrs for TargetAddr<'a> {
impl ToSocketAddrs for TargetAddr<'_> {
type Iter = Either<std::option::IntoIter<SocketAddr>, std::vec::IntoIter<SocketAddr>>;

fn to_socket_addrs(&self) -> IoResult<Self::Iter> {
Expand Down Expand Up @@ -252,7 +262,7 @@ enum Authentication<'a> {
None,
}

impl<'a> Authentication<'a> {
impl Authentication<'_> {
fn id(&self) -> u8 {
match self {
Authentication::Password { .. } => 0x02,
Expand All @@ -276,6 +286,46 @@ mod tests {
Ok(block_on(t.to_proxy_addrs().map(Result::unwrap).collect()))
}

#[test]
fn test_clone_ip() {
let addr = TargetAddr::Ip(SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080));
let addr_clone = addr.clone();
assert_eq!(addr, addr_clone);
assert_eq!(addr.to_string(), addr_clone.to_string());
}

#[test]
fn test_clone_domain() {
let addr = TargetAddr::Domain(Cow::Borrowed("example.com"), 80);
let addr_clone = addr.clone();
assert_eq!(addr, addr_clone);
assert_eq!(addr.to_string(), addr_clone.to_string());
}

#[test]
fn test_display_ip() {
let addr = TargetAddr::Ip(SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080));
assert_eq!(format!("{}", addr), "127.0.0.1:8080");
}

#[test]
fn test_display_domain() {
let addr = TargetAddr::Domain(Cow::Borrowed("example.com"), 80);
assert_eq!(format!("{}", addr), "example.com:80");
}

#[test]
fn test_to_string_ip() {
let addr = TargetAddr::Ip(SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080));
assert_eq!(addr.to_string(), "127.0.0.1:8080");
}

#[test]
fn test_to_string_domain() {
let addr = TargetAddr::Domain(Cow::Borrowed("example.com"), 80);
assert_eq!(addr.to_string(), "example.com:80");
}

#[test]
fn converts_socket_addr_to_proxy_addrs() -> Result<()> {
let addr = SocketAddr::from(([1, 1, 1, 1], 443));
Expand Down
47 changes: 29 additions & 18 deletions src/tcp/socks5.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,7 @@ use tokio::net::TcpStream;
use crate::ToProxyAddrs;
use crate::{
io::{AsyncSocket, AsyncSocketExt},
Authentication,
Error,
IntoTargetAddr,
Result,
TargetAddr,
Authentication, Error, IntoTargetAddr, Result, TargetAddr,
};

#[repr(u8)]
Expand Down Expand Up @@ -149,7 +145,8 @@ impl Socks5Stream<TcpStream> {
}

impl<S> Socks5Stream<S>
where S: AsyncSocket + Unpin
where
S: AsyncSocket + Unpin,
{
/// Connects to a target server through a SOCKS5 proxy given a socket to it.
///
Expand All @@ -158,7 +155,9 @@ where S: AsyncSocket + Unpin
/// It propagates the error that occurs in the conversion from `T` to
/// `TargetAddr`.
pub async fn connect_with_socket<'t, T>(socket: S, target: T) -> Result<Socks5Stream<S>>
where T: IntoTargetAddr<'t> {
where
T: IntoTargetAddr<'t>,
{
Self::execute_command_with_socket(socket, target, Authentication::None, Command::Connect).await
}

Expand Down Expand Up @@ -190,11 +189,11 @@ where S: AsyncSocket + Unpin
fn validate_auth(auth: &Authentication<'_>) -> Result<()> {
match auth {
Authentication::Password { username, password } => {
let username_len = username.as_bytes().len();
let username_len = username.len();
if !(1..=255).contains(&username_len) {
Err(Error::InvalidAuthValues("username length should between 1 to 255"))?
}
let password_len = password.as_bytes().len();
let password_len = password.len();
if !(1..=255).contains(&password_len) {
Err(Error::InvalidAuthValues("password length should between 1 to 255"))?
}
Expand All @@ -208,7 +207,9 @@ where S: AsyncSocket + Unpin
/// Resolve the domain name to an ip using special Tor Resolve command, by
/// connecting to a Tor compatible proxy given a socket to it.
pub async fn tor_resolve_with_socket<'t, T>(socket: S, target: T) -> Result<TargetAddr<'static>>
where T: IntoTargetAddr<'t> {
where
T: IntoTargetAddr<'t>,
{
let sock = Self::execute_command_with_socket(socket, target, Authentication::None, Command::TorResolve).await?;

Ok(sock.target_addr().to_owned())
Expand All @@ -219,7 +220,9 @@ where S: AsyncSocket + Unpin
/// PTR command, by connecting to a Tor compatible proxy given a socket
/// to it.
pub async fn tor_resolve_ptr_with_socket<'t, T>(socket: S, target: T) -> Result<TargetAddr<'static>>
where T: IntoTargetAddr<'t> {
where
T: IntoTargetAddr<'t>,
{
let sock =
Self::execute_command_with_socket(socket, target, Authentication::None, Command::TorResolvePtr).await?;

Expand Down Expand Up @@ -274,7 +277,8 @@ pub struct SocksConnector<'a, 't, S> {
}

impl<'a, 't, S> SocksConnector<'a, 't, S>
where S: Stream<Item = Result<SocketAddr>> + Unpin
where
S: Stream<Item = Result<SocketAddr>> + Unpin,
{
fn new(auth: Authentication<'a>, command: Command, proxy: Fuse<S>, target: TargetAddr<'t>) -> Self {
SocksConnector {
Expand Down Expand Up @@ -585,7 +589,8 @@ impl Socks5Listener<TcpStream> {
}

impl<S> Socks5Listener<S>
where S: AsyncSocket + Unpin
where
S: AsyncSocket + Unpin,
{
/// Initiates a BIND request to the specified proxy using the given socket
/// to it.
Expand All @@ -598,7 +603,9 @@ where S: AsyncSocket + Unpin
/// It propagates the error that occurs in the conversion from `T` to
/// `TargetAddr`.
pub async fn bind_with_socket<'t, T>(socket: S, target: T) -> Result<Socks5Listener<S>>
where T: IntoTargetAddr<'t> {
where
T: IntoTargetAddr<'t>,
{
Self::bind_with_auth_and_socket(Authentication::None, socket, target).await
}

Expand Down Expand Up @@ -674,7 +681,8 @@ where S: AsyncSocket + Unpin

#[cfg(feature = "tokio")]
impl<T> tokio::io::AsyncRead for Socks5Stream<T>
where T: tokio::io::AsyncRead + Unpin
where
T: tokio::io::AsyncRead + Unpin,
{
fn poll_read(
mut self: Pin<&mut Self>,
Expand All @@ -687,7 +695,8 @@ where T: tokio::io::AsyncRead + Unpin

#[cfg(feature = "tokio")]
impl<T> tokio::io::AsyncWrite for Socks5Stream<T>
where T: tokio::io::AsyncWrite + Unpin
where
T: tokio::io::AsyncWrite + Unpin,
{
fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
tokio::io::AsyncWrite::poll_write(Pin::new(&mut self.socket), cx, buf)
Expand All @@ -704,7 +713,8 @@ where T: tokio::io::AsyncWrite + Unpin

#[cfg(feature = "futures-io")]
impl<T> futures_io::AsyncRead for Socks5Stream<T>
where T: futures_io::AsyncRead + Unpin
where
T: futures_io::AsyncRead + Unpin,
{
fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
futures_io::AsyncRead::poll_read(Pin::new(&mut self.socket), cx, buf)
Expand All @@ -713,7 +723,8 @@ where T: futures_io::AsyncRead + Unpin

#[cfg(feature = "futures-io")]
impl<T> futures_io::AsyncWrite for Socks5Stream<T>
where T: futures_io::AsyncWrite + Unpin
where
T: futures_io::AsyncWrite + Unpin,
{
fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
futures_io::AsyncWrite::poll_write(Pin::new(&mut self.socket), cx, buf)
Expand Down

0 comments on commit 76c76f4

Please sign in to comment.