Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added to_string and display methods for TargetAddr Fixes #60 #63

Merged
merged 6 commits into from
Feb 19, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 62 additions & 9 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::{
borrow::Cow,
fmt,
io::Result as IoResult,
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6, ToSocketAddrs},
pin::Pin,
Expand Down Expand Up @@ -47,7 +48,7 @@ trivial_impl_to_proxy_addrs!((Ipv6Addr, u16));
trivial_impl_to_proxy_addrs!(SocketAddrV4);
trivial_impl_to_proxy_addrs!(SocketAddrV6);

impl<'a> ToProxyAddrs for &'a [SocketAddr] {
impl ToProxyAddrs for &[SocketAddr] {
type Output = ProxyAddrsStream;

fn to_proxy_addrs(&self) -> Self::Output {
Expand All @@ -64,15 +65,15 @@ impl ToProxyAddrs for str {
}
}

impl<'a> ToProxyAddrs for (&'a str, u16) {
impl ToProxyAddrs for (&str, u16) {
type Output = ProxyAddrsStream;

fn to_proxy_addrs(&self) -> Self::Output {
ProxyAddrsStream(Some(self.to_socket_addrs()))
}
}

impl<'a, T: ToProxyAddrs + ?Sized> ToProxyAddrs for &'a T {
impl<T: ToProxyAddrs + ?Sized> ToProxyAddrs for &T {
type Output = T::Output;

fn to_proxy_addrs(&self) -> Self::Output {
Expand All @@ -98,7 +99,7 @@ impl Stream for ProxyAddrsStream {
}

/// A SOCKS connection target.
#[derive(Debug, PartialEq, Eq)]
#[derive(Debug, PartialEq, Eq, Clone)]
pub enum TargetAddr<'a> {
/// Connect to an IP address.
Ip(SocketAddr),
Expand All @@ -110,7 +111,16 @@ pub enum TargetAddr<'a> {
Domain(Cow<'a, str>, u16),
}

impl<'a> TargetAddr<'a> {
impl fmt::Display for TargetAddr<'_> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
TargetAddr::Ip(addr) => write!(f, "{}", addr),
TargetAddr::Domain(domain, port) => write!(f, "{}:{}", domain, port),
}
}
}

impl TargetAddr<'_> {
/// Creates owned `TargetAddr` by cloning. It is usually used to eliminate
/// the lifetime bound.
pub fn to_owned(&self) -> TargetAddr<'static> {
Expand All @@ -121,7 +131,7 @@ impl<'a> TargetAddr<'a> {
}
}

impl<'a> ToSocketAddrs for TargetAddr<'a> {
impl ToSocketAddrs for TargetAddr<'_> {
type Iter = Either<std::option::IntoIter<SocketAddr>, std::vec::IntoIter<SocketAddr>>;

fn to_socket_addrs(&self) -> IoResult<Self::Iter> {
Expand Down Expand Up @@ -236,7 +246,8 @@ impl IntoTargetAddr<'static> for (String, u16) {
}

impl<'a, T> IntoTargetAddr<'a> for &'a T
where T: IntoTargetAddr<'a> + Copy
where
T: IntoTargetAddr<'a> + Copy,
{
fn into_target_addr(self) -> Result<TargetAddr<'a>> {
(*self).into_target_addr()
Expand All @@ -250,7 +261,7 @@ enum Authentication<'a> {
None,
}

impl<'a> Authentication<'a> {
impl Authentication<'_> {
fn id(&self) -> u8 {
match self {
Authentication::Password { .. } => 0x02,
Expand All @@ -274,6 +285,46 @@ mod tests {
Ok(block_on(t.to_proxy_addrs().map(Result::unwrap).collect()))
}

#[test]
fn test_clone_ip() {
let addr = TargetAddr::Ip(SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080));
let addr_clone = addr.clone();
assert_eq!(addr, addr_clone);
assert_eq!(addr.to_string(), addr_clone.to_string());
}

#[test]
fn test_clone_domain() {
let addr = TargetAddr::Domain(Cow::Borrowed("example.com"), 80);
let addr_clone = addr.clone();
assert_eq!(addr, addr_clone);
assert_eq!(addr.to_string(), addr_clone.to_string());
}

#[test]
fn test_display_ip() {
let addr = TargetAddr::Ip(SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080));
assert_eq!(format!("{}", addr), "127.0.0.1:8080");
}

#[test]
fn test_display_domain() {
let addr = TargetAddr::Domain(Cow::Borrowed("example.com"), 80);
assert_eq!(format!("{}", addr), "example.com:80");
}

#[test]
fn test_to_string_ip() {
let addr = TargetAddr::Ip(SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080));
assert_eq!(addr.to_string(), "127.0.0.1:8080");
}

#[test]
fn test_to_string_domain() {
let addr = TargetAddr::Domain(Cow::Borrowed("example.com"), 80);
assert_eq!(addr.to_string(), "example.com:80");
}

#[test]
fn converts_socket_addr_to_proxy_addrs() -> Result<()> {
let addr = SocketAddr::from(([1, 1, 1, 1], 443));
Expand Down Expand Up @@ -302,7 +353,9 @@ mod tests {
}

fn into_target_addr<'a, T>(t: T) -> Result<TargetAddr<'a>>
where T: IntoTargetAddr<'a> {
where
T: IntoTargetAddr<'a>,
{
t.into_target_addr()
}

Expand Down
47 changes: 29 additions & 18 deletions src/tcp/socks5.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,7 @@ use tokio::net::TcpStream;
use crate::ToProxyAddrs;
use crate::{
io::{AsyncSocket, AsyncSocketExt},
Authentication,
Error,
IntoTargetAddr,
Result,
TargetAddr,
Authentication, Error, IntoTargetAddr, Result, TargetAddr,
};

#[repr(u8)]
Expand Down Expand Up @@ -149,7 +145,8 @@ impl Socks5Stream<TcpStream> {
}

impl<S> Socks5Stream<S>
where S: AsyncSocket + Unpin
where
S: AsyncSocket + Unpin,
{
/// Connects to a target server through a SOCKS5 proxy given a socket to it.
///
Expand All @@ -158,7 +155,9 @@ where S: AsyncSocket + Unpin
/// It propagates the error that occurs in the conversion from `T` to
/// `TargetAddr`.
pub async fn connect_with_socket<'t, T>(socket: S, target: T) -> Result<Socks5Stream<S>>
where T: IntoTargetAddr<'t> {
where
T: IntoTargetAddr<'t>,
{
Self::execute_command_with_socket(socket, target, Authentication::None, Command::Connect).await
}

Expand Down Expand Up @@ -190,11 +189,11 @@ where S: AsyncSocket + Unpin
fn validate_auth(auth: &Authentication<'_>) -> Result<()> {
match auth {
Authentication::Password { username, password } => {
let username_len = username.as_bytes().len();
let username_len = username.len();
if !(1..=255).contains(&username_len) {
Err(Error::InvalidAuthValues("username length should between 1 to 255"))?
}
let password_len = password.as_bytes().len();
let password_len = password.len();
if !(1..=255).contains(&password_len) {
Err(Error::InvalidAuthValues("password length should between 1 to 255"))?
}
Expand All @@ -208,7 +207,9 @@ where S: AsyncSocket + Unpin
/// Resolve the domain name to an ip using special Tor Resolve command, by
/// connecting to a Tor compatible proxy given a socket to it.
pub async fn tor_resolve_with_socket<'t, T>(socket: S, target: T) -> Result<TargetAddr<'static>>
where T: IntoTargetAddr<'t> {
where
T: IntoTargetAddr<'t>,
{
let sock = Self::execute_command_with_socket(socket, target, Authentication::None, Command::TorResolve).await?;

Ok(sock.target_addr().to_owned())
Expand All @@ -219,7 +220,9 @@ where S: AsyncSocket + Unpin
/// PTR command, by connecting to a Tor compatible proxy given a socket
/// to it.
pub async fn tor_resolve_ptr_with_socket<'t, T>(socket: S, target: T) -> Result<TargetAddr<'static>>
where T: IntoTargetAddr<'t> {
where
T: IntoTargetAddr<'t>,
{
let sock =
Self::execute_command_with_socket(socket, target, Authentication::None, Command::TorResolvePtr).await?;

Expand Down Expand Up @@ -274,7 +277,8 @@ pub struct SocksConnector<'a, 't, S> {
}

impl<'a, 't, S> SocksConnector<'a, 't, S>
where S: Stream<Item = Result<SocketAddr>> + Unpin
where
S: Stream<Item = Result<SocketAddr>> + Unpin,
{
fn new(auth: Authentication<'a>, command: Command, proxy: Fuse<S>, target: TargetAddr<'t>) -> Self {
SocksConnector {
Expand Down Expand Up @@ -585,7 +589,8 @@ impl Socks5Listener<TcpStream> {
}

impl<S> Socks5Listener<S>
where S: AsyncSocket + Unpin
where
S: AsyncSocket + Unpin,
{
/// Initiates a BIND request to the specified proxy using the given socket
/// to it.
Expand All @@ -598,7 +603,9 @@ where S: AsyncSocket + Unpin
/// It propagates the error that occurs in the conversion from `T` to
/// `TargetAddr`.
pub async fn bind_with_socket<'t, T>(socket: S, target: T) -> Result<Socks5Listener<S>>
where T: IntoTargetAddr<'t> {
where
T: IntoTargetAddr<'t>,
{
Self::bind_with_auth_and_socket(Authentication::None, socket, target).await
}

Expand Down Expand Up @@ -674,7 +681,8 @@ where S: AsyncSocket + Unpin

#[cfg(feature = "tokio")]
impl<T> tokio::io::AsyncRead for Socks5Stream<T>
where T: tokio::io::AsyncRead + Unpin
where
T: tokio::io::AsyncRead + Unpin,
{
fn poll_read(
mut self: Pin<&mut Self>,
Expand All @@ -687,7 +695,8 @@ where T: tokio::io::AsyncRead + Unpin

#[cfg(feature = "tokio")]
impl<T> tokio::io::AsyncWrite for Socks5Stream<T>
where T: tokio::io::AsyncWrite + Unpin
where
T: tokio::io::AsyncWrite + Unpin,
{
fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
tokio::io::AsyncWrite::poll_write(Pin::new(&mut self.socket), cx, buf)
Expand All @@ -704,7 +713,8 @@ where T: tokio::io::AsyncWrite + Unpin

#[cfg(feature = "futures-io")]
impl<T> futures_io::AsyncRead for Socks5Stream<T>
where T: futures_io::AsyncRead + Unpin
where
T: futures_io::AsyncRead + Unpin,
{
fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
futures_io::AsyncRead::poll_read(Pin::new(&mut self.socket), cx, buf)
Expand All @@ -713,7 +723,8 @@ where T: futures_io::AsyncRead + Unpin

#[cfg(feature = "futures-io")]
impl<T> futures_io::AsyncWrite for Socks5Stream<T>
where T: futures_io::AsyncWrite + Unpin
where
T: futures_io::AsyncWrite + Unpin,
{
fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
futures_io::AsyncWrite::poll_write(Pin::new(&mut self.socket), cx, buf)
Expand Down
Loading