Skip to content

Commit

Permalink
implement trojan inbound
Browse files Browse the repository at this point in the history
Signed-off-by: Uncle Jack <[email protected]>
  • Loading branch information
unclejacki committed Jul 10, 2024
1 parent 2a1ce02 commit a257134
Show file tree
Hide file tree
Showing 10 changed files with 282 additions and 37 deletions.
9 changes: 7 additions & 2 deletions config.schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,21 @@
"type": "object",
"required": [
"path",
"protocol",
"uuid"
"protocol"
],
"properties": {
"password": {
"default": "",
"type": "string"
},
"path": {
"type": "string"
},
"protocol": {
"$ref": "#/definitions/Protocol"
},
"uuid": {
"default": "00000000-0000-0000-0000-000000000000",
"type": "string",
"format": "uuid"
}
Expand Down Expand Up @@ -81,6 +85,7 @@
"enum": [
"vmess",
"vless",
"trojan",
"bepass",
"relay_v1",
"relay_v2",
Expand Down
5 changes: 5 additions & 0 deletions config.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@ protocol = "vmess"
uuid = "0fbf4f81-2598-4b6a-a623-0ead4cb9efa8"
path = "/vmess"

[[inbound]]
protocol = "trojan"
path = "/trojan"
password = "test"

[outbound]
addresses = ["127.0.0.1"]
port = 6666
Expand Down
79 changes: 46 additions & 33 deletions src/common/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,26 @@ macro_rules! sha256 {
}
}

#[macro_export]
macro_rules! sha224 {
( $($v:expr),+ ) => {
{
let mut hash = Sha224::new();
$(
hash.update($v);
)*
hash.finalize()
}
}
}

#[macro_export]
macro_rules! hex {
($v:expr) => {
$v.iter().map(|b| format!("{:02x}", b)).collect::<String>()
};
}

pub fn encode_addr(addr: &str) -> Result<Vec<u8>> {
let ip = addr
.parse::<IpAddr>()
Expand All @@ -51,38 +71,31 @@ pub fn encode_addr(addr: &str) -> Result<Vec<u8>> {
})
}

