diff --git a/src/vmess.rs b/src/vmess.rs index db3481e..9afacb4 100644 --- a/src/vmess.rs +++ b/src/vmess.rs @@ -53,9 +53,20 @@ impl Vmess { pub fn into_split(self) -> (VmessReader, VmessWriter) { let stream = self.stream.expect("stream should contain a value"); let (reader, writer) = stream.into_split(); + let encoder = Encoder::new(); - let r = VmessReader { reader, encoder }; + let key = md5!(&encoder.key); + let iv = md5!(&encoder.iv); + let decoder = Aes128CfbDec::new(&key.into(), &iv.into()); + + let r = VmessReader { + reader, + encoder, + decoder, + remaining_length: 0, + handshaked: false, + }; let w = VmessWriter { encoder, writer, @@ -70,6 +81,9 @@ impl Vmess { pub struct VmessReader { reader: R, encoder: Encoder, + remaining_length: usize, + decoder: Aes128CfbDec, + handshaked: bool, } impl VmessReader { @@ -83,45 +97,50 @@ impl VmessReader { // | Response Authentication V | Option Opt | Command Cmd | Command Length M | Command Content | Actual Response Data | // +---------------------------+------------+-------------+------------------+-----------------+----------------------+ - let key = md5!(&self.encoder.key); - let iv = md5!(&self.encoder.iv); - let mut decoder = Aes128CfbDec::new(&key.into(), &iv.into()); - - let mut header = [0u8; 4]; - self.reader.read_exact(&mut header).await?; - decoder.decrypt(&mut header); // ignore the header for now - // just decrypt it because our decoder is stateful - - // https://xtls.github.io/en/development/protocols/vmess.html#data-section - // - // +----------+-------------+ - // | 2 Bytes | L Bytes | - // +----------+-------------+ - // | Length L | Data Packet | - // +----------+-------------+ - // - // - Length L: A big-endian integer with a maximum value of 2^14 - // - Packet: A data packet encrypted by the specified encryption method - - // AES-128-CFB: - // The entire data section is encrypted using AES-128-CFB - // - 4 bytes: FNV1a hash of actual data - // - L - 4 bytes: actual data - let mut length = [0u8; 2]; - self.reader.read_exact(&mut length).await?; - decoder.decrypt(&mut length); - - // When Opt(M) is enabled, the value of L is equal to the true value xor Mask - // Mask = (RequestMask.NextByte() << 8) + RequestMask.NextByte() - let length = (length[0] as usize) << 8 | (length[1] as usize) - 4; // 4bytes checksum + if !self.handshaked { + let mut header = [0u8; 4]; + self.reader.read_exact(&mut header).await?; + self.decoder.decrypt(&mut header); // ignore the header for now + // just decrypt it because our decoder is stateful + + // https://xtls.github.io/en/development/protocols/vmess.html#data-section + // + // +----------+-------------+ + // | 2 Bytes | L Bytes | + // +----------+-------------+ + // | Length L | Data Packet | + // +----------+-------------+ + // + // - Length L: A big-endian integer with a maximum value of 2^14 + // - Packet: A data packet encrypted by the specified encryption method + + // AES-128-CFB: + // The entire data section is encrypted using AES-128-CFB + // - 4 bytes: FNV1a hash of actual data + // - L - 4 bytes: actual data + let mut length = [0u8; 2]; + self.reader.read_exact(&mut length).await?; + self.decoder.decrypt(&mut length); + + // When Opt(M) is enabled, the value of L is equal to the true value xor Mask + // Mask = (RequestMask.NextByte() << 8) + RequestMask.NextByte() + self.remaining_length = (length[0] as usize) << 8 | (length[1] as usize) - 4; // 4bytes checksum + + let mut checksum = [0u8; 4]; + self.reader.read(&mut checksum).await?; + self.decoder.decrypt(&mut checksum); // ignore the checksum for now + // just decrypt it because our decoder is stateful + self.handshaked = true; + } - let mut checksum = [0u8; 4]; - self.reader.read(&mut checksum).await?; - decoder.decrypt(&mut checksum); // ignore the checksum for now - // just decrypt it because our decoder is stateful + let mut length = self.remaining_length; + if length > buf.len() { + length = buf.len(); + self.remaining_length -= buf.len(); + } self.reader.read(&mut buf[..length]).await?; - decoder.decrypt(&mut buf[..length]); + self.decoder.decrypt(&mut buf[..length]); Ok(length) }