pub async fn parse_addr<R: AsyncRead + std::marker::Unpin>(buf: &mut R) -> Result<String> {
let addr = match buf.read_u8().await? {
1 => {
let mut addr = [0u8; 4];
buf.read_exact(&mut addr).await?;
Ipv4Addr::new(addr[0], addr[1], addr[2], addr[3]).to_string()
}
2 => {
let len = buf.read_u8().await?;
let mut domain = vec![0u8; len as _];
buf.read_exact(&mut domain).await?;
String::from_utf8_lossy(&domain).to_string()
}
3 => {
let mut addr = [0u8; 16];
buf.read_exact(&mut addr).await?;
Ipv6Addr::new(
((addr[0] as u16) << 16) | (addr[1] as u16),
((addr[2] as u16) << 16) | (addr[3] as u16),
((addr[4] as u16) << 16) | (addr[5] as u16),
((addr[6] as u16) << 16) | (addr[7] as u16),
((addr[8] as u16) << 16) | (addr[9] as u16),
((addr[10] as u16) << 16) | (addr[11] as u16),
((addr[12] as u16) << 16) | (addr[13] as u16),
((addr[14] as u16) << 16) | (addr[15] as u16),
)
.to_string()
}
_ => {
return Err(Error::RustError("invalid address".to_string()));
}
};
pub async fn parse_ipv4<R: AsyncRead + std::marker::Unpin>(buf: &mut R) -> Result<String> {
let mut addr = [0u8; 4];
buf.read_exact(&mut addr).await?;
Ok(Ipv4Addr::new(addr[0], addr[1], addr[2], addr[3]).to_string())
}

pub async fn parse_ipv6<R: AsyncRead + std::marker::Unpin>(buf: &mut R) -> Result<String> {
let mut addr = [0u8; 16];
buf.read_exact(&mut addr).await?;
Ok(Ipv6Addr::new(
((addr[0] as u16) << 16) | (addr[1] as u16),
((addr[2] as u16) << 16) | (addr[3] as u16),
((addr[4] as u16) << 16) | (addr[5] as u16),
((addr[6] as u16) << 16) | (addr[7] as u16),
((addr[8] as u16) << 16) | (addr[9] as u16),
((addr[10] as u16) << 16) | (addr[11] as u16),
((addr[12] as u16) << 16) | (addr[13] as u16),
((addr[14] as u16) << 16) | (addr[15] as u16),
)
.to_string())
}

Ok(addr)
pub async fn parse_domain<R: AsyncRead + std::marker::Unpin>(buf: &mut R) -> Result<String> {
let len = buf.read_u8().await?;
let mut domain = vec![0u8; len as _];
buf.read_exact(&mut domain).await?;
Ok(String::from_utf8_lossy(&domain).to_string())
}
5 changes: 5 additions & 0 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ pub enum Protocol {
#[default]
Vmess,
Vless,
Trojan,
Bepass,
RelayV1,
RelayV2,
Expand All @@ -38,8 +39,12 @@ pub struct Outbound {
#[derive(Default, Clone, Serialize, Deserialize, JsonSchema)]
pub struct Inbound {
pub protocol: Protocol,
// only for vmess/vless
#[serde(default)]
pub uuid: Uuid,
// only for trojan
#[serde(default)]
pub password: String,
pub path: String,
#[serde(skip)]
pub context: RequestContext,
Expand Down
6 changes: 6 additions & 0 deletions src/proxy/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
pub mod bepass;
pub mod blackhole;
pub mod relay;
pub mod trojan;
pub mod vless;
pub mod vmess;

Expand Down Expand Up @@ -140,6 +141,11 @@ pub async fn process(
.process()
.await
}
Protocol::Trojan => {
trojan::inbound::TrojanStream::new(config, inbound, events, ws)
.process()
.await
}
Protocol::Bepass => {
bepass::inbound::BepassStream::new(config, inbound, events, ws)
.process()
Expand Down
84 changes: 84 additions & 0 deletions src/proxy/trojan/encoding.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
use crate::proxy::Network;

use sha2::{Digest, Sha224};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite};
use worker::*;

pub struct Header {
pub network: Network,
pub address: String,
pub port: u16,
}

pub async fn decode_request_header<S: AsyncRead + AsyncWrite + Unpin>(
stream: &mut S,
password: &str,
) -> Result<Header> {
// TODO: using BufReader instead of reading directly from the stream

// +-----------------------+---------+----------------+---------+----------+
// | hex(SHA224(password)) | CRLF | Trojan Request | CRLF | Payload |
// +-----------------------+---------+----------------+---------+----------+
// | 56 | X'0D0A' | Variable | X'0D0A' | Variable |
// +-----------------------+---------+----------------+---------+----------+
let mut crlf = [0u8; 2];

let mut header_pass = [0u8; 56];
stream.read_exact(&mut header_pass).await?;
{
let header_pass = String::from_utf8_lossy(&header_pass);
let password = {
let p = &crate::sha224!(&password)[..];
crate::hex!(p)
};

if password != header_pass {
return Err(Error::RustError("invalid password".to_string()));
}
}

stream.read_exact(&mut crlf).await?;

// +-----+------+----------+----------+
// | CMD | ATYP | DST.ADDR | DST.PORT |
// +-----+------+----------+----------+
// | 1 | 1 | Variable | 2 |
// +-----+------+----------+----------+
let network = match stream.read_u8().await? {
0x01 => Network::Tcp,
0x03 => Network::Udp,
_ => return Err(Error::RustError("invalid network type".to_string())),
};

let address = match stream.read_u8().await? {
0x01 => crate::common::parse_ipv4(stream).await?,
0x03 => crate::common::parse_domain(stream).await?,
0x04 => crate::common::parse_ipv6(stream).await?,
_ => return Err(Error::RustError("invalid address".to_string())),
};
let port = {
let mut p = [0u8; 2];
stream.read_exact(&mut p).await?;
u16::from_be_bytes(p)
};

// UDP
// +--------+
// | Length |
// +--------+
// | 2 |
// +--------+
match network {
Network::Udp => {
stream.read_exact(&mut [0u8; 2]).await?;
}
_ => {}
}
stream.read_exact(&mut crlf).await?;

Ok(Header {
network,
address,
port,
})
}
115 changes: 115 additions & 0 deletions src/proxy/trojan/inbound.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
use crate::config::{Config, Inbound};
use crate::proxy::{trojan::encoding, Proxy};

use std::pin::Pin;
use std::task::{Context, Poll};

use async_trait::async_trait;
use bytes::{BufMut, BytesMut};
use futures_util::Stream;
use pin_project_lite::pin_project;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use worker::*;

pin_project! {
pub struct TrojanStream<'a> {
pub config: Config,
pub inbound: Inbound,
pub ws: &'a WebSocket,
pub buffer: BytesMut,
#[pin]
pub events: EventStream<'a>,
}
}

unsafe impl<'a> Send for TrojanStream<'a> {}

impl<'a> TrojanStream<'a> {
pub fn new(
config: Config,
inbound: Inbound,
events: EventStream<'a>,
ws: &'a WebSocket,
) -> Self {
let buffer = BytesMut::new();

Self {
config,
inbound,
ws,
buffer,
events,
}
}
}

#[async_trait]
impl<'a> Proxy for TrojanStream<'a> {
async fn process(&mut self) -> Result<()> {
let password = self.inbound.password.clone();
let header = encoding::decode_request_header(&mut self, &password).await?;

let mut context = self.inbound.context.clone();
{
context.address = header.address;
context.port = header.port;
context.network = header.network;
}

let outbound = self.config.dispatch_outbound(&context);
let mut upstream = crate::proxy::connect_outbound(context, outbound).await?;

tokio::io::copy_bidirectional(self, &mut upstream).await?;

Ok(())
}
}

impl<'a> AsyncRead for TrojanStream<'a> {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<tokio::io::Result<()>> {
let mut this = self.project();

loop {
let size = std::cmp::min(this.buffer.len(), buf.remaining());
if size > 0 {
buf.put_slice(&this.buffer.split_to(size));
return Poll::Ready(Ok(()));
}

match this.events.as_mut().poll_next(cx) {
Poll::Ready(Some(Ok(WebsocketEvent::Message(msg)))) => {
msg.bytes().iter().for_each(|x| this.buffer.put_slice(&x));
}
Poll::Pending => return Poll::Pending,
_ => return Poll::Ready(Ok(())),
}
}
}
}

impl<'a> AsyncWrite for TrojanStream<'a> {
fn poll_write(
self: Pin<&mut Self>,
_: &mut Context<'_>,
buf: &[u8],
) -> Poll<tokio::io::Result<usize>> {
return Poll::Ready(
self.ws
.send_with_bytes(buf)
.map(|_| buf.len())
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e.to_string())),
);
}

fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<tokio::io::Result<()>> {
Poll::Ready(Ok(()))
}

fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<tokio::io::Result<()>> {
unimplemented!()
}
}
2 changes: 2 additions & 0 deletions src/proxy/trojan/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
mod encoding;
pub mod inbound;
7 changes: 6 additions & 1 deletion src/proxy/vless/encoding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,12 @@ pub async fn decode_request_header<S: AsyncRead + AsyncWrite + Unpin>(
stream.read_exact(&mut p).await?;
u16::from_be_bytes(p)
};
let address = crate::common::parse_addr(stream).await?;
let address = match stream.read_u8().await? {
0x01 => crate::common::parse_ipv4(stream).await?,
0x02 => crate::common::parse_domain(stream).await?,
0x03 => crate::common::parse_ipv6(stream).await?,
_ => return Err(Error::RustError("invalid address".to_string())),
};

Ok(Header {
network,
Expand Down
Loading

0 comments on commit a257134

Please sign in to comment.