From aed8b12044a9a66b90dd1d8b1c16703b223524c4 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Mon, 11 Dec 2023 15:12:26 +0100 Subject: [PATCH 01/20] split network i/o from task --- src/network/icmp.rs | 107 +++++++ src/network/io.rs | 532 +++++++++++++++++++++++++++++++++++ src/network/mod.rs | 2 + src/network/task.rs | 661 ++------------------------------------------ 4 files changed, 668 insertions(+), 634 deletions(-) create mode 100644 src/network/icmp.rs create mode 100644 src/network/io.rs diff --git a/src/network/icmp.rs b/src/network/icmp.rs new file mode 100644 index 00000000..318fa1f2 --- /dev/null +++ b/src/network/icmp.rs @@ -0,0 +1,107 @@ +use crate::messages::IpPacket; +use smoltcp::phy::ChecksumCapabilities; +use smoltcp::wire::{ + Icmpv4Message, Icmpv4Packet, Icmpv4Repr, Icmpv6Message, Icmpv6Packet, Icmpv6Repr, IpAddress, + IpProtocol, Ipv4Packet, Ipv4Repr, Ipv6Packet, Ipv6Repr, +}; + +pub(super) fn handle_icmpv4_echo_request( + mut input_packet: Ipv4Packet>, +) -> Option { + let src_addr = input_packet.src_addr(); + let dst_addr = input_packet.dst_addr(); + + // Parsing ICMP Packet + let mut input_icmpv4_packet = match Icmpv4Packet::new_checked(input_packet.payload_mut()) { + Ok(p) => p, + Err(e) => { + log::debug!("Received invalid ICMPv4 packet: {}", e); + return None; + } + }; + + // Checking that it is an ICMP Echo Request. + if input_icmpv4_packet.msg_type() != Icmpv4Message::EchoRequest { + log::debug!( + "Unsupported ICMPv4 packet of type: {}", + input_icmpv4_packet.msg_type() + ); + return None; + } + + // Creating fake response packet. + let icmp_repr = Icmpv4Repr::EchoReply { + ident: input_icmpv4_packet.echo_ident(), + seq_no: input_icmpv4_packet.echo_seq_no(), + data: input_icmpv4_packet.data_mut(), + }; + let ip_repr = Ipv4Repr { + // Directing fake reply back to the original source address. + src_addr: dst_addr, + dst_addr: src_addr, + next_header: IpProtocol::Icmp, + payload_len: icmp_repr.buffer_len(), + hop_limit: 255, + }; + let buf = vec![0u8; ip_repr.buffer_len() + icmp_repr.buffer_len()]; + let mut output_ipv4_packet = Ipv4Packet::new_unchecked(buf); + ip_repr.emit(&mut output_ipv4_packet, &ChecksumCapabilities::default()); + let mut output_ip_packet = IpPacket::from(output_ipv4_packet); + icmp_repr.emit( + &mut Icmpv4Packet::new_unchecked(output_ip_packet.payload_mut()), + &ChecksumCapabilities::default(), + ); + Some(output_ip_packet) +} + +pub(super) fn handle_icmpv6_echo_request( + mut input_packet: Ipv6Packet>, +) -> Option { + let src_addr = input_packet.src_addr(); + let dst_addr = input_packet.dst_addr(); + + // Parsing ICMP Packet + let mut input_icmpv6_packet = match Icmpv6Packet::new_checked(input_packet.payload_mut()) { + Ok(p) => p, + Err(e) => { + log::debug!("Received invalid ICMPv6 packet: {}", e); + return None; + } + }; + + // Checking that it is an ICMP Echo Request. + if input_icmpv6_packet.msg_type() != Icmpv6Message::EchoRequest { + log::debug!( + "Unsupported ICMPv6 packet of type: {}", + input_icmpv6_packet.msg_type() + ); + return None; + } + + // Creating fake response packet. + let icmp_repr = Icmpv6Repr::EchoReply { + ident: input_icmpv6_packet.echo_ident(), + seq_no: input_icmpv6_packet.echo_seq_no(), + data: input_icmpv6_packet.payload_mut(), + }; + let ip_repr = Ipv6Repr { + // Directing fake reply back to the original source address. + src_addr: dst_addr, + dst_addr: src_addr, + next_header: IpProtocol::Icmp, + payload_len: icmp_repr.buffer_len(), + hop_limit: 255, + }; + let buf = vec![0u8; ip_repr.buffer_len() + icmp_repr.buffer_len()]; + let mut output_ipv6_packet = Ipv6Packet::new_unchecked(buf); + ip_repr.emit(&mut output_ipv6_packet); + let mut output_ip_packet = IpPacket::from(output_ipv6_packet); + icmp_repr.emit( + // Directing fake reply back to the original source address. + &IpAddress::from(dst_addr), + &IpAddress::from(src_addr), + &mut Icmpv6Packet::new_unchecked(output_ip_packet.payload_mut()), + &ChecksumCapabilities::default(), + ); + Some(output_ip_packet) +} diff --git a/src/network/io.rs b/src/network/io.rs new file mode 100644 index 00000000..6a9a5046 --- /dev/null +++ b/src/network/io.rs @@ -0,0 +1,532 @@ +use std::collections::{HashMap, HashSet, VecDeque}; +use std::net::SocketAddr; +use std::{cmp, fmt}; + +use anyhow::Result; +use pretty_hex::pretty_hex; +use smoltcp::iface::{Config, SocketSet}; +use smoltcp::socket::{tcp, Socket}; + +use smoltcp::wire::HardwareAddress; +use smoltcp::{ + iface::{Interface, SocketHandle}, + phy::ChecksumCapabilities, + time::{Duration, Instant}, + wire::{ + IpAddress, IpCidr, IpProtocol, IpRepr, Ipv4Address, Ipv4Packet, Ipv4Repr, Ipv6Address, + Ipv6Packet, Ipv6Repr, TcpPacket, UdpPacket, UdpRepr, + }, +}; +use tokio::sync::{ + mpsc::{Permit, Sender}, + oneshot, +}; + +use crate::messages::{ + ConnectionId, IpPacket, NetworkCommand, NetworkEvent, TransportCommand, TransportEvent, + TunnelInfo, +}; +use crate::network::icmp::{handle_icmpv4_echo_request, handle_icmpv6_echo_request}; + +use super::virtual_device::VirtualDevice; + +/// Associated data for a smoltcp socket. +#[derive(Debug)] +struct SocketData { + handle: SocketHandle, + /// smoltcp can only operate with fixed-size buffers, but Python's stream implementation assumes + /// an infinite buffer. So we have a second send buffer here, plus a boolean to indicate that + /// we want to send a FIN. + send_buffer: VecDeque, + write_eof: bool, + // Gets notified once there's data to be read. + recv_waiter: Option<(u32, oneshot::Sender>)>, + // Gets notified once there is enough space in the write buffer. + drain_waiter: Vec>, + addr_tuple: (SocketAddr, SocketAddr), +} + +pub struct NetworkIO<'a> { + iface: Interface, + device: VirtualDevice, + sockets: SocketSet<'a>, + + net_tx: Sender, + + socket_data: HashMap, + active_connections: HashSet<(SocketAddr, SocketAddr)>, + next_connection_id: ConnectionId, + remove_conns: Vec, +} + +impl<'a> NetworkIO<'a> { + pub fn new(net_tx: Sender) -> Self { + let mut device = VirtualDevice::new(net_tx.clone()); + + let config = Config::new(HardwareAddress::Ip); + let mut iface = Interface::new(config, &mut device, Instant::now()); + + iface.set_any_ip(true); + + iface.update_ip_addrs(|ip_address| { + ip_address + .push(IpCidr::new(IpAddress::v4(0, 0, 0, 1), 0)) + .unwrap(); + }); + // TODO: IPv6 + iface + .routes_mut() + .add_default_ipv4_route(Ipv4Address::new(0, 0, 0, 1)) + .unwrap(); + + NetworkIO { + iface, + device, + sockets: SocketSet::new(Vec::new()), + net_tx, + socket_data: HashMap::new(), + active_connections: HashSet::new(), + next_connection_id: 0, + remove_conns: Vec::new(), + } + } + + fn receive_packet( + &mut self, + packet: IpPacket, + tunnel_info: TunnelInfo, + permit: Permit<'_, TransportEvent>, + ) -> Result<()> { + if let IpPacket::V4(p) = &packet { + if !p.verify_checksum() { + log::warn!("Received invalid IP packet (checksum error)."); + return Ok(()); + } + } + + match packet.transport_protocol() { + IpProtocol::Tcp => self.receive_packet_tcp(packet, tunnel_info, permit), + IpProtocol::Udp => self.receive_packet_udp(packet, tunnel_info, permit), + IpProtocol::Icmp => self.receive_packet_icmp(packet), + _ => { + log::debug!( + "Received IP packet for unknown protocol: {}", + packet.transport_protocol() + ); + Ok(()) + } + } + } + + fn receive_packet_udp( + &mut self, + mut packet: IpPacket, + tunnel_info: TunnelInfo, + permit: Permit<'_, TransportEvent>, + ) -> Result<()> { + let src_ip = packet.src_ip(); + let dst_ip = packet.dst_ip(); + + let mut udp_packet = match UdpPacket::new_checked(packet.payload_mut()) { + Ok(p) => p, + Err(e) => { + log::debug!("Received invalid UDP packet: {}", e); + return Ok(()); + } + }; + + let src_addr = SocketAddr::new(src_ip, udp_packet.src_port()); + let dst_addr = SocketAddr::new(dst_ip, udp_packet.dst_port()); + + let event = TransportEvent::DatagramReceived { + data: udp_packet.payload_mut().to_vec(), + src_addr, + dst_addr, + tunnel_info, + }; + + permit.send(event); + Ok(()) + } + + fn receive_packet_tcp( + &mut self, + mut packet: IpPacket, + tunnel_info: TunnelInfo, + permit: Permit<'_, TransportEvent>, + ) -> Result<()> { + let src_ip = packet.src_ip(); + let dst_ip = packet.dst_ip(); + + let tcp_packet = match TcpPacket::new_checked(packet.payload_mut()) { + // packet with correct length + Ok(p) => { + // packet with correct checksum + if p.verify_checksum(&src_ip.into(), &dst_ip.into()) { + p + } else { + // packet with incorrect checksum + log::warn!("Received invalid TCP packet (checksum error)."); + return Ok(()); + } + } + // packet with incorrect length + Err(e) => { + log::debug!("Received invalid TCP packet ({}) with payload:", e); + log::debug!("{}", pretty_hex(&packet.payload_mut())); + return Ok(()); + } + }; + + let src_addr = SocketAddr::new(src_ip, tcp_packet.src_port()); + let dst_addr = SocketAddr::new(dst_ip, tcp_packet.dst_port()); + + if tcp_packet.syn() + && !tcp_packet.ack() + && !self.active_connections.contains(&(src_addr, dst_addr)) + { + let mut socket = tcp::Socket::new( + tcp::SocketBuffer::new(vec![0u8; 64 * 1024]), + tcp::SocketBuffer::new(vec![0u8; 64 * 1024]), + ); + + socket.listen(dst_addr)?; + socket.set_timeout(Some(Duration::from_secs(60))); + socket.set_keep_alive(Some(Duration::from_secs(28))); + + let handle = self.sockets.add(socket); + + let connection_id = { + self.next_connection_id += 1; + self.next_connection_id + }; + + let data = SocketData { + handle, + send_buffer: VecDeque::new(), + write_eof: false, + recv_waiter: None, + drain_waiter: Vec::new(), + addr_tuple: (src_addr, dst_addr), + }; + self.socket_data.insert(connection_id, data); + self.active_connections.insert((src_addr, dst_addr)); + + let event = TransportEvent::ConnectionEstablished { + connection_id, + src_addr, + dst_addr, + tunnel_info, + }; + permit.send(event); + } + + self.device.receive_packet(packet); + Ok(()) + } + + fn receive_packet_icmp(&mut self, packet: IpPacket) -> Result<()> { + // Some apps check network connectivity by sending ICMP pings. ICMP traffic is currently + // swallowed by mitmproxy_rs, which makes them believe that there is no network connectivity. + // Generating fake ICMP replies as a simple workaround. + + if let Ok(permit) = self.net_tx.try_reserve() { + // Generating and sending fake replies for ICMP echo requests. Ignoring all other ICMP types. + let response_packet = match packet { + IpPacket::V4(packet) => handle_icmpv4_echo_request(packet), + IpPacket::V6(packet) => handle_icmpv6_echo_request(packet), + }; + if let Some(response_packet) = response_packet { + permit.send(NetworkCommand::SendPacket(response_packet)); + } + } else { + log::debug!("Channel full, discarding ICMP packet."); + } + Ok(()) + } + + fn read_data(&mut self, id: ConnectionId, n: u32, tx: oneshot::Sender>) { + if let Some(data) = self.socket_data.get_mut(&id) { + assert!(data.recv_waiter.is_none()); + data.recv_waiter = Some((n, tx)); + } else { + // connection is has already been removed because the connection is closed, + // so we just drop the tx. + } + } + + fn write_data(&mut self, id: ConnectionId, buf: Vec) { + if let Some(data) = self.socket_data.get_mut(&id) { + data.send_buffer.extend(buf); + } else { + // connection is has already been removed because the connection is closed, + // so we just ignore the write. + } + } + + fn drain_writer(&mut self, id: ConnectionId, tx: oneshot::Sender<()>) { + if let Some(data) = self.socket_data.get_mut(&id) { + data.drain_waiter.push(tx); + } else { + // connection is has already been removed because the connection is closed, + // so we just drop the tx. + } + } + + fn close_connection(&mut self, id: ConnectionId, _half_close: bool) { + if let Some(data) = self.socket_data.get_mut(&id) { + // smoltcp does not have a good way to do "SHUT_RDWR". We can't call .abort() + // here because that sends a RST instead of a FIN (and breaks + // retransmissions of the connection close packet). Alternatively, we could manually + // set a timer on .close() and then forcibly .abort() once the timer expires (see + // tcp-abort branch). This incurs a bit of unnecessary complexity, so we try something + // dumber here: We simply close our end and then hope that either the client sends a FIN + // or times out via the keepalive mechanism. + + data.write_eof = true; + } else { + // connection is already dead. + } + } + + fn send_datagram(&mut self, data: Vec, src_addr: SocketAddr, dst_addr: SocketAddr) { + let permit = match self.net_tx.try_reserve() { + Ok(p) => p, + Err(_) => { + log::debug!("Channel full, discarding UDP packet."); + return; + } + }; + + // We now know that there's space for us to send, + // let's painstakingly reassemble the IP packet... + + let udp_repr = UdpRepr { + src_port: src_addr.port(), + dst_port: dst_addr.port(), + }; + + let ip_repr: IpRepr = match (src_addr, dst_addr) { + (SocketAddr::V4(src_addr), SocketAddr::V4(dst_addr)) => IpRepr::Ipv4(Ipv4Repr { + src_addr: Ipv4Address::from(*src_addr.ip()), + dst_addr: Ipv4Address::from(*dst_addr.ip()), + next_header: IpProtocol::Udp, + payload_len: udp_repr.header_len() + data.len(), + hop_limit: 255, + }), + (SocketAddr::V6(src_addr), SocketAddr::V6(dst_addr)) => IpRepr::Ipv6(Ipv6Repr { + src_addr: Ipv6Address::from(*src_addr.ip()), + dst_addr: Ipv6Address::from(*dst_addr.ip()), + next_header: IpProtocol::Udp, + payload_len: udp_repr.header_len() + data.len(), + hop_limit: 255, + }), + _ => { + log::error!("Failed to assemble UDP datagram: mismatched IP address versions"); + return; + } + }; + + let buf = vec![0u8; ip_repr.buffer_len()]; + + let mut ip_packet = match ip_repr { + IpRepr::Ipv4(repr) => { + let mut packet = Ipv4Packet::new_unchecked(buf); + repr.emit(&mut packet, &ChecksumCapabilities::default()); + IpPacket::from(packet) + } + IpRepr::Ipv6(repr) => { + let mut packet = Ipv6Packet::new_unchecked(buf); + repr.emit(&mut packet); + IpPacket::from(packet) + } + }; + + udp_repr.emit( + &mut UdpPacket::new_unchecked(ip_packet.payload_mut()), + &ip_repr.src_addr(), + &ip_repr.dst_addr(), + data.len(), + |buf| buf.copy_from_slice(data.as_slice()), + &ChecksumCapabilities::default(), + ); + + permit.send(NetworkCommand::SendPacket(ip_packet)); + } + + pub fn handle_network_event( + &mut self, + event: NetworkEvent, + permit: Permit<'_, TransportEvent>, + ) -> Result<()> { + match event { + NetworkEvent::ReceivePacket { + packet, + tunnel_info, + } => { + self.receive_packet(packet, tunnel_info, permit)?; + } + } + Ok(()) + } + + pub fn handle_transport_command(&mut self, command: TransportCommand) { + match command { + TransportCommand::ReadData(id, n, tx) => { + self.read_data(id, n, tx); + } + TransportCommand::WriteData(id, buf) => { + self.write_data(id, buf); + } + TransportCommand::DrainWriter(id, tx) => { + self.drain_writer(id, tx); + } + TransportCommand::CloseConnection(id, half_close) => { + self.close_connection(id, half_close); + } + TransportCommand::SendDatagram { + data, + src_addr, + dst_addr, + } => { + self.send_datagram(data, src_addr, dst_addr); + } + } + } + + pub fn poll_delay(&mut self) -> Option { + self.iface.poll_delay(Instant::now(), &self.sockets) + } + + pub fn poll(&mut self) -> Result<()> { + // poll virtual network device + #[cfg(debug_assertions)] + log::debug!("Polling virtual network device ..."); + self.iface + .poll(Instant::now(), &mut self.device, &mut self.sockets); + + // Process TCP socket I/O + #[cfg(debug_assertions)] + log::debug!("Processing TCP connections ..."); + self.process_tcp()?; + + // poll again. we may have new stuff to do. + #[cfg(debug_assertions)] + log::debug!("Polling virtual network device ..."); + self.iface + .poll(Instant::now(), &mut self.device, &mut self.sockets); + Ok(()) + } + + fn process_tcp(&mut self) -> Result<()> { + for (connection_id, data) in self.socket_data.iter_mut() { + let socket = self.sockets.get_mut::(data.handle); + + // receive data over the socket + if data.recv_waiter.is_some() { + if socket.can_recv() { + let (n, tx) = data.recv_waiter.take().unwrap(); + let bytes_available = socket.recv_queue(); + + let mut buf = vec![0u8; cmp::min(bytes_available, n as usize)]; + let bytes_read = socket.recv_slice(&mut buf)?; + + buf.truncate(bytes_read); + if tx.send(buf).is_err() { + log::debug!("Cannot send received data, channel was already closed."); + } + } else { + // We can't use .may_recv() here as it returns false during establishment. + use tcp::State::*; + match socket.state() { + // can we still receive something in the future? + CloseWait | LastAck | Closed | Closing | TimeWait => { + let (_, tx) = data.recv_waiter.take().unwrap(); + if tx.send(Vec::new()).is_err() { + log::debug!("Cannot send close, channel was already closed."); + } + } + _ => {} + } + } + } + + // send data over the socket + if !data.send_buffer.is_empty() && socket.can_send() { + let (a, b) = data.send_buffer.as_slices(); + let sent = socket.send_slice(a)? + socket.send_slice(b)?; + data.send_buffer.drain(..sent); + } + + // if necessary, drain write buffers: + // either when drain has been requested explicitly, or when socket is being closed + // TODO: benchmark different variants here. (e.g. only return on half capacity) + if (!data.drain_waiter.is_empty() || data.write_eof) + && socket.send_queue() < socket.send_capacity() + { + for waiter in data.drain_waiter.drain(..) { + if waiter.send(()).is_err() { + log::debug!("TcpStream already closed, cannot send notification about drained buffers.") + } + } + } + + #[cfg(debug_assertions)] + log::debug!( + "TCP connection {}: socket state {} for {:?}", + connection_id, + socket.state(), + data.addr_tuple, + ); + + // if requested, close socket + if data.write_eof && data.send_buffer.is_empty() { + socket.close(); + data.write_eof = false; + } + + // if socket is closed, mark connection for removal + if socket.state() == tcp::State::Closed { + self.remove_conns.push(*connection_id); + } + } + + for connection_id in self.remove_conns.drain(..) { + let data = self.socket_data.remove(&connection_id).unwrap(); + self.sockets.remove(data.handle); + self.active_connections.remove(&data.addr_tuple); + } + Ok(()) + } +} + +impl fmt::Debug for NetworkIO<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let sockets: Vec = self + .sockets + .iter() + .filter_map(|(_h, s)| match s { + Socket::Tcp(s) => Some(s), + _ => None, + }) + .map(|sock| { + format!( + "TCP {:<21} {:<21} {}", + sock.remote_endpoint() + .map(|e| e.to_string()) + .as_ref() + .map_or("not connected", String::as_str), + sock.local_endpoint() + .map(|e| e.to_string()) + .as_ref() + .map_or("not connected", String::as_str), + sock.state() + ) + }) + .collect(); + + f.debug_struct("NetworkIO") + .field("sockets", &sockets) + .finish() + } +} diff --git a/src/network/mod.rs b/src/network/mod.rs index 52b4221b..aaa86e04 100755 --- a/src/network/mod.rs +++ b/src/network/mod.rs @@ -4,6 +4,8 @@ pub use task::NetworkTask; mod virtual_device; +mod icmp; +mod io; #[cfg(test)] mod tests; diff --git a/src/network/task.rs b/src/network/task.rs index b5789c5d..d15fec5a 100755 --- a/src/network/task.rs +++ b/src/network/task.rs @@ -1,501 +1,18 @@ -use std::cmp; -use std::collections::{HashMap, HashSet, VecDeque}; use std::fmt; -use std::net::SocketAddr; use anyhow::Result; -use pretty_hex::pretty_hex; -use smoltcp::iface::{Config, SocketSet}; -use smoltcp::socket::{tcp, Socket}; +use smoltcp::time::Duration; -use smoltcp::wire::{ - HardwareAddress, Icmpv4Message, Icmpv4Packet, Icmpv4Repr, Icmpv6Message, Icmpv6Packet, - Icmpv6Repr, -}; -use smoltcp::{ - iface::{Interface, SocketHandle}, - phy::ChecksumCapabilities, - time::{Duration, Instant}, - wire::{ - IpAddress, IpCidr, IpProtocol, IpRepr, Ipv4Address, Ipv4Packet, Ipv4Repr, Ipv6Address, - Ipv6Packet, Ipv6Repr, TcpPacket, UdpPacket, UdpRepr, - }, -}; use tokio::sync::{ broadcast, broadcast::Receiver as BroadcastReceiver, mpsc, mpsc::{Permit, Receiver, Sender, UnboundedReceiver}, - oneshot, }; use tokio::task::JoinHandle; -use crate::messages::{ - ConnectionId, IpPacket, NetworkCommand, NetworkEvent, TransportCommand, TransportEvent, - TunnelInfo, -}; - -use super::virtual_device::VirtualDevice; - -/// Associated data for a smoltcp socket. -#[derive(Debug)] -pub(super) struct SocketData { - handle: SocketHandle, - /// smoltcp can only operate with fixed-size buffers, but Python's stream implementation assumes - /// an infinite buffer. So we have a second send buffer here, plus a boolean to indicate that - /// we want to send a FIN. - send_buffer: VecDeque, - write_eof: bool, - // Gets notified once there's data to be read. - recv_waiter: Option<(u32, oneshot::Sender>)>, - // Gets notified once there is enough space in the write buffer. - drain_waiter: Vec>, - addr_tuple: (SocketAddr, SocketAddr), -} - -struct NetworkIO<'a> { - iface: Interface, - device: VirtualDevice, - sockets: SocketSet<'a>, - - net_tx: Sender, - - socket_data: HashMap, - active_connections: HashSet<(SocketAddr, SocketAddr)>, - next_connection_id: ConnectionId, -} - -impl<'a> NetworkIO<'a> { - fn new(net_tx: Sender) -> Self { - let mut device = VirtualDevice::new(net_tx.clone()); - - let config = Config::new(HardwareAddress::Ip); - let mut iface = Interface::new(config, &mut device, Instant::now()); - - iface.set_any_ip(true); - - iface.update_ip_addrs(|ip_address| { - ip_address - .push(IpCidr::new(IpAddress::v4(0, 0, 0, 1), 0)) - .unwrap(); - }); - // TODO: IPv6 - iface - .routes_mut() - .add_default_ipv4_route(Ipv4Address::new(0, 0, 0, 1)) - .unwrap(); - - NetworkIO { - iface, - device, - sockets: SocketSet::new(Vec::new()), - net_tx, - socket_data: HashMap::new(), - active_connections: HashSet::new(), - next_connection_id: 0, - } - } - - fn receive_packet( - &mut self, - packet: IpPacket, - tunnel_info: TunnelInfo, - permit: Permit<'_, TransportEvent>, - ) -> Result<()> { - if let IpPacket::V4(p) = &packet { - if !p.verify_checksum() { - log::warn!("Received invalid IP packet (checksum error)."); - return Ok(()); - } - } - - match packet.transport_protocol() { - IpProtocol::Tcp => self.receive_packet_tcp(packet, tunnel_info, permit), - IpProtocol::Udp => self.receive_packet_udp(packet, tunnel_info, permit), - IpProtocol::Icmp => self.receive_packet_icmp(packet), - _ => { - log::debug!( - "Received IP packet for unknown protocol: {}", - packet.transport_protocol() - ); - Ok(()) - } - } - } - - fn receive_packet_udp( - &mut self, - mut packet: IpPacket, - tunnel_info: TunnelInfo, - permit: Permit<'_, TransportEvent>, - ) -> Result<()> { - let src_ip = packet.src_ip(); - let dst_ip = packet.dst_ip(); - - let mut udp_packet = match UdpPacket::new_checked(packet.payload_mut()) { - Ok(p) => p, - Err(e) => { - log::debug!("Received invalid UDP packet: {}", e); - return Ok(()); - } - }; - - let src_addr = SocketAddr::new(src_ip, udp_packet.src_port()); - let dst_addr = SocketAddr::new(dst_ip, udp_packet.dst_port()); - - let event = TransportEvent::DatagramReceived { - data: udp_packet.payload_mut().to_vec(), - src_addr, - dst_addr, - tunnel_info, - }; - - permit.send(event); - Ok(()) - } - - fn receive_packet_tcp( - &mut self, - mut packet: IpPacket, - tunnel_info: TunnelInfo, - permit: Permit<'_, TransportEvent>, - ) -> Result<()> { - let src_ip = packet.src_ip(); - let dst_ip = packet.dst_ip(); - - let tcp_packet = match TcpPacket::new_checked(packet.payload_mut()) { - // packet with correct length - Ok(p) => { - // packet with correct checksum - if p.verify_checksum(&src_ip.into(), &dst_ip.into()) { - p - } else { - // packet with incorrect checksum - log::warn!("Received invalid TCP packet (checksum error)."); - return Ok(()); - } - } - // packet with incorrect length - Err(e) => { - log::debug!("Received invalid TCP packet ({}) with payload:", e); - log::debug!("{}", pretty_hex(&packet.payload_mut())); - return Ok(()); - } - }; - - let src_addr = SocketAddr::new(src_ip, tcp_packet.src_port()); - let dst_addr = SocketAddr::new(dst_ip, tcp_packet.dst_port()); - - if tcp_packet.syn() - && !tcp_packet.ack() - && !self.active_connections.contains(&(src_addr, dst_addr)) - { - let mut socket = tcp::Socket::new( - tcp::SocketBuffer::new(vec![0u8; 64 * 1024]), - tcp::SocketBuffer::new(vec![0u8; 64 * 1024]), - ); - - socket.listen(dst_addr)?; - socket.set_timeout(Some(Duration::from_secs(60))); - socket.set_keep_alive(Some(Duration::from_secs(28))); - - let handle = self.sockets.add(socket); - - let connection_id = { - self.next_connection_id += 1; - self.next_connection_id - }; - - let data = SocketData { - handle, - send_buffer: VecDeque::new(), - write_eof: false, - recv_waiter: None, - drain_waiter: Vec::new(), - addr_tuple: (src_addr, dst_addr), - }; - self.socket_data.insert(connection_id, data); - self.active_connections.insert((src_addr, dst_addr)); - - let event = TransportEvent::ConnectionEstablished { - connection_id, - src_addr, - dst_addr, - tunnel_info, - }; - permit.send(event); - } - - self.device.receive_packet(packet); - Ok(()) - } - - fn receive_packet_icmp(&mut self, packet: IpPacket) -> Result<()> { - // Some apps check network connectivity by sending ICMP pings. ICMP traffic is currently - // swallowed by mitmproxy_rs, which makes them believe that there is no network connectivity. - // Generating fake ICMP replies as a simple workaround. - - if let Ok(permit) = self.net_tx.try_reserve() { - // Generating and sending fake replies for ICMP echo requests. Ignoring all other ICMP types. - let response_packet = match packet { - IpPacket::V4(packet) => handle_icmpv4_echo_request(packet), - IpPacket::V6(packet) => handle_icmpv6_echo_request(packet), - }; - if let Some(response_packet) = response_packet { - permit.send(NetworkCommand::SendPacket(response_packet)); - } - } else { - log::debug!("Channel full, discarding ICMP packet."); - } - Ok(()) - } - - fn read_data(&mut self, id: ConnectionId, n: u32, tx: oneshot::Sender>) { - if let Some(data) = self.socket_data.get_mut(&id) { - assert!(data.recv_waiter.is_none()); - data.recv_waiter = Some((n, tx)); - } else { - // connection is has already been removed because the connection is closed, - // so we just drop the tx. - } - } - - fn write_data(&mut self, id: ConnectionId, buf: Vec) { - if let Some(data) = self.socket_data.get_mut(&id) { - data.send_buffer.extend(buf); - } else { - // connection is has already been removed because the connection is closed, - // so we just ignore the write. - } - } - - fn drain_writer(&mut self, id: ConnectionId, tx: oneshot::Sender<()>) { - if let Some(data) = self.socket_data.get_mut(&id) { - data.drain_waiter.push(tx); - } else { - // connection is has already been removed because the connection is closed, - // so we just drop the tx. - } - } - - fn close_connection(&mut self, id: ConnectionId, _half_close: bool) { - if let Some(data) = self.socket_data.get_mut(&id) { - // smoltcp does not have a good way to do "SHUT_RDWR". We can't call .abort() - // here because that sends a RST instead of a FIN (and breaks - // retransmissions of the connection close packet). Alternatively, we could manually - // set a timer on .close() and then forcibly .abort() once the timer expires (see - // tcp-abort branch). This incurs a bit of unnecessary complexity, so we try something - // dumber here: We simply close our end and then hope that either the client sends a FIN - // or times out via the keepalive mechanism. - - data.write_eof = true; - } else { - // connection is already dead. - } - } - - fn send_datagram(&mut self, data: Vec, src_addr: SocketAddr, dst_addr: SocketAddr) { - let permit = match self.net_tx.try_reserve() { - Ok(p) => p, - Err(_) => { - log::debug!("Channel full, discarding UDP packet."); - return; - } - }; - - // We now know that there's space for us to send, - // let's painstakingly reassemble the IP packet... - - let udp_repr = UdpRepr { - src_port: src_addr.port(), - dst_port: dst_addr.port(), - }; - - let ip_repr: IpRepr = match (src_addr, dst_addr) { - (SocketAddr::V4(src_addr), SocketAddr::V4(dst_addr)) => IpRepr::Ipv4(Ipv4Repr { - src_addr: Ipv4Address::from(*src_addr.ip()), - dst_addr: Ipv4Address::from(*dst_addr.ip()), - next_header: IpProtocol::Udp, - payload_len: udp_repr.header_len() + data.len(), - hop_limit: 255, - }), - (SocketAddr::V6(src_addr), SocketAddr::V6(dst_addr)) => IpRepr::Ipv6(Ipv6Repr { - src_addr: Ipv6Address::from(*src_addr.ip()), - dst_addr: Ipv6Address::from(*dst_addr.ip()), - next_header: IpProtocol::Udp, - payload_len: udp_repr.header_len() + data.len(), - hop_limit: 255, - }), - _ => { - log::error!("Failed to assemble UDP datagram: mismatched IP address versions"); - return; - } - }; - - let buf = vec![0u8; ip_repr.buffer_len()]; - - let mut ip_packet = match ip_repr { - IpRepr::Ipv4(repr) => { - let mut packet = Ipv4Packet::new_unchecked(buf); - repr.emit(&mut packet, &ChecksumCapabilities::default()); - IpPacket::from(packet) - } - IpRepr::Ipv6(repr) => { - let mut packet = Ipv6Packet::new_unchecked(buf); - repr.emit(&mut packet); - IpPacket::from(packet) - } - }; - - udp_repr.emit( - &mut UdpPacket::new_unchecked(ip_packet.payload_mut()), - &ip_repr.src_addr(), - &ip_repr.dst_addr(), - data.len(), - |buf| buf.copy_from_slice(data.as_slice()), - &ChecksumCapabilities::default(), - ); - - permit.send(NetworkCommand::SendPacket(ip_packet)); - } - - fn handle_network_event( - &mut self, - event: NetworkEvent, - permit: Permit<'_, TransportEvent>, - ) -> Result<()> { - match event { - NetworkEvent::ReceivePacket { - packet, - tunnel_info, - } => { - self.receive_packet(packet, tunnel_info, permit)?; - } - } - Ok(()) - } - - fn handle_transport_command(&mut self, command: TransportCommand) { - match command { - TransportCommand::ReadData(id, n, tx) => { - self.read_data(id, n, tx); - } - TransportCommand::WriteData(id, buf) => { - self.write_data(id, buf); - } - TransportCommand::DrainWriter(id, tx) => { - self.drain_writer(id, tx); - } - TransportCommand::CloseConnection(id, half_close) => { - self.close_connection(id, half_close); - } - TransportCommand::SendDatagram { - data, - src_addr, - dst_addr, - } => { - self.send_datagram(data, src_addr, dst_addr); - } - } - } -} - -fn handle_icmpv4_echo_request(mut input_packet: Ipv4Packet>) -> Option { - let src_addr = input_packet.src_addr(); - let dst_addr = input_packet.dst_addr(); - - // Parsing ICMP Packet - let mut input_icmpv4_packet = match Icmpv4Packet::new_checked(input_packet.payload_mut()) { - Ok(p) => p, - Err(e) => { - log::debug!("Received invalid ICMPv4 packet: {}", e); - return None; - } - }; - - // Checking that it is an ICMP Echo Request. - if input_icmpv4_packet.msg_type() != Icmpv4Message::EchoRequest { - log::debug!( - "Unsupported ICMPv4 packet of type: {}", - input_icmpv4_packet.msg_type() - ); - return None; - } - - // Creating fake response packet. - let icmp_repr = Icmpv4Repr::EchoReply { - ident: input_icmpv4_packet.echo_ident(), - seq_no: input_icmpv4_packet.echo_seq_no(), - data: input_icmpv4_packet.data_mut(), - }; - let ip_repr = Ipv4Repr { - // Directing fake reply back to the original source address. - src_addr: dst_addr, - dst_addr: src_addr, - next_header: IpProtocol::Icmp, - payload_len: icmp_repr.buffer_len(), - hop_limit: 255, - }; - let buf = vec![0u8; ip_repr.buffer_len() + icmp_repr.buffer_len()]; - let mut output_ipv4_packet = Ipv4Packet::new_unchecked(buf); - ip_repr.emit(&mut output_ipv4_packet, &ChecksumCapabilities::default()); - let mut output_ip_packet = IpPacket::from(output_ipv4_packet); - icmp_repr.emit( - &mut Icmpv4Packet::new_unchecked(output_ip_packet.payload_mut()), - &ChecksumCapabilities::default(), - ); - Some(output_ip_packet) -} - -fn handle_icmpv6_echo_request(mut input_packet: Ipv6Packet>) -> Option { - let src_addr = input_packet.src_addr(); - let dst_addr = input_packet.dst_addr(); - - // Parsing ICMP Packet - let mut input_icmpv6_packet = match Icmpv6Packet::new_checked(input_packet.payload_mut()) { - Ok(p) => p, - Err(e) => { - log::debug!("Received invalid ICMPv6 packet: {}", e); - return None; - } - }; - - // Checking that it is an ICMP Echo Request. - if input_icmpv6_packet.msg_type() != Icmpv6Message::EchoRequest { - log::debug!( - "Unsupported ICMPv6 packet of type: {}", - input_icmpv6_packet.msg_type() - ); - return None; - } - - // Creating fake response packet. - let icmp_repr = Icmpv6Repr::EchoReply { - ident: input_icmpv6_packet.echo_ident(), - seq_no: input_icmpv6_packet.echo_seq_no(), - data: input_icmpv6_packet.payload_mut(), - }; - let ip_repr = Ipv6Repr { - // Directing fake reply back to the original source address. - src_addr: dst_addr, - dst_addr: src_addr, - next_header: IpProtocol::Icmp, - payload_len: icmp_repr.buffer_len(), - hop_limit: 255, - }; - let buf = vec![0u8; ip_repr.buffer_len() + icmp_repr.buffer_len()]; - let mut output_ipv6_packet = Ipv6Packet::new_unchecked(buf); - ip_repr.emit(&mut output_ipv6_packet); - let mut output_ip_packet = IpPacket::from(output_ipv6_packet); - icmp_repr.emit( - // Directing fake reply back to the original source address. - &IpAddress::from(dst_addr), - &IpAddress::from(src_addr), - &mut Icmpv6Packet::new_unchecked(output_ip_packet.payload_mut()), - &ChecksumCapabilities::default(), - ); - Some(output_ip_packet) -} +use crate::messages::{NetworkCommand, NetworkEvent, TransportCommand, TransportEvent}; +use crate::network::io::NetworkIO; pub struct NetworkTask<'a> { net_tx: Sender, @@ -552,10 +69,8 @@ impl NetworkTask<'_> { } pub async fn run(mut self) -> Result<()> { - let mut io = self.io; - let mut remove_conns = Vec::new(); - let mut py_tx_permit: Option> = None; + let mut delay: Option = None; 'task: loop { // On a high level, we do three things in our main loop: @@ -563,9 +78,6 @@ impl NetworkTask<'_> { // 2. `.poll()` the smoltcp interface until it's finished with everything for now. // 3. Check if we can wake up any waiters, move more data in the send buffer, or clean up sockets. - // check device for timeouts - let delay = io.iface.poll_delay(Instant::now(), &io.sockets); - #[cfg(debug_assertions)] if let Some(d) = delay { log::debug!("Waiting for device timeout: {} ...", d); @@ -574,139 +86,52 @@ impl NetworkTask<'_> { #[cfg(debug_assertions)] log::debug!("Waiting for events ..."); - if py_tx_permit.is_none() { - py_tx_permit = self.py_tx.try_reserve().ok(); - } - let net_tx_full = self.net_tx.capacity() == 0; + let py_tx_available = py_tx_permit.is_some(); + let net_tx_available = self.net_tx.capacity() > 0; tokio::select! { // wait for graceful shutdown _ = self.shutdown.recv() => break 'task, // wait for timeouts when the device is idle _ = async { tokio::time::sleep(delay.unwrap().into()).await }, if delay.is_some() => {}, - // wait for incoming packets - Some(e) = self.net_rx.recv(), if py_tx_permit.is_some() => { + // wait for py_tx channel capacity... + Ok(permit) = self.py_tx.reserve(), if !py_tx_available => { + py_tx_permit = Some(permit); + continue 'task; + }, + // ...or process incoming packets + Some(e) = self.net_rx.recv(), if py_tx_available => { // handle pending network events until channel is full - io.handle_network_event(e, py_tx_permit.take().unwrap())?; - + self.io.handle_network_event(e, py_tx_permit.take().unwrap())?; while let Ok(p) = self.py_tx.try_reserve() { if let Ok(e) = self.net_rx.try_recv() { - io.handle_network_event(e, p)?; + self.io.handle_network_event(e, p)?; } else { break; } } }, - // wait for outgoing packets - Some(c) = self.py_rx.recv(), if !net_tx_full => { + // wait for net_tx capacity... + Ok(permit) = self.net_tx.reserve(), if !net_tx_available => { + drop(permit); // smoltcp's device stuff is not permit-based. + continue 'task; + }, + // ...or process outgoing packets + Some(c) = self.py_rx.recv(), if net_tx_available => { // handle pending transport commands until channel is full - io.handle_transport_command(c); - + self.io.handle_transport_command(c); while self.net_tx.capacity() > 0 { if let Ok(c) = self.py_rx.try_recv() { - io.handle_transport_command(c); + self.io.handle_transport_command(c); } else { break; } } }, - // wait until channels are no longer full - Ok(()) = wait_for_channel_capacity(&self.py_tx), if py_tx_permit.is_none() => {}, - Ok(()) = wait_for_channel_capacity(&self.net_tx), if net_tx_full => {}, } - // poll virtual network device - #[cfg(debug_assertions)] - log::debug!("Polling virtual network device ..."); - io.iface - .poll(Instant::now(), &mut io.device, &mut io.sockets); - - #[cfg(debug_assertions)] - log::debug!("Processing TCP connections ..."); - - for (connection_id, data) in io.socket_data.iter_mut() { - let socket = io.sockets.get_mut::(data.handle); - - // receive data over the socket - if data.recv_waiter.is_some() { - if socket.can_recv() { - let (n, tx) = data.recv_waiter.take().unwrap(); - let bytes_available = socket.recv_queue(); - - let mut buf = vec![0u8; cmp::min(bytes_available, n as usize)]; - let bytes_read = socket.recv_slice(&mut buf)?; - - buf.truncate(bytes_read); - if tx.send(buf).is_err() { - log::debug!("Cannot send received data, channel was already closed."); - } - } else { - // We can't use .may_recv() here as it returns false during establishment. - use tcp::State::*; - match socket.state() { - // can we still receive something in the future? - CloseWait | LastAck | Closed | Closing | TimeWait => { - let (_, tx) = data.recv_waiter.take().unwrap(); - if tx.send(Vec::new()).is_err() { - log::debug!("Cannot send close, channel was already closed."); - } - } - _ => {} - } - } - } - - // send data over the socket - if !data.send_buffer.is_empty() && socket.can_send() { - let (a, b) = data.send_buffer.as_slices(); - let sent = socket.send_slice(a)? + socket.send_slice(b)?; - data.send_buffer.drain(..sent); - } - - // if necessary, drain write buffers: - // either when drain has been requested explicitly, or when socket is being closed - // TODO: benchmark different variants here. (e.g. only return on half capacity) - if (!data.drain_waiter.is_empty() || data.write_eof) - && socket.send_queue() < socket.send_capacity() - { - for waiter in data.drain_waiter.drain(..) { - if waiter.send(()).is_err() { - log::debug!("TcpStream already closed, cannot send notification about drained buffers.") - } - } - } - - #[cfg(debug_assertions)] - log::debug!( - "TCP connection {}: socket state {} for {:?}", - connection_id, - socket.state(), - data.addr_tuple, - ); - - // if requested, close socket - if data.write_eof && data.send_buffer.is_empty() { - socket.close(); - data.write_eof = false; - } - - // if socket is closed, mark connection for removal - if socket.state() == tcp::State::Closed { - remove_conns.push(*connection_id); - } - } - - for connection_id in remove_conns.drain(..) { - let data = io.socket_data.remove(&connection_id).unwrap(); - io.sockets.remove(data.handle); - io.active_connections.remove(&data.addr_tuple); - } - - // poll again. we may have new stuff to do. - #[cfg(debug_assertions)] - log::debug!("Polling virtual network device ..."); - io.iface - .poll(Instant::now(), &mut io.device, &mut io.sockets); + self.io.poll()?; + delay = self.io.poll_delay(); } // TODO: process remaining pending data after the shutdown request was received? @@ -718,38 +143,6 @@ impl NetworkTask<'_> { impl fmt::Debug for NetworkTask<'_> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let sockets: Vec = self - .io - .sockets - .iter() - .filter_map(|(_h, s)| match s { - Socket::Tcp(s) => Some(s), - _ => None, - }) - .map(|sock| { - format!( - "TCP {:<21} {:<21} {}", - sock.remote_endpoint() - .map(|e| e.to_string()) - .as_ref() - .map_or("not connected", String::as_str), - sock.local_endpoint() - .map(|e| e.to_string()) - .as_ref() - .map_or("not connected", String::as_str), - sock.state() - ) - }) - .collect(); - - f.debug_struct("NetworkTask") - .field("sockets", &sockets) - .finish() + f.debug_struct("NetworkTask").field("io", &self.io).finish() } } - -async fn wait_for_channel_capacity(s: &Sender) -> Result<()> { - let permit = s.reserve().await?; - drop(permit); - Ok(()) -} From b738b241559c6a93bf8183646ad00c30ac82f4bd Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Thu, 14 Dec 2023 16:13:48 +0100 Subject: [PATCH 02/20] add udp handler --- Cargo.lock | 8 +- Cargo.toml | 2 + mitmproxy-windows/redirector/Cargo.toml | 2 +- src/messages.rs | 13 ++ src/network/core.rs | 133 +++++++++++ src/network/mod.rs | 4 +- src/network/task.rs | 12 +- src/network/{io.rs => tcp.rs} | 239 ++----------------- src/network/udp.rs | 292 ++++++++++++++++++++++++ 9 files changed, 479 insertions(+), 226 deletions(-) create mode 100644 src/network/core.rs rename src/network/{io.rs => tcp.rs} (59%) create mode 100644 src/network/udp.rs diff --git a/Cargo.lock b/Cargo.lock index 6cf614a1..9d1245dc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1126,10 +1126,12 @@ checksum = "fc6d6206008e25125b1f97fbe5d309eb7b85141cf9199d52dbd3729a1584dd16" [[package]] name = "internet-packet" -version = "0.1.0" -source = "git+https://github.com/mhils/internet-packet.git#41a0b6711147269ba5be0eae595bfac810b4b195" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95d8d20ad61a92e71edf571fa568e14aeba0c5f00548acd491fbf694ce9a5ad8" dependencies = [ "internet-checksum", + "smoltcp", ] [[package]] @@ -1353,7 +1355,9 @@ dependencies = [ "env_logger", "futures-util", "image", + "internet-packet", "log", + "lru_time_cache", "nix 0.27.1", "once_cell", "pretty-hex", diff --git a/Cargo.toml b/Cargo.toml index 80d8850d..5a8731b1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -44,6 +44,8 @@ image = "0.24.6" prost = "0.12.3" tokio-util = { version = "0.7.10", features = ["codec"] } futures-util = { version = "0.3.29", features = ["sink"] } +lru_time_cache = "0.11.11" +internet-packet = { version = "0.2.0", features = ["smoltcp"] } # [patch.crates-io] # tokio = { path = "../tokio/tokio" } diff --git a/mitmproxy-windows/redirector/Cargo.toml b/mitmproxy-windows/redirector/Cargo.toml index 4ab749d8..6cba51d2 100644 --- a/mitmproxy-windows/redirector/Cargo.toml +++ b/mitmproxy-windows/redirector/Cargo.toml @@ -19,7 +19,7 @@ lru_time_cache = "0.11.11" log = "0.4.18" env_logger = "0.10.1" prost = "0.12.3" -internet-packet = { git = "https://github.com/mhils/internet-packet.git", features = ["checksums"] } +internet-packet = { version = "0.2.0", features = ["checksums"] } [target.'cfg(windows)'.dev-dependencies] hex = "0.4.3" diff --git a/src/messages.rs b/src/messages.rs index ba45ce5f..fe436605 100755 --- a/src/messages.rs +++ b/src/messages.rs @@ -1,6 +1,7 @@ use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; use anyhow::{anyhow, Result}; +use internet_packet::InternetPacket; use smoltcp::wire::{IpProtocol, Ipv4Packet, Ipv6Packet}; use tokio::sync::oneshot; @@ -60,6 +61,7 @@ pub enum TransportCommand { WriteData(ConnectionId, Vec), DrainWriter(ConnectionId, oneshot::Sender<()>), CloseConnection(ConnectionId, bool), + // FIXME: Remove SendDatagram { data: Vec, src_addr: SocketAddr, @@ -86,6 +88,17 @@ impl From>> for IpPacket { } } +impl TryInto for IpPacket { + type Error = internet_packet::ParseError; + + fn try_into(self) -> std::result::Result { + match self { + IpPacket::V4(packet) => InternetPacket::try_from(packet), + IpPacket::V6(packet) => InternetPacket::try_from(packet), + } + } +} + impl TryFrom> for IpPacket { type Error = anyhow::Error; diff --git a/src/network/core.rs b/src/network/core.rs new file mode 100644 index 00000000..d35f4868 --- /dev/null +++ b/src/network/core.rs @@ -0,0 +1,133 @@ +use std::cmp::min; +use std::fmt; + +use std::time::Duration; + +use anyhow::Result; + +use smoltcp::wire::IpProtocol; +use tokio::sync::mpsc::{Permit, Sender}; + +use crate::messages::{IpPacket, NetworkCommand, NetworkEvent, TransportCommand, TransportEvent}; +use crate::network::icmp::{handle_icmpv4_echo_request, handle_icmpv6_echo_request}; + +use crate::network::tcp::TcpHandler; +use crate::network::udp::UdpHandler; + +pub struct NetworkStack<'a> { + tcp: TcpHandler<'a>, + udp: UdpHandler, + net_tx: Sender, +} + +impl<'a> NetworkStack<'a> { + pub fn new(net_tx: Sender) -> Self { + Self { + tcp: TcpHandler::new(net_tx.clone()), + udp: UdpHandler::new(net_tx.clone()), + net_tx, + } + } + + pub fn handle_network_event( + &mut self, + event: NetworkEvent, + permit: Permit<'_, TransportEvent>, + ) -> Result<()> { + let (packet, tunnel_info) = match event { + NetworkEvent::ReceivePacket { + packet, + tunnel_info, + } => (packet, tunnel_info), + }; + + if let IpPacket::V4(p) = &packet { + if !p.verify_checksum() { + log::warn!("Received invalid IP packet (checksum error)."); + return Ok(()); + } + } + + match packet.transport_protocol() { + IpProtocol::Tcp => self.tcp.receive_packet(packet, tunnel_info, permit), + IpProtocol::Udp => self.udp.receive_packet(packet, tunnel_info, permit), + IpProtocol::Icmp => self.receive_packet_icmp(packet), + _ => { + log::debug!( + "Received IP packet for unknown protocol: {}", + packet.transport_protocol() + ); + Ok(()) + } + } + } + + fn receive_packet_icmp(&mut self, packet: IpPacket) -> Result<()> { + // Some apps check network connectivity by sending ICMP pings. ICMP traffic is currently + // swallowed by mitmproxy_rs, which makes them believe that there is no network connectivity. + // Generating fake ICMP replies as a simple workaround. + + if let Ok(permit) = self.net_tx.try_reserve() { + // Generating and sending fake replies for ICMP echo requests. Ignoring all other ICMP types. + let response_packet = match packet { + IpPacket::V4(packet) => handle_icmpv4_echo_request(packet), + IpPacket::V6(packet) => handle_icmpv6_echo_request(packet), + }; + if let Some(response_packet) = response_packet { + permit.send(NetworkCommand::SendPacket(response_packet)); + } + } else { + log::debug!("Channel full, discarding ICMP packet."); + } + Ok(()) + } + + pub fn handle_transport_command(&mut self, command: TransportCommand) { + match command { + TransportCommand::ReadData(id, n, tx) => match id & 1 == 1 { + true => self.udp.read_data(id, tx), + false => self.tcp.read_data(id, n, tx), + }, + TransportCommand::WriteData(id, buf) => match id & 1 == 1 { + true => self.udp.write_data(id, buf), + false => self.tcp.write_data(id, buf), + }, + TransportCommand::DrainWriter(id, tx) => match id & 1 == 1 { + true => self.udp.drain_writer(id, tx), + false => self.tcp.drain_writer(id, tx), + }, + TransportCommand::CloseConnection(id, half_close) => match id & 1 == 1 { + true => self.udp.close_connection(id), + false => self.tcp.close_connection(id, half_close), + }, + TransportCommand::SendDatagram { + data: _, + src_addr: _, + dst_addr: _, + } => { + // TODO remove + log::error!("Error: SendDatagram is deprecated."); + } + }; + } + + pub fn poll_delay(&mut self) -> Option { + match (self.tcp.poll_delay(), self.udp.poll_delay()) { + (Some(a), Some(b)) => Some(min(a, b)), + (Some(a), None) => Some(a), + (None, Some(b)) => Some(b), + (None, None) => None, + } + } + + pub fn poll(&mut self) -> Result<()> { + self.udp.poll(); + self.tcp.poll() + } +} + +impl fmt::Debug for NetworkStack<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("NetworkIO").field("tcp", &self.tcp).finish() + } +} diff --git a/src/network/mod.rs b/src/network/mod.rs index aaa86e04..998a8331 100755 --- a/src/network/mod.rs +++ b/src/network/mod.rs @@ -4,9 +4,11 @@ pub use task::NetworkTask; mod virtual_device; +mod core; mod icmp; -mod io; +mod tcp; #[cfg(test)] mod tests; +mod udp; pub const MAX_PACKET_SIZE: usize = 65535; diff --git a/src/network/task.rs b/src/network/task.rs index d15fec5a..0e61e8ab 100755 --- a/src/network/task.rs +++ b/src/network/task.rs @@ -1,8 +1,8 @@ use std::fmt; use anyhow::Result; -use smoltcp::time::Duration; +use std::time::Duration; use tokio::sync::{ broadcast, broadcast::Receiver as BroadcastReceiver, @@ -12,7 +12,7 @@ use tokio::sync::{ use tokio::task::JoinHandle; use crate::messages::{NetworkCommand, NetworkEvent, TransportCommand, TransportEvent}; -use crate::network::io::NetworkIO; +use crate::network::core::NetworkStack; pub struct NetworkTask<'a> { net_tx: Sender, @@ -21,7 +21,7 @@ pub struct NetworkTask<'a> { py_rx: UnboundedReceiver, shutdown: BroadcastReceiver<()>, - io: NetworkIO<'a>, + io: NetworkStack<'a>, } #[allow(clippy::type_complexity)] @@ -57,7 +57,7 @@ impl NetworkTask<'_> { py_rx: UnboundedReceiver, sd_watcher: BroadcastReceiver<()>, ) -> Result { - let io = NetworkIO::new(net_tx.clone()); + let io = NetworkStack::new(net_tx.clone()); Ok(Self { net_tx, net_rx, @@ -80,7 +80,7 @@ impl NetworkTask<'_> { #[cfg(debug_assertions)] if let Some(d) = delay { - log::debug!("Waiting for device timeout: {} ...", d); + log::debug!("Waiting for device timeout: {:?} ...", d); } #[cfg(debug_assertions)] @@ -93,7 +93,7 @@ impl NetworkTask<'_> { // wait for graceful shutdown _ = self.shutdown.recv() => break 'task, // wait for timeouts when the device is idle - _ = async { tokio::time::sleep(delay.unwrap().into()).await }, if delay.is_some() => {}, + _ = async { tokio::time::sleep(delay.unwrap()).await }, if delay.is_some() => {}, // wait for py_tx channel capacity... Ok(permit) = self.py_tx.reserve(), if !py_tx_available => { py_tx_permit = Some(permit); diff --git a/src/network/io.rs b/src/network/tcp.rs similarity index 59% rename from src/network/io.rs rename to src/network/tcp.rs index 6a9a5046..b699fb34 100644 --- a/src/network/io.rs +++ b/src/network/tcp.rs @@ -6,27 +6,19 @@ use anyhow::Result; use pretty_hex::pretty_hex; use smoltcp::iface::{Config, SocketSet}; use smoltcp::socket::{tcp, Socket}; - use smoltcp::wire::HardwareAddress; use smoltcp::{ iface::{Interface, SocketHandle}, - phy::ChecksumCapabilities, - time::{Duration, Instant}, - wire::{ - IpAddress, IpCidr, IpProtocol, IpRepr, Ipv4Address, Ipv4Packet, Ipv4Repr, Ipv6Address, - Ipv6Packet, Ipv6Repr, TcpPacket, UdpPacket, UdpRepr, - }, + time::Instant, + wire::{IpAddress, IpCidr, Ipv4Address, TcpPacket}, }; +use std::time::Duration; use tokio::sync::{ mpsc::{Permit, Sender}, oneshot, }; -use crate::messages::{ - ConnectionId, IpPacket, NetworkCommand, NetworkEvent, TransportCommand, TransportEvent, - TunnelInfo, -}; -use crate::network::icmp::{handle_icmpv4_echo_request, handle_icmpv6_echo_request}; +use crate::messages::{ConnectionId, IpPacket, NetworkCommand, TransportEvent, TunnelInfo}; use super::virtual_device::VirtualDevice; @@ -46,22 +38,19 @@ struct SocketData { addr_tuple: (SocketAddr, SocketAddr), } -pub struct NetworkIO<'a> { +pub struct TcpHandler<'a> { + next_connection_id: ConnectionId, iface: Interface, device: VirtualDevice, sockets: SocketSet<'a>, - - net_tx: Sender, - socket_data: HashMap, - active_connections: HashSet<(SocketAddr, SocketAddr)>, - next_connection_id: ConnectionId, remove_conns: Vec, + active_connections: HashSet<(SocketAddr, SocketAddr)>, } -impl<'a> NetworkIO<'a> { +impl<'a> TcpHandler<'a> { pub fn new(net_tx: Sender) -> Self { - let mut device = VirtualDevice::new(net_tx.clone()); + let mut device = VirtualDevice::new(net_tx); let config = Config::new(HardwareAddress::Ip); let mut iface = Interface::new(config, &mut device, Instant::now()); @@ -79,11 +68,10 @@ impl<'a> NetworkIO<'a> { .add_default_ipv4_route(Ipv4Address::new(0, 0, 0, 1)) .unwrap(); - NetworkIO { + TcpHandler { iface, device, sockets: SocketSet::new(Vec::new()), - net_tx, socket_data: HashMap::new(), active_connections: HashSet::new(), next_connection_id: 0, @@ -91,65 +79,7 @@ impl<'a> NetworkIO<'a> { } } - fn receive_packet( - &mut self, - packet: IpPacket, - tunnel_info: TunnelInfo, - permit: Permit<'_, TransportEvent>, - ) -> Result<()> { - if let IpPacket::V4(p) = &packet { - if !p.verify_checksum() { - log::warn!("Received invalid IP packet (checksum error)."); - return Ok(()); - } - } - - match packet.transport_protocol() { - IpProtocol::Tcp => self.receive_packet_tcp(packet, tunnel_info, permit), - IpProtocol::Udp => self.receive_packet_udp(packet, tunnel_info, permit), - IpProtocol::Icmp => self.receive_packet_icmp(packet), - _ => { - log::debug!( - "Received IP packet for unknown protocol: {}", - packet.transport_protocol() - ); - Ok(()) - } - } - } - - fn receive_packet_udp( - &mut self, - mut packet: IpPacket, - tunnel_info: TunnelInfo, - permit: Permit<'_, TransportEvent>, - ) -> Result<()> { - let src_ip = packet.src_ip(); - let dst_ip = packet.dst_ip(); - - let mut udp_packet = match UdpPacket::new_checked(packet.payload_mut()) { - Ok(p) => p, - Err(e) => { - log::debug!("Received invalid UDP packet: {}", e); - return Ok(()); - } - }; - - let src_addr = SocketAddr::new(src_ip, udp_packet.src_port()); - let dst_addr = SocketAddr::new(dst_ip, udp_packet.dst_port()); - - let event = TransportEvent::DatagramReceived { - data: udp_packet.payload_mut().to_vec(), - src_addr, - dst_addr, - tunnel_info, - }; - - permit.send(event); - Ok(()) - } - - fn receive_packet_tcp( + pub fn receive_packet( &mut self, mut packet: IpPacket, tunnel_info: TunnelInfo, @@ -191,13 +121,13 @@ impl<'a> NetworkIO<'a> { ); socket.listen(dst_addr)?; - socket.set_timeout(Some(Duration::from_secs(60))); - socket.set_keep_alive(Some(Duration::from_secs(28))); + socket.set_timeout(Some(smoltcp::time::Duration::from_secs(60))); + socket.set_keep_alive(Some(smoltcp::time::Duration::from_secs(28))); let handle = self.sockets.add(socket); let connection_id = { - self.next_connection_id += 1; + self.next_connection_id += 2; // only even ids. self.next_connection_id }; @@ -225,27 +155,13 @@ impl<'a> NetworkIO<'a> { Ok(()) } - fn receive_packet_icmp(&mut self, packet: IpPacket) -> Result<()> { - // Some apps check network connectivity by sending ICMP pings. ICMP traffic is currently - // swallowed by mitmproxy_rs, which makes them believe that there is no network connectivity. - // Generating fake ICMP replies as a simple workaround. - - if let Ok(permit) = self.net_tx.try_reserve() { - // Generating and sending fake replies for ICMP echo requests. Ignoring all other ICMP types. - let response_packet = match packet { - IpPacket::V4(packet) => handle_icmpv4_echo_request(packet), - IpPacket::V6(packet) => handle_icmpv6_echo_request(packet), - }; - if let Some(response_packet) = response_packet { - permit.send(NetworkCommand::SendPacket(response_packet)); - } - } else { - log::debug!("Channel full, discarding ICMP packet."); - } - Ok(()) + pub fn poll_delay(&mut self) -> Option { + self.iface + .poll_delay(Instant::now(), &self.sockets) + .map(Duration::from) } - fn read_data(&mut self, id: ConnectionId, n: u32, tx: oneshot::Sender>) { + pub fn read_data(&mut self, id: ConnectionId, n: u32, tx: oneshot::Sender>) { if let Some(data) = self.socket_data.get_mut(&id) { assert!(data.recv_waiter.is_none()); data.recv_waiter = Some((n, tx)); @@ -255,7 +171,7 @@ impl<'a> NetworkIO<'a> { } } - fn write_data(&mut self, id: ConnectionId, buf: Vec) { + pub fn write_data(&mut self, id: ConnectionId, buf: Vec) { if let Some(data) = self.socket_data.get_mut(&id) { data.send_buffer.extend(buf); } else { @@ -264,7 +180,7 @@ impl<'a> NetworkIO<'a> { } } - fn drain_writer(&mut self, id: ConnectionId, tx: oneshot::Sender<()>) { + pub fn drain_writer(&mut self, id: ConnectionId, tx: oneshot::Sender<()>) { if let Some(data) = self.socket_data.get_mut(&id) { data.drain_waiter.push(tx); } else { @@ -273,7 +189,7 @@ impl<'a> NetworkIO<'a> { } } - fn close_connection(&mut self, id: ConnectionId, _half_close: bool) { + pub fn close_connection(&mut self, id: ConnectionId, _half_close: bool) { if let Some(data) = self.socket_data.get_mut(&id) { // smoltcp does not have a good way to do "SHUT_RDWR". We can't call .abort() // here because that sends a RST instead of a FIN (and breaks @@ -289,115 +205,6 @@ impl<'a> NetworkIO<'a> { } } - fn send_datagram(&mut self, data: Vec, src_addr: SocketAddr, dst_addr: SocketAddr) { - let permit = match self.net_tx.try_reserve() { - Ok(p) => p, - Err(_) => { - log::debug!("Channel full, discarding UDP packet."); - return; - } - }; - - // We now know that there's space for us to send, - // let's painstakingly reassemble the IP packet... - - let udp_repr = UdpRepr { - src_port: src_addr.port(), - dst_port: dst_addr.port(), - }; - - let ip_repr: IpRepr = match (src_addr, dst_addr) { - (SocketAddr::V4(src_addr), SocketAddr::V4(dst_addr)) => IpRepr::Ipv4(Ipv4Repr { - src_addr: Ipv4Address::from(*src_addr.ip()), - dst_addr: Ipv4Address::from(*dst_addr.ip()), - next_header: IpProtocol::Udp, - payload_len: udp_repr.header_len() + data.len(), - hop_limit: 255, - }), - (SocketAddr::V6(src_addr), SocketAddr::V6(dst_addr)) => IpRepr::Ipv6(Ipv6Repr { - src_addr: Ipv6Address::from(*src_addr.ip()), - dst_addr: Ipv6Address::from(*dst_addr.ip()), - next_header: IpProtocol::Udp, - payload_len: udp_repr.header_len() + data.len(), - hop_limit: 255, - }), - _ => { - log::error!("Failed to assemble UDP datagram: mismatched IP address versions"); - return; - } - }; - - let buf = vec![0u8; ip_repr.buffer_len()]; - - let mut ip_packet = match ip_repr { - IpRepr::Ipv4(repr) => { - let mut packet = Ipv4Packet::new_unchecked(buf); - repr.emit(&mut packet, &ChecksumCapabilities::default()); - IpPacket::from(packet) - } - IpRepr::Ipv6(repr) => { - let mut packet = Ipv6Packet::new_unchecked(buf); - repr.emit(&mut packet); - IpPacket::from(packet) - } - }; - - udp_repr.emit( - &mut UdpPacket::new_unchecked(ip_packet.payload_mut()), - &ip_repr.src_addr(), - &ip_repr.dst_addr(), - data.len(), - |buf| buf.copy_from_slice(data.as_slice()), - &ChecksumCapabilities::default(), - ); - - permit.send(NetworkCommand::SendPacket(ip_packet)); - } - - pub fn handle_network_event( - &mut self, - event: NetworkEvent, - permit: Permit<'_, TransportEvent>, - ) -> Result<()> { - match event { - NetworkEvent::ReceivePacket { - packet, - tunnel_info, - } => { - self.receive_packet(packet, tunnel_info, permit)?; - } - } - Ok(()) - } - - pub fn handle_transport_command(&mut self, command: TransportCommand) { - match command { - TransportCommand::ReadData(id, n, tx) => { - self.read_data(id, n, tx); - } - TransportCommand::WriteData(id, buf) => { - self.write_data(id, buf); - } - TransportCommand::DrainWriter(id, tx) => { - self.drain_writer(id, tx); - } - TransportCommand::CloseConnection(id, half_close) => { - self.close_connection(id, half_close); - } - TransportCommand::SendDatagram { - data, - src_addr, - dst_addr, - } => { - self.send_datagram(data, src_addr, dst_addr); - } - } - } - - pub fn poll_delay(&mut self) -> Option { - self.iface.poll_delay(Instant::now(), &self.sockets) - } - pub fn poll(&mut self) -> Result<()> { // poll virtual network device #[cfg(debug_assertions)] @@ -500,7 +307,7 @@ impl<'a> NetworkIO<'a> { } } -impl fmt::Debug for NetworkIO<'_> { +impl fmt::Debug for TcpHandler<'_> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let sockets: Vec = self .sockets diff --git a/src/network/udp.rs b/src/network/udp.rs new file mode 100644 index 00000000..92def8b7 --- /dev/null +++ b/src/network/udp.rs @@ -0,0 +1,292 @@ +use std::collections::VecDeque; +use std::net::SocketAddr; +use std::time::Duration; + +use lru_time_cache::LruCache; +use tokio::sync::mpsc::{Permit, Sender}; +use tokio::sync::oneshot; + +use crate::messages::{ConnectionId, IpPacket, NetworkCommand, TransportEvent, TunnelInfo}; +use anyhow::Result; +use internet_packet::InternetPacket; +use smoltcp::phy::ChecksumCapabilities; + +use smoltcp::wire::{ + IpProtocol, IpRepr, Ipv4Address, Ipv4Packet, Ipv4Repr, Ipv6Address, Ipv6Packet, Ipv6Repr, + UdpPacket, UdpRepr, +}; + +struct ConnectionState { + remote_addr: SocketAddr, + local_addr: SocketAddr, + closed: bool, + packets: VecDeque>, + read_tx: Option>>, +} + +impl ConnectionState { + fn new(remote_addr: SocketAddr, local_addr: SocketAddr) -> Self { + Self { + remote_addr, + local_addr, + closed: false, + packets: VecDeque::new(), + read_tx: None, + } + } + fn receive_packet_payload(&mut self, data: Vec) { + if self.closed { + } else if let Some(tx) = self.read_tx.take() { + tx.send(data).ok(); + } else { + self.packets.push_back(data); + } + } + fn read_packet_payload(&mut self, tx: oneshot::Sender>) { + assert!(self.read_tx.is_none()); + if self.closed { + drop(tx); + } else if let Some(data) = self.packets.pop_front() { + tx.send(data).ok(); + } else { + self.read_tx = Some(tx); + } + } + fn close(&mut self) { + if self.closed { + } else if let Some(tx) = self.read_tx.take() { + drop(tx); + self.closed = true; + } else { + self.packets.clear(); + self.closed = true; + } + } +} + +pub struct UdpHandler { + next_connection_id: ConnectionId, + id_lookup: LruCache<(SocketAddr, SocketAddr), ConnectionId>, + connections: LruCache, + net_tx: Sender, +} + +impl UdpHandler { + pub fn new(net_tx: Sender) -> Self { + let connections = LruCache::::with_expiry_duration( + Duration::from_secs(60), + ); + let id_lookup = LruCache::<(SocketAddr, SocketAddr), ConnectionId>::with_expiry_duration( + Duration::from_secs(60), + ); + Self { + connections, + id_lookup, + net_tx, + next_connection_id: 1, + } + } + + pub fn read_data(&mut self, id: ConnectionId, tx: oneshot::Sender>) { + if let Some(state) = self.connections.get_mut(&id) { + state.read_packet_payload(tx); + } + } + + pub fn write_data(&mut self, id: ConnectionId, data: Vec) { + let Some(state) = self.connections.get(&id) else { + return; + }; + // Refresh id lookup. + self.id_lookup + .insert((state.local_addr, state.remote_addr), id); + + let permit = match self.net_tx.try_reserve() { + Ok(p) => p, + Err(_) => { + log::debug!("Channel full, discarding UDP packet."); + return; + } + }; + + // We now know that there's space for us to send, + // let's painstakingly reassemble the IP packet... + + let udp_repr = UdpRepr { + src_port: state.local_addr.port(), + dst_port: state.remote_addr.port(), + }; + + let ip_repr: IpRepr = match (state.local_addr, state.remote_addr) { + (SocketAddr::V4(src_addr), SocketAddr::V4(dst_addr)) => IpRepr::Ipv4(Ipv4Repr { + src_addr: Ipv4Address::from(*src_addr.ip()), + dst_addr: Ipv4Address::from(*dst_addr.ip()), + next_header: IpProtocol::Udp, + payload_len: udp_repr.header_len() + data.len(), + hop_limit: 255, + }), + (SocketAddr::V6(src_addr), SocketAddr::V6(dst_addr)) => IpRepr::Ipv6(Ipv6Repr { + src_addr: Ipv6Address::from(*src_addr.ip()), + dst_addr: Ipv6Address::from(*dst_addr.ip()), + next_header: IpProtocol::Udp, + payload_len: udp_repr.header_len() + data.len(), + hop_limit: 255, + }), + _ => { + log::error!("Failed to assemble UDP datagram: mismatched IP address versions"); + return; + } + }; + + let buf = vec![0u8; ip_repr.buffer_len()]; + + let mut ip_packet = match ip_repr { + IpRepr::Ipv4(repr) => { + let mut packet = Ipv4Packet::new_unchecked(buf); + repr.emit(&mut packet, &ChecksumCapabilities::default()); + IpPacket::from(packet) + } + IpRepr::Ipv6(repr) => { + let mut packet = Ipv6Packet::new_unchecked(buf); + repr.emit(&mut packet); + IpPacket::from(packet) + } + }; + + udp_repr.emit( + &mut UdpPacket::new_unchecked(ip_packet.payload_mut()), + &ip_repr.src_addr(), + &ip_repr.dst_addr(), + data.len(), + |buf| buf.copy_from_slice(data.as_slice()), + &ChecksumCapabilities::default(), + ); + + permit.send(NetworkCommand::SendPacket(ip_packet)); + } + + pub fn drain_writer(&mut self, _id: ConnectionId, tx: oneshot::Sender<()>) { + tx.send(()).ok(); + } + + pub fn close_connection(&mut self, id: ConnectionId) { + if let Some(state) = self.connections.get_mut(&id) { + state.close(); + } + } + + pub fn receive_packet( + &mut self, + packet: IpPacket, + tunnel_info: TunnelInfo, + permit: Permit<'_, TransportEvent>, + ) -> Result<()> { + let packet: InternetPacket = match packet.try_into() { + Ok(p) => p, + Err(e) => { + log::debug!("Received invalid IP packet: {}", e); + return Ok(()); + } + }; + let src_addr = packet.src(); + let dst_addr = packet.dst(); + + let potential_cid = self + .id_lookup + .get(&(src_addr, dst_addr)) + .cloned() + .unwrap_or(0); // guaranteed to not exist. + + let payload = packet.payload().to_vec(); + + match self.connections.get_mut(&potential_cid) { + Some(state) => { + state.receive_packet_payload(payload); + } + None => { + let mut state = ConnectionState::new(src_addr, dst_addr); + state.receive_packet_payload(payload); + let connection_id = { + self.next_connection_id += 2; // only odd ids. + self.next_connection_id + }; + self.id_lookup.insert((src_addr, dst_addr), connection_id); + self.connections.insert(connection_id, state); + permit.send(TransportEvent::ConnectionEstablished { + connection_id, + src_addr, + dst_addr, + tunnel_info, + }); + } + }; + + Ok(()) + } + + pub fn poll_delay(&mut self) -> Option { + if self.connections.is_empty() { + None + } else { + Some(Duration::from_secs(5)) + } + } + + pub fn poll(&mut self) { + // Creating an iterator removes expired entries. + self.connections.iter(); + self.id_lookup.iter(); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::net::{IpAddr, Ipv4Addr}; + + const SRC: SocketAddr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 54321); + const DST: SocketAddr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 80); + + #[test] + fn test_connection_state_recv_recv_read_read() { + let mut state = ConnectionState::new(SRC, DST); + state.receive_packet_payload(vec![1, 2, 3]); + state.receive_packet_payload(vec![4, 5, 6]); + let (tx, rx) = oneshot::channel(); + state.read_packet_payload(tx); + assert_eq!(vec![1, 2, 3], rx.blocking_recv().unwrap()); + let (tx, rx) = oneshot::channel(); + state.read_packet_payload(tx); + assert_eq!(vec![4, 5, 6], rx.blocking_recv().unwrap()); + } + + #[test] + fn test_connection_state_read_recv_recv() { + let mut state = ConnectionState::new(SRC, DST); + let (tx, rx) = oneshot::channel(); + state.read_packet_payload(tx); + state.receive_packet_payload(vec![1, 2, 3]); + state.receive_packet_payload(vec![4, 5, 6]); + assert_eq!(vec![1, 2, 3], rx.blocking_recv().unwrap()); + } + + #[test] + fn test_connection_state_close_recv_read() { + let mut state = ConnectionState::new(SRC, DST); + let (tx, rx) = oneshot::channel(); + state.close(); + state.receive_packet_payload(vec![1, 2, 3]); + state.read_packet_payload(tx); + assert!(rx.blocking_recv().is_err()); + } + + #[test] + fn test_connection_state_read_close_recv() { + let mut state = ConnectionState::new(SRC, DST); + let (tx, rx) = oneshot::channel(); + state.read_packet_payload(tx); + state.close(); + state.receive_packet_payload(vec![1, 2, 3]); + assert!(rx.blocking_recv().is_err()); + } +} From 200ae9b7481e1543deb21c9d986e0b2df70ee1bf Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Thu, 14 Dec 2023 22:01:01 +0100 Subject: [PATCH 03/20] remove old udp code --- mitmproxy-rs/src/datagram_transport.rs | 81 ----------------- mitmproxy-rs/src/lib.rs | 2 - mitmproxy-rs/src/task.rs | 47 ++-------- src/messages.rs | 12 --- src/network/core.rs | 8 -- src/network/tests.rs | 115 ++++++++++++------------- src/packet_sources/macos.rs | 5 +- 7 files changed, 64 insertions(+), 206 deletions(-) delete mode 100644 mitmproxy-rs/src/datagram_transport.rs diff --git a/mitmproxy-rs/src/datagram_transport.rs b/mitmproxy-rs/src/datagram_transport.rs deleted file mode 100644 index 48aa809b..00000000 --- a/mitmproxy-rs/src/datagram_transport.rs +++ /dev/null @@ -1,81 +0,0 @@ -use std::net::SocketAddr; - -use pyo3::prelude::*; -use pyo3::types::PyTuple; -use tokio::sync::mpsc; - -use crate::util::get_tunnel_info; -use mitmproxy::messages::{TransportCommand, TunnelInfo}; - -use crate::util::{event_queue_unavailable, py_to_socketaddr, socketaddr_to_py}; - -#[pyclass(module = "mitmproxy_rs")] -#[derive(Debug)] -pub struct DatagramTransport { - pub event_tx: mpsc::UnboundedSender, - pub peername: SocketAddr, - pub sockname: SocketAddr, - pub tunnel_info: TunnelInfo, -} - -#[pymethods] -impl DatagramTransport { - #[pyo3(text_signature = "(self, data, addr=None)")] - fn sendto(&self, data: Vec, addr: Option<&PyTuple>) -> PyResult<()> { - let dst_addr = match addr { - Some(addr) => py_to_socketaddr(addr)?, - None => self.peername, - }; - self.event_tx - .send(TransportCommand::SendDatagram { - data, - src_addr: self.sockname, - dst_addr, - }) - .map_err(event_queue_unavailable)?; - Ok(()) - } - - /// Query the UDP transport for details of the underlying network connection. - /// - /// Supported values: `peername`, `sockname`, `original_src`, and `original_dst`. - #[pyo3(text_signature = "(self, name, default=None)")] - fn get_extra_info( - &self, - py: Python, - name: String, - default: Option, - ) -> PyResult { - match name.as_str() { - "peername" => Ok(socketaddr_to_py(py, self.peername)), - "sockname" => Ok(socketaddr_to_py(py, self.sockname)), - _ => get_tunnel_info(&self.tunnel_info, py, name, default), - } - } - - /// Close the UDP transport. - /// This method is a no-op and only exists for API compatibility with DatagramTransport - fn close(&mut self) -> PyResult<()> { - Ok(()) - } - - /// Check whether this UDP transport is being closed. - /// This method is a no-op and only exists for API compatibility with DatagramTransport - fn is_closing(&self) -> PyResult { - Ok(false) - } - - /// Wait until the UDP transport is closed. - /// This method is a no-op and only exists for API compatibility with DatagramTransport - fn wait_closed<'p>(&self, py: Python<'p>) -> PyResult<&'p PyAny> { - pyo3_asyncio::tokio::future_into_py(py, std::future::ready(Ok(()))) - } - - fn get_protocol(self_: Py) -> Py { - self_ - } - - fn drain<'p>(&self, py: Python<'p>) -> PyResult<&'p PyAny> { - pyo3_asyncio::tokio::future_into_py(py, std::future::ready(Ok(()))) - } -} diff --git a/mitmproxy-rs/src/lib.rs b/mitmproxy-rs/src/lib.rs index 4fedff42..6335c7ac 100644 --- a/mitmproxy-rs/src/lib.rs +++ b/mitmproxy-rs/src/lib.rs @@ -5,7 +5,6 @@ use std::sync::RwLock; use once_cell::sync::Lazy; use pyo3::{exceptions::PyException, prelude::*}; -mod datagram_transport; mod process_info; mod server; mod task; @@ -55,7 +54,6 @@ pub fn mitmproxy_rs(py: Python, m: &PyModule) -> PyResult<()> { m.add_function(wrap_pyfunction!(process_info::executable_icon, m)?)?; m.add_class::()?; - m.add_class::()?; // Import platform-specific modules here so that missing dependencies are raising immediately. #[cfg(target_os = "macos")] diff --git a/mitmproxy-rs/src/task.rs b/mitmproxy-rs/src/task.rs index 06d3c9fd..9d8603f7 100644 --- a/mitmproxy-rs/src/task.rs +++ b/mitmproxy-rs/src/task.rs @@ -2,15 +2,13 @@ use std::collections::HashMap; use std::sync::Arc; use anyhow::Result; -use pyo3::{prelude::*, types::PyBytes}; +use pyo3::prelude::*; use tokio::sync::{broadcast, mpsc, Mutex}; use mitmproxy::messages::{TransportCommand, TransportEvent}; -use crate::datagram_transport::DatagramTransport; use crate::tcp_stream::TcpStream; use crate::tcp_stream::TcpStreamState; -use crate::util::socketaddr_to_py; pub struct PyInteropTask { py_loop: PyObject, @@ -58,7 +56,7 @@ impl PyInteropTask { dst_addr, tunnel_info, } => { - // initialize new TCP stream + // initialize new TCP/UDP stream let stream = TcpStream { connection_id, state: TcpStreamState::Open, @@ -70,12 +68,16 @@ impl PyInteropTask { let mut conns = active_tcp_connections.lock().await; - // spawn TCP connection handler coroutine + // spawn TCP/UDP connection handler coroutine if let Err(err) = Python::with_gil(|py| -> Result<(), PyErr> { let stream = stream.into_py(py); // calling Python coroutine object yields an awaitable object - let coro = self.py_tcp_handler.call1(py, (stream, ))?; + let coro = if connection_id & 1 == 1 { + self.py_udp_handler.call1(py, (stream, ))? + } else { + self.py_tcp_handler.call1(py, (stream, ))? + }; // convert Python awaitable into Rust Future let locals = pyo3_asyncio::TaskLocals::new(self.py_loop.as_ref(py)) @@ -101,39 +103,6 @@ impl PyInteropTask { log::error!("Failed to spawn TCP connection handler coroutine:\n{}", err); }; }, - TransportEvent::DatagramReceived { - data, - src_addr, - dst_addr, - tunnel_info, - } => { - - let transport = DatagramTransport { - event_tx: self.transport_commands.clone(), - peername: src_addr, - sockname: dst_addr, - tunnel_info, - }; - - Python::with_gil(|py| { - let transport = transport.into_py(py); - let bytes: Py = PyBytes::new(py, &data).into_py(py); - - if let Err(err) = self.py_loop.call_method1( - py, - "call_soon_threadsafe", - ( - self.py_udp_handler.as_ref(py), - transport, - bytes, - socketaddr_to_py(py, src_addr), - socketaddr_to_py(py, dst_addr), - ), - ) { - err.print(py); - } - }); - }, } } else { // channel was closed diff --git a/src/messages.rs b/src/messages.rs index fe436605..b3a61338 100755 --- a/src/messages.rs +++ b/src/messages.rs @@ -46,12 +46,6 @@ pub enum TransportEvent { dst_addr: SocketAddr, tunnel_info: TunnelInfo, }, - DatagramReceived { - data: Vec, - src_addr: SocketAddr, - dst_addr: SocketAddr, - tunnel_info: TunnelInfo, - }, } /// Commands that are sent by the Python side to the TCP stack. @@ -61,12 +55,6 @@ pub enum TransportCommand { WriteData(ConnectionId, Vec), DrainWriter(ConnectionId, oneshot::Sender<()>), CloseConnection(ConnectionId, bool), - // FIXME: Remove - SendDatagram { - data: Vec, - src_addr: SocketAddr, - dst_addr: SocketAddr, - }, } /// Generic IPv4/IPv6 packet type that wraps smoltcp's IPv4 and IPv6 packet buffers diff --git a/src/network/core.rs b/src/network/core.rs index d35f4868..1cd44865 100644 --- a/src/network/core.rs +++ b/src/network/core.rs @@ -99,14 +99,6 @@ impl<'a> NetworkStack<'a> { TransportCommand::CloseConnection(id, half_close) => match id & 1 == 1 { true => self.udp.close_connection(id), false => self.tcp.close_connection(id, half_close), - }, - TransportCommand::SendDatagram { - data: _, - src_addr: _, - dst_addr: _, - } => { - // TODO remove - log::error!("Error: SendDatagram is deprecated."); } }; } diff --git a/src/network/tests.rs b/src/network/tests.rs index bb938aeb..5fcd217e 100755 --- a/src/network/tests.rs +++ b/src/network/tests.rs @@ -1,4 +1,4 @@ -use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6}; +use std::net::{Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}; use anyhow::{anyhow, Result}; use smoltcp::{phy::ChecksumCapabilities, wire::*}; @@ -339,66 +339,70 @@ async fn do_nothing() -> Result<()> { mock.stop().await } -#[tokio::test] -async fn receive_ipv4_datagram() -> Result<()> { - init_logger(); +async fn receive_datagram( + packet: IpPacket, + src_addr: SocketAddr, + dst_addr: SocketAddr, +) -> Result<()> { let mut mock = MockNetwork::init().await?; - let src_addr = Ipv4Address([10, 0, 0, 1]); - let dst_addr = Ipv4Address([10, 0, 0, 42]); - let data = "hello world!".as_bytes(); - - let udp_ip_packet = build_ipv4_udp_packet(src_addr, dst_addr, 1234, 31337, data); - - mock.push_wg_packet(udp_ip_packet.into()).await?; + mock.push_wg_packet(packet).await?; let event = mock.pull_py_event().await.unwrap(); - if let TransportEvent::DatagramReceived { - data: recv_data, + let TransportEvent::ConnectionEstablished { + connection_id, src_addr: recv_src_addr, dst_addr: recv_dst_addr, tunnel_info: _, - } = event - { - assert_eq!(data, recv_data); - assert_eq!(IpAddress::Ipv4(src_addr), recv_src_addr.ip().into()); - assert_eq!(IpAddress::Ipv4(dst_addr), recv_dst_addr.ip().into()); - } else { - return Err(anyhow!("Wrong Transport event emitted!")); - } + } = event; + + assert_eq!(src_addr, recv_src_addr); + assert_eq!(dst_addr, recv_dst_addr); + + let (tx, rx) = oneshot::channel(); + mock.push_py_command(TransportCommand::ReadData(connection_id, 0, tx)) + .await?; + assert_eq!(rx.await?, b"hello world!"); mock.stop().await } +#[tokio::test] +async fn receive_ipv4_datagram() -> Result<()> { + init_logger(); + let src_addr = Ipv4Address([10, 0, 0, 1]); + let dst_addr = Ipv4Address([10, 0, 0, 42]); + let data = "hello world!".as_bytes(); + + let udp_ip_packet = build_ipv4_udp_packet(src_addr, dst_addr, 1234, 31337, data); + + receive_datagram( + udp_ip_packet.into(), + "10.0.0.1:1234".parse()?, + "10.0.0.42:31337".parse()?, + ) + .await +} + #[tokio::test] async fn receive_ipv6_datagram() -> Result<()> { init_logger(); - let mut mock = MockNetwork::init().await?; - let src_addr = Ipv6Address(b"cafecafecafe0001".to_owned()); - let dst_addr = Ipv6Address(b"cafecafecafe0002".to_owned()); + let src: Ipv6Addr = "ca:fe:ca:fe:ca:fe:00:01".parse()?; + let dst: Ipv6Addr = "ca:fe:ca:fe:ca:fe:00:02".parse()?; + + let src_addr = Ipv6Address::from(src); + let dst_addr = Ipv6Address::from(dst); let data = "hello world!".as_bytes(); let udp_ip_packet = build_ipv6_udp_packet(src_addr, dst_addr, 1234, 31337, data); - mock.push_wg_packet(udp_ip_packet.into()).await?; - let event = mock.pull_py_event().await.unwrap(); - - if let TransportEvent::DatagramReceived { - data: recv_data, - src_addr: recv_src_addr, - dst_addr: recv_dst_addr, - tunnel_info: _, - } = event - { - assert_eq!(data, recv_data); - assert_eq!(IpAddress::Ipv6(src_addr), recv_src_addr.ip().into()); - assert_eq!(IpAddress::Ipv6(dst_addr), recv_dst_addr.ip().into()); - } else { - return Err(anyhow!("Wrong Transport event emitted!")); - } - - mock.stop().await + receive_datagram( + udp_ip_packet.into(), + SocketAddr::from((src, 1234)), + SocketAddr::from((dst, 31337)), + ) + .await } #[tokio::test] @@ -554,19 +558,14 @@ async fn tcp_ipv4_connection() -> Result<()> { // expect ConnectionEstablished event let event = mock.pull_py_event().await.unwrap(); - let (tcp_conn_id, tcp_src_sock, tcp_dst_sock) = if let TransportEvent::ConnectionEstablished { + let TransportEvent::ConnectionEstablished { connection_id: tcp_conn_id, src_addr: tcp_src_sock, dst_addr: tcp_dst_sock, tunnel_info: _, - } = event - { - assert_eq!(IpAddress::Ipv4(src_addr), tcp_src_sock.ip().into()); - assert_eq!(IpAddress::Ipv4(dst_addr), tcp_dst_sock.ip().into()); - (tcp_conn_id, tcp_src_sock, tcp_dst_sock) - } else { - return Err(anyhow!("Wrong Transport event emitted!")); - }; + } = event; + assert_eq!(IpAddress::Ipv4(src_addr), tcp_src_sock.ip().into()); + assert_eq!(IpAddress::Ipv4(dst_addr), tcp_dst_sock.ip().into()); // expect TCP data log::debug!("Reading from TCP stream"); @@ -736,20 +735,14 @@ async fn tcp_ipv6_connection() -> Result<()> { // expect ConnectionEstablished event let event = mock.pull_py_event().await.unwrap(); - let (tcp_conn_id, tcp_src_sock, tcp_dst_sock) = if let TransportEvent::ConnectionEstablished { + let TransportEvent::ConnectionEstablished { connection_id: tcp_conn_id, src_addr: tcp_src_sock, dst_addr: tcp_dst_sock, tunnel_info: _, - } = event - { - assert_eq!(IpAddress::Ipv6(src_addr), tcp_src_sock.ip().into()); - assert_eq!(IpAddress::Ipv6(dst_addr), tcp_dst_sock.ip().into()); - - (tcp_conn_id, tcp_src_sock, tcp_dst_sock) - } else { - return Err(anyhow!("Wrong Transport event emitted!")); - }; + } = event; + assert_eq!(IpAddress::Ipv6(src_addr), tcp_src_sock.ip().into()); + assert_eq!(IpAddress::Ipv6(dst_addr), tcp_dst_sock.ip().into()); // expect TCP data log::debug!("Reading from TCP stream"); diff --git a/src/packet_sources/macos.rs b/src/packet_sources/macos.rs index f252c93e..c2d8eeac 100644 --- a/src/packet_sources/macos.rs +++ b/src/packet_sources/macos.rs @@ -320,6 +320,7 @@ impl ConnectionTask { SocketAddr::try_from(dst_addr).context("invalid socket address")? }; + todo!(); if let Err(e) = self.events.try_send(TransportEvent::DatagramReceived { data: packet.data, src_addr: local_addr, @@ -334,6 +335,7 @@ impl ConnectionTask { break; }; match command { + todo!(); TransportCommand::SendDatagram { data, src_addr, dst_addr } => { assert_eq!(dst_addr, local_addr); let packet = ipc::UdpPacket { @@ -426,9 +428,6 @@ impl ConnectionTask { if !half_close { break; } - }, - TransportCommand::SendDatagram { .. } => { - bail!("TCP connection received UDP event: {command:?}"); } } } From c31dccb3ce913bfad6f2b1a36a7c373ac9f45499 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Fri, 15 Dec 2023 11:28:44 +0100 Subject: [PATCH 04/20] unify stream class --- Cargo.lock | 10 +- Cargo.toml | 2 +- mitmproxy-rs/mitmproxy_rs/__init__.pyi | 6 +- mitmproxy-rs/src/lib.rs | 4 +- mitmproxy-rs/src/stream.rs | 189 +++++++++++++++++++++ mitmproxy-rs/src/task.rs | 117 ++++++------- mitmproxy-rs/src/tcp_stream.rs | 154 ----------------- mitmproxy-rs/src/util.rs | 1 + src/messages.rs | 95 +++++++---- src/network/core.rs | 36 ++-- src/network/icmp.rs | 10 +- src/network/tcp.rs | 19 +-- src/network/tests.rs | 218 +++++++++---------------- src/network/udp.rs | 21 ++- src/network/virtual_device.rs | 6 +- src/packet_sources/windows.rs | 4 +- src/packet_sources/wireguard.rs | 8 +- 17 files changed, 453 insertions(+), 447 deletions(-) create mode 100644 mitmproxy-rs/src/stream.rs delete mode 100644 mitmproxy-rs/src/tcp_stream.rs diff --git a/Cargo.lock b/Cargo.lock index 9d1245dc..b54a0069 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1257,7 +1257,7 @@ checksum = "9106e1d747ffd48e6be5bb2d97fa706ed25b144fbee4d5c02eae110cd8d6badd" [[package]] name = "macos-certificate-truster" -version = "0.4.0" +version = "0.5.0" dependencies = [ "apple-security-framework", ] @@ -1333,7 +1333,7 @@ dependencies = [ [[package]] name = "mitm-wg-test-client" -version = "0.4.0" +version = "0.5.0" dependencies = [ "anyhow", "boringtun", @@ -1344,7 +1344,7 @@ dependencies = [ [[package]] name = "mitmproxy" -version = "0.4.0" +version = "0.5.0" dependencies = [ "anyhow", "apple-security-framework", @@ -1373,7 +1373,7 @@ dependencies = [ [[package]] name = "mitmproxy_rs" -version = "0.4.0" +version = "0.5.0" dependencies = [ "anyhow", "boringtun", @@ -2625,7 +2625,7 @@ dependencies = [ [[package]] name = "windows-redirector" -version = "0.4.0" +version = "0.5.0" dependencies = [ "anyhow", "env_logger", diff --git a/Cargo.toml b/Cargo.toml index 5a8731b1..24958b28 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,7 +12,7 @@ authors = [ "Fabio Valentini ", "Maximilian Hils ", ] -version = "0.4.0" +version = "0.5.0" publish = false repository = "https://github.com/mitmproxy/mitmproxy-rs" edition = "2021" diff --git a/mitmproxy-rs/mitmproxy_rs/__init__.pyi b/mitmproxy-rs/mitmproxy_rs/__init__.pyi index c4d2c2ba..30dbb806 100644 --- a/mitmproxy-rs/mitmproxy_rs/__init__.pyi +++ b/mitmproxy-rs/mitmproxy_rs/__init__.pyi @@ -12,7 +12,7 @@ async def start_wireguard_server( port: int, private_key: str, peer_public_keys: list[str], - handle_connection: Callable[[TcpStream], Awaitable[None]], + handle_connection: Callable[[Stream], Awaitable[None]], receive_datagram: Callable[[DatagramTransport, bytes, tuple[str, int], tuple[str, int]], None], ) -> WireGuardServer: ... @@ -29,7 +29,7 @@ def pubkey(private_key: str) -> str: ... # Windows async def start_local_redirector( - handle_connection: Callable[[TcpStream], Awaitable[None]], + handle_connection: Callable[[Stream], Awaitable[None]], receive_datagram: Callable[[DatagramTransport, bytes, tuple[str, int], tuple[str, int]], None], ) -> LocalRedirector: ... @@ -48,7 +48,7 @@ def remove_cert() -> None: ... # TCP / UDP @final -class TcpStream: +class Stream: async def read(self, n: int) -> bytes: ... def write(self, data: bytes): ... async def drain(self) -> None: ... diff --git a/mitmproxy-rs/src/lib.rs b/mitmproxy-rs/src/lib.rs index 6335c7ac..66aca938 100644 --- a/mitmproxy-rs/src/lib.rs +++ b/mitmproxy-rs/src/lib.rs @@ -7,8 +7,8 @@ use pyo3::{exceptions::PyException, prelude::*}; mod process_info; mod server; +mod stream; mod task; -mod tcp_stream; mod util; static LOGGER_INITIALIZED: Lazy> = Lazy::new(|| RwLock::new(false)); @@ -53,7 +53,7 @@ pub fn mitmproxy_rs(py: Python, m: &PyModule) -> PyResult<()> { m.add_class::()?; m.add_function(wrap_pyfunction!(process_info::executable_icon, m)?)?; - m.add_class::()?; + m.add_class::()?; // Import platform-specific modules here so that missing dependencies are raising immediately. #[cfg(target_os = "macos")] diff --git a/mitmproxy-rs/src/stream.rs b/mitmproxy-rs/src/stream.rs new file mode 100644 index 00000000..0bffbc8d --- /dev/null +++ b/mitmproxy-rs/src/stream.rs @@ -0,0 +1,189 @@ +use std::net::SocketAddr; + +use once_cell::sync::Lazy; + +use pyo3::{exceptions::PyOSError, intern, prelude::*, types::PyBytes}; + +use tokio::sync::{ + mpsc::{self}, + oneshot::{self}, +}; + +use mitmproxy::messages::{ConnectionId, TransportCommand, TunnelInfo}; + +use crate::util::{event_queue_unavailable, get_tunnel_info, socketaddr_to_py}; + +#[derive(Debug)] +pub enum StreamState { + Open, + HalfClosed, + Closed, +} + +/// An individual TCP or UDP stream with an API that is similar to +/// [`asyncio.StreamReader` and `asyncio.StreamWriter`](https://docs.python.org/3/library/asyncio-stream.html) +/// from the Python standard library. +#[pyclass(module = "mitmproxy_rs")] +#[derive(Debug)] +pub struct Stream { + pub connection_id: ConnectionId, + pub state: StreamState, + pub event_tx: mpsc::UnboundedSender, + pub peername: SocketAddr, + pub sockname: SocketAddr, + pub tunnel_info: TunnelInfo, +} + +/// Do *not* hold the GIL while accessing. +static EMPTY_BYTES: Lazy> = + Lazy::new(|| Python::with_gil(|py| PyBytes::new(py, &[]).into_py(py))); + +#[pymethods] +impl Stream { + /// Read up to `n` bytes of a TCP stream, or a single UDP packet (`n` is ignored for UDP). + /// + /// Return an empty `bytes` object if the connection was closed + /// or the server has been shut down. + fn read<'p>(&self, py: Python<'p>, n: u32) -> PyResult<&'p PyAny> { + match self.state { + StreamState::Open | StreamState::HalfClosed => { + let (tx, rx) = oneshot::channel(); + + self.event_tx + .send(TransportCommand::ReadData(self.connection_id, n, tx)) + .ok(); // if this fails tx is dropped and rx.await will error. + + pyo3_asyncio::tokio::future_into_py(py, async move { + if let Ok(data) = rx.await { + Python::with_gil(|py| Ok(PyBytes::new(py, &data).into_py(py))) + } else { + Ok(EMPTY_BYTES.clone()) + } + }) + } + StreamState::Closed => { + pyo3_asyncio::tokio::future_into_py(py, async move { Ok(EMPTY_BYTES.clone()) }) + } + } + } + + /// Write bytes onto the TCP stream, or send a single UDP packet. + /// + /// For TCP, this queues the data into a write buffer. To wait until the stream can be written + /// to again, await `Stream.drain`. + /// + /// Raises: + /// OSError if the connection has previously been closed or if server has been shut down. + fn write(&self, data: Vec) -> PyResult<()> { + match self.state { + StreamState::Open => self + .event_tx + .send(TransportCommand::WriteData(self.connection_id, data)) + .map_err(event_queue_unavailable), + StreamState::HalfClosed => Err(PyOSError::new_err("connection closed")), + StreamState::Closed => Err(PyOSError::new_err("connection closed")), + } + } + + /// Wait until the stream can be written to again. + /// + /// Raises: + /// OSError if the stream is closed or the server has been shut down. + fn drain<'p>(&self, py: Python<'p>) -> PyResult<&'p PyAny> { + let (tx, rx) = oneshot::channel(); + + self.event_tx + .send(TransportCommand::DrainWriter(self.connection_id, tx)) + .map_err(event_queue_unavailable)?; + + pyo3_asyncio::tokio::future_into_py(py, async move { + rx.await + .map_err(|_| PyOSError::new_err("connection closed")) + }) + } + + /// Close the TCP stream after flushing the write buffer. + /// This method is a no-op for UDP streams, but may still raise an error (see below). + /// + /// Raises: + /// OSError if the server has been shut down. + fn write_eof(&mut self) -> PyResult<()> { + match self.state { + StreamState::Open => { + self.state = StreamState::HalfClosed; + self.event_tx + .send(TransportCommand::CloseConnection(self.connection_id, true)) + .map_err(event_queue_unavailable) + } + StreamState::HalfClosed => Ok(()), + StreamState::Closed => Ok(()), + } + } + + /// Close the stream for both reading and writing. + /// + /// Raises: + /// OSError if the server has been shut down. + fn close(&mut self) -> PyResult<()> { + match self.state { + StreamState::Open | StreamState::HalfClosed => { + self.state = StreamState::Closed; + self.event_tx + .send(TransportCommand::CloseConnection(self.connection_id, false)) + .map_err(event_queue_unavailable) + } + StreamState::Closed => Ok(()), + } + } + + /// Check whether this stream is being closed. + fn is_closing(&self) -> bool { + match self.state { + StreamState::Open => false, + StreamState::HalfClosed | StreamState::Closed => true, + } + } + + /// Wait until the stream is closed (currently a no-op). + fn wait_closed<'p>(&self, py: Python<'p>) -> PyResult<&'p PyAny> { + pyo3_asyncio::tokio::future_into_py(py, std::future::ready(Ok(()))) + } + + /// Query the stream for details of the underlying network connection. + /// + /// Supported values: + /// - Always available: `transport_protocol`, `peername`, `sockname` + /// - WireGuard mode: `original_dst`, `original_src` + /// - Local redirector mode: `pid`, `process_name`, `remote_endpoint` + #[pyo3(text_signature = "(self, name, default=None)")] + fn get_extra_info( + &self, + py: Python, + name: String, + default: Option, + ) -> PyResult { + match name.as_str() { + "transport_protocol" => Ok(PyObject::from(if self.connection_id.is_tcp() { + intern!(py, "tcp") + } else { + intern!(py, "udp") + })), + "peername" => Ok(socketaddr_to_py(py, self.peername)), + "sockname" => Ok(socketaddr_to_py(py, self.sockname)), + _ => get_tunnel_info(&self.tunnel_info, py, name, default), + } + } + + fn __repr__(&self) -> String { + format!( + "Stream({}, peer={}, sock={}, tunnel_info={:?})", + self.connection_id, self.peername, self.sockname, self.tunnel_info, + ) + } +} + +impl Drop for Stream { + fn drop(&mut self) { + self.close().ok(); + } +} diff --git a/mitmproxy-rs/src/task.rs b/mitmproxy-rs/src/task.rs index 9d8603f7..d0081989 100644 --- a/mitmproxy-rs/src/task.rs +++ b/mitmproxy-rs/src/task.rs @@ -3,12 +3,13 @@ use std::sync::Arc; use anyhow::Result; use pyo3::prelude::*; +use pyo3_asyncio::TaskLocals; use tokio::sync::{broadcast, mpsc, Mutex}; use mitmproxy::messages::{TransportCommand, TransportEvent}; -use crate::tcp_stream::TcpStream; -use crate::tcp_stream::TcpStreamState; +use crate::stream::Stream; +use crate::stream::StreamState; pub struct PyInteropTask { py_loop: PyObject, @@ -40,7 +41,11 @@ impl PyInteropTask { } pub async fn run(mut self) -> Result<()> { - let active_tcp_connections = Arc::new(Mutex::new(HashMap::new())); + let active_streams = Arc::new(Mutex::new(HashMap::new())); + + let locals = Python::with_gil(|py| { + TaskLocals::new(self.py_loop.as_ref(py)).copy_context(self.py_loop.as_ref(py).py()) + })?; loop { tokio::select!( @@ -48,65 +53,61 @@ impl PyInteropTask { _ = self.sd_watcher.recv() => break, // wait for network events event = self.transport_events.recv() => { - if let Some(event) = event { - match event { - TransportEvent::ConnectionEstablished { + let Some(event) = event else { + // channel was closed + break; + }; + match event { + TransportEvent::ConnectionEstablished { + connection_id, + src_addr, + dst_addr, + tunnel_info, + } => { + // initialize new stream + let stream = Stream { connection_id, - src_addr, - dst_addr, + state: StreamState::Open, + event_tx: self.transport_commands.clone(), + peername: src_addr, + sockname: dst_addr, tunnel_info, - } => { - // initialize new TCP/UDP stream - let stream = TcpStream { - connection_id, - state: TcpStreamState::Open, - event_tx: self.transport_commands.clone(), - peername: src_addr, - sockname: dst_addr, - tunnel_info, + }; + + let mut conns = active_streams.lock().await; + + // spawn connection handler coroutine + if let Err(err) = Python::with_gil(|py| -> Result<(), PyErr> { + let stream = stream.into_py(py); + + // calling Python coroutine object yields an awaitable object + let coro = if connection_id.is_tcp() { + self.py_tcp_handler.call1(py, (stream, ))? + } else { + self.py_udp_handler.call1(py, (stream, ))? }; - let mut conns = active_tcp_connections.lock().await; - - // spawn TCP/UDP connection handler coroutine - if let Err(err) = Python::with_gil(|py| -> Result<(), PyErr> { - let stream = stream.into_py(py); - - // calling Python coroutine object yields an awaitable object - let coro = if connection_id & 1 == 1 { - self.py_udp_handler.call1(py, (stream, ))? - } else { - self.py_tcp_handler.call1(py, (stream, ))? - }; - - // convert Python awaitable into Rust Future - let locals = pyo3_asyncio::TaskLocals::new(self.py_loop.as_ref(py)) - .copy_context(self.py_loop.as_ref(py).py())?; - let future = pyo3_asyncio::into_future_with_locals(&locals, coro.as_ref(py))?; - - // run Future on a new Tokio task - - let handle = { - let active_tcp_connections = active_tcp_connections.clone(); - tokio::spawn(async move { - if let Err(err) = future.await { - log::error!("TCP connection handler coroutine raised an exception:\n{}", err) - } - active_tcp_connections.lock().await.remove(&connection_id); - }) - }; - - conns.insert(connection_id, handle); - - Ok(()) - }) { - log::error!("Failed to spawn TCP connection handler coroutine:\n{}", err); + // convert Python awaitable into Rust Future + let future = pyo3_asyncio::into_future_with_locals(&locals, coro.as_ref(py))?; + + // run Future on a new Tokio task + let handle = { + let active_streams = active_streams.clone(); + tokio::spawn(async move { + if let Err(err) = future.await { + log::error!("TCP connection handler coroutine raised an exception:\n{}", err) + } + active_streams.lock().await.remove(&connection_id); + }) }; - }, - } - } else { - // channel was closed - break; + + conns.insert(connection_id, handle); + + Ok(()) + }) { + log::error!("Failed to spawn TCP connection handler coroutine:\n{}", err); + }; + }, } }, ); @@ -114,7 +115,7 @@ impl PyInteropTask { log::debug!("Python interoperability task shutting down."); - while let Some((_, handle)) = active_tcp_connections.lock().await.drain().next() { + while let Some((_, handle)) = active_streams.lock().await.drain().next() { if handle.is_finished() { // Future is already finished: just await; // Python exceptions are already logged by the wrapper coroutine diff --git a/mitmproxy-rs/src/tcp_stream.rs b/mitmproxy-rs/src/tcp_stream.rs deleted file mode 100644 index ab451b04..00000000 --- a/mitmproxy-rs/src/tcp_stream.rs +++ /dev/null @@ -1,154 +0,0 @@ -use std::net::SocketAddr; - -use pyo3::{exceptions::PyOSError, prelude::*, types::PyBytes}; - -use tokio::sync::{ - mpsc::{self}, - oneshot::{self, error::RecvError}, -}; - -use mitmproxy::messages::{ConnectionId, TransportCommand, TunnelInfo}; - -use crate::util::{event_queue_unavailable, get_tunnel_info, socketaddr_to_py}; - -#[derive(Debug)] -pub enum TcpStreamState { - Open, - HalfClosed, - Closed, -} - -/// An individual TCP stream with an API that is similar to -/// [`asyncio.StreamReader` and `asyncio.StreamWriter`](https://docs.python.org/3/library/asyncio-stream.html) -/// from the Python standard library. -#[pyclass(module = "mitmproxy_rs")] -#[derive(Debug)] -pub struct TcpStream { - pub connection_id: ConnectionId, - pub state: TcpStreamState, - pub event_tx: mpsc::UnboundedSender, - pub peername: SocketAddr, - pub sockname: SocketAddr, - pub tunnel_info: TunnelInfo, -} - -#[pymethods] -impl TcpStream { - /// Read up to `n` bytes from the TCP stream. - /// - /// If the connection was closed, this returns an empty `bytes` object. - fn read<'p>(&self, py: Python<'p>, n: u32) -> PyResult<&'p PyAny> { - let (tx, rx) = oneshot::channel(); - - self.event_tx - .send(TransportCommand::ReadData(self.connection_id, n, tx)) - .map_err(event_queue_unavailable)?; - - pyo3_asyncio::tokio::future_into_py::<_, Py>(py, async move { - let data = rx.await.map_err(connection_closed)?; - Python::with_gil(|py| Ok(PyBytes::new(py, &data).into_py(py))) - }) - } - - /// Write bytes onto the TCP stream. - /// - /// This queues the data into a write buffer. To wait until the TCP connection can be written to - /// again, use the `TcpStream.drain` coroutine. - fn write(&self, data: Vec) -> PyResult<()> { - self.event_tx - .send(TransportCommand::WriteData(self.connection_id, data)) - .map_err(event_queue_unavailable)?; - - Ok(()) - } - - /// Wait until the TCP stream can be written to again. - fn drain<'p>(&self, py: Python<'p>) -> PyResult<&'p PyAny> { - let (tx, rx) = oneshot::channel(); - - self.event_tx - .send(TransportCommand::DrainWriter(self.connection_id, tx)) - .map_err(event_queue_unavailable)?; - - pyo3_asyncio::tokio::future_into_py(py, async move { - rx.await.map_err(connection_closed)?; - Ok(()) - }) - } - - /// Close the stream after flushing the write buffer. - fn write_eof(&mut self) -> PyResult<()> { - match self.state { - TcpStreamState::Open => { - self.state = TcpStreamState::HalfClosed; - self.event_tx - .send(TransportCommand::CloseConnection(self.connection_id, true)) - .map_err(event_queue_unavailable) - } - TcpStreamState::HalfClosed => Ok(()), - TcpStreamState::Closed => Ok(()), - } - } - - /// Close the TCP stream and the underlying socket immediately. - fn close(&mut self) -> PyResult<()> { - match self.state { - TcpStreamState::Open | TcpStreamState::HalfClosed => { - self.state = TcpStreamState::Closed; - self.event_tx - .send(TransportCommand::CloseConnection(self.connection_id, false)) - .map_err(event_queue_unavailable) - } - TcpStreamState::Closed => Ok(()), - } - } - - /// Check whether this TCP stream is being closed. - fn is_closing(&self) -> bool { - match self.state { - TcpStreamState::Open => false, - TcpStreamState::HalfClosed | TcpStreamState::Closed => true, - } - } - - /// Wait until the TCP stream is closed (currently a no-op). - fn wait_closed<'p>(&self, py: Python<'p>) -> PyResult<&'p PyAny> { - pyo3_asyncio::tokio::future_into_py(py, std::future::ready(Ok(()))) - } - - /// Query the TCP stream for details of the underlying network connection. - /// - /// Supported values: `peername`, `sockname`, `original_dst`, and `original_src`. - #[pyo3(text_signature = "(self, name, default=None)")] - fn get_extra_info( - &self, - py: Python, - name: String, - default: Option, - ) -> PyResult { - match name.as_str() { - "peername" => Ok(socketaddr_to_py(py, self.peername)), - "sockname" => Ok(socketaddr_to_py(py, self.sockname)), - _ => get_tunnel_info(&self.tunnel_info, py, name, default), - } - } - - fn __repr__(&self) -> String { - format!( - "TcpStream({}, peer={}, sock={}, tunnel_info={:?})", - self.connection_id, self.peername, self.sockname, self.tunnel_info, - ) - } -} - -impl Drop for TcpStream { - fn drop(&mut self) { - if let Err(error) = self.close() { - log::debug!("Failed to close TCP stream during clean up: {}", error); - } - } -} - -pub fn connection_closed(_: RecvError) -> PyErr { - PyOSError::new_err("connection closed") -} diff --git a/mitmproxy-rs/src/util.rs b/mitmproxy-rs/src/util.rs index aa11dc22..5a3f90e3 100644 --- a/mitmproxy-rs/src/util.rs +++ b/mitmproxy-rs/src/util.rs @@ -40,6 +40,7 @@ pub fn socketaddr_to_py(py: Python, s: SocketAddr) -> PyObject { } } +#[allow(dead_code)] pub fn py_to_socketaddr(t: &PyTuple) -> PyResult { if t.len() == 2 { let host = t.get_item(0)?.downcast::()?; diff --git a/src/messages.rs b/src/messages.rs index b3a61338..129d8b59 100755 --- a/src/messages.rs +++ b/src/messages.rs @@ -1,3 +1,4 @@ +use std::fmt; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; use anyhow::{anyhow, Result}; @@ -24,7 +25,7 @@ pub enum TunnelInfo { #[derive(Debug)] pub enum NetworkEvent { ReceivePacket { - packet: IpPacket, + packet: SmolPacket, tunnel_info: TunnelInfo, }, } @@ -32,10 +33,48 @@ pub enum NetworkEvent { /// Commands that are sent by the TCP stack to WireGuard. #[derive(Debug)] pub enum NetworkCommand { - SendPacket(IpPacket), + SendPacket(SmolPacket), } -pub type ConnectionId = usize; +pub struct ConnectionIdGenerator(usize); +impl ConnectionIdGenerator { + pub fn tcp() -> Self { + Self(2) + } + pub fn udp() -> Self { + Self(1) + } + pub fn next_id(&mut self) -> ConnectionId { + let ret = ConnectionId(self.0); + self.0 += 2; + ret + } +} + +#[derive(Clone, Copy, Eq, Ord, PartialEq, PartialOrd, Hash)] +pub struct ConnectionId(usize); +impl ConnectionId { + pub fn is_tcp(&self) -> bool { + self.0 & 1 == 0 + } + pub fn unassigned() -> Self { + ConnectionId(0) + } +} +impl fmt::Display for ConnectionId { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", self.0) + } +} +impl fmt::Debug for ConnectionId { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + if self.is_tcp() { + write!(f, "{}#TCP", self.0) + } else { + write!(f, "{}#UDP", self.0) + } + } +} /// Events that are sent by the TCP stack to Python. #[derive(Debug)] @@ -58,36 +97,36 @@ pub enum TransportCommand { } /// Generic IPv4/IPv6 packet type that wraps smoltcp's IPv4 and IPv6 packet buffers -#[derive(Debug)] -pub enum IpPacket { +#[derive(Debug, Clone)] +pub enum SmolPacket { V4(Ipv4Packet>), V6(Ipv6Packet>), } -impl From>> for IpPacket { +impl From>> for SmolPacket { fn from(packet: Ipv4Packet>) -> Self { - IpPacket::V4(packet) + SmolPacket::V4(packet) } } -impl From>> for IpPacket { +impl From>> for SmolPacket { fn from(packet: Ipv6Packet>) -> Self { - IpPacket::V6(packet) + SmolPacket::V6(packet) } } -impl TryInto for IpPacket { +impl TryInto for SmolPacket { type Error = internet_packet::ParseError; fn try_into(self) -> std::result::Result { match self { - IpPacket::V4(packet) => InternetPacket::try_from(packet), - IpPacket::V6(packet) => InternetPacket::try_from(packet), + SmolPacket::V4(packet) => InternetPacket::try_from(packet), + SmolPacket::V6(packet) => InternetPacket::try_from(packet), } } } -impl TryFrom> for IpPacket { +impl TryFrom> for SmolPacket { type Error = anyhow::Error; fn try_from(value: Vec) -> Result { @@ -96,32 +135,32 @@ impl TryFrom> for IpPacket { } match value[0] >> 4 { - 4 => Ok(IpPacket::V4(Ipv4Packet::new_checked(value)?)), - 6 => Ok(IpPacket::V6(Ipv6Packet::new_checked(value)?)), + 4 => Ok(SmolPacket::V4(Ipv4Packet::new_checked(value)?)), + 6 => Ok(SmolPacket::V6(Ipv6Packet::new_checked(value)?)), _ => Err(anyhow!("Not an IP packet: {:?}", value)), } } } -impl IpPacket { +impl SmolPacket { pub fn src_ip(&self) -> IpAddr { match self { - IpPacket::V4(packet) => IpAddr::V4(Ipv4Addr::from(packet.src_addr())), - IpPacket::V6(packet) => IpAddr::V6(Ipv6Addr::from(packet.src_addr())), + SmolPacket::V4(packet) => IpAddr::V4(Ipv4Addr::from(packet.src_addr())), + SmolPacket::V6(packet) => IpAddr::V6(Ipv6Addr::from(packet.src_addr())), } } pub fn dst_ip(&self) -> IpAddr { match self { - IpPacket::V4(packet) => IpAddr::V4(Ipv4Addr::from(packet.dst_addr())), - IpPacket::V6(packet) => IpAddr::V6(Ipv6Addr::from(packet.dst_addr())), + SmolPacket::V4(packet) => IpAddr::V4(Ipv4Addr::from(packet.dst_addr())), + SmolPacket::V6(packet) => IpAddr::V6(Ipv6Addr::from(packet.dst_addr())), } } pub fn transport_protocol(&self) -> IpProtocol { match self { - IpPacket::V4(packet) => packet.next_header(), - IpPacket::V6(packet) => { + SmolPacket::V4(packet) => packet.next_header(), + SmolPacket::V6(packet) => { log::debug!("TODO: Implement IPv6 next_header logic."); packet.next_header() } @@ -130,22 +169,22 @@ impl IpPacket { pub fn payload_mut(&mut self) -> &mut [u8] { match self { - IpPacket::V4(packet) => packet.payload_mut(), - IpPacket::V6(packet) => packet.payload_mut(), + SmolPacket::V4(packet) => packet.payload_mut(), + SmolPacket::V6(packet) => packet.payload_mut(), } } pub fn into_inner(self) -> Vec { match self { - IpPacket::V4(packet) => packet.into_inner(), - IpPacket::V6(packet) => packet.into_inner(), + SmolPacket::V4(packet) => packet.into_inner(), + SmolPacket::V6(packet) => packet.into_inner(), } } pub fn fill_ip_checksum(&mut self) { match self { - IpPacket::V4(packet) => packet.fill_checksum(), - IpPacket::V6(_) => (), + SmolPacket::V4(packet) => packet.fill_checksum(), + SmolPacket::V6(_) => (), } } } diff --git a/src/network/core.rs b/src/network/core.rs index 1cd44865..b8f94218 100644 --- a/src/network/core.rs +++ b/src/network/core.rs @@ -8,7 +8,7 @@ use anyhow::Result; use smoltcp::wire::IpProtocol; use tokio::sync::mpsc::{Permit, Sender}; -use crate::messages::{IpPacket, NetworkCommand, NetworkEvent, TransportCommand, TransportEvent}; +use crate::messages::{NetworkCommand, NetworkEvent, SmolPacket, TransportCommand, TransportEvent}; use crate::network::icmp::{handle_icmpv4_echo_request, handle_icmpv6_echo_request}; use crate::network::tcp::TcpHandler; @@ -41,7 +41,7 @@ impl<'a> NetworkStack<'a> { } => (packet, tunnel_info), }; - if let IpPacket::V4(p) = &packet { + if let SmolPacket::V4(p) = &packet { if !p.verify_checksum() { log::warn!("Received invalid IP packet (checksum error)."); return Ok(()); @@ -62,7 +62,7 @@ impl<'a> NetworkStack<'a> { } } - fn receive_packet_icmp(&mut self, packet: IpPacket) -> Result<()> { + fn receive_packet_icmp(&mut self, packet: SmolPacket) -> Result<()> { // Some apps check network connectivity by sending ICMP pings. ICMP traffic is currently // swallowed by mitmproxy_rs, which makes them believe that there is no network connectivity. // Generating fake ICMP replies as a simple workaround. @@ -70,8 +70,8 @@ impl<'a> NetworkStack<'a> { if let Ok(permit) = self.net_tx.try_reserve() { // Generating and sending fake replies for ICMP echo requests. Ignoring all other ICMP types. let response_packet = match packet { - IpPacket::V4(packet) => handle_icmpv4_echo_request(packet), - IpPacket::V6(packet) => handle_icmpv6_echo_request(packet), + SmolPacket::V4(packet) => handle_icmpv4_echo_request(packet), + SmolPacket::V6(packet) => handle_icmpv6_echo_request(packet), }; if let Some(response_packet) = response_packet { permit.send(NetworkCommand::SendPacket(response_packet)); @@ -84,22 +84,22 @@ impl<'a> NetworkStack<'a> { pub fn handle_transport_command(&mut self, command: TransportCommand) { match command { - TransportCommand::ReadData(id, n, tx) => match id & 1 == 1 { - true => self.udp.read_data(id, tx), - false => self.tcp.read_data(id, n, tx), + TransportCommand::ReadData(id, n, tx) => match id.is_tcp() { + true => self.tcp.read_data(id, n, tx), + false => self.udp.read_data(id, tx), }, - TransportCommand::WriteData(id, buf) => match id & 1 == 1 { - true => self.udp.write_data(id, buf), - false => self.tcp.write_data(id, buf), + TransportCommand::WriteData(id, buf) => match id.is_tcp() { + true => self.tcp.write_data(id, buf), + false => self.udp.write_data(id, buf), }, - TransportCommand::DrainWriter(id, tx) => match id & 1 == 1 { - true => self.udp.drain_writer(id, tx), - false => self.tcp.drain_writer(id, tx), + TransportCommand::DrainWriter(id, tx) => match id.is_tcp() { + true => self.tcp.drain_writer(id, tx), + false => self.udp.drain_writer(id, tx), + }, + TransportCommand::CloseConnection(id, half_close) => match id.is_tcp() { + true => self.tcp.close_connection(id, half_close), + false => self.udp.close_connection(id), }, - TransportCommand::CloseConnection(id, half_close) => match id & 1 == 1 { - true => self.udp.close_connection(id), - false => self.tcp.close_connection(id, half_close), - } }; } diff --git a/src/network/icmp.rs b/src/network/icmp.rs index 318fa1f2..79d1c3fb 100644 --- a/src/network/icmp.rs +++ b/src/network/icmp.rs @@ -1,4 +1,4 @@ -use crate::messages::IpPacket; +use crate::messages::SmolPacket; use smoltcp::phy::ChecksumCapabilities; use smoltcp::wire::{ Icmpv4Message, Icmpv4Packet, Icmpv4Repr, Icmpv6Message, Icmpv6Packet, Icmpv6Repr, IpAddress, @@ -7,7 +7,7 @@ use smoltcp::wire::{ pub(super) fn handle_icmpv4_echo_request( mut input_packet: Ipv4Packet>, -) -> Option { +) -> Option { let src_addr = input_packet.src_addr(); let dst_addr = input_packet.dst_addr(); @@ -46,7 +46,7 @@ pub(super) fn handle_icmpv4_echo_request( let buf = vec![0u8; ip_repr.buffer_len() + icmp_repr.buffer_len()]; let mut output_ipv4_packet = Ipv4Packet::new_unchecked(buf); ip_repr.emit(&mut output_ipv4_packet, &ChecksumCapabilities::default()); - let mut output_ip_packet = IpPacket::from(output_ipv4_packet); + let mut output_ip_packet = SmolPacket::from(output_ipv4_packet); icmp_repr.emit( &mut Icmpv4Packet::new_unchecked(output_ip_packet.payload_mut()), &ChecksumCapabilities::default(), @@ -56,7 +56,7 @@ pub(super) fn handle_icmpv4_echo_request( pub(super) fn handle_icmpv6_echo_request( mut input_packet: Ipv6Packet>, -) -> Option { +) -> Option { let src_addr = input_packet.src_addr(); let dst_addr = input_packet.dst_addr(); @@ -95,7 +95,7 @@ pub(super) fn handle_icmpv6_echo_request( let buf = vec![0u8; ip_repr.buffer_len() + icmp_repr.buffer_len()]; let mut output_ipv6_packet = Ipv6Packet::new_unchecked(buf); ip_repr.emit(&mut output_ipv6_packet); - let mut output_ip_packet = IpPacket::from(output_ipv6_packet); + let mut output_ip_packet = SmolPacket::from(output_ipv6_packet); icmp_repr.emit( // Directing fake reply back to the original source address. &IpAddress::from(dst_addr), diff --git a/src/network/tcp.rs b/src/network/tcp.rs index b699fb34..ed26d9c6 100644 --- a/src/network/tcp.rs +++ b/src/network/tcp.rs @@ -18,7 +18,9 @@ use tokio::sync::{ oneshot, }; -use crate::messages::{ConnectionId, IpPacket, NetworkCommand, TransportEvent, TunnelInfo}; +use crate::messages::{ + ConnectionId, ConnectionIdGenerator, NetworkCommand, SmolPacket, TransportEvent, TunnelInfo, +}; use super::virtual_device::VirtualDevice; @@ -39,7 +41,7 @@ struct SocketData { } pub struct TcpHandler<'a> { - next_connection_id: ConnectionId, + connection_id_generator: ConnectionIdGenerator, iface: Interface, device: VirtualDevice, sockets: SocketSet<'a>, @@ -74,14 +76,14 @@ impl<'a> TcpHandler<'a> { sockets: SocketSet::new(Vec::new()), socket_data: HashMap::new(), active_connections: HashSet::new(), - next_connection_id: 0, + connection_id_generator: ConnectionIdGenerator::tcp(), remove_conns: Vec::new(), } } pub fn receive_packet( &mut self, - mut packet: IpPacket, + mut packet: SmolPacket, tunnel_info: TunnelInfo, permit: Permit<'_, TransportEvent>, ) -> Result<()> { @@ -126,10 +128,7 @@ impl<'a> TcpHandler<'a> { let handle = self.sockets.add(socket); - let connection_id = { - self.next_connection_id += 2; // only even ids. - self.next_connection_id - }; + let connection_id = self.connection_id_generator.next_id(); let data = SocketData { handle, @@ -191,8 +190,8 @@ impl<'a> TcpHandler<'a> { pub fn close_connection(&mut self, id: ConnectionId, _half_close: bool) { if let Some(data) = self.socket_data.get_mut(&id) { - // smoltcp does not have a good way to do "SHUT_RDWR". We can't call .abort() - // here because that sends a RST instead of a FIN (and breaks + // smoltcp does not have a good way to do a full close ("SHUT_RDWR"). We can't call + // .abort() here because that sends a RST instead of a FIN (and breaks // retransmissions of the connection close packet). Alternatively, we could manually // set a timer on .close() and then forcibly .abort() once the timer expires (see // tcp-abort branch). This incurs a bit of unnecessary complexity, so we try something diff --git a/src/network/tests.rs b/src/network/tests.rs index 5fcd217e..bec43ca0 100755 --- a/src/network/tests.rs +++ b/src/network/tests.rs @@ -1,6 +1,7 @@ -use std::net::{Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}; +use std::net::{Ipv6Addr, SocketAddr}; use anyhow::{anyhow, Result}; +use internet_packet::InternetPacket; use smoltcp::{phy::ChecksumCapabilities, wire::*}; use tokio::{ sync::{ @@ -12,7 +13,7 @@ use tokio::{ }; use crate::messages::{ - IpPacket, NetworkCommand, NetworkEvent, TransportCommand, TransportEvent, TunnelInfo, + NetworkCommand, NetworkEvent, SmolPacket, TransportCommand, TransportEvent, TunnelInfo, }; use super::task::NetworkTask; @@ -63,7 +64,7 @@ impl MockNetwork { self.handle.await? } - async fn push_wg_packet(&self, packet: IpPacket) -> Result<()> { + async fn push_smol_packet(&self, packet: SmolPacket) -> Result<()> { let tunnel_info = TunnelInfo::WireGuard { src_addr: "192.168.86.134:12345".parse()?, dst_addr: "0.0.0.0:0".parse()?, @@ -76,20 +77,22 @@ impl MockNetwork { Ok(()) } + async fn pull_smol_packet(&mut self) -> SmolPacket { + let NetworkCommand::SendPacket(packet) = + self.smol_to_wg_rx.recv().await.expect("No packet received"); + packet + } + + async fn pull_packet(&mut self) -> InternetPacket { + let packet = self.pull_smol_packet().await; + packet.try_into().unwrap() + } + async fn push_py_command(&self, command: TransportCommand) -> Result<()> { self.py_to_smol_tx.send(command)?; Ok(()) } - async fn pull_wg_packet(&mut self) -> Option { - self.smol_to_wg_rx - .recv() - .await - .map(|command| match command { - NetworkCommand::SendPacket(packet) => packet, - }) - } - async fn pull_py_event(&mut self) -> Option { self.smol_to_py_rx.recv().await } @@ -339,14 +342,14 @@ async fn do_nothing() -> Result<()> { mock.stop().await } -async fn receive_datagram( - packet: IpPacket, +async fn udp_read_write( + packet: SmolPacket, src_addr: SocketAddr, dst_addr: SocketAddr, ) -> Result<()> { let mut mock = MockNetwork::init().await?; - mock.push_wg_packet(packet).await?; + mock.push_smol_packet(packet.clone()).await?; let event = mock.pull_py_event().await.unwrap(); let TransportEvent::ConnectionEstablished { @@ -364,11 +367,30 @@ async fn receive_datagram( .await?; assert_eq!(rx.await?, b"hello world!"); + mock.push_py_command(TransportCommand::WriteData( + connection_id, + b"HELLO WORLD!".to_vec(), + )) + .await?; + let response = mock.pull_packet().await; + assert_eq!(response.payload(), b"HELLO WORLD!"); + assert_eq!(response.src(), dst_addr); + assert_eq!(response.dst(), src_addr); + + mock.push_py_command(TransportCommand::CloseConnection(connection_id, false)) + .await?; + mock.push_smol_packet(packet.clone()).await?; + + let (tx, rx) = oneshot::channel(); + mock.push_py_command(TransportCommand::ReadData(connection_id, 0, tx)) + .await?; + assert!(rx.await.is_err()); + mock.stop().await } #[tokio::test] -async fn receive_ipv4_datagram() -> Result<()> { +async fn ivp4_udp() -> Result<()> { init_logger(); let src_addr = Ipv4Address([10, 0, 0, 1]); let dst_addr = Ipv4Address([10, 0, 0, 42]); @@ -376,7 +398,7 @@ async fn receive_ipv4_datagram() -> Result<()> { let udp_ip_packet = build_ipv4_udp_packet(src_addr, dst_addr, 1234, 31337, data); - receive_datagram( + udp_read_write( udp_ip_packet.into(), "10.0.0.1:1234".parse()?, "10.0.0.42:31337".parse()?, @@ -385,7 +407,7 @@ async fn receive_ipv4_datagram() -> Result<()> { } #[tokio::test] -async fn receive_ipv6_datagram() -> Result<()> { +async fn ipv6_udp() -> Result<()> { init_logger(); let src: Ipv6Addr = "ca:fe:ca:fe:ca:fe:00:01".parse()?; @@ -397,7 +419,7 @@ async fn receive_ipv6_datagram() -> Result<()> { let udp_ip_packet = build_ipv6_udp_packet(src_addr, dst_addr, 1234, 31337, data); - receive_datagram( + udp_read_write( udp_ip_packet.into(), SocketAddr::from((src, 1234)), SocketAddr::from((dst, 31337)), @@ -405,96 +427,6 @@ async fn receive_ipv6_datagram() -> Result<()> { .await } -#[tokio::test] -async fn send_ipv4_datagram() -> Result<()> { - init_logger(); - let mut mock = MockNetwork::init().await?; - - let src_addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Address([10, 0, 0, 42]).into(), 31337)); - let dst_addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Address([10, 0, 0, 1]).into(), 1234)); - let data = "hello world!".as_bytes(); - - mock.push_py_command(TransportCommand::SendDatagram { - data: data.to_vec(), - src_addr, - dst_addr, - }) - .await?; - - let mut udp_ip_packet = match mock.pull_wg_packet().await.unwrap() { - IpPacket::V4(packet) => packet, - IpPacket::V6(_) => return Err(anyhow!("Received unexpected IPv6 packet!")), - }; - - let udp_ip_src_addr = udp_ip_packet.src_addr(); - let udp_ip_dst_addr = udp_ip_packet.dst_addr(); - - let udp_packet = UdpPacket::new_unchecked(udp_ip_packet.payload_mut() as &[u8]); - let udp_repr = UdpRepr::parse( - &udp_packet, - &udp_ip_src_addr.into(), - &udp_ip_dst_addr.into(), - &ChecksumCapabilities::default(), - ) - .unwrap(); - - assert_eq!(udp_packet.payload(), data); - assert_eq!(udp_repr.src_port, 31337); - assert_eq!(udp_repr.dst_port, 1234); - - mock.stop().await -} - -#[tokio::test] -async fn send_ipv6_datagram() -> Result<()> { - init_logger(); - let mut mock = MockNetwork::init().await?; - - let src_addr = SocketAddr::V6(SocketAddrV6::new( - Ipv6Address(b"cafecafecafe0001".to_owned()).into(), - 31337, - 0, - 0, - )); - let dst_addr = SocketAddr::V6(SocketAddrV6::new( - Ipv6Address(b"cafecafecafe0002".to_owned()).into(), - 1234, - 0, - 0, - )); - let data = "hello world!".as_bytes(); - - mock.push_py_command(TransportCommand::SendDatagram { - data: data.to_vec(), - src_addr, - dst_addr, - }) - .await?; - - let mut udp_ip_packet = match mock.pull_wg_packet().await.unwrap() { - IpPacket::V6(packet) => packet, - IpPacket::V4(_) => return Err(anyhow!("Received unexpected IPv4 packet!")), - }; - - let udp_ip_src_addr = udp_ip_packet.src_addr(); - let udp_ip_dst_addr = udp_ip_packet.dst_addr(); - - let udp_packet = UdpPacket::new_unchecked(udp_ip_packet.payload_mut() as &[u8]); - let udp_repr = UdpRepr::parse( - &udp_packet, - &udp_ip_src_addr.into(), - &udp_ip_dst_addr.into(), - &ChecksumCapabilities::default(), - ) - .unwrap(); - - assert_eq!(udp_packet.payload(), data); - assert_eq!(udp_repr.src_port, 31337); - assert_eq!(udp_repr.dst_port, 1234); - - mock.stop().await -} - #[tokio::test] async fn tcp_ipv4_connection() -> Result<()> { init_logger(); @@ -517,12 +449,12 @@ async fn tcp_ipv4_connection() -> Result<()> { None, &[], ); - mock.push_wg_packet(tcp_ip_syn_packet.into()).await?; + mock.push_smol_packet(tcp_ip_syn_packet.into()).await?; // expect TCP SYN/ACK - let mut tcp_synack_ip_packet = match mock.pull_wg_packet().await.unwrap() { - IpPacket::V4(packet) => packet, - IpPacket::V6(_) => return Err(anyhow!("Received unexpected IPv6 packet!")), + let mut tcp_synack_ip_packet = match mock.pull_smol_packet().await { + SmolPacket::V4(packet) => packet, + SmolPacket::V6(_) => return Err(anyhow!("Received unexpected IPv6 packet!")), }; let synack_src_addr = tcp_synack_ip_packet.src_addr(); @@ -553,7 +485,7 @@ async fn tcp_ipv4_connection() -> Result<()> { Some(ack), data, ); - mock.push_wg_packet(tcp_ip_ack_packet.into()).await?; + mock.push_smol_packet(tcp_ip_ack_packet.into()).await?; // expect ConnectionEstablished event let event = mock.pull_py_event().await.unwrap(); @@ -590,12 +522,12 @@ async fn tcp_ipv4_connection() -> Result<()> { drain_rx.await?; // expect TCP/IP packets - mock.pull_wg_packet().await.unwrap(); - mock.pull_wg_packet().await.unwrap(); + mock.pull_smol_packet().await; + mock.pull_smol_packet().await; - let mut tcp_resp_ip_packet = match mock.pull_wg_packet().await.unwrap() { - IpPacket::V4(packet) => packet, - IpPacket::V6(_) => return Err(anyhow!("Received unexpected IPv6 packet!")), + let mut tcp_resp_ip_packet = match mock.pull_smol_packet().await { + SmolPacket::V4(packet) => packet, + SmolPacket::V6(_) => return Err(anyhow!("Received unexpected IPv6 packet!")), }; let tcp_ip_resp_src_addr = tcp_resp_ip_packet.src_addr(); @@ -626,9 +558,9 @@ async fn tcp_ipv4_connection() -> Result<()> { .await?; // expect TCP FIN - let mut tcp_fin_ip_packet = match mock.pull_wg_packet().await.unwrap() { - IpPacket::V4(packet) => packet, - IpPacket::V6(_) => return Err(anyhow!("Received unexpected IPv6 packet!")), + let mut tcp_fin_ip_packet = match mock.pull_smol_packet().await { + SmolPacket::V4(packet) => packet, + SmolPacket::V6(_) => return Err(anyhow!("Received unexpected IPv6 packet!")), }; let tcp_ip_fin_src_addr = tcp_fin_ip_packet.src_addr(); @@ -667,7 +599,7 @@ async fn tcp_ipv4_connection() -> Result<()> { Some(ack), &[], ); - mock.push_wg_packet(tcp_ip_syn_packet.into()).await?; + mock.push_smol_packet(tcp_ip_syn_packet.into()).await?; mock.stop().await } @@ -694,12 +626,12 @@ async fn tcp_ipv6_connection() -> Result<()> { None, &[], ); - mock.push_wg_packet(tcp_ip_syn_packet.into()).await?; + mock.push_smol_packet(tcp_ip_syn_packet.into()).await?; // expect TCP SYN/ACK - let mut tcp_synack_ip_packet = match mock.pull_wg_packet().await.unwrap() { - IpPacket::V6(packet) => packet, - IpPacket::V4(_) => return Err(anyhow!("Received unexpected IPv4 packet!")), + let mut tcp_synack_ip_packet = match mock.pull_smol_packet().await { + SmolPacket::V6(packet) => packet, + SmolPacket::V4(_) => return Err(anyhow!("Received unexpected IPv4 packet!")), }; let synack_src_addr = tcp_synack_ip_packet.src_addr(); @@ -730,7 +662,7 @@ async fn tcp_ipv6_connection() -> Result<()> { Some(ack), data, ); - mock.push_wg_packet(tcp_ip_ack_packet.into()).await?; + mock.push_smol_packet(tcp_ip_ack_packet.into()).await?; // expect ConnectionEstablished event let event = mock.pull_py_event().await.unwrap(); @@ -767,12 +699,12 @@ async fn tcp_ipv6_connection() -> Result<()> { drain_rx.await?; // expect TCP/IP packets - mock.pull_wg_packet().await.unwrap(); - mock.pull_wg_packet().await.unwrap(); + mock.pull_smol_packet().await; + mock.pull_smol_packet().await; - let mut tcp_resp_ip_packet = match mock.pull_wg_packet().await.unwrap() { - IpPacket::V6(packet) => packet, - IpPacket::V4(_) => return Err(anyhow!("Received unexpected IPv4 packet!")), + let mut tcp_resp_ip_packet = match mock.pull_smol_packet().await { + SmolPacket::V6(packet) => packet, + SmolPacket::V4(_) => return Err(anyhow!("Received unexpected IPv4 packet!")), }; let tcp_ip_resp_src_addr = tcp_resp_ip_packet.src_addr(); @@ -803,9 +735,9 @@ async fn tcp_ipv6_connection() -> Result<()> { .await?; // expect TCP FIN - let mut tcp_fin_ip_packet = match mock.pull_wg_packet().await.unwrap() { - IpPacket::V6(packet) => packet, - IpPacket::V4(_) => return Err(anyhow!("Received unexpected IPv4 packet!")), + let mut tcp_fin_ip_packet = match mock.pull_smol_packet().await { + SmolPacket::V6(packet) => packet, + SmolPacket::V4(_) => return Err(anyhow!("Received unexpected IPv4 packet!")), }; let tcp_ip_fin_src_addr = tcp_fin_ip_packet.src_addr(); @@ -844,7 +776,7 @@ async fn tcp_ipv6_connection() -> Result<()> { Some(ack), &[], ); - mock.push_wg_packet(tcp_ip_syn_packet.into()).await?; + mock.push_smol_packet(tcp_ip_syn_packet.into()).await?; mock.stop().await } @@ -860,11 +792,11 @@ async fn receive_icmp4_echo() -> Result<()> { let icmp_echo_ip_packet = build_icmp4_echo_packet(src_addr, dst_addr, 42, 31337, data); - mock.push_wg_packet(icmp_echo_ip_packet.into()).await?; + mock.push_smol_packet(icmp_echo_ip_packet.into()).await?; - let response = mock.pull_wg_packet().await.unwrap(); + let response = mock.pull_smol_packet().await; - if let IpPacket::V4(mut response) = response { + if let SmolPacket::V4(mut response) = response { // Checking that source and destination addresses were flipped and data was the same. assert_eq!(src_addr, response.dst_addr()); assert_eq!(dst_addr, response.src_addr()); @@ -898,11 +830,11 @@ async fn receive_icmp6_echo() -> Result<()> { let icmp_echo_ip_packet = build_icmp6_echo_packet(src_addr, dst_addr, 42, 31337, data); - mock.push_wg_packet(icmp_echo_ip_packet.into()).await?; + mock.push_smol_packet(icmp_echo_ip_packet.into()).await?; - let response = mock.pull_wg_packet().await.unwrap(); + let response = mock.pull_smol_packet().await; - if let IpPacket::V6(mut response) = response { + if let SmolPacket::V6(mut response) = response { // Checking that source and destination addresses were flipped and data was the same. assert_eq!(src_addr, response.dst_addr()); assert_eq!(dst_addr, response.src_addr()); diff --git a/src/network/udp.rs b/src/network/udp.rs index 92def8b7..3fc7c0d0 100644 --- a/src/network/udp.rs +++ b/src/network/udp.rs @@ -6,7 +6,9 @@ use lru_time_cache::LruCache; use tokio::sync::mpsc::{Permit, Sender}; use tokio::sync::oneshot; -use crate::messages::{ConnectionId, IpPacket, NetworkCommand, TransportEvent, TunnelInfo}; +use crate::messages::{ + ConnectionId, ConnectionIdGenerator, NetworkCommand, SmolPacket, TransportEvent, TunnelInfo, +}; use anyhow::Result; use internet_packet::InternetPacket; use smoltcp::phy::ChecksumCapabilities; @@ -65,7 +67,7 @@ impl ConnectionState { } pub struct UdpHandler { - next_connection_id: ConnectionId, + connection_id_generator: ConnectionIdGenerator, id_lookup: LruCache<(SocketAddr, SocketAddr), ConnectionId>, connections: LruCache, net_tx: Sender, @@ -83,7 +85,7 @@ impl UdpHandler { connections, id_lookup, net_tx, - next_connection_id: 1, + connection_id_generator: ConnectionIdGenerator::udp(), } } @@ -144,12 +146,12 @@ impl UdpHandler { IpRepr::Ipv4(repr) => { let mut packet = Ipv4Packet::new_unchecked(buf); repr.emit(&mut packet, &ChecksumCapabilities::default()); - IpPacket::from(packet) + SmolPacket::from(packet) } IpRepr::Ipv6(repr) => { let mut packet = Ipv6Packet::new_unchecked(buf); repr.emit(&mut packet); - IpPacket::from(packet) + SmolPacket::from(packet) } }; @@ -177,7 +179,7 @@ impl UdpHandler { pub fn receive_packet( &mut self, - packet: IpPacket, + packet: SmolPacket, tunnel_info: TunnelInfo, permit: Permit<'_, TransportEvent>, ) -> Result<()> { @@ -195,7 +197,7 @@ impl UdpHandler { .id_lookup .get(&(src_addr, dst_addr)) .cloned() - .unwrap_or(0); // guaranteed to not exist. + .unwrap_or(ConnectionId::unassigned()); let payload = packet.payload().to_vec(); @@ -206,10 +208,7 @@ impl UdpHandler { None => { let mut state = ConnectionState::new(src_addr, dst_addr); state.receive_packet_payload(payload); - let connection_id = { - self.next_connection_id += 2; // only odd ids. - self.next_connection_id - }; + let connection_id = self.connection_id_generator.next_id(); self.id_lookup.insert((src_addr, dst_addr), connection_id); self.connections.insert(connection_id, state); permit.send(TransportEvent::ConnectionEstablished { diff --git a/src/network/virtual_device.rs b/src/network/virtual_device.rs index e3323a09..62ce2966 100755 --- a/src/network/virtual_device.rs +++ b/src/network/virtual_device.rs @@ -6,7 +6,7 @@ use smoltcp::{ }; use tokio::sync::mpsc::{Permit, Sender}; -use crate::messages::{IpPacket, NetworkCommand}; +use crate::messages::{NetworkCommand, SmolPacket}; /// A virtual smoltcp device into which we manually feed packets using /// [VirtualDevice::receive_packet] and which send outgoing packets to a channel. @@ -23,7 +23,7 @@ impl VirtualDevice { } } - pub fn receive_packet(&mut self, packet: IpPacket) { + pub fn receive_packet(&mut self, packet: SmolPacket) { self.rx_buffer.push_back(packet.into_inner()); } } @@ -75,7 +75,7 @@ impl<'a> TxToken for VirtualTxToken<'a> { let mut buffer = vec![0; len]; let result = f(&mut buffer); - match IpPacket::try_from(buffer) { + match SmolPacket::try_from(buffer) { Ok(packet) => { self.permit.send(NetworkCommand::SendPacket(packet)); } diff --git a/src/packet_sources/windows.rs b/src/packet_sources/windows.rs index f4ab41a9..ea38b0c6 100755 --- a/src/packet_sources/windows.rs +++ b/src/packet_sources/windows.rs @@ -20,7 +20,7 @@ use crate::intercept_conf::InterceptConf; use crate::ipc; use crate::ipc::PacketWithMeta; use crate::messages::{ - IpPacket, NetworkCommand, NetworkEvent, TransportCommand, TransportEvent, TunnelInfo, + NetworkCommand, NetworkEvent, SmolPacket, TransportCommand, TransportEvent, TunnelInfo, }; use crate::network::{add_network_layer, MAX_PACKET_SIZE}; use crate::packet_sources::{PacketSourceConf, PacketSourceTask}; @@ -175,7 +175,7 @@ impl PacketSourceTask for WindowsTask { }; assert_eq!(cursor.position(), len as u64); - let Ok(mut packet) = IpPacket::try_from(data) else { + let Ok(mut packet) = SmolPacket::try_from(data) else { log::error!("Skipping invalid packet: {:?}", &self.buf[..len]); continue; }; diff --git a/src/packet_sources/wireguard.rs b/src/packet_sources/wireguard.rs index 52786311..e13928a2 100755 --- a/src/packet_sources/wireguard.rs +++ b/src/packet_sources/wireguard.rs @@ -21,7 +21,7 @@ use tokio::{ }; use crate::messages::{ - IpPacket, NetworkCommand, NetworkEvent, TransportCommand, TransportEvent, TunnelInfo, + NetworkCommand, NetworkEvent, SmolPacket, TransportCommand, TransportEvent, TunnelInfo, }; use crate::network::{add_network_layer, MAX_PACKET_SIZE}; use crate::packet_sources::{PacketSourceConf, PacketSourceTask}; @@ -287,7 +287,7 @@ impl WireGuardTask { self.peers_by_ip .insert(Ipv4Addr::from(packet.src_addr()).into(), peer); let event = NetworkEvent::ReceivePacket { - packet: IpPacket::from(packet), + packet: SmolPacket::from(packet), tunnel_info: TunnelInfo::WireGuard { src_addr: sender_addr, dst_addr: self.socket.local_addr()?, @@ -319,7 +319,7 @@ impl WireGuardTask { self.peers_by_ip .insert(Ipv6Addr::from(packet.src_addr()).into(), peer); let event = NetworkEvent::ReceivePacket { - packet: IpPacket::from(packet), + packet: SmolPacket::from(packet), tunnel_info: TunnelInfo::WireGuard { src_addr: sender_addr, dst_addr: self.socket.local_addr()?, @@ -341,7 +341,7 @@ impl WireGuardTask { } /// process packets and send the encrypted WireGuard datagrams to the peer. - async fn process_outgoing_packet(&mut self, packet: IpPacket) -> Result<()> { + async fn process_outgoing_packet(&mut self, packet: SmolPacket) -> Result<()> { let peer = self .peers_by_ip .get(&packet.dst_ip()) From ed5820c714f4ba58644185c00a2432e2a1c25c37 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Fri, 15 Dec 2023 11:33:55 +0100 Subject: [PATCH 05/20] update type hints --- mitmproxy-rs/mitmproxy_rs/__init__.pyi | 29 +++++++------------------- 1 file changed, 8 insertions(+), 21 deletions(-) diff --git a/mitmproxy-rs/mitmproxy_rs/__init__.pyi b/mitmproxy-rs/mitmproxy_rs/__init__.pyi index 30dbb806..2aa5449b 100644 --- a/mitmproxy-rs/mitmproxy_rs/__init__.pyi +++ b/mitmproxy-rs/mitmproxy_rs/__init__.pyi @@ -12,8 +12,8 @@ async def start_wireguard_server( port: int, private_key: str, peer_public_keys: list[str], - handle_connection: Callable[[Stream], Awaitable[None]], - receive_datagram: Callable[[DatagramTransport, bytes, tuple[str, int], tuple[str, int]], None], + handle_tcp_stream: Callable[[Stream], Awaitable[None]], + handle_udp_stream: Callable[[Stream], Awaitable[None]], ) -> WireGuardServer: ... @final @@ -26,11 +26,11 @@ def genkey() -> str: ... def pubkey(private_key: str) -> str: ... -# Windows +# Local Redirector async def start_local_redirector( - handle_connection: Callable[[Stream], Awaitable[None]], - receive_datagram: Callable[[DatagramTransport, bytes, tuple[str, int], tuple[str, int]], None], + handle_tcp_stream: Callable[[Stream], Awaitable[None]], + handle_udp_stream: Callable[[Stream], Awaitable[None]], ) -> LocalRedirector: ... @final @@ -41,9 +41,6 @@ class LocalRedirector: def close(self) -> None: ... async def wait_closed(self) -> None: ... -# MacOS -def add_cert(pem: str) -> None: ... -def remove_cert() -> None: ... # TCP / UDP @@ -62,24 +59,14 @@ class Stream: def __repr__(self) -> str: ... -@final -class DatagramTransport: - def sendto(self, data: bytes, addr: tuple[str, int] | None = None): ... - async def drain(self) -> None: ... - - def close(self): ... - def is_closing(self) -> bool: ... - async def wait_closed(self) -> None: ... - - def get_extra_info(self, name: str, default: Any = None) -> Any: ... - def __repr__(self) -> str: ... +# Certificate Installation - def get_protocol(self) -> DatagramTransport: ... +def add_cert(pem: str) -> None: ... +def remove_cert() -> None: ... # Process Info - def active_executables() -> list[Process]: ... def executable_icon(path: Path | str) -> bytes: ... From b5e60c0b4fc4b164c30dd495a610fb2842c4d1f3 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Fri, 15 Dec 2023 11:38:31 +0100 Subject: [PATCH 06/20] update python api --- mitmproxy-rs/src/server.rs | 27 ++++++++++++--------------- 1 file changed, 12 insertions(+), 15 deletions(-) diff --git a/mitmproxy-rs/src/server.rs b/mitmproxy-rs/src/server.rs index 8c75adf6..fb135062 100644 --- a/mitmproxy-rs/src/server.rs +++ b/mitmproxy-rs/src/server.rs @@ -226,14 +226,8 @@ impl WireGuardServer { /// - `port`: The listen port for the WireGuard server. The default port for WireGuard is `51820`. /// - `private_key`: The private X25519 key for the WireGuard server as a base64-encoded string. /// - `peer_public_keys`: List of public X25519 keys for WireGuard peers as base64-encoded strings. -/// - `handle_connection`: A coroutine that will be called for each new `TcpStream`. -/// - `receive_datagram`: A function that will be called for each received UDP datagram. -/// -/// The `receive_datagram` function will be called with the following arguments: -/// -/// - payload of the UDP datagram as `bytes` -/// - source address as `(host: str, port: int)` tuple -/// - destination address as `(host: str, port: int)` tuple +/// - `handle_tcp_stream`: An async function that will be called for each new TCP `Stream`. +/// - `handle_udp_stream`: An async function that will be called for each new UDP `Stream`. #[pyfunction] pub fn start_wireguard_server( py: Python<'_>, @@ -241,8 +235,8 @@ pub fn start_wireguard_server( port: u16, private_key: String, peer_public_keys: Vec, - handle_connection: PyObject, - receive_datagram: PyObject, + handle_tcp_stream: PyObject, + handle_udp_stream: PyObject, ) -> PyResult<&PyAny> { let private_key = string_to_key(private_key)?; let peer_public_keys = peer_public_keys @@ -256,20 +250,23 @@ pub fn start_wireguard_server( peer_public_keys, }; pyo3_asyncio::tokio::future_into_py(py, async move { - let (server, local_addr) = Server::init(conf, handle_connection, receive_datagram).await?; + let (server, local_addr) = Server::init(conf, handle_tcp_stream, handle_udp_stream).await?; Ok(WireGuardServer { server, local_addr }) }) } /// Start an OS-level proxy to intercept traffic from the current machine. /// -/// *Availability: Windows* +/// - `handle_tcp_stream`: An async function that will be called for each new TCP `Stream`. +/// - `handle_udp_stream`: An async function that will be called for each new UDP `Stream`. +/// +/// *Availability: Windows and macOS* #[pyfunction] #[allow(unused_variables)] pub fn start_local_redirector( py: Python<'_>, - handle_connection: PyObject, - receive_datagram: PyObject, + handle_tcp_stream: PyObject, + handle_udp_stream: PyObject, ) -> PyResult<&PyAny> { #[cfg(windows)] { @@ -282,7 +279,7 @@ pub fn start_local_redirector( } let conf = WindowsConf { executable_path }; pyo3_asyncio::tokio::future_into_py(py, async move { - let (server, conf_tx) = Server::init(conf, handle_connection, receive_datagram).await?; + let (server, conf_tx) = Server::init(conf, handle_tcp_stream, handle_udp_stream).await?; Ok(LocalRedirector { server, conf_tx }) }) From ae8a005a675809d5dd6f9a8704ce22524322e553 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Fri, 15 Dec 2023 11:40:16 +0100 Subject: [PATCH 07/20] cargo fmt --- mitmproxy-rs/src/server.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mitmproxy-rs/src/server.rs b/mitmproxy-rs/src/server.rs index fb135062..365d61ef 100644 --- a/mitmproxy-rs/src/server.rs +++ b/mitmproxy-rs/src/server.rs @@ -279,7 +279,8 @@ pub fn start_local_redirector( } let conf = WindowsConf { executable_path }; pyo3_asyncio::tokio::future_into_py(py, async move { - let (server, conf_tx) = Server::init(conf, handle_tcp_stream, handle_udp_stream).await?; + let (server, conf_tx) = + Server::init(conf, handle_tcp_stream, handle_udp_stream).await?; Ok(LocalRedirector { server, conf_tx }) }) From 44abe5dfb79886024471480644c3bed4412ba3a0 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Fri, 15 Dec 2023 16:43:04 +0100 Subject: [PATCH 08/20] wip --- mitmproxy-rs/mitmproxy_rs/__init__.pyi | 24 +- mitmproxy-rs/src/lib.rs | 3 + mitmproxy-rs/src/server.rs | 320 -------------------- mitmproxy-rs/src/server/base.rs | 130 ++++++++ mitmproxy-rs/src/server/local_redirector.rs | 126 ++++++++ mitmproxy-rs/src/server/mod.rs | 8 + mitmproxy-rs/src/server/udp.rs | 78 +++++ mitmproxy-rs/src/server/wireguard.rs | 93 ++++++ mitmproxy-rs/src/task.rs | 2 +- mitmproxy-rs/src/util.rs | 12 +- src/messages.rs | 13 + src/network/core.rs | 41 ++- src/network/mod.rs | 2 +- src/network/task.rs | 4 +- src/network/tcp.rs | 14 +- src/network/udp.rs | 212 +++++++------ src/packet_sources/macos.rs | 2 +- src/packet_sources/mod.rs | 1 + src/packet_sources/udp.rs | 153 ++++++++++ src/packet_sources/windows.rs | 2 +- src/packet_sources/wireguard.rs | 2 +- 21 files changed, 790 insertions(+), 452 deletions(-) delete mode 100644 mitmproxy-rs/src/server.rs create mode 100644 mitmproxy-rs/src/server/base.rs create mode 100644 mitmproxy-rs/src/server/local_redirector.rs create mode 100644 mitmproxy-rs/src/server/mod.rs create mode 100644 mitmproxy-rs/src/server/udp.rs create mode 100644 mitmproxy-rs/src/server/wireguard.rs create mode 100644 src/packet_sources/udp.rs diff --git a/mitmproxy-rs/mitmproxy_rs/__init__.pyi b/mitmproxy-rs/mitmproxy_rs/__init__.pyi index 2aa5449b..a2875b11 100644 --- a/mitmproxy-rs/mitmproxy_rs/__init__.pyi +++ b/mitmproxy-rs/mitmproxy_rs/__init__.pyi @@ -1,8 +1,8 @@ from __future__ import annotations from pathlib import Path -from typing import Awaitable, Callable, Any -from typing import final +from typing import Awaitable, Callable, Any, Literal +from typing import final, overload # WireGuard @@ -21,6 +21,7 @@ class WireGuardServer: def getsockname(self) -> tuple[str, int]: ... def close(self) -> None: ... async def wait_closed(self) -> None: ... + def __repr__(self) -> str: ... def genkey() -> str: ... def pubkey(private_key: str) -> str: ... @@ -42,6 +43,22 @@ class LocalRedirector: async def wait_closed(self) -> None: ... +# UDP Server + +async def start_udp_server( + host: str, + port: int, + handle_udp_stream: Callable[[Stream], Awaitable[None]], +) -> UdpServer: ... + +@final +class UdpServer: + def getsockname(self) -> tuple[str, int]: ... + def close(self) -> None: ... + async def wait_closed(self) -> None: ... + def __repr__(self) -> str: ... + + # TCP / UDP @final @@ -55,6 +72,9 @@ class Stream: def is_closing(self) -> bool: ... async def wait_closed(self) -> None: ... + @overload + def get_extra_info(self, name: Literal["transport_protocol"], default: Any = None) -> Literal["tcp", "udp"]: ... + def get_extra_info(self, name: str, default: Any = None) -> Any: ... def __repr__(self) -> str: ... diff --git a/mitmproxy-rs/src/lib.rs b/mitmproxy-rs/src/lib.rs index 66aca938..add36a15 100644 --- a/mitmproxy-rs/src/lib.rs +++ b/mitmproxy-rs/src/lib.rs @@ -49,6 +49,9 @@ pub fn mitmproxy_rs(py: Python, m: &PyModule) -> PyResult<()> { m.add_function(wrap_pyfunction!(server::start_local_redirector, m)?)?; m.add_class::()?; + m.add_function(wrap_pyfunction!(server::start_udp_server, m)?)?; + m.add_class::()?; + m.add_function(wrap_pyfunction!(process_info::active_executables, m)?)?; m.add_class::()?; m.add_function(wrap_pyfunction!(process_info::executable_icon, m)?)?; diff --git a/mitmproxy-rs/src/server.rs b/mitmproxy-rs/src/server.rs deleted file mode 100644 index 365d61ef..00000000 --- a/mitmproxy-rs/src/server.rs +++ /dev/null @@ -1,320 +0,0 @@ -use crate::task::PyInteropTask; - -use crate::util::{socketaddr_to_py, string_to_key}; - -use anyhow::Result; -use mitmproxy::intercept_conf::InterceptConf; - -#[cfg(target_os = "macos")] -use mitmproxy::packet_sources::macos::MacosConf; -#[cfg(windows)] -use mitmproxy::packet_sources::windows::WindowsConf; -use mitmproxy::packet_sources::wireguard::WireGuardConf; -use mitmproxy::packet_sources::{PacketSourceConf, PacketSourceTask}; -use mitmproxy::shutdown::ShutdownTask; -use pyo3::prelude::*; -use std::net::SocketAddr; -#[cfg(target_os = "macos")] -use std::path::Path; -#[cfg(windows)] -use std::path::PathBuf; - -use boringtun::x25519::PublicKey; -use tokio::{sync::broadcast, sync::mpsc}; - -#[derive(Debug)] -pub struct Server { - /// channel for notifying subtasks of requested server shutdown - sd_trigger: broadcast::Sender<()>, - /// channel for getting notified of successful server shutdown - sd_barrier: broadcast::Sender<()>, - /// flag to indicate whether server shutdown is in progress - closing: bool, -} - -impl Server { - pub fn close(&mut self) { - if !self.closing { - self.closing = true; - // XXX: Does not really belong here. - #[cfg(target_os = "macos")] - { - if Path::new("/Applications/MitmproxyAppleTunnel.app").exists() { - std::fs::remove_dir_all("/Applications/MitmproxyAppleTunnel.app").expect( - "Failed to remove MitmproxyAppleTunnel.app from Applications folder", - ); - } - } - log::info!("Shutting down."); - // notify tasks to shut down - let _ = self.sd_trigger.send(()); - } - } - - pub fn wait_closed<'p>(&self, py: Python<'p>) -> PyResult<&'p PyAny> { - let mut barrier = self.sd_barrier.subscribe(); - pyo3_asyncio::tokio::future_into_py(py, async move { - barrier.recv().await.map_err(|_| { - pyo3::exceptions::PyRuntimeError::new_err("Failed to wait for server shutdown.") - }) - }) - } -} - -impl Server { - /// Set up and initialize a new WireGuard server. - pub async fn init( - packet_source_conf: T, - py_tcp_handler: PyObject, - py_udp_handler: PyObject, - ) -> Result<(Self, T::Data)> - where - T: PacketSourceConf, - { - let typ = packet_source_conf.name(); - log::debug!("Initializing {} ...", typ); - - // initialize channels between the virtual network device and the python interop task - // - only used to notify of incoming connections and datagrams - let (transport_events_tx, transport_events_rx) = mpsc::channel(256); - // - used to send data and to ask for packets - // This channel needs to be unbounded because write() is not async. - let (transport_commands_tx, transport_commands_rx) = mpsc::unbounded_channel(); - - // initialize barriers for handling graceful shutdown - let shutdown = broadcast::channel(1).0; - let shutdown_done = broadcast::channel(1).0; - - let (packet_source_task, data) = packet_source_conf - .build( - transport_events_tx, - transport_commands_rx, - shutdown.subscribe(), - ) - .await?; - - // initialize Python interop task - // Note: The current asyncio event loop needs to be determined here on the main thread. - let py_loop: PyObject = Python::with_gil(|py| { - let py_loop = pyo3_asyncio::tokio::get_current_loop(py)?.into_py(py); - Ok::(py_loop) - })?; - - let py_task = PyInteropTask::new( - py_loop, - transport_commands_tx, - transport_events_rx, - py_tcp_handler, - py_udp_handler, - shutdown.subscribe(), - ); - - // spawn tasks - let wg_handle = tokio::spawn(async move { packet_source_task.run().await }); - let py_handle = tokio::spawn(async move { py_task.run().await }); - - // initialize and run shutdown handler - let sd_task = ShutdownTask::new( - py_handle, - wg_handle, - shutdown.clone(), - shutdown_done.clone(), - ); - tokio::spawn(async move { sd_task.run().await }); - - log::debug!("{} successfully initialized.", typ); - - Ok(( - Server { - sd_trigger: shutdown, - sd_barrier: shutdown_done, - closing: false, - }, - data, - )) - } -} - -impl Drop for Server { - fn drop(&mut self) { - self.close() - } -} - -#[pyclass(module = "mitmproxy_rs")] -#[derive(Debug)] -pub struct LocalRedirector { - server: Server, - conf_tx: mpsc::UnboundedSender, -} - -#[pymethods] -impl LocalRedirector { - /// Return a textual description of the given spec, - /// or raise a ValueError if the spec is invalid. - #[staticmethod] - fn describe_spec(spec: &str) -> PyResult { - InterceptConf::try_from(spec) - .map(|conf| conf.description()) - .map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string())) - } - - /// Set a new intercept spec. - pub fn set_intercept(&self, spec: String) -> PyResult<()> { - let conf = InterceptConf::try_from(spec.as_str())?; - self.conf_tx - .send(conf) - .map_err(crate::util::event_queue_unavailable)?; - Ok(()) - } - - /// Close the OS proxy server. - pub fn close(&mut self) { - self.server.close() - } - - pub fn wait_closed<'p>(&self, py: Python<'p>) -> PyResult<&'p PyAny> { - self.server.wait_closed(py) - } -} - -/// A running WireGuard server. -/// -/// A new server can be started by calling the `start_wireguard_server` coroutine. Its public API is intended -/// to be similar to the API provided by -/// [`asyncio.Server`](https://docs.python.org/3/library/asyncio-eventloop.html#asyncio.Server) -/// from the Python standard library. -#[pyclass(module = "mitmproxy_rs")] -#[derive(Debug)] -pub struct WireGuardServer { - /// local address of the WireGuard UDP socket - local_addr: SocketAddr, - server: Server, -} - -#[pymethods] -impl WireGuardServer { - /// Request the WireGuard server to gracefully shut down. - /// - /// The server will stop accepting new connections on its UDP socket, but will flush pending - /// outgoing data before shutting down. - pub fn close(&mut self) { - self.server.close() - } - - /// Wait until the WireGuard server has shut down. - /// - /// This coroutine will yield once pending data has been flushed and all server tasks have - /// successfully terminated after calling the `Server.close` method. - pub fn wait_closed<'p>(&self, py: Python<'p>) -> PyResult<&'p PyAny> { - self.server.wait_closed(py) - } - - /// Get the local socket address that the WireGuard server is listening on. - pub fn getsockname(&self, py: Python) -> PyObject { - socketaddr_to_py(py, self.local_addr) - } - - pub fn __repr__(&self) -> String { - format!("WireGuardServer({})", self.local_addr) - } -} - -/// Start a WireGuard server that is configured with the given parameters: -/// -/// - `host`: The host address for the WireGuard UDP socket. -/// - `port`: The listen port for the WireGuard server. The default port for WireGuard is `51820`. -/// - `private_key`: The private X25519 key for the WireGuard server as a base64-encoded string. -/// - `peer_public_keys`: List of public X25519 keys for WireGuard peers as base64-encoded strings. -/// - `handle_tcp_stream`: An async function that will be called for each new TCP `Stream`. -/// - `handle_udp_stream`: An async function that will be called for each new UDP `Stream`. -#[pyfunction] -pub fn start_wireguard_server( - py: Python<'_>, - host: String, - port: u16, - private_key: String, - peer_public_keys: Vec, - handle_tcp_stream: PyObject, - handle_udp_stream: PyObject, -) -> PyResult<&PyAny> { - let private_key = string_to_key(private_key)?; - let peer_public_keys = peer_public_keys - .into_iter() - .map(string_to_key) - .collect::>>()?; - let conf = WireGuardConf { - host, - port, - private_key, - peer_public_keys, - }; - pyo3_asyncio::tokio::future_into_py(py, async move { - let (server, local_addr) = Server::init(conf, handle_tcp_stream, handle_udp_stream).await?; - Ok(WireGuardServer { server, local_addr }) - }) -} - -/// Start an OS-level proxy to intercept traffic from the current machine. -/// -/// - `handle_tcp_stream`: An async function that will be called for each new TCP `Stream`. -/// - `handle_udp_stream`: An async function that will be called for each new UDP `Stream`. -/// -/// *Availability: Windows and macOS* -#[pyfunction] -#[allow(unused_variables)] -pub fn start_local_redirector( - py: Python<'_>, - handle_tcp_stream: PyObject, - handle_udp_stream: PyObject, -) -> PyResult<&PyAny> { - #[cfg(windows)] - { - let executable_path: PathBuf = py - .import("mitmproxy_windows")? - .call_method0("executable_path")? - .extract()?; - if !executable_path.exists() { - return Err(anyhow::anyhow!("{} does not exist", executable_path.display()).into()); - } - let conf = WindowsConf { executable_path }; - pyo3_asyncio::tokio::future_into_py(py, async move { - let (server, conf_tx) = - Server::init(conf, handle_tcp_stream, handle_udp_stream).await?; - - Ok(LocalRedirector { server, conf_tx }) - }) - } - #[cfg(target_os = "macos")] - { - let destination_path = Path::new("/Applications/Mitmproxy Redirector.app"); - if destination_path.exists() { - log::info!("Using existing mitmproxy redirector app."); - } else { - let filename = py.import("mitmproxy_macos")?.filename()?; - - let source_path = Path::new(filename) - .parent() - .ok_or_else(|| anyhow::anyhow!("invalid path"))? - .join("Mitmproxy Redirector.app.tar"); - - if !source_path.exists() { - return Err(anyhow::anyhow!("{} does not exist", source_path.display()).into()); - } - - // XXX: tokio here? - let redirector_tar = std::fs::File::open(source_path)?; - let mut archive = tar::Archive::new(redirector_tar); - archive.unpack(destination_path.parent().unwrap())?; - } - let conf = MacosConf; - pyo3_asyncio::tokio::future_into_py(py, async move { - let (server, conf_tx) = Server::init(conf, handle_connection, receive_datagram).await?; - Ok(LocalRedirector { server, conf_tx }) - }) - } - #[cfg(not(any(windows, target_os = "macos")))] - Err(pyo3::exceptions::PyNotImplementedError::new_err( - "OS proxy mode is only available on Windows and macOS", - )) -} diff --git a/mitmproxy-rs/src/server/base.rs b/mitmproxy-rs/src/server/base.rs new file mode 100644 index 00000000..aaa2dcc9 --- /dev/null +++ b/mitmproxy-rs/src/server/base.rs @@ -0,0 +1,130 @@ +use crate::task::PyInteropTask; + +use anyhow::Result; + +use mitmproxy::packet_sources::{PacketSourceConf, PacketSourceTask}; +use mitmproxy::shutdown::ShutdownTask; +use pyo3::prelude::*; +#[cfg(target_os = "macos")] +use std::path::Path; + +use tokio::{sync::broadcast, sync::mpsc}; + +#[derive(Debug)] +pub struct Server { + /// channel for notifying subtasks of requested server shutdown + sd_trigger: broadcast::Sender<()>, + /// channel for getting notified of successful server shutdown + sd_barrier: broadcast::Sender<()>, + /// flag to indicate whether server shutdown is in progress + closing: bool, +} + +impl Server { + pub fn close(&mut self) { + if !self.closing { + self.closing = true; + // XXX: Does not really belong here. + #[cfg(target_os = "macos")] + { + if Path::new("/Applications/MitmproxyAppleTunnel.app").exists() { + std::fs::remove_dir_all("/Applications/MitmproxyAppleTunnel.app").expect( + "Failed to remove MitmproxyAppleTunnel.app from Applications folder", + ); + } + } + log::info!("Shutting down."); + // notify tasks to shut down + let _ = self.sd_trigger.send(()); + } + } + + pub fn wait_closed<'p>(&self, py: Python<'p>) -> PyResult<&'p PyAny> { + let mut barrier = self.sd_barrier.subscribe(); + pyo3_asyncio::tokio::future_into_py(py, async move { + barrier.recv().await.map_err(|_| { + pyo3::exceptions::PyRuntimeError::new_err("Failed to wait for server shutdown.") + }) + }) + } +} + +impl Server { + /// Set up and initialize a new WireGuard server. + pub async fn init( + packet_source_conf: T, + py_tcp_handler: PyObject, + py_udp_handler: PyObject, + ) -> Result<(Self, T::Data)> + where + T: PacketSourceConf, + { + let typ = packet_source_conf.name(); + log::debug!("Initializing {} ...", typ); + + // initialize channels between the virtual network device and the python interop task + // - only used to notify of incoming connections and datagrams + let (transport_events_tx, transport_events_rx) = mpsc::channel(256); + // - used to send data and to ask for packets + // This channel needs to be unbounded because write() is not async. + let (transport_commands_tx, transport_commands_rx) = mpsc::unbounded_channel(); + + // initialize barriers for handling graceful shutdown + let shutdown = broadcast::channel(1).0; + let shutdown_done = broadcast::channel(1).0; + + let (packet_source_task, data) = packet_source_conf + .build( + transport_events_tx, + transport_commands_rx, + shutdown.subscribe(), + ) + .await?; + + // initialize Python interop task + // Note: The current asyncio event loop needs to be determined here on the main thread. + let py_loop: PyObject = Python::with_gil(|py| { + let py_loop = pyo3_asyncio::tokio::get_current_loop(py)?.into_py(py); + Ok::(py_loop) + })?; + + let py_task = PyInteropTask::new( + py_loop, + transport_commands_tx, + transport_events_rx, + py_tcp_handler, + py_udp_handler, + shutdown.subscribe(), + ); + + // spawn tasks + let wg_handle = tokio::spawn(async move { packet_source_task.run().await }); + let py_handle = tokio::spawn(async move { py_task.run().await }); + + // initialize and run shutdown handler + let sd_task = ShutdownTask::new( + py_handle, + wg_handle, + shutdown.clone(), + shutdown_done.clone(), + ); + tokio::spawn(async move { sd_task.run().await }); + + log::debug!("{} successfully initialized.", typ); + + Ok(( + Server { + sd_trigger: shutdown, + sd_barrier: shutdown_done, + closing: false, + }, + data, + )) + } +} + +impl Drop for Server { + fn drop(&mut self) { + self.close() + } +} diff --git a/mitmproxy-rs/src/server/local_redirector.rs b/mitmproxy-rs/src/server/local_redirector.rs new file mode 100644 index 00000000..08a932a8 --- /dev/null +++ b/mitmproxy-rs/src/server/local_redirector.rs @@ -0,0 +1,126 @@ +use mitmproxy::intercept_conf::InterceptConf; + +#[cfg(target_os = "macos")] +use mitmproxy::packet_sources::macos::MacosConf; +#[cfg(windows)] +use mitmproxy::packet_sources::windows::WindowsConf; + +use pyo3::prelude::*; +#[cfg(target_os = "macos")] +use std::path::Path; +#[cfg(windows)] +use std::path::PathBuf; + +use crate::server::base::Server; +use tokio::sync::mpsc; + +#[pyclass(module = "mitmproxy_rs")] +#[derive(Debug)] +pub struct LocalRedirector { + server: Server, + conf_tx: mpsc::UnboundedSender, + spec: String, +} + +#[pymethods] +impl LocalRedirector { + /// Return a textual description of the given spec, + /// or raise a ValueError if the spec is invalid. + #[staticmethod] + fn describe_spec(spec: &str) -> PyResult { + InterceptConf::try_from(spec) + .map(|conf| conf.description()) + .map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string())) + } + + /// Set a new intercept spec. + pub fn set_intercept(&mut self, spec: String) -> PyResult<()> { + let conf = InterceptConf::try_from(spec.as_str())?; + self.spec = spec; + self.conf_tx + .send(conf) + .map_err(crate::util::event_queue_unavailable)?; + Ok(()) + } + + /// Close the OS proxy server. + pub fn close(&mut self) { + self.server.close() + } + + pub fn wait_closed<'p>(&self, py: Python<'p>) -> PyResult<&'p PyAny> { + self.server.wait_closed(py) + } + + pub fn __repr__(&self) -> String { + format!("Local Redirector({})", self.spec) + } +} + +/// Start an OS-level proxy to intercept traffic from the current machine. +/// +/// - `handle_tcp_stream`: An async function that will be called for each new TCP `Stream`. +/// - `handle_udp_stream`: An async function that will be called for each new UDP `Stream`. +/// +/// *Availability: Windows and macOS* +#[pyfunction] +#[allow(unused_variables)] +pub fn start_local_redirector( + py: Python<'_>, + handle_tcp_stream: PyObject, + handle_udp_stream: PyObject, +) -> PyResult<&PyAny> { + #[cfg(windows)] + { + let executable_path: PathBuf = py + .import("mitmproxy_windows")? + .call_method0("executable_path")? + .extract()?; + if !executable_path.exists() { + return Err(anyhow::anyhow!("{} does not exist", executable_path.display()).into()); + } + let conf = WindowsConf { executable_path }; + pyo3_asyncio::tokio::future_into_py(py, async move { + let (server, conf_tx) = + Server::init(conf, handle_tcp_stream, handle_udp_stream).await?; + + Ok(LocalRedirector { + server, + conf_tx, + spec: "inactive".to_string(), + }) + }) + } + #[cfg(target_os = "macos")] + { + let destination_path = Path::new("/Applications/Mitmproxy Redirector.app"); + if destination_path.exists() { + log::info!("Using existing mitmproxy redirector app."); + } else { + let filename = py.import("mitmproxy_macos")?.filename()?; + + let source_path = Path::new(filename) + .parent() + .ok_or_else(|| anyhow::anyhow!("invalid path"))? + .join("Mitmproxy Redirector.app.tar"); + + if !source_path.exists() { + return Err(anyhow::anyhow!("{} does not exist", source_path.display()).into()); + } + + // XXX: tokio here? + let redirector_tar = std::fs::File::open(source_path)?; + let mut archive = tar::Archive::new(redirector_tar); + archive.unpack(destination_path.parent().unwrap())?; + } + let conf = MacosConf; + pyo3_asyncio::tokio::future_into_py(py, async move { + let (server, conf_tx) = Server::init(conf, handle_connection, receive_datagram).await?; + Ok(LocalRedirector { server, conf_tx }) + }) + } + #[cfg(not(any(windows, target_os = "macos")))] + Err(pyo3::exceptions::PyNotImplementedError::new_err( + "OS proxy mode is only available on Windows and macOS", + )) +} diff --git a/mitmproxy-rs/src/server/mod.rs b/mitmproxy-rs/src/server/mod.rs new file mode 100644 index 00000000..d95eac80 --- /dev/null +++ b/mitmproxy-rs/src/server/mod.rs @@ -0,0 +1,8 @@ +mod base; +mod local_redirector; +mod udp; +mod wireguard; + +pub use local_redirector::{start_local_redirector, LocalRedirector}; +pub use udp::{start_udp_server, UdpServer}; +pub use wireguard::{start_wireguard_server, WireGuardServer}; diff --git a/mitmproxy-rs/src/server/udp.rs b/mitmproxy-rs/src/server/udp.rs new file mode 100644 index 00000000..9b6d1ab3 --- /dev/null +++ b/mitmproxy-rs/src/server/udp.rs @@ -0,0 +1,78 @@ +use std::net::{IpAddr, Ipv4Addr, SocketAddr}; + +use mitmproxy::packet_sources::udp::UdpConf; + +use pyo3::prelude::*; + +use crate::server::base::Server; +use crate::util::socketaddr_to_py; + +/// A running UDP server. +/// +/// A new server can be started by calling `start_udp_server`. +/// The public API is intended to be similar to the API provided by +/// [`asyncio.Server`](https://docs.python.org/3/library/asyncio-eventloop.html#asyncio.Server) +/// from the Python standard library. +#[pyclass(module = "mitmproxy_rs")] +#[derive(Debug)] +pub struct UdpServer { + /// local address of the UDP socket + local_addr: SocketAddr, + server: Server, +} + +#[pymethods] +impl UdpServer { + /// Request the WireGuard server to gracefully shut down. + /// + /// The server will stop accepting new connections on its UDP socket, but will flush pending + /// outgoing data before shutting down. + pub fn close(&mut self) { + self.server.close() + } + + /// Wait until the WireGuard server has shut down. + /// + /// This coroutine will yield once pending data has been flushed and all server tasks have + /// successfully terminated after calling the `Server.close` method. + pub fn wait_closed<'p>(&self, py: Python<'p>) -> PyResult<&'p PyAny> { + self.server.wait_closed(py) + } + + /// Get the local socket address that the UDP server is listening on. + pub fn getsockname(&self, py: Python) -> PyObject { + socketaddr_to_py(py, self.local_addr) + } + + pub fn __repr__(&self) -> String { + format!("UdpServer({})", self.local_addr) + } +} + +/// Start a UDP server that is configured with the given parameters: +/// +/// - `host`: The host address. +/// - `port`: The listen port. +/// - `handle_tcp_stream`: An async function that will be called for each new TCP `Stream`. +/// - `handle_udp_stream`: An async function that will be called for each new UDP `Stream`. +#[pyfunction] +pub fn start_udp_server( + py: Python<'_>, + host: String, + port: u16, + handle_udp_stream: PyObject, +) -> PyResult<&PyAny> { + let is_unspecified = host.is_empty(); + let conf = UdpConf { host, port }; + let handle_tcp_stream = py.None(); + pyo3_asyncio::tokio::future_into_py(py, async move { + let (server, mut local_addr) = + Server::init(conf, handle_tcp_stream, handle_udp_stream).await?; + // Work around Windows limitation, see packet_sources/udp.rs + if is_unspecified && local_addr == SocketAddr::from((Ipv4Addr::LOCALHOST, port)) { + local_addr.set_ip(IpAddr::V4(Ipv4Addr::UNSPECIFIED)); + } + + Ok(UdpServer { server, local_addr }) + }) +} diff --git a/mitmproxy-rs/src/server/wireguard.rs b/mitmproxy-rs/src/server/wireguard.rs new file mode 100644 index 00000000..4f9e7361 --- /dev/null +++ b/mitmproxy-rs/src/server/wireguard.rs @@ -0,0 +1,93 @@ +use std::net::SocketAddr; + +use crate::util::{socketaddr_to_py, string_to_key}; + +#[cfg(target_os = "macos")] +use mitmproxy::packet_sources::macos::MacosConf; + +use mitmproxy::packet_sources::wireguard::WireGuardConf; + +use pyo3::prelude::*; +#[cfg(target_os = "macos")] +use std::path::Path; + +use boringtun::x25519::PublicKey; + +use crate::server::base::Server; + +/// A running WireGuard server. +/// +/// A new server can be started by calling `start_udp_server`. +/// The public API is intended to be similar to the API provided by +/// [`asyncio.Server`](https://docs.python.org/3/library/asyncio-eventloop.html#asyncio.Server) +/// from the Python standard library. +#[pyclass(module = "mitmproxy_rs")] +#[derive(Debug)] +pub struct WireGuardServer { + /// local address of the WireGuard UDP socket + local_addr: SocketAddr, + server: Server, +} + +#[pymethods] +impl WireGuardServer { + /// Request the WireGuard server to gracefully shut down. + /// + /// The server will stop accepting new connections on its UDP socket, but will flush pending + /// outgoing data before shutting down. + pub fn close(&mut self) { + self.server.close() + } + + /// Wait until the WireGuard server has shut down. + /// + /// This coroutine will yield once pending data has been flushed and all server tasks have + /// successfully terminated after calling the `Server.close` method. + pub fn wait_closed<'p>(&self, py: Python<'p>) -> PyResult<&'p PyAny> { + self.server.wait_closed(py) + } + + /// Get the local socket address that the WireGuard server is listening on. + pub fn getsockname(&self, py: Python) -> PyObject { + socketaddr_to_py(py, self.local_addr) + } + + pub fn __repr__(&self) -> String { + format!("WireGuardServer({})", self.local_addr) + } +} + +/// Start a WireGuard server that is configured with the given parameters: +/// +/// - `host`: The host address for the WireGuard UDP socket. +/// - `port`: The listen port for the WireGuard server. The default port for WireGuard is `51820`. +/// - `private_key`: The private X25519 key for the WireGuard server as a base64-encoded string. +/// - `peer_public_keys`: List of public X25519 keys for WireGuard peers as base64-encoded strings. +/// - `handle_tcp_stream`: An async function that will be called for each new TCP `Stream`. +/// - `handle_udp_stream`: An async function that will be called for each new UDP `Stream`. +#[pyfunction] +pub fn start_wireguard_server( + py: Python<'_>, + host: String, + port: u16, + private_key: String, + peer_public_keys: Vec, + handle_tcp_stream: PyObject, + handle_udp_stream: PyObject, +) -> PyResult<&PyAny> { + let private_key = string_to_key(private_key)?; + let peer_public_keys = peer_public_keys + .into_iter() + .map(string_to_key) + .collect::>>()?; + let conf = WireGuardConf { + host, + port, + private_key, + peer_public_keys, + }; + pyo3_asyncio::tokio::future_into_py(py, async move { + let (server, local_addr) = Server::init(conf, handle_tcp_stream, handle_udp_stream).await?; + Ok(WireGuardServer { server, local_addr }) + }) +} diff --git a/mitmproxy-rs/src/task.rs b/mitmproxy-rs/src/task.rs index d0081989..413086c5 100644 --- a/mitmproxy-rs/src/task.rs +++ b/mitmproxy-rs/src/task.rs @@ -105,7 +105,7 @@ impl PyInteropTask { Ok(()) }) { - log::error!("Failed to spawn TCP connection handler coroutine:\n{}", err); + log::error!("Failed to spawn connection handler:\n{}", err); }; }, } diff --git a/mitmproxy-rs/src/util.rs b/mitmproxy-rs/src/util.rs index 5a3f90e3..195a1fee 100644 --- a/mitmproxy-rs/src/util.rs +++ b/mitmproxy-rs/src/util.rs @@ -28,16 +28,7 @@ where } pub fn socketaddr_to_py(py: Python, s: SocketAddr) -> PyObject { - match s { - SocketAddr::V4(addr) => (addr.ip().to_string(), addr.port()).into_py(py), - SocketAddr::V6(addr) => { - log::debug!( - "Converting IPv6 address/port to Python equivalent (not sure if this is correct): {:?}", - (addr.ip().to_string(), addr.port()) - ); - (addr.ip().to_string(), addr.port()).into_py(py) - } - } + (s.ip().to_string(), s.port()).into_py(py) } #[allow(dead_code)] @@ -146,6 +137,7 @@ pub(crate) fn get_tunnel_info( "remote_endpoint" => return Ok(remote_endpoint.clone().into_py(py)), _ => (), }, + TunnelInfo::Udp {} => (), } match default { Some(x) => Ok(x), diff --git a/src/messages.rs b/src/messages.rs index 129d8b59..ae1d9c15 100755 --- a/src/messages.rs +++ b/src/messages.rs @@ -19,6 +19,7 @@ pub enum TunnelInfo { /// an unresolved remote_endpoint instead. remote_endpoint: Option<(String, u16)>, }, + Udp {}, } /// Events that are sent by WireGuard to the TCP stack. @@ -96,6 +97,18 @@ pub enum TransportCommand { CloseConnection(ConnectionId, bool), } +impl TransportCommand { + pub fn is_tcp(&self) -> bool { + match self { + TransportCommand::ReadData(id, _, _) => id, + TransportCommand::WriteData(id, _) => id, + TransportCommand::DrainWriter(id, _) => id, + TransportCommand::CloseConnection(id, _) => id, + } + .is_tcp() + } +} + /// Generic IPv4/IPv6 packet type that wraps smoltcp's IPv4 and IPv6 packet buffers #[derive(Debug, Clone)] pub enum SmolPacket { diff --git a/src/network/core.rs b/src/network/core.rs index b8f94218..8575cec2 100644 --- a/src/network/core.rs +++ b/src/network/core.rs @@ -12,7 +12,7 @@ use crate::messages::{NetworkCommand, NetworkEvent, SmolPacket, TransportCommand use crate::network::icmp::{handle_icmpv4_echo_request, handle_icmpv6_echo_request}; use crate::network::tcp::TcpHandler; -use crate::network::udp::UdpHandler; +use crate::network::udp::{UdpHandler, UdpPacket}; pub struct NetworkStack<'a> { tcp: TcpHandler<'a>, @@ -24,7 +24,7 @@ impl<'a> NetworkStack<'a> { pub fn new(net_tx: Sender) -> Self { Self { tcp: TcpHandler::new(net_tx.clone()), - udp: UdpHandler::new(net_tx.clone()), + udp: UdpHandler::new(), net_tx, } } @@ -50,7 +50,13 @@ impl<'a> NetworkStack<'a> { match packet.transport_protocol() { IpProtocol::Tcp => self.tcp.receive_packet(packet, tunnel_info, permit), - IpProtocol::Udp => self.udp.receive_packet(packet, tunnel_info, permit), + IpProtocol::Udp => { + match UdpPacket::try_from(packet) { + Ok(packet) => self.udp.receive_data(packet, tunnel_info, permit), + Err(e) => log::debug!("Received invalid UDP packet: {}", e), + }; + Ok(()) + } IpProtocol::Icmp => self.receive_packet_icmp(packet), _ => { log::debug!( @@ -83,24 +89,17 @@ impl<'a> NetworkStack<'a> { } pub fn handle_transport_command(&mut self, command: TransportCommand) { - match command { - TransportCommand::ReadData(id, n, tx) => match id.is_tcp() { - true => self.tcp.read_data(id, n, tx), - false => self.udp.read_data(id, tx), - }, - TransportCommand::WriteData(id, buf) => match id.is_tcp() { - true => self.tcp.write_data(id, buf), - false => self.udp.write_data(id, buf), - }, - TransportCommand::DrainWriter(id, tx) => match id.is_tcp() { - true => self.tcp.drain_writer(id, tx), - false => self.udp.drain_writer(id, tx), - }, - TransportCommand::CloseConnection(id, half_close) => match id.is_tcp() { - true => self.tcp.close_connection(id, half_close), - false => self.udp.close_connection(id), - }, - }; + if command.is_tcp() { + self.tcp.handle_transport_command(command); + } else if let Some(packet) = self.udp.handle_transport_command(command) { + if self + .net_tx + .try_send(NetworkCommand::SendPacket(SmolPacket::from(packet))) + .is_err() + { + log::debug!("Channel unavailable, discarding UDP packet."); + } + } } pub fn poll_delay(&mut self) -> Option { diff --git a/src/network/mod.rs b/src/network/mod.rs index 998a8331..66924e57 100755 --- a/src/network/mod.rs +++ b/src/network/mod.rs @@ -9,6 +9,6 @@ mod icmp; mod tcp; #[cfg(test)] mod tests; -mod udp; +pub(crate) mod udp; pub const MAX_PACKET_SIZE: usize = 65535; diff --git a/src/network/task.rs b/src/network/task.rs index 0e61e8ab..11a43ef2 100755 --- a/src/network/task.rs +++ b/src/network/task.rs @@ -55,7 +55,7 @@ impl NetworkTask<'_> { net_rx: Receiver, py_tx: Sender, py_rx: UnboundedReceiver, - sd_watcher: BroadcastReceiver<()>, + shutdown: BroadcastReceiver<()>, ) -> Result { let io = NetworkStack::new(net_tx.clone()); Ok(Self { @@ -63,7 +63,7 @@ impl NetworkTask<'_> { net_rx, py_tx, py_rx, - shutdown: sd_watcher, + shutdown, io, }) } diff --git a/src/network/tcp.rs b/src/network/tcp.rs index ed26d9c6..ddd26b59 100644 --- a/src/network/tcp.rs +++ b/src/network/tcp.rs @@ -19,7 +19,8 @@ use tokio::sync::{ }; use crate::messages::{ - ConnectionId, ConnectionIdGenerator, NetworkCommand, SmolPacket, TransportEvent, TunnelInfo, + ConnectionId, ConnectionIdGenerator, NetworkCommand, SmolPacket, TransportCommand, + TransportEvent, TunnelInfo, }; use super::virtual_device::VirtualDevice; @@ -160,6 +161,17 @@ impl<'a> TcpHandler<'a> { .map(Duration::from) } + pub fn handle_transport_command(&mut self, command: TransportCommand) { + match command { + TransportCommand::ReadData(id, n, tx) => self.read_data(id, n, tx), + TransportCommand::WriteData(id, buf) => self.write_data(id, buf), + TransportCommand::DrainWriter(id, tx) => self.drain_writer(id, tx), + TransportCommand::CloseConnection(id, half_close) => { + self.close_connection(id, half_close) + } + }; + } + pub fn read_data(&mut self, id: ConnectionId, n: u32, tx: oneshot::Sender>) { if let Some(data) = self.socket_data.get_mut(&id) { assert!(data.recv_waiter.is_none()); diff --git a/src/network/udp.rs b/src/network/udp.rs index 3fc7c0d0..8e45d6f1 100644 --- a/src/network/udp.rs +++ b/src/network/udp.rs @@ -3,19 +3,18 @@ use std::net::SocketAddr; use std::time::Duration; use lru_time_cache::LruCache; -use tokio::sync::mpsc::{Permit, Sender}; +use tokio::sync::mpsc::Permit; use tokio::sync::oneshot; use crate::messages::{ - ConnectionId, ConnectionIdGenerator, NetworkCommand, SmolPacket, TransportEvent, TunnelInfo, + ConnectionId, ConnectionIdGenerator, SmolPacket, TransportCommand, TransportEvent, TunnelInfo, }; -use anyhow::Result; use internet_packet::InternetPacket; use smoltcp::phy::ChecksumCapabilities; use smoltcp::wire::{ IpProtocol, IpRepr, Ipv4Address, Ipv4Packet, Ipv4Repr, Ipv6Address, Ipv6Packet, Ipv6Repr, - UdpPacket, UdpRepr, + UdpRepr, }; struct ConnectionState { @@ -70,11 +69,10 @@ pub struct UdpHandler { connection_id_generator: ConnectionIdGenerator, id_lookup: LruCache<(SocketAddr, SocketAddr), ConnectionId>, connections: LruCache, - net_tx: Sender, } impl UdpHandler { - pub fn new(net_tx: Sender) -> Self { + pub fn new() -> Self { let connections = LruCache::::with_expiry_duration( Duration::from_secs(60), ); @@ -84,87 +82,54 @@ impl UdpHandler { Self { connections, id_lookup, - net_tx, connection_id_generator: ConnectionIdGenerator::udp(), } } + pub(crate) fn handle_transport_command( + &mut self, + command: TransportCommand, + ) -> Option { + match command { + TransportCommand::ReadData(id, _, tx) => { + self.read_data(id, tx); + None + } + TransportCommand::WriteData(id, data) => self.write_data(id, data), + TransportCommand::DrainWriter(id, tx) => { + self.drain_writer(id, tx); + None + } + TransportCommand::CloseConnection(id, _) => { + self.close_connection(id); + None + } + } + } + pub fn read_data(&mut self, id: ConnectionId, tx: oneshot::Sender>) { if let Some(state) = self.connections.get_mut(&id) { state.read_packet_payload(tx); } } - pub fn write_data(&mut self, id: ConnectionId, data: Vec) { + pub(crate) fn write_data(&mut self, id: ConnectionId, data: Vec) -> Option { let Some(state) = self.connections.get(&id) else { - return; + return None; }; // Refresh id lookup. self.id_lookup .insert((state.local_addr, state.remote_addr), id); - let permit = match self.net_tx.try_reserve() { - Ok(p) => p, - Err(_) => { - log::debug!("Channel full, discarding UDP packet."); - return; - } - }; - - // We now know that there's space for us to send, - // let's painstakingly reassemble the IP packet... - - let udp_repr = UdpRepr { - src_port: state.local_addr.port(), - dst_port: state.remote_addr.port(), - }; - - let ip_repr: IpRepr = match (state.local_addr, state.remote_addr) { - (SocketAddr::V4(src_addr), SocketAddr::V4(dst_addr)) => IpRepr::Ipv4(Ipv4Repr { - src_addr: Ipv4Address::from(*src_addr.ip()), - dst_addr: Ipv4Address::from(*dst_addr.ip()), - next_header: IpProtocol::Udp, - payload_len: udp_repr.header_len() + data.len(), - hop_limit: 255, - }), - (SocketAddr::V6(src_addr), SocketAddr::V6(dst_addr)) => IpRepr::Ipv6(Ipv6Repr { - src_addr: Ipv6Address::from(*src_addr.ip()), - dst_addr: Ipv6Address::from(*dst_addr.ip()), - next_header: IpProtocol::Udp, - payload_len: udp_repr.header_len() + data.len(), - hop_limit: 255, - }), - _ => { - log::error!("Failed to assemble UDP datagram: mismatched IP address versions"); - return; - } - }; - - let buf = vec![0u8; ip_repr.buffer_len()]; - - let mut ip_packet = match ip_repr { - IpRepr::Ipv4(repr) => { - let mut packet = Ipv4Packet::new_unchecked(buf); - repr.emit(&mut packet, &ChecksumCapabilities::default()); - SmolPacket::from(packet) - } - IpRepr::Ipv6(repr) => { - let mut packet = Ipv6Packet::new_unchecked(buf); - repr.emit(&mut packet); - SmolPacket::from(packet) - } - }; - - udp_repr.emit( - &mut UdpPacket::new_unchecked(ip_packet.payload_mut()), - &ip_repr.src_addr(), - &ip_repr.dst_addr(), - data.len(), - |buf| buf.copy_from_slice(data.as_slice()), - &ChecksumCapabilities::default(), - ); + if state.closed { + return None; + } - permit.send(NetworkCommand::SendPacket(ip_packet)); + Some(UdpPacket { + src_addr: state.local_addr, + dst_addr: state.remote_addr, + payload: data, + }) } pub fn drain_writer(&mut self, _id: ConnectionId, tx: oneshot::Sender<()>) { @@ -177,50 +142,39 @@ impl UdpHandler { } } - pub fn receive_packet( + pub(crate) fn receive_data( &mut self, - packet: SmolPacket, + packet: UdpPacket, tunnel_info: TunnelInfo, permit: Permit<'_, TransportEvent>, - ) -> Result<()> { - let packet: InternetPacket = match packet.try_into() { - Ok(p) => p, - Err(e) => { - log::debug!("Received invalid IP packet: {}", e); - return Ok(()); - } - }; - let src_addr = packet.src(); - let dst_addr = packet.dst(); - + ) { let potential_cid = self .id_lookup - .get(&(src_addr, dst_addr)) + .get(&(packet.src_addr, packet.dst_addr)) .cloned() .unwrap_or(ConnectionId::unassigned()); - let payload = packet.payload().to_vec(); + let payload = packet.payload; match self.connections.get_mut(&potential_cid) { Some(state) => { state.receive_packet_payload(payload); } None => { - let mut state = ConnectionState::new(src_addr, dst_addr); + let mut state = ConnectionState::new(packet.src_addr, packet.dst_addr); state.receive_packet_payload(payload); let connection_id = self.connection_id_generator.next_id(); - self.id_lookup.insert((src_addr, dst_addr), connection_id); + self.id_lookup + .insert((packet.src_addr, packet.dst_addr), connection_id); self.connections.insert(connection_id, state); permit.send(TransportEvent::ConnectionEstablished { connection_id, - src_addr, - dst_addr, + src_addr: packet.src_addr, + dst_addr: packet.dst_addr, tunnel_info, }); } }; - - Ok(()) } pub fn poll_delay(&mut self) -> Option { @@ -238,6 +192,82 @@ impl UdpHandler { } } +pub(crate) struct UdpPacket { + pub src_addr: SocketAddr, + pub dst_addr: SocketAddr, + pub payload: Vec, +} +impl TryFrom for UdpPacket { + type Error = internet_packet::ParseError; + + fn try_from(value: SmolPacket) -> Result { + let packet: InternetPacket = value.try_into()?; + Ok(UdpPacket { + src_addr: packet.src(), + dst_addr: packet.dst(), + payload: packet.payload().to_vec(), + }) + } +} + +impl From for SmolPacket { + fn from(value: UdpPacket) -> Self { + let UdpPacket { + src_addr, + dst_addr, + payload, + } = value; + + let udp_repr = UdpRepr { + src_port: src_addr.port(), + dst_port: dst_addr.port(), + }; + + let ip_repr: IpRepr = match (src_addr, dst_addr) { + (SocketAddr::V4(src_addr), SocketAddr::V4(dst_addr)) => IpRepr::Ipv4(Ipv4Repr { + src_addr: Ipv4Address::from(*src_addr.ip()), + dst_addr: Ipv4Address::from(*dst_addr.ip()), + next_header: IpProtocol::Udp, + payload_len: udp_repr.header_len() + payload.len(), + hop_limit: 255, + }), + (SocketAddr::V6(src_addr), SocketAddr::V6(dst_addr)) => IpRepr::Ipv6(Ipv6Repr { + src_addr: Ipv6Address::from(*src_addr.ip()), + dst_addr: Ipv6Address::from(*dst_addr.ip()), + next_header: IpProtocol::Udp, + payload_len: udp_repr.header_len() + payload.len(), + hop_limit: 255, + }), + _ => unreachable!("Mismatched IP address versions"), + }; + + let buf = vec![0u8; ip_repr.buffer_len()]; + + let mut smol_packet = match ip_repr { + IpRepr::Ipv4(repr) => { + let mut packet = Ipv4Packet::new_unchecked(buf); + repr.emit(&mut packet, &ChecksumCapabilities::default()); + SmolPacket::from(packet) + } + IpRepr::Ipv6(repr) => { + let mut packet = Ipv6Packet::new_unchecked(buf); + repr.emit(&mut packet); + SmolPacket::from(packet) + } + }; + + udp_repr.emit( + &mut smoltcp::wire::UdpPacket::new_unchecked(smol_packet.payload_mut()), + &ip_repr.src_addr(), + &ip_repr.dst_addr(), + payload.len(), + |buf| buf.copy_from_slice(payload.as_slice()), + &ChecksumCapabilities::default(), + ); + smol_packet + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/packet_sources/macos.rs b/src/packet_sources/macos.rs index c2d8eeac..7d758c0f 100644 --- a/src/packet_sources/macos.rs +++ b/src/packet_sources/macos.rs @@ -80,7 +80,7 @@ impl PacketSourceConf for MacosConf { transport_events_tx: Sender, transport_commands_rx: UnboundedReceiver, shutdown: broadcast::Receiver<()>, - ) -> Result<(MacOsTask, Self::Data)> { + ) -> Result<(Self::Task, Self::Data)> { let listener_addr = format!("/tmp/mitmproxy-{}", std::process::id()); let listener = UnixListener::bind(&listener_addr)?; diff --git a/src/packet_sources/mod.rs b/src/packet_sources/mod.rs index 6f763335..f11004a4 100755 --- a/src/packet_sources/mod.rs +++ b/src/packet_sources/mod.rs @@ -6,6 +6,7 @@ use crate::messages::{TransportCommand, TransportEvent}; #[cfg(target_os = "macos")] pub mod macos; +pub mod udp; #[cfg(windows)] pub mod windows; pub mod wireguard; diff --git a/src/packet_sources/udp.rs b/src/packet_sources/udp.rs new file mode 100644 index 00000000..869153eb --- /dev/null +++ b/src/packet_sources/udp.rs @@ -0,0 +1,153 @@ +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; + +use anyhow::Result; +use async_trait::async_trait; + +use tokio::sync::mpsc::{Permit, UnboundedReceiver}; +use tokio::{ + net::UdpSocket, + sync::{broadcast, mpsc::Sender}, +}; + +use crate::messages::{TransportCommand, TransportEvent, TunnelInfo}; +use crate::network::udp::{UdpHandler, UdpPacket}; +use crate::network::MAX_PACKET_SIZE; +use crate::packet_sources::{PacketSourceConf, PacketSourceTask}; + +pub struct UdpConf { + pub host: String, + pub port: u16, +} + +#[async_trait] +impl PacketSourceConf for UdpConf { + type Task = UdpTask; + type Data = SocketAddr; + + fn name(&self) -> &'static str { + "UDP server" + } + + async fn build( + self, + transport_events_tx: Sender, + transport_commands_rx: UnboundedReceiver, + shutdown: broadcast::Receiver<()>, + ) -> Result<(Self::Task, Self::Data)> { + // bind to UDP socket(s) + + let socket_addrs = if self.host.is_empty() { + vec![ + // Windows quirks: We need to bind to 127.0.0.1 explicitly for IPv4. + #[cfg(windows)] + SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), self.port), + #[cfg(not(windows))] + SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), self.port), + SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), self.port), + ] + } else { + vec![SocketAddr::new(self.host.parse()?, self.port)] + }; + + let socket = UdpSocket::bind(socket_addrs.as_slice()).await?; + let local_addr = socket.local_addr()?; + + log::debug!( + "UDP server listening on {} ...", + socket_addrs + .iter() + .map(|addr| addr.to_string()) + .collect::>() + .join(" and ") + ); + + Ok(( + UdpTask { + socket, + local_addr, + handler: UdpHandler::new(), + transport_events_tx, + transport_commands_rx, + shutdown, + }, + local_addr, + )) + } +} + +pub struct UdpTask { + socket: UdpSocket, + local_addr: SocketAddr, + + handler: UdpHandler, + + transport_events_tx: Sender, + transport_commands_rx: UnboundedReceiver, + shutdown: broadcast::Receiver<()>, +} + +#[async_trait] +impl PacketSourceTask for UdpTask { + async fn run(mut self) -> Result<()> { + let transport_events_tx = self.transport_events_tx.clone(); + let mut udp_buf = [0; MAX_PACKET_SIZE]; + + let mut packet_needs_sending = false; + let mut packet_payload = Vec::new(); + let mut packet_dst = SocketAddr::from((Ipv4Addr::UNSPECIFIED, 0)); + + let mut permit: Option> = None; + + loop { + let py_tx_available = permit.is_some(); + + tokio::select! { + // wait for graceful shutdown + _ = self.shutdown.recv() => break, + // wait for transport_events_tx channel capacity... + Ok(p) = transport_events_tx.reserve(), if !py_tx_available => { + permit = Some(p); + continue; + }, + // ... or process incoming packets + Ok((len, src_addr)) = self.socket.recv_from(&mut udp_buf), if py_tx_available => { + self.process_incoming_datagram(&udp_buf[..len], src_addr, permit.take().unwrap()).await?; + }, + // send_to is cancel safe, so we can use that for backpressure. + _ = self.socket.send_to(&packet_payload, packet_dst), if packet_needs_sending => { + packet_needs_sending = false; + }, + Some(command) = self.transport_commands_rx.recv(), if !packet_needs_sending => { + if let Some(UdpPacket { payload, dst_addr, .. }) = self.handler.handle_transport_command(command) { + packet_payload = payload; + packet_dst = dst_addr; + packet_needs_sending = true; + } + } + } + } + log::debug!("UDP server task shutting down."); + Ok(()) + } +} + +impl UdpTask { + async fn process_incoming_datagram( + &mut self, + data: &[u8], + sender_addr: SocketAddr, + permit: Permit<'_, TransportEvent>, + ) -> Result<()> { + let packet = UdpPacket { + src_addr: sender_addr, + dst_addr: self.local_addr, + payload: data.to_vec(), + }; + let tunnel_info = TunnelInfo::WireGuard { + src_addr: sender_addr, + dst_addr: self.socket.local_addr()?, + }; + self.handler.receive_data(packet, tunnel_info, permit); + Ok(()) + } +} diff --git a/src/packet_sources/windows.rs b/src/packet_sources/windows.rs index ea38b0c6..d32f43e4 100755 --- a/src/packet_sources/windows.rs +++ b/src/packet_sources/windows.rs @@ -46,7 +46,7 @@ impl PacketSourceConf for WindowsConf { transport_events_tx: Sender, transport_commands_rx: UnboundedReceiver, shutdown: broadcast::Receiver<()>, - ) -> Result<(WindowsTask, Self::Data)> { + ) -> Result<(Self::Task, Self::Data)> { let pipe_name = format!( r"\\.\pipe\mitmproxy-transparent-proxy-{}", std::process::id() diff --git a/src/packet_sources/wireguard.rs b/src/packet_sources/wireguard.rs index e13928a2..c6f94a74 100755 --- a/src/packet_sources/wireguard.rs +++ b/src/packet_sources/wireguard.rs @@ -56,7 +56,7 @@ impl PacketSourceConf for WireGuardConf { transport_events_tx: Sender, transport_commands_rx: UnboundedReceiver, shutdown: broadcast::Receiver<()>, - ) -> Result<(WireGuardTask, Self::Data)> { + ) -> Result<(Self::Task, Self::Data)> { let (network_task_handle, net_tx, net_rx) = add_network_layer(transport_events_tx, transport_commands_rx, shutdown)?; From d147d986134ad7ef4c8a345b62c9cdc8398d27f4 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Sun, 17 Dec 2023 17:20:27 +0100 Subject: [PATCH 09/20] add UDP client --- mitmproxy-rs/mitmproxy_rs/__init__.pyi | 31 ++++- mitmproxy-rs/src/lib.rs | 3 + mitmproxy-rs/src/server/udp.rs | 12 +- mitmproxy-rs/src/stream.rs | 56 ++++++--- mitmproxy-rs/src/task.rs | 2 +- mitmproxy-rs/src/udp_client.rs | 151 +++++++++++++++++++++++++ mitmproxy-rs/src/util.rs | 34 +----- src/messages.rs | 2 +- src/packet_sources/udp.rs | 33 ++---- 9 files changed, 235 insertions(+), 89 deletions(-) create mode 100644 mitmproxy-rs/src/udp_client.rs diff --git a/mitmproxy-rs/mitmproxy_rs/__init__.pyi b/mitmproxy-rs/mitmproxy_rs/__init__.pyi index a2875b11..d39a1e6b 100644 --- a/mitmproxy-rs/mitmproxy_rs/__init__.pyi +++ b/mitmproxy-rs/mitmproxy_rs/__init__.pyi @@ -2,9 +2,11 @@ from __future__ import annotations from pathlib import Path from typing import Awaitable, Callable, Any, Literal -from typing import final, overload +from typing import final, overload, TypeVar +T = TypeVar("T") + # WireGuard async def start_wireguard_server( @@ -43,7 +45,7 @@ class LocalRedirector: async def wait_closed(self) -> None: ... -# UDP Server +# UDP async def start_udp_server( host: str, @@ -58,6 +60,12 @@ class UdpServer: async def wait_closed(self) -> None: ... def __repr__(self) -> str: ... +async def open_udp_connection( + host: str, + port: int, + *, + local_addr: tuple[str, int] | None = None, +) -> Stream: ... # TCP / UDP @@ -73,9 +81,24 @@ class Stream: async def wait_closed(self) -> None: ... @overload - def get_extra_info(self, name: Literal["transport_protocol"], default: Any = None) -> Literal["tcp", "udp"]: ... + def get_extra_info(self, name: Literal["transport_protocol"], default: None = None) -> Literal["tcp", "udp"]: ... + @overload + def get_extra_info(self, name: Literal["transport_protocol"], default: T) -> Literal["tcp", "udp"] | T: ... + @overload + def get_extra_info(self, name: Literal["peername", "sockname", "original_src", "original_dst", "remote_endpoint"], default: None = None) -> tuple[str, int]: ... + @overload + def get_extra_info(self, name: Literal["peername", "sockname", "original_src", "original_dst", "remote_endpoint"], default: T) -> tuple[str, int] | T: ... + @overload + def get_extra_info(self, name: Literal["pid"], default: None = None) -> int: ... + @overload + def get_extra_info(self, name: Literal["pid"], default: T) -> int | T: ... + @overload + def get_extra_info(self, name: Literal["process_name"], default: None = None) -> str: ... + @overload + def get_extra_info(self, name: Literal["process_name"], default: T) -> str | T: ... + @overload + def get_extra_info(self, name: str, default: T) -> T: ... - def get_extra_info(self, name: str, default: Any = None) -> Any: ... def __repr__(self) -> str: ... diff --git a/mitmproxy-rs/src/lib.rs b/mitmproxy-rs/src/lib.rs index add36a15..b4f3e746 100644 --- a/mitmproxy-rs/src/lib.rs +++ b/mitmproxy-rs/src/lib.rs @@ -9,6 +9,7 @@ mod process_info; mod server; mod stream; mod task; +mod udp_client; mod util; static LOGGER_INITIALIZED: Lazy> = Lazy::new(|| RwLock::new(false)); @@ -52,6 +53,8 @@ pub fn mitmproxy_rs(py: Python, m: &PyModule) -> PyResult<()> { m.add_function(wrap_pyfunction!(server::start_udp_server, m)?)?; m.add_class::()?; + m.add_function(wrap_pyfunction!(udp_client::open_udp_connection, m)?)?; + m.add_function(wrap_pyfunction!(process_info::active_executables, m)?)?; m.add_class::()?; m.add_function(wrap_pyfunction!(process_info::executable_icon, m)?)?; diff --git a/mitmproxy-rs/src/server/udp.rs b/mitmproxy-rs/src/server/udp.rs index 9b6d1ab3..8bebce36 100644 --- a/mitmproxy-rs/src/server/udp.rs +++ b/mitmproxy-rs/src/server/udp.rs @@ -1,10 +1,11 @@ -use std::net::{IpAddr, Ipv4Addr, SocketAddr}; +use std::net::SocketAddr; use mitmproxy::packet_sources::udp::UdpConf; use pyo3::prelude::*; use crate::server::base::Server; + use crate::util::socketaddr_to_py; /// A running UDP server. @@ -62,17 +63,10 @@ pub fn start_udp_server( port: u16, handle_udp_stream: PyObject, ) -> PyResult<&PyAny> { - let is_unspecified = host.is_empty(); let conf = UdpConf { host, port }; let handle_tcp_stream = py.None(); pyo3_asyncio::tokio::future_into_py(py, async move { - let (server, mut local_addr) = - Server::init(conf, handle_tcp_stream, handle_udp_stream).await?; - // Work around Windows limitation, see packet_sources/udp.rs - if is_unspecified && local_addr == SocketAddr::from((Ipv4Addr::LOCALHOST, port)) { - local_addr.set_ip(IpAddr::V4(Ipv4Addr::UNSPECIFIED)); - } - + let (server, local_addr) = Server::init(conf, handle_tcp_stream, handle_udp_stream).await?; Ok(UdpServer { server, local_addr }) }) } diff --git a/mitmproxy-rs/src/stream.rs b/mitmproxy-rs/src/stream.rs index 0bffbc8d..a59b6076 100644 --- a/mitmproxy-rs/src/stream.rs +++ b/mitmproxy-rs/src/stream.rs @@ -2,6 +2,7 @@ use std::net::SocketAddr; use once_cell::sync::Lazy; +use pyo3::exceptions::PyKeyError; use pyo3::{exceptions::PyOSError, intern, prelude::*, types::PyBytes}; use tokio::sync::{ @@ -11,7 +12,7 @@ use tokio::sync::{ use mitmproxy::messages::{ConnectionId, TransportCommand, TunnelInfo}; -use crate::util::{event_queue_unavailable, get_tunnel_info, socketaddr_to_py}; +use crate::util::{event_queue_unavailable, socketaddr_to_py}; #[derive(Debug)] pub enum StreamState { @@ -28,7 +29,7 @@ pub enum StreamState { pub struct Stream { pub connection_id: ConnectionId, pub state: StreamState, - pub event_tx: mpsc::UnboundedSender, + pub command_tx: mpsc::UnboundedSender, pub peername: SocketAddr, pub sockname: SocketAddr, pub tunnel_info: TunnelInfo, @@ -49,7 +50,7 @@ impl Stream { StreamState::Open | StreamState::HalfClosed => { let (tx, rx) = oneshot::channel(); - self.event_tx + self.command_tx .send(TransportCommand::ReadData(self.connection_id, n, tx)) .ok(); // if this fails tx is dropped and rx.await will error. @@ -77,7 +78,7 @@ impl Stream { fn write(&self, data: Vec) -> PyResult<()> { match self.state { StreamState::Open => self - .event_tx + .command_tx .send(TransportCommand::WriteData(self.connection_id, data)) .map_err(event_queue_unavailable), StreamState::HalfClosed => Err(PyOSError::new_err("connection closed")), @@ -92,7 +93,7 @@ impl Stream { fn drain<'p>(&self, py: Python<'p>) -> PyResult<&'p PyAny> { let (tx, rx) = oneshot::channel(); - self.event_tx + self.command_tx .send(TransportCommand::DrainWriter(self.connection_id, tx)) .map_err(event_queue_unavailable)?; @@ -111,7 +112,7 @@ impl Stream { match self.state { StreamState::Open => { self.state = StreamState::HalfClosed; - self.event_tx + self.command_tx .send(TransportCommand::CloseConnection(self.connection_id, true)) .map_err(event_queue_unavailable) } @@ -128,7 +129,7 @@ impl Stream { match self.state { StreamState::Open | StreamState::HalfClosed => { self.state = StreamState::Closed; - self.event_tx + self.command_tx .send(TransportCommand::CloseConnection(self.connection_id, false)) .map_err(event_queue_unavailable) } @@ -155,7 +156,6 @@ impl Stream { /// - Always available: `transport_protocol`, `peername`, `sockname` /// - WireGuard mode: `original_dst`, `original_src` /// - Local redirector mode: `pid`, `process_name`, `remote_endpoint` - #[pyo3(text_signature = "(self, name, default=None)")] fn get_extra_info( &self, py: Python, @@ -163,14 +163,38 @@ impl Stream { default: Option, ) -> PyResult { match name.as_str() { - "transport_protocol" => Ok(PyObject::from(if self.connection_id.is_tcp() { - intern!(py, "tcp") - } else { - intern!(py, "udp") - })), - "peername" => Ok(socketaddr_to_py(py, self.peername)), - "sockname" => Ok(socketaddr_to_py(py, self.sockname)), - _ => get_tunnel_info(&self.tunnel_info, py, name, default), + "transport_protocol" => { + if self.connection_id.is_tcp() { + return Ok(PyObject::from(intern!(py, "tcp"))); + } else { + return Ok(PyObject::from(intern!(py, "udp"))); + } + } + "peername" => return Ok(socketaddr_to_py(py, self.peername)), + "sockname" => return Ok(socketaddr_to_py(py, self.sockname)), + _ => (), + } + match &self.tunnel_info { + TunnelInfo::WireGuard { src_addr, dst_addr } => match name.as_str() { + "original_src" => return Ok(socketaddr_to_py(py, *src_addr)), + "original_dst" => return Ok(socketaddr_to_py(py, *dst_addr)), + _ => (), + }, + TunnelInfo::LocalRedirector { + pid, + process_name, + remote_endpoint, + } => match name.as_str() { + "pid" => return Ok(pid.into_py(py)), + "process_name" => return Ok(process_name.clone().into_py(py)), + "remote_endpoint" => return Ok(remote_endpoint.clone().into_py(py)), + _ => (), + }, + TunnelInfo::Udp {} => (), + } + match default { + Some(x) => Ok(x), + None => Err(PyKeyError::new_err(name)), } } diff --git a/mitmproxy-rs/src/task.rs b/mitmproxy-rs/src/task.rs index 413086c5..9937c240 100644 --- a/mitmproxy-rs/src/task.rs +++ b/mitmproxy-rs/src/task.rs @@ -68,7 +68,7 @@ impl PyInteropTask { let stream = Stream { connection_id, state: StreamState::Open, - event_tx: self.transport_commands.clone(), + command_tx: self.transport_commands.clone(), peername: src_addr, sockname: dst_addr, tunnel_info, diff --git a/mitmproxy-rs/src/udp_client.rs b/mitmproxy-rs/src/udp_client.rs new file mode 100644 index 00000000..56e13466 --- /dev/null +++ b/mitmproxy-rs/src/udp_client.rs @@ -0,0 +1,151 @@ +use anyhow::Context; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; + +use anyhow::Result; +use pyo3::prelude::*; +use tokio::net::{lookup_host, UdpSocket}; +use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver}; +use tokio::sync::oneshot; + +use crate::stream::{Stream, StreamState}; +use mitmproxy::messages::{ConnectionId, TransportCommand, TunnelInfo}; +use mitmproxy::MAX_PACKET_SIZE; + +/// Start a UDP client that is configured with the given parameters: +/// +/// - `host`: The host address. +/// - `port`: The listen port. +/// - `local_addr`: The local address to bind to. +#[pyfunction] +#[pyo3(signature = (host, port, *, local_addr = None))] +pub fn open_udp_connection( + py: Python<'_>, + host: String, + port: u16, + local_addr: Option<(String, u16)>, +) -> PyResult<&PyAny> { + pyo3_asyncio::tokio::future_into_py(py, async move { + let socket = udp_connect(host, port, local_addr).await?; + + let peername = socket.peer_addr()?; + let sockname = socket.local_addr()?; + + let (command_tx, command_rx) = unbounded_channel(); + + tokio::spawn( + UdpClientTask { + socket, + transport_commands_rx: command_rx, + } + .run(), + ); + + let stream = Stream { + connection_id: ConnectionId::unassigned(), + state: StreamState::Open, + command_tx, + peername, + sockname, + tunnel_info: TunnelInfo::Udp, + }; + + Ok(stream) + }) +} + +/// Open an UDP socket from bind_to to host:port. +/// This is a bit trickier than expected because we want to support IPv4 and IPv6. +async fn udp_connect( + host: String, + port: u16, + local_addr: Option<(String, u16)>, +) -> Result { + let addrs: Vec = lookup_host((host.as_str(), port)) + .await + .with_context(|| format!("unable to resolve hostname: {}", host))? + .collect(); + + if let Some((host, port)) = local_addr { + let socket = UdpSocket::bind((host.as_str(), port)) + .await + .with_context(|| format!("unable to bind to ({}, {})", host, port))?; + socket + .connect(addrs.as_slice()) + .await + .context("unable to connect to remote address")?; + Ok(socket) + } else { + if let Ok(socket) = + UdpSocket::bind(SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0)).await + { + if socket.connect(addrs.as_slice()).await.is_ok() { + return Ok(socket); + } + } + let socket = UdpSocket::bind(SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0)) + .await + .context("unable to bind to 127.0.0.1:0")?; + socket + .connect(addrs.as_slice()) + .await + .context("unable to connect to remote address")?; + Ok(socket) + } +} + +#[derive(Debug)] +pub struct UdpClientTask { + socket: UdpSocket, + transport_commands_rx: UnboundedReceiver, +} + +impl UdpClientTask { + pub async fn run(mut self) { + let mut udp_buf = [0; MAX_PACKET_SIZE]; + + // this here isn't perfect because we block the entire transport_commands_rx channel if we + // cannot send (so we also block receiving new packets), but that's hopefully good enough. + let mut packet_needs_sending = false; + let mut packet_payload = Vec::new(); + + let mut packet_tx: Option>> = None; + + loop { + tokio::select! { + // wait for transport_events_tx channel capacity... + Ok(len) = self.socket.recv(&mut udp_buf), if packet_tx.is_some() => { + packet_tx + .take() + .unwrap() + .send(udp_buf[..len].to_vec()) + .ok(); + }, + // send_to is cancel safe, so we can use that for backpressure. + _ = self.socket.send(&packet_payload), if packet_needs_sending => { + packet_needs_sending = false; + }, + Some(command) = self.transport_commands_rx.recv(), if !packet_needs_sending => { + match command { + TransportCommand::ReadData(_,_,tx) => { + packet_tx = Some(tx); + }, + TransportCommand::WriteData(_, data) => { + packet_payload = data; + packet_needs_sending = true; + }, + TransportCommand::DrainWriter(_,tx) => { + tx.send(()).ok(); + }, + TransportCommand::CloseConnection(_, half_close) => { + if !half_close { + break; + } + }, + } + } + else => break, + } + } + log::debug!("UDP client task shutting down."); + } +} diff --git a/mitmproxy-rs/src/util.rs b/mitmproxy-rs/src/util.rs index 195a1fee..bce1b7bd 100644 --- a/mitmproxy-rs/src/util.rs +++ b/mitmproxy-rs/src/util.rs @@ -3,8 +3,8 @@ use anyhow::{anyhow, Result}; use data_encoding::BASE64; #[cfg(target_os = "macos")] use mitmproxy::macos; -use mitmproxy::messages::TunnelInfo; -use pyo3::exceptions::{PyKeyError, PyOSError}; + +use pyo3::exceptions::PyOSError; use pyo3::types::{PyString, PyTuple}; use pyo3::{exceptions::PyValueError, prelude::*}; use rand_core::OsRng; @@ -114,33 +114,3 @@ pub fn remove_cert() -> PyResult<()> { "OS proxy mode is only available on macos", )) } - -pub(crate) fn get_tunnel_info( - tunnel: &TunnelInfo, - py: Python, - name: String, - default: Option, -) -> PyResult { - match tunnel { - TunnelInfo::WireGuard { src_addr, dst_addr } => match name.as_str() { - "original_src" => return Ok(socketaddr_to_py(py, *src_addr)), - "original_dst" => return Ok(socketaddr_to_py(py, *dst_addr)), - _ => (), - }, - TunnelInfo::LocalRedirector { - pid, - process_name, - remote_endpoint, - } => match name.as_str() { - "pid" => return Ok(pid.into_py(py)), - "process_name" => return Ok(process_name.clone().into_py(py)), - "remote_endpoint" => return Ok(remote_endpoint.clone().into_py(py)), - _ => (), - }, - TunnelInfo::Udp {} => (), - } - match default { - Some(x) => Ok(x), - None => Err(PyKeyError::new_err(name)), - } -} diff --git a/src/messages.rs b/src/messages.rs index ae1d9c15..efbbe547 100755 --- a/src/messages.rs +++ b/src/messages.rs @@ -19,7 +19,7 @@ pub enum TunnelInfo { /// an unresolved remote_endpoint instead. remote_endpoint: Option<(String, u16)>, }, - Udp {}, + Udp, } /// Events that are sent by WireGuard to the TCP stack. diff --git a/src/packet_sources/udp.rs b/src/packet_sources/udp.rs index 869153eb..da228b9d 100644 --- a/src/packet_sources/udp.rs +++ b/src/packet_sources/udp.rs @@ -1,6 +1,6 @@ -use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; +use std::net::{Ipv4Addr, SocketAddr}; -use anyhow::Result; +use anyhow::{Context, Result}; use async_trait::async_trait; use tokio::sync::mpsc::{Permit, UnboundedReceiver}; @@ -34,32 +34,13 @@ impl PacketSourceConf for UdpConf { transport_commands_rx: UnboundedReceiver, shutdown: broadcast::Receiver<()>, ) -> Result<(Self::Task, Self::Data)> { - // bind to UDP socket(s) - - let socket_addrs = if self.host.is_empty() { - vec![ - // Windows quirks: We need to bind to 127.0.0.1 explicitly for IPv4. - #[cfg(windows)] - SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), self.port), - #[cfg(not(windows))] - SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), self.port), - SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), self.port), - ] - } else { - vec![SocketAddr::new(self.host.parse()?, self.port)] - }; - - let socket = UdpSocket::bind(socket_addrs.as_slice()).await?; + // bind to UDP socket. Note that UdpSocket::bind accepts ToSocketAddrs, but will only ever bind to one address! + let socket = UdpSocket::bind((self.host.as_str(), self.port)) + .await + .with_context(|| format!("Failed to bind UDP socket to {}:{}", self.host, self.port))?; let local_addr = socket.local_addr()?; - log::debug!( - "UDP server listening on {} ...", - socket_addrs - .iter() - .map(|addr| addr.to_string()) - .collect::>() - .join(" and ") - ); + log::debug!("UDP server listening on {} ...", local_addr); Ok(( UdpTask { From e4dc32770230ead4d74755637a551b9b2137c2ef Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Sun, 17 Dec 2023 19:44:58 +0100 Subject: [PATCH 10/20] fixup type annotation --- mitmproxy-rs/mitmproxy_rs/__init__.pyi | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mitmproxy-rs/mitmproxy_rs/__init__.pyi b/mitmproxy-rs/mitmproxy_rs/__init__.pyi index d39a1e6b..bee90ab5 100644 --- a/mitmproxy-rs/mitmproxy_rs/__init__.pyi +++ b/mitmproxy-rs/mitmproxy_rs/__init__.pyi @@ -97,7 +97,7 @@ class Stream: @overload def get_extra_info(self, name: Literal["process_name"], default: T) -> str | T: ... @overload - def get_extra_info(self, name: str, default: T) -> T: ... + def get_extra_info(self, name: str, default: Any) -> Any: ... def __repr__(self) -> str: ... From b31620d5dce83b91102c0a88a11518c2a1ddb351 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Sun, 17 Dec 2023 22:12:20 +0100 Subject: [PATCH 11/20] fixup macos mode --- .../TransparentProxyProvider.swift | 2 +- mitmproxy-rs/src/server/local_redirector.rs | 21 +- mitmproxy-rs/src/server/wireguard.rs | 5 - mitmproxy-rs/src/task.rs | 4 +- src/messages.rs | 10 +- src/network/core.rs | 2 +- src/network/task.rs | 2 +- src/network/tcp.rs | 1 + src/network/tests.rs | 6 +- src/network/udp.rs | 13 +- src/packet_sources/macos.rs | 237 +++++++----------- 11 files changed, 130 insertions(+), 173 deletions(-) diff --git a/mitmproxy-macos/redirector/network-extension/TransparentProxyProvider.swift b/mitmproxy-macos/redirector/network-extension/TransparentProxyProvider.swift index e4718d42..bd5171b6 100644 --- a/mitmproxy-macos/redirector/network-extension/TransparentProxyProvider.swift +++ b/mitmproxy-macos/redirector/network-extension/TransparentProxyProvider.swift @@ -157,7 +157,7 @@ class TransparentProxyProvider: NETransparentProxyProvider { guard let remoteEndpoint = tcp_flow.remoteEndpoint as? NWHostEndpoint else { throw TransparentProxyError.noRemoteEndpoint } - log.debug("remoteEndpoint: \(String(describing: remoteEndpoint), privacy: .public)") + // log.debug("remoteEndpoint: \(String(describing: remoteEndpoint), privacy: .public)") // It would be nice if we could also include info on the local endpoint here, but that's not exposed. message = MitmproxyIpc_NewFlow.with { $0.tcp = MitmproxyIpc_TcpFlow.with { diff --git a/mitmproxy-rs/src/server/local_redirector.rs b/mitmproxy-rs/src/server/local_redirector.rs index 08a932a8..b9df16e1 100644 --- a/mitmproxy-rs/src/server/local_redirector.rs +++ b/mitmproxy-rs/src/server/local_redirector.rs @@ -22,6 +22,16 @@ pub struct LocalRedirector { spec: String, } +impl LocalRedirector { + pub fn new(server: Server, conf_tx: mpsc::UnboundedSender) -> Self { + Self { + server, + conf_tx, + spec: "inactive".to_string(), + } + } +} + #[pymethods] impl LocalRedirector { /// Return a textual description of the given spec, @@ -84,11 +94,7 @@ pub fn start_local_redirector( let (server, conf_tx) = Server::init(conf, handle_tcp_stream, handle_udp_stream).await?; - Ok(LocalRedirector { - server, - conf_tx, - spec: "inactive".to_string(), - }) + Ok(LocalRedirector::new(server, conf_tx)) }) } #[cfg(target_os = "macos")] @@ -115,8 +121,9 @@ pub fn start_local_redirector( } let conf = MacosConf; pyo3_asyncio::tokio::future_into_py(py, async move { - let (server, conf_tx) = Server::init(conf, handle_connection, receive_datagram).await?; - Ok(LocalRedirector { server, conf_tx }) + let (server, conf_tx) = + Server::init(conf, handle_tcp_stream, handle_udp_stream).await?; + Ok(LocalRedirector::new(server, conf_tx)) }) } #[cfg(not(any(windows, target_os = "macos")))] diff --git a/mitmproxy-rs/src/server/wireguard.rs b/mitmproxy-rs/src/server/wireguard.rs index 4f9e7361..aebd94ac 100644 --- a/mitmproxy-rs/src/server/wireguard.rs +++ b/mitmproxy-rs/src/server/wireguard.rs @@ -2,14 +2,9 @@ use std::net::SocketAddr; use crate::util::{socketaddr_to_py, string_to_key}; -#[cfg(target_os = "macos")] -use mitmproxy::packet_sources::macos::MacosConf; - use mitmproxy::packet_sources::wireguard::WireGuardConf; use pyo3::prelude::*; -#[cfg(target_os = "macos")] -use std::path::Path; use boringtun::x25519::PublicKey; diff --git a/mitmproxy-rs/src/task.rs b/mitmproxy-rs/src/task.rs index 9937c240..394603d6 100644 --- a/mitmproxy-rs/src/task.rs +++ b/mitmproxy-rs/src/task.rs @@ -63,12 +63,14 @@ impl PyInteropTask { src_addr, dst_addr, tunnel_info, + command_tx, } => { + let command_tx = command_tx.unwrap_or_else(|| self.transport_commands.clone()); // initialize new stream let stream = Stream { connection_id, state: StreamState::Open, - command_tx: self.transport_commands.clone(), + command_tx, peername: src_addr, sockname: dst_addr, tunnel_info, diff --git a/src/messages.rs b/src/messages.rs index efbbe547..2e496c77 100755 --- a/src/messages.rs +++ b/src/messages.rs @@ -4,7 +4,7 @@ use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; use anyhow::{anyhow, Result}; use internet_packet::InternetPacket; use smoltcp::wire::{IpProtocol, Ipv4Packet, Ipv6Packet}; -use tokio::sync::oneshot; +use tokio::sync::{mpsc, oneshot}; #[derive(Debug, Clone)] pub enum TunnelInfo { @@ -39,10 +39,10 @@ pub enum NetworkCommand { pub struct ConnectionIdGenerator(usize); impl ConnectionIdGenerator { - pub fn tcp() -> Self { + pub const fn tcp() -> Self { Self(2) } - pub fn udp() -> Self { + pub const fn udp() -> Self { Self(1) } pub fn next_id(&mut self) -> ConnectionId { @@ -85,6 +85,7 @@ pub enum TransportEvent { src_addr: SocketAddr, dst_addr: SocketAddr, tunnel_info: TunnelInfo, + command_tx: Option>, }, } @@ -98,14 +99,13 @@ pub enum TransportCommand { } impl TransportCommand { - pub fn is_tcp(&self) -> bool { + pub fn connection_id(&self) -> &ConnectionId { match self { TransportCommand::ReadData(id, _, _) => id, TransportCommand::WriteData(id, _) => id, TransportCommand::DrainWriter(id, _) => id, TransportCommand::CloseConnection(id, _) => id, } - .is_tcp() } } diff --git a/src/network/core.rs b/src/network/core.rs index 8575cec2..c9f62a36 100644 --- a/src/network/core.rs +++ b/src/network/core.rs @@ -89,7 +89,7 @@ impl<'a> NetworkStack<'a> { } pub fn handle_transport_command(&mut self, command: TransportCommand) { - if command.is_tcp() { + if command.connection_id().is_tcp() { self.tcp.handle_transport_command(command); } else if let Some(packet) = self.udp.handle_transport_command(command) { if self diff --git a/src/network/task.rs b/src/network/task.rs index 11a43ef2..6c06e57c 100755 --- a/src/network/task.rs +++ b/src/network/task.rs @@ -93,7 +93,7 @@ impl NetworkTask<'_> { // wait for graceful shutdown _ = self.shutdown.recv() => break 'task, // wait for timeouts when the device is idle - _ = async { tokio::time::sleep(delay.unwrap()).await }, if delay.is_some() => {}, + _ = tokio::time::sleep(delay.unwrap()), if delay.is_some() => {}, // wait for py_tx channel capacity... Ok(permit) = self.py_tx.reserve(), if !py_tx_available => { py_tx_permit = Some(permit); diff --git a/src/network/tcp.rs b/src/network/tcp.rs index ddd26b59..43e5d02e 100644 --- a/src/network/tcp.rs +++ b/src/network/tcp.rs @@ -147,6 +147,7 @@ impl<'a> TcpHandler<'a> { src_addr, dst_addr, tunnel_info, + command_tx: None, }; permit.send(event); } diff --git a/src/network/tests.rs b/src/network/tests.rs index bec43ca0..b655b220 100755 --- a/src/network/tests.rs +++ b/src/network/tests.rs @@ -356,7 +356,7 @@ async fn udp_read_write( connection_id, src_addr: recv_src_addr, dst_addr: recv_dst_addr, - tunnel_info: _, + .. } = event; assert_eq!(src_addr, recv_src_addr); @@ -494,7 +494,7 @@ async fn tcp_ipv4_connection() -> Result<()> { connection_id: tcp_conn_id, src_addr: tcp_src_sock, dst_addr: tcp_dst_sock, - tunnel_info: _, + .. } = event; assert_eq!(IpAddress::Ipv4(src_addr), tcp_src_sock.ip().into()); assert_eq!(IpAddress::Ipv4(dst_addr), tcp_dst_sock.ip().into()); @@ -671,7 +671,7 @@ async fn tcp_ipv6_connection() -> Result<()> { connection_id: tcp_conn_id, src_addr: tcp_src_sock, dst_addr: tcp_dst_sock, - tunnel_info: _, + .. } = event; assert_eq!(IpAddress::Ipv6(src_addr), tcp_src_sock.ip().into()); assert_eq!(IpAddress::Ipv6(dst_addr), tcp_dst_sock.ip().into()); diff --git a/src/network/udp.rs b/src/network/udp.rs index 8e45d6f1..a3a54b2c 100644 --- a/src/network/udp.rs +++ b/src/network/udp.rs @@ -65,6 +65,8 @@ impl ConnectionState { } } +pub const UDP_TIMEOUT: Duration = Duration::from_secs(60); + pub struct UdpHandler { connection_id_generator: ConnectionIdGenerator, id_lookup: LruCache<(SocketAddr, SocketAddr), ConnectionId>, @@ -73,12 +75,10 @@ pub struct UdpHandler { impl UdpHandler { pub fn new() -> Self { - let connections = LruCache::::with_expiry_duration( - Duration::from_secs(60), - ); - let id_lookup = LruCache::<(SocketAddr, SocketAddr), ConnectionId>::with_expiry_duration( - Duration::from_secs(60), - ); + let connections = + LruCache::::with_expiry_duration(UDP_TIMEOUT); + let id_lookup = + LruCache::<(SocketAddr, SocketAddr), ConnectionId>::with_expiry_duration(UDP_TIMEOUT); Self { connections, id_lookup, @@ -172,6 +172,7 @@ impl UdpHandler { src_addr: packet.src_addr, dst_addr: packet.dst_addr, tunnel_info, + command_tx: None, }); } }; diff --git a/src/packet_sources/macos.rs b/src/packet_sources/macos.rs index 7d758c0f..58b32c3a 100644 --- a/src/packet_sources/macos.rs +++ b/src/packet_sources/macos.rs @@ -1,7 +1,6 @@ -use std::collections::HashMap; -use std::net::{Ipv4Addr, SocketAddr}; +use std::net::{IpAddr, Ipv4Addr, SocketAddr}; -use crate::messages::{ConnectionId, TransportCommand, TransportEvent, TunnelInfo}; +use crate::messages::{ConnectionIdGenerator, TransportCommand, TransportEvent, TunnelInfo}; use crate::intercept_conf::InterceptConf; use crate::ipc; @@ -12,7 +11,7 @@ use async_trait::async_trait; use futures_util::SinkExt; use futures_util::StreamExt; -use prost::bytes::{Buf, BytesMut}; +use prost::bytes::BytesMut; use prost::Message; use std::process::Stdio; @@ -78,7 +77,7 @@ impl PacketSourceConf for MacosConf { async fn build( self, transport_events_tx: Sender, - transport_commands_rx: UnboundedReceiver, + _transport_commands_rx: UnboundedReceiver, shutdown: broadcast::Receiver<()>, ) -> Result<(Self::Task, Self::Data)> { let listener_addr = format!("/tmp/mitmproxy-{}", std::process::id()); @@ -100,11 +99,7 @@ impl PacketSourceConf for MacosConf { control_channel, listener, connections: JoinSet::new(), - connection_by_id: HashMap::new(), - connection_by_addr: HashMap::new(), - next_connection_id: 0, transport_events_tx, - transport_commands_rx, conf_rx, shutdown, }, @@ -116,12 +111,8 @@ impl PacketSourceConf for MacosConf { pub struct MacOsTask { control_channel: UnixStream, listener: UnixListener, - connections: JoinSet)>>, - connection_by_id: HashMap>, - connection_by_addr: HashMap>, - next_connection_id: ConnectionId, + connections: JoinSet>, transport_events_tx: Sender, - transport_commands_rx: UnboundedReceiver, conf_rx: UnboundedReceiver, shutdown: broadcast::Receiver<()>, } @@ -131,9 +122,6 @@ impl PacketSourceTask for MacOsTask { async fn run(mut self) -> Result<()> { let mut control_channel = Framed::new(self.control_channel, LengthDelimitedCodec::new()); - let (register_addr_tx, mut register_addr_rx) = - unbounded_channel::(); - loop { tokio::select! { // wait for graceful shutdown @@ -144,66 +132,24 @@ impl PacketSourceTask for MacOsTask { }, Some(task) = self.connections.join_next() => { match task { - Ok(Ok((cid, src_addr))) => { - self.connection_by_id.remove(&cid); - if let Some(src_addr) = src_addr { - self.connection_by_addr.remove(&src_addr); - } - }, + Ok(Ok(())) => (), Ok(Err(e)) => log::error!("Connection task failure: {e:?}"), Err(e) => log::error!("Connection task panic: {e:?}"), } }, - Some(RegisterConnectionSocketAddr(cid, addr, done)) = register_addr_rx.recv() => { - let tx = self.connection_by_id.get(&cid).unwrap().clone(); - self.connection_by_addr.insert(addr, tx); - done.send(()).expect("ok channel dead"); - }, l = self.listener.accept() => { match l { Ok((stream, _)) => { - let (conn_tx, conn_rx) = unbounded_channel(); - let connection_id = { - self.next_connection_id += 1; - self.next_connection_id - }; - self.connections.spawn( - ConnectionTask::new(connection_id, stream, conn_rx, self.transport_events_tx.clone(), register_addr_tx.clone()) - .run() - ); - self.connection_by_id.insert( - connection_id, - conn_tx + let task = ConnectionTask::new( + stream, + self.transport_events_tx.clone(), + self.shutdown.resubscribe(), ); + self.connections.spawn(task.run()); }, Err(e) => log::error!("Error accepting connection from macos-redirector: {}", e) } }, - Some(cmd) = self.transport_commands_rx.recv() => { - match &cmd { - TransportCommand::ReadData(connection_id, _, _) - | TransportCommand::WriteData(connection_id, _) - | TransportCommand::DrainWriter(connection_id, _) - | TransportCommand::CloseConnection(connection_id, _) => { - let Some(conn_tx) = self.connection_by_id.get(connection_id) else { - log::error!("Received command for unknown connection: {:?}", &cmd); - continue; - }; - conn_tx.send(cmd).ok(); - }, - TransportCommand::SendDatagram { - data: _, - src_addr, - dst_addr, - } => { - let Some(conn_tx) = self.connection_by_addr.get(dst_addr) else { - log::error!("Received command for unknown address: src={:?} dst={:?}", src_addr, dst_addr); - continue; - }; - conn_tx.send(cmd).ok(); - }, - } - } // pipe through changes to the intercept list Some(conf) = self.conf_rx.recv() => { let msg = ipc::InterceptConf::from(conf); @@ -220,37 +166,25 @@ impl PacketSourceTask for MacOsTask { } } -struct RegisterConnectionSocketAddr(ConnectionId, SocketAddr, oneshot::Sender<()>); - struct ConnectionTask { - id: ConnectionId, stream: UnixStream, - commands: UnboundedReceiver, events: Sender, - read_tx: Option<(usize, oneshot::Sender>)>, - drain_tx: Option>, - register_addr: UnboundedSender, + shutdown: broadcast::Receiver<()>, } impl ConnectionTask { pub fn new( - id: ConnectionId, stream: UnixStream, - commands: UnboundedReceiver, events: Sender, - register_addr: UnboundedSender, + shutdown: broadcast::Receiver<()>, ) -> Self { Self { - id, stream, - commands, events, - read_tx: None, - drain_tx: None, - register_addr, + shutdown, } } - async fn run(mut self) -> Result<(ConnectionId, Option)> { + async fn run(mut self) -> Result<()> { let new_flow = { let len = self .stream @@ -276,7 +210,7 @@ impl ConnectionTask { } } - async fn handle_udp(mut self, flow: UdpFlow) -> Result<(ConnectionId, Option)> { + async fn handle_udp(mut self, flow: UdpFlow) -> Result<()> { // For UDP connections, we pass length-delimited protobuf messages over the unix socket // in both directions. let mut write_buf = BytesMut::new(); @@ -292,134 +226,150 @@ impl ConnectionTask { remote_endpoint: None, } }; - let local_addr = { + let local_address = { let Some(addr) = &flow.local_address else { bail!("no local address") }; SocketAddr::try_from(addr)? }; + let mut remote_address = SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0); + let (command_tx, mut command_rx) = unbounded_channel(); - // Send our socket address to the main macos task and wait until it has been processed. - let (done_tx, done_rx) = oneshot::channel(); - self.register_addr - .send(RegisterConnectionSocketAddr(self.id, local_addr, done_tx))?; - done_rx.await?; + let mut first_packet = Some((tunnel_info, local_address, command_tx)); + + let mut read_data: Option> = None; + let mut read_tx: Option>> = None; loop { tokio::select! { - packet = stream.next() => { - let Some(packet) = packet else { - break; - }; + _ = self.shutdown.recv() => break, + Some(packet) = stream.next(), if read_data.is_none() => { let packet = ipc::UdpPacket::decode( packet.context("IPC read error")? ).context("invalid IPC message")?; - let dst_addr = { let Some(dst_addr) = &packet.remote_address else { bail!("no remote addr") }; SocketAddr::try_from(dst_addr).context("invalid socket address")? }; - todo!(); - if let Err(e) = self.events.try_send(TransportEvent::DatagramReceived { - data: packet.data, - src_addr: local_addr, - dst_addr, - tunnel_info: tunnel_info.clone(), - }) { - log::debug!("Failed to send UDP packet: {}", e); + // We can only send ConnectionEstablished once we know the destination address. + if let Some((tunnel_info, local_address, command_tx)) = first_packet.take() { + remote_address = dst_addr; + self.events.send(TransportEvent::ConnectionEstablished { + connection_id: ConnectionIdGenerator::udp().next_id(), + src_addr: local_address, + dst_addr, + tunnel_info, + command_tx: Some(command_tx), + }).await?; + } else if remote_address != dst_addr { + bail!("UDP packet destinations do not match: {remote_address} -> {dst_addr}") + } + if let Some(tx) = read_tx.take() { + tx.send(packet.data).ok(); + } else { + read_data = Some(packet.data); } }, - command = self.commands.recv() => { - let Some(command) = command else { - break; - }; + Some(command) = command_rx.recv() => { match command { - todo!(); - TransportCommand::SendDatagram { data, src_addr, dst_addr } => { - assert_eq!(dst_addr, local_addr); + TransportCommand::ReadData(_, _, tx) => { + if let Some(data) = read_data.take() { + tx.send(data).ok(); + } else { + if read_tx.is_some() { + bail!("Concurrent readers are not supported."); + } + read_tx = Some(tx); + } + }, + TransportCommand::WriteData(_, data) => { + assert!(first_packet.is_none()); let packet = ipc::UdpPacket { data, - remote_address: Some(src_addr.into()), + remote_address: Some(remote_address.into()), }; write_buf.reserve(packet.encoded_len()); packet.encode(&mut write_buf)?; - stream.send(write_buf.split().freeze()).await?; + // Awaiting here isn't ideal because it blocks reading, but what to do. + stream.send(write_buf.split().freeze()).await.ok(); }, - TransportCommand::ReadData(_, _, _) | - TransportCommand::WriteData(_, _) | - TransportCommand::DrainWriter(_, _) | - TransportCommand::CloseConnection(_, _) => { - bail!("UDP connection received TCP event: {command:?}"); + TransportCommand::DrainWriter(_, tx) => { + tx.send(()).ok(); + }, + TransportCommand::CloseConnection(_, half_close) => { + if !half_close { + break; + } } } } + else => break, } } - Ok((self.id, Some(local_addr))) + Ok(()) } - async fn handle_tcp(mut self, flow: TcpFlow) -> Result<(ConnectionId, Option)> { + async fn handle_tcp(mut self, flow: TcpFlow) -> Result<()> { let mut write_buf = BytesMut::new(); + let mut drain_tx: Option> = None; + let mut read_tx: Option<(usize, oneshot::Sender>)> = None; + + let (command_tx, mut command_rx) = unbounded_channel(); let remote = flow.remote_address.expect("no remote address"); let src_addr = SocketAddr::from((Ipv4Addr::LOCALHOST, 0)); - let dst_addr = match SocketAddr::try_from(&remote) { - Ok(addr) => addr, - Err(_) => SocketAddr::from((Ipv4Addr::UNSPECIFIED, 0)), + let dst_addr = SocketAddr::try_from(&remote) + .unwrap_or_else(|_| SocketAddr::from((Ipv4Addr::UNSPECIFIED, 0))); + let tunnel_info = TunnelInfo::LocalRedirector { + pid: flow.tunnel_info.as_ref().map(|t| t.pid).unwrap_or(0), + process_name: flow.tunnel_info.and_then(|t| t.process_name), + remote_endpoint: Some((remote.host, remote.port as u16)), }; - let remote_endpoint = Some((remote.host, remote.port as u16)); self.events .send(TransportEvent::ConnectionEstablished { - connection_id: self.id, + connection_id: ConnectionIdGenerator::tcp().next_id(), src_addr, dst_addr, - tunnel_info: TunnelInfo::LocalRedirector { - pid: flow.tunnel_info.as_ref().map(|t| t.pid).unwrap_or(0), - process_name: flow.tunnel_info.and_then(|t| t.process_name), - remote_endpoint, - }, + tunnel_info, + command_tx: Some(command_tx), }) .await?; loop { tokio::select! { + _ = self.shutdown.recv() => break, Ok(()) = self.stream.writable(), if !write_buf.is_empty() => { self.stream.write_buf(&mut write_buf).await.context("failed to write to socket from buf")?; if write_buf.is_empty() { - if let Some(tx) = self.drain_tx.take() { + if let Some(tx) = drain_tx.take() { tx.send(()).ok(); } } }, - Ok(()) = self.stream.readable(), if self.read_tx.is_some() => { - let (n, tx) = self.read_tx.take().unwrap(); + Ok(()) = self.stream.readable(), if read_tx.is_some() => { + let (n, tx) = read_tx.take().unwrap(); let mut data = Vec::with_capacity(n); self.stream.read_buf(&mut data).await.context("failed to read from socket")?; tx.send(data).ok(); }, - command = self.commands.recv() => { - let Some(command) = command else { - break; - }; + Some(command) = command_rx.recv() => { match command { TransportCommand::ReadData(_, n, tx) => { - assert!(self.read_tx.is_none()); - self.read_tx = Some((n as usize, tx)); + assert!(read_tx.is_none()); + read_tx = Some((n as usize, tx)); }, TransportCommand::WriteData(_, data) => { - let mut c = std::io::Cursor::new(data); - self.stream.write_buf(&mut c).await.context("failed to write to socket")?; - write_buf.extend_from_slice(c.chunk()); + write_buf.extend_from_slice(data.as_slice()); }, TransportCommand::DrainWriter(_, tx) => { - assert!(self.drain_tx.is_none()); + assert!(drain_tx.is_none()); if write_buf.is_empty() { tx.send(()).ok(); } else { - self.drain_tx = Some(tx); + drain_tx = Some(tx); } }, TransportCommand::CloseConnection(_, half_close) => { @@ -430,9 +380,10 @@ impl ConnectionTask { } } } - } + }, + else => break, } } - Ok((self.id, None)) + Ok(()) } } From f616bb972afa4a51e767b395be0d52b9bfc26dfe Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Sun, 17 Dec 2023 22:37:00 +0100 Subject: [PATCH 12/20] fix tests --- src/network/task.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/network/task.rs b/src/network/task.rs index 6c06e57c..11a43ef2 100755 --- a/src/network/task.rs +++ b/src/network/task.rs @@ -93,7 +93,7 @@ impl NetworkTask<'_> { // wait for graceful shutdown _ = self.shutdown.recv() => break 'task, // wait for timeouts when the device is idle - _ = tokio::time::sleep(delay.unwrap()), if delay.is_some() => {}, + _ = async { tokio::time::sleep(delay.unwrap()).await }, if delay.is_some() => {}, // wait for py_tx channel capacity... Ok(permit) = self.py_tx.reserve(), if !py_tx_available => { py_tx_permit = Some(permit); From 8d3434b82dc6f77ff06b6b6e6be297ca12a1fa8d Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Sun, 17 Dec 2023 22:43:38 +0100 Subject: [PATCH 13/20] ignore typevar in stubtest --- mitmproxy-rs/stubtest-allowlist.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/mitmproxy-rs/stubtest-allowlist.txt b/mitmproxy-rs/stubtest-allowlist.txt index 472a68e3..7a781506 100644 --- a/mitmproxy-rs/stubtest-allowlist.txt +++ b/mitmproxy-rs/stubtest-allowlist.txt @@ -2,3 +2,4 @@ mitmproxy_rs.mitmproxy_rs mitmproxy_rs._pyinstaller.hook-mitmproxy_rs mitmproxy_rs._pyinstaller.hook-mitmproxy_windows mitmproxy_rs._pyinstaller.hook-mitmproxy_macos +mitmproxy_rs.T \ No newline at end of file From 99ba2742684ff37dfcff77c77e002c7eead862a1 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Mon, 18 Dec 2023 00:23:09 +0100 Subject: [PATCH 14/20] improve error handling --- mitmproxy-rs/src/server/udp.rs | 5 ++--- mitmproxy-rs/src/udp_client.rs | 25 +++++++++++++-------- src/messages.rs | 8 +++++-- src/network/udp.rs | 41 +++++++++++++++++----------------- src/packet_sources/macos.rs | 17 +++++++++----- src/packet_sources/udp.rs | 34 +++++++++------------------- 6 files changed, 67 insertions(+), 63 deletions(-) diff --git a/mitmproxy-rs/src/server/udp.rs b/mitmproxy-rs/src/server/udp.rs index 8bebce36..6ef6e4b2 100644 --- a/mitmproxy-rs/src/server/udp.rs +++ b/mitmproxy-rs/src/server/udp.rs @@ -24,7 +24,7 @@ pub struct UdpServer { #[pymethods] impl UdpServer { - /// Request the WireGuard server to gracefully shut down. + /// Request the server to gracefully shut down. /// /// The server will stop accepting new connections on its UDP socket, but will flush pending /// outgoing data before shutting down. @@ -32,7 +32,7 @@ impl UdpServer { self.server.close() } - /// Wait until the WireGuard server has shut down. + /// Wait until the server has shut down. /// /// This coroutine will yield once pending data has been flushed and all server tasks have /// successfully terminated after calling the `Server.close` method. @@ -54,7 +54,6 @@ impl UdpServer { /// /// - `host`: The host address. /// - `port`: The listen port. -/// - `handle_tcp_stream`: An async function that will be called for each new TCP `Stream`. /// - `handle_udp_stream`: An async function that will be called for each new UDP `Stream`. #[pyfunction] pub fn start_udp_server( diff --git a/mitmproxy-rs/src/udp_client.rs b/mitmproxy-rs/src/udp_client.rs index 56e13466..f470119f 100644 --- a/mitmproxy-rs/src/udp_client.rs +++ b/mitmproxy-rs/src/udp_client.rs @@ -32,13 +32,15 @@ pub fn open_udp_connection( let (command_tx, command_rx) = unbounded_channel(); - tokio::spawn( - UdpClientTask { + tokio::spawn(async { + let task = UdpClientTask { socket, transport_commands_rx: command_rx, + }; + if let Err(e) = task.run().await { + log::error!("UDP client errored: {e}"); } - .run(), - ); + }); let stream = Stream { connection_id: ConnectionId::unassigned(), @@ -100,7 +102,7 @@ pub struct UdpClientTask { } impl UdpClientTask { - pub async fn run(mut self) { + pub async fn run(mut self) -> Result<()> { let mut udp_buf = [0; MAX_PACKET_SIZE]; // this here isn't perfect because we block the entire transport_commands_rx channel if we @@ -113,7 +115,8 @@ impl UdpClientTask { loop { tokio::select! { // wait for transport_events_tx channel capacity... - Ok(len) = self.socket.recv(&mut udp_buf), if packet_tx.is_some() => { + len = self.socket.recv(&mut udp_buf), if packet_tx.is_some() => { + let len = len.context("UDP recv() failed")?; packet_tx .take() .unwrap() @@ -121,10 +124,14 @@ impl UdpClientTask { .ok(); }, // send_to is cancel safe, so we can use that for backpressure. - _ = self.socket.send(&packet_payload), if packet_needs_sending => { + e = self.socket.send(&packet_payload), if packet_needs_sending => { + e.context("UDP send() failed")?; packet_needs_sending = false; }, - Some(command) = self.transport_commands_rx.recv(), if !packet_needs_sending => { + command = self.transport_commands_rx.recv(), if !packet_needs_sending => { + let Some(command) = command else { + break; + }; match command { TransportCommand::ReadData(_,_,tx) => { packet_tx = Some(tx); @@ -143,9 +150,9 @@ impl UdpClientTask { }, } } - else => break, } } log::debug!("UDP client task shutting down."); + Ok(()) } } diff --git a/src/messages.rs b/src/messages.rs index 2e496c77..2838434e 100755 --- a/src/messages.rs +++ b/src/messages.rs @@ -56,7 +56,7 @@ impl ConnectionIdGenerator { pub struct ConnectionId(usize); impl ConnectionId { pub fn is_tcp(&self) -> bool { - self.0 & 1 == 0 + self.0 > 0 && self.0 & 1 == 0 } pub fn unassigned() -> Self { ConnectionId(0) @@ -69,7 +69,9 @@ impl fmt::Display for ConnectionId { } impl fmt::Debug for ConnectionId { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - if self.is_tcp() { + if self.0 == 0 { + write!(f, "0") + } else if self.is_tcp() { write!(f, "{}#TCP", self.0) } else { write!(f, "{}#UDP", self.0) @@ -85,6 +87,8 @@ pub enum TransportEvent { src_addr: SocketAddr, dst_addr: SocketAddr, tunnel_info: TunnelInfo, + // Channel over which the stream should emit commands. + // If command_tx is None, the main channel is used. command_tx: Option>, }, } diff --git a/src/network/udp.rs b/src/network/udp.rs index a3a54b2c..662d18d0 100644 --- a/src/network/udp.rs +++ b/src/network/udp.rs @@ -35,15 +35,16 @@ impl ConnectionState { read_tx: None, } } - fn receive_packet_payload(&mut self, data: Vec) { + fn add_packet(&mut self, data: Vec) { if self.closed { + drop(data); } else if let Some(tx) = self.read_tx.take() { tx.send(data).ok(); } else { self.packets.push_back(data); } } - fn read_packet_payload(&mut self, tx: oneshot::Sender>) { + fn add_reader(&mut self, tx: oneshot::Sender>) { assert!(self.read_tx.is_none()); if self.closed { drop(tx); @@ -55,13 +56,13 @@ impl ConnectionState { } fn close(&mut self) { if self.closed { + // already closed. } else if let Some(tx) = self.read_tx.take() { drop(tx); - self.closed = true; } else { self.packets.clear(); - self.closed = true; } + self.closed = true; } } @@ -75,6 +76,8 @@ pub struct UdpHandler { impl UdpHandler { pub fn new() -> Self { + // This implementation is largely based on the fact that LruCache eventually + // drops the state, which closes the respective channels. let connections = LruCache::::with_expiry_duration(UDP_TIMEOUT); let id_lookup = @@ -109,7 +112,7 @@ impl UdpHandler { pub fn read_data(&mut self, id: ConnectionId, tx: oneshot::Sender>) { if let Some(state) = self.connections.get_mut(&id) { - state.read_packet_payload(tx); + state.add_reader(tx); } } @@ -154,15 +157,13 @@ impl UdpHandler { .cloned() .unwrap_or(ConnectionId::unassigned()); - let payload = packet.payload; - match self.connections.get_mut(&potential_cid) { Some(state) => { - state.receive_packet_payload(payload); + state.add_packet(packet.payload); } None => { let mut state = ConnectionState::new(packet.src_addr, packet.dst_addr); - state.receive_packet_payload(payload); + state.add_packet(packet.payload); let connection_id = self.connection_id_generator.next_id(); self.id_lookup .insert((packet.src_addr, packet.dst_addr), connection_id); @@ -280,13 +281,13 @@ mod tests { #[test] fn test_connection_state_recv_recv_read_read() { let mut state = ConnectionState::new(SRC, DST); - state.receive_packet_payload(vec![1, 2, 3]); - state.receive_packet_payload(vec![4, 5, 6]); + state.add_packet(vec![1, 2, 3]); + state.add_packet(vec![4, 5, 6]); let (tx, rx) = oneshot::channel(); - state.read_packet_payload(tx); + state.add_reader(tx); assert_eq!(vec![1, 2, 3], rx.blocking_recv().unwrap()); let (tx, rx) = oneshot::channel(); - state.read_packet_payload(tx); + state.add_reader(tx); assert_eq!(vec![4, 5, 6], rx.blocking_recv().unwrap()); } @@ -294,9 +295,9 @@ mod tests { fn test_connection_state_read_recv_recv() { let mut state = ConnectionState::new(SRC, DST); let (tx, rx) = oneshot::channel(); - state.read_packet_payload(tx); - state.receive_packet_payload(vec![1, 2, 3]); - state.receive_packet_payload(vec![4, 5, 6]); + state.add_reader(tx); + state.add_packet(vec![1, 2, 3]); + state.add_packet(vec![4, 5, 6]); assert_eq!(vec![1, 2, 3], rx.blocking_recv().unwrap()); } @@ -305,8 +306,8 @@ mod tests { let mut state = ConnectionState::new(SRC, DST); let (tx, rx) = oneshot::channel(); state.close(); - state.receive_packet_payload(vec![1, 2, 3]); - state.read_packet_payload(tx); + state.add_packet(vec![1, 2, 3]); + state.add_reader(tx); assert!(rx.blocking_recv().is_err()); } @@ -314,9 +315,9 @@ mod tests { fn test_connection_state_read_close_recv() { let mut state = ConnectionState::new(SRC, DST); let (tx, rx) = oneshot::channel(); - state.read_packet_payload(tx); + state.add_reader(tx); state.close(); - state.receive_packet_payload(vec![1, 2, 3]); + state.add_packet(vec![1, 2, 3]); assert!(rx.blocking_recv().is_err()); } } diff --git a/src/packet_sources/macos.rs b/src/packet_sources/macos.rs index 58b32c3a..8f49accc 100644 --- a/src/packet_sources/macos.rs +++ b/src/packet_sources/macos.rs @@ -243,7 +243,10 @@ impl ConnectionTask { loop { tokio::select! { _ = self.shutdown.recv() => break, - Some(packet) = stream.next(), if read_data.is_none() => { + packet = stream.next(), if read_data.is_none() => { + let Some(packet) = packet else { + break + }; let packet = ipc::UdpPacket::decode( packet.context("IPC read error")? ).context("invalid IPC message")?; @@ -271,7 +274,10 @@ impl ConnectionTask { read_data = Some(packet.data); } }, - Some(command) = command_rx.recv() => { + command = command_rx.recv() => { + let Some(command) = command else { + break; + }; match command { TransportCommand::ReadData(_, _, tx) => { if let Some(data) = read_data.take() { @@ -304,7 +310,6 @@ impl ConnectionTask { } } } - else => break, } } @@ -355,7 +360,10 @@ impl ConnectionTask { self.stream.read_buf(&mut data).await.context("failed to read from socket")?; tx.send(data).ok(); }, - Some(command) = command_rx.recv() => { + command = command_rx.recv() => { + let Some(command) = command else { + break; + }; match command { TransportCommand::ReadData(_, n, tx) => { assert!(read_tx.is_none()); @@ -381,7 +389,6 @@ impl ConnectionTask { } } }, - else => break, } } Ok(()) diff --git a/src/packet_sources/udp.rs b/src/packet_sources/udp.rs index da228b9d..bcdab639 100644 --- a/src/packet_sources/udp.rs +++ b/src/packet_sources/udp.rs @@ -88,14 +88,21 @@ impl PacketSourceTask for UdpTask { // wait for transport_events_tx channel capacity... Ok(p) = transport_events_tx.reserve(), if !py_tx_available => { permit = Some(p); - continue; }, // ... or process incoming packets Ok((len, src_addr)) = self.socket.recv_from(&mut udp_buf), if py_tx_available => { - self.process_incoming_datagram(&udp_buf[..len], src_addr, permit.take().unwrap()).await?; + self.handler.receive_data( + UdpPacket { + src_addr, + dst_addr: self.local_addr, + payload: udp_buf[..len].to_vec(), + }, + TunnelInfo::Udp {}, + permit.take().unwrap() + ); }, // send_to is cancel safe, so we can use that for backpressure. - _ = self.socket.send_to(&packet_payload, packet_dst), if packet_needs_sending => { + Ok(_) = self.socket.send_to(&packet_payload, packet_dst), if packet_needs_sending => { packet_needs_sending = false; }, Some(command) = self.transport_commands_rx.recv(), if !packet_needs_sending => { @@ -111,24 +118,3 @@ impl PacketSourceTask for UdpTask { Ok(()) } } - -impl UdpTask { - async fn process_incoming_datagram( - &mut self, - data: &[u8], - sender_addr: SocketAddr, - permit: Permit<'_, TransportEvent>, - ) -> Result<()> { - let packet = UdpPacket { - src_addr: sender_addr, - dst_addr: self.local_addr, - payload: data.to_vec(), - }; - let tunnel_info = TunnelInfo::WireGuard { - src_addr: sender_addr, - dst_addr: self.socket.local_addr()?, - }; - self.handler.receive_data(packet, tunnel_info, permit); - Ok(()) - } -} From cd85233c3656958310bcc630e8877366cb83ad9c Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Mon, 18 Dec 2023 00:26:49 +0100 Subject: [PATCH 15/20] unassigned -> udp --- mitmproxy-rs/src/udp_client.rs | 2 +- src/messages.rs | 12 +++++------- src/network/udp.rs | 2 +- 3 files changed, 7 insertions(+), 9 deletions(-) diff --git a/mitmproxy-rs/src/udp_client.rs b/mitmproxy-rs/src/udp_client.rs index f470119f..d31e8f9f 100644 --- a/mitmproxy-rs/src/udp_client.rs +++ b/mitmproxy-rs/src/udp_client.rs @@ -43,7 +43,7 @@ pub fn open_udp_connection( }); let stream = Stream { - connection_id: ConnectionId::unassigned(), + connection_id: ConnectionId::unassigned_udp(), state: StreamState::Open, command_tx, peername, diff --git a/src/messages.rs b/src/messages.rs index 2838434e..d1ef4f32 100755 --- a/src/messages.rs +++ b/src/messages.rs @@ -43,7 +43,7 @@ impl ConnectionIdGenerator { Self(2) } pub const fn udp() -> Self { - Self(1) + Self(3) } pub fn next_id(&mut self) -> ConnectionId { let ret = ConnectionId(self.0); @@ -56,10 +56,10 @@ impl ConnectionIdGenerator { pub struct ConnectionId(usize); impl ConnectionId { pub fn is_tcp(&self) -> bool { - self.0 > 0 && self.0 & 1 == 0 + self.0 & 1 == 0 } - pub fn unassigned() -> Self { - ConnectionId(0) + pub const fn unassigned_udp() -> Self { + ConnectionId(1) } } impl fmt::Display for ConnectionId { @@ -69,9 +69,7 @@ impl fmt::Display for ConnectionId { } impl fmt::Debug for ConnectionId { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - if self.0 == 0 { - write!(f, "0") - } else if self.is_tcp() { + if self.is_tcp() { write!(f, "{}#TCP", self.0) } else { write!(f, "{}#UDP", self.0) diff --git a/src/network/udp.rs b/src/network/udp.rs index 662d18d0..90c95df6 100644 --- a/src/network/udp.rs +++ b/src/network/udp.rs @@ -155,7 +155,7 @@ impl UdpHandler { .id_lookup .get(&(packet.src_addr, packet.dst_addr)) .cloned() - .unwrap_or(ConnectionId::unassigned()); + .unwrap_or(ConnectionId::unassigned_udp()); match self.connections.get_mut(&potential_cid) { Some(state) => { From 74a030e0d93197f1366099456e8fa6f777c28b55 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Mon, 18 Dec 2023 15:15:46 +0100 Subject: [PATCH 16/20] macos: use udp connection state --- src/network/udp.rs | 69 +++++++++++++++---------------------- src/packet_sources/macos.rs | 22 ++++-------- 2 files changed, 34 insertions(+), 57 deletions(-) diff --git a/src/network/udp.rs b/src/network/udp.rs index 90c95df6..843efc6b 100644 --- a/src/network/udp.rs +++ b/src/network/udp.rs @@ -17,25 +17,15 @@ use smoltcp::wire::{ UdpRepr, }; -struct ConnectionState { - remote_addr: SocketAddr, - local_addr: SocketAddr, +#[derive(Default)] +pub struct ConnectionState { closed: bool, packets: VecDeque>, read_tx: Option>>, } impl ConnectionState { - fn new(remote_addr: SocketAddr, local_addr: SocketAddr) -> Self { - Self { - remote_addr, - local_addr, - closed: false, - packets: VecDeque::new(), - read_tx: None, - } - } - fn add_packet(&mut self, data: Vec) { + pub fn add_packet(&mut self, data: Vec) { if self.closed { drop(data); } else if let Some(tx) = self.read_tx.take() { @@ -44,7 +34,10 @@ impl ConnectionState { self.packets.push_back(data); } } - fn add_reader(&mut self, tx: oneshot::Sender>) { + pub fn packet_queue_len(&self) -> usize { + self.packets.len() + } + pub fn add_reader(&mut self, tx: oneshot::Sender>) { assert!(self.read_tx.is_none()); if self.closed { drop(tx); @@ -54,7 +47,7 @@ impl ConnectionState { self.read_tx = Some(tx); } } - fn close(&mut self) { + pub fn close(&mut self) { if self.closed { // already closed. } else if let Some(tx) = self.read_tx.take() { @@ -68,23 +61,21 @@ impl ConnectionState { pub const UDP_TIMEOUT: Duration = Duration::from_secs(60); +type FourTuple = (SocketAddr, SocketAddr); + pub struct UdpHandler { connection_id_generator: ConnectionIdGenerator, - id_lookup: LruCache<(SocketAddr, SocketAddr), ConnectionId>, - connections: LruCache, + id_lookup: LruCache, + connections: LruCache, } impl UdpHandler { pub fn new() -> Self { // This implementation is largely based on the fact that LruCache eventually // drops the state, which closes the respective channels. - let connections = - LruCache::::with_expiry_duration(UDP_TIMEOUT); - let id_lookup = - LruCache::<(SocketAddr, SocketAddr), ConnectionId>::with_expiry_duration(UDP_TIMEOUT); Self { - connections, - id_lookup, + connections: LruCache::with_expiry_duration(UDP_TIMEOUT), + id_lookup: LruCache::with_expiry_duration(UDP_TIMEOUT), connection_id_generator: ConnectionIdGenerator::udp(), } } @@ -111,26 +102,25 @@ impl UdpHandler { } pub fn read_data(&mut self, id: ConnectionId, tx: oneshot::Sender>) { - if let Some(state) = self.connections.get_mut(&id) { + if let Some((state, _)) = self.connections.get_mut(&id) { state.add_reader(tx); } } pub(crate) fn write_data(&mut self, id: ConnectionId, data: Vec) -> Option { - let Some(state) = self.connections.get(&id) else { + let Some((state, addrs)) = self.connections.get(&id) else { return None; }; // Refresh id lookup. - self.id_lookup - .insert((state.local_addr, state.remote_addr), id); + self.id_lookup.insert(*addrs, id); if state.closed { return None; } Some(UdpPacket { - src_addr: state.local_addr, - dst_addr: state.remote_addr, + src_addr: addrs.0, + dst_addr: addrs.1, payload: data, }) } @@ -140,7 +130,7 @@ impl UdpHandler { } pub fn close_connection(&mut self, id: ConnectionId) { - if let Some(state) = self.connections.get_mut(&id) { + if let Some((state, _)) = self.connections.get_mut(&id) { state.close(); } } @@ -158,16 +148,17 @@ impl UdpHandler { .unwrap_or(ConnectionId::unassigned_udp()); match self.connections.get_mut(&potential_cid) { - Some(state) => { + Some((state, _)) => { state.add_packet(packet.payload); } None => { - let mut state = ConnectionState::new(packet.src_addr, packet.dst_addr); + let mut state = ConnectionState::default(); state.add_packet(packet.payload); let connection_id = self.connection_id_generator.next_id(); self.id_lookup .insert((packet.src_addr, packet.dst_addr), connection_id); - self.connections.insert(connection_id, state); + self.connections + .insert(connection_id, (state, (packet.src_addr, packet.dst_addr))); permit.send(TransportEvent::ConnectionEstablished { connection_id, src_addr: packet.src_addr, @@ -273,14 +264,10 @@ impl From for SmolPacket { #[cfg(test)] mod tests { use super::*; - use std::net::{IpAddr, Ipv4Addr}; - - const SRC: SocketAddr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 54321); - const DST: SocketAddr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 80); #[test] fn test_connection_state_recv_recv_read_read() { - let mut state = ConnectionState::new(SRC, DST); + let mut state = ConnectionState::default(); state.add_packet(vec![1, 2, 3]); state.add_packet(vec![4, 5, 6]); let (tx, rx) = oneshot::channel(); @@ -293,7 +280,7 @@ mod tests { #[test] fn test_connection_state_read_recv_recv() { - let mut state = ConnectionState::new(SRC, DST); + let mut state = ConnectionState::default(); let (tx, rx) = oneshot::channel(); state.add_reader(tx); state.add_packet(vec![1, 2, 3]); @@ -303,7 +290,7 @@ mod tests { #[test] fn test_connection_state_close_recv_read() { - let mut state = ConnectionState::new(SRC, DST); + let mut state = ConnectionState::default(); let (tx, rx) = oneshot::channel(); state.close(); state.add_packet(vec![1, 2, 3]); @@ -313,7 +300,7 @@ mod tests { #[test] fn test_connection_state_read_close_recv() { - let mut state = ConnectionState::new(SRC, DST); + let mut state = ConnectionState::default(); let (tx, rx) = oneshot::channel(); state.add_reader(tx); state.close(); diff --git a/src/packet_sources/macos.rs b/src/packet_sources/macos.rs index 8f49accc..d57f3486 100644 --- a/src/packet_sources/macos.rs +++ b/src/packet_sources/macos.rs @@ -21,6 +21,7 @@ use std::time::Duration; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::{UnixListener, UnixStream}; +use crate::network::udp::ConnectionState; use tokio::process::Command; use tokio::sync::mpsc::Sender; use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender}; @@ -237,13 +238,12 @@ impl ConnectionTask { let mut first_packet = Some((tunnel_info, local_address, command_tx)); - let mut read_data: Option> = None; - let mut read_tx: Option>> = None; + let mut state = ConnectionState::default(); loop { tokio::select! { _ = self.shutdown.recv() => break, - packet = stream.next(), if read_data.is_none() => { + packet = stream.next(), if state.packet_queue_len() < 10 => { let Some(packet) = packet else { break }; @@ -268,11 +268,7 @@ impl ConnectionTask { } else if remote_address != dst_addr { bail!("UDP packet destinations do not match: {remote_address} -> {dst_addr}") } - if let Some(tx) = read_tx.take() { - tx.send(packet.data).ok(); - } else { - read_data = Some(packet.data); - } + state.add_packet(packet.data); }, command = command_rx.recv() => { let Some(command) = command else { @@ -280,14 +276,7 @@ impl ConnectionTask { }; match command { TransportCommand::ReadData(_, _, tx) => { - if let Some(data) = read_data.take() { - tx.send(data).ok(); - } else { - if read_tx.is_some() { - bail!("Concurrent readers are not supported."); - } - read_tx = Some(tx); - } + state.add_reader(tx); }, TransportCommand::WriteData(_, data) => { assert!(first_packet.is_none()); @@ -305,6 +294,7 @@ impl ConnectionTask { }, TransportCommand::CloseConnection(_, half_close) => { if !half_close { + state.close(); break; } } From f0a9324e2660925c74321c5504a9b72bfe565f1a Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Mon, 18 Dec 2023 15:49:46 +0100 Subject: [PATCH 17/20] fix binding on macos --- mitmproxy-rs/src/stream.rs | 6 ++++- mitmproxy-rs/src/udp_client.rs | 43 ++++++++++++++++------------------ 2 files changed, 25 insertions(+), 24 deletions(-) diff --git a/mitmproxy-rs/src/stream.rs b/mitmproxy-rs/src/stream.rs index a59b6076..4e036ca8 100644 --- a/mitmproxy-rs/src/stream.rs +++ b/mitmproxy-rs/src/stream.rs @@ -187,7 +187,11 @@ impl Stream { } => match name.as_str() { "pid" => return Ok(pid.into_py(py)), "process_name" => return Ok(process_name.clone().into_py(py)), - "remote_endpoint" => return Ok(remote_endpoint.clone().into_py(py)), + "remote_endpoint" => { + if let Some(endpoint) = remote_endpoint { + return Ok(endpoint.clone().into_py(py)); + } + } _ => (), }, TunnelInfo::Udp {} => (), diff --git a/mitmproxy-rs/src/udp_client.rs b/mitmproxy-rs/src/udp_client.rs index d31e8f9f..cfb20b06 100644 --- a/mitmproxy-rs/src/udp_client.rs +++ b/mitmproxy-rs/src/udp_client.rs @@ -64,35 +64,32 @@ async fn udp_connect( ) -> Result { let addrs: Vec = lookup_host((host.as_str(), port)) .await - .with_context(|| format!("unable to resolve hostname: {}", host))? + .with_context(|| format!("unable to resolve hostname: {host}"))? .collect(); - if let Some((host, port)) = local_addr { - let socket = UdpSocket::bind((host.as_str(), port)) + let socket = if let Some((host, port)) = local_addr { + UdpSocket::bind((host.as_str(), port)) .await - .with_context(|| format!("unable to bind to ({}, {})", host, port))?; - socket - .connect(addrs.as_slice()) + .with_context(|| format!("unable to bind to ({}, {})", host, port))? + } else if addrs.iter().any(|x| x.is_ipv4()) { + // we initially tried to bind to IPv6 by default if that doesn't fail, + // but binding mysteriously works if there are only IPv4 addresses in addrs, + // and then we get a weird "invalid argument" error when calling socket.recv(). + // So we just do the lazy thing and do IPv4 by default. + UdpSocket::bind(SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0)) .await - .context("unable to connect to remote address")?; - Ok(socket) + .context("unable to bind to 127.0.0.1:0")? } else { - if let Ok(socket) = - UdpSocket::bind(SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0)).await - { - if socket.connect(addrs.as_slice()).await.is_ok() { - return Ok(socket); - } - } - let socket = UdpSocket::bind(SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0)) - .await - .context("unable to bind to 127.0.0.1:0")?; - socket - .connect(addrs.as_slice()) + UdpSocket::bind(SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0)) .await - .context("unable to connect to remote address")?; - Ok(socket) - } + .context("unable to bind to [::]:0")? + }; + + socket + .connect(addrs.as_slice()) + .await + .with_context(|| format!("unable to connect to {host}"))?; + Ok(socket) } #[derive(Debug)] From d72424e84d7b2b2ef75c800fdc7cb6c598c4387d Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Mon, 18 Dec 2023 16:51:05 +0100 Subject: [PATCH 18/20] tests++ --- mitmproxy-rs/src/udp_client.rs | 40 ++++++++++++++++++++++ src/network/tests.rs | 3 +- src/network/udp.rs | 61 ++++++++++++++++++++++++++++++++-- 3 files changed, 101 insertions(+), 3 deletions(-) diff --git a/mitmproxy-rs/src/udp_client.rs b/mitmproxy-rs/src/udp_client.rs index cfb20b06..427f0aa5 100644 --- a/mitmproxy-rs/src/udp_client.rs +++ b/mitmproxy-rs/src/udp_client.rs @@ -153,3 +153,43 @@ impl UdpClientTask { Ok(()) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_udp_client_echo() -> Result<()> { + let server = UdpSocket::bind("127.0.0.1:0").await?; + let addr = server.local_addr()?; + + let socket = udp_connect(addr.ip().to_string(), addr.port(), None).await?; + + let (command_tx, command_rx) = unbounded_channel(); + + let handle = tokio::spawn( + UdpClientTask { + socket, + transport_commands_rx: command_rx, + } + .run(), + ); + let cid = ConnectionId::unassigned_udp(); + + command_tx.send(TransportCommand::WriteData(cid, b"Hello World".to_vec()))?; + + let mut recv_buf = [0u8; 20]; + let (n, src) = server.recv_from(&mut recv_buf).await?; + assert_eq!(&recv_buf[..n], b"Hello World"); + + server.send_to(b"Hello back", src).await?; + + let (tx, rx) = oneshot::channel(); + command_tx.send(TransportCommand::ReadData(cid, 0, tx))?; + assert_eq!(rx.await?, b"Hello back"); + + command_tx.send(TransportCommand::CloseConnection(cid, false))?; + handle.await??; + Ok(()) + } +} diff --git a/src/network/tests.rs b/src/network/tests.rs index b655b220..9597aaa5 100755 --- a/src/network/tests.rs +++ b/src/network/tests.rs @@ -15,6 +15,7 @@ use tokio::{ use crate::messages::{ NetworkCommand, NetworkEvent, SmolPacket, TransportCommand, TransportEvent, TunnelInfo, }; +use crate::packet_sources::PacketSourceConf; use super::task::NetworkTask; @@ -390,7 +391,7 @@ async fn udp_read_write( } #[tokio::test] -async fn ivp4_udp() -> Result<()> { +async fn ipv4_udp() -> Result<()> { init_logger(); let src_addr = Ipv4Address([10, 0, 0, 1]); let dst_addr = Ipv4Address([10, 0, 0, 42]); diff --git a/src/network/udp.rs b/src/network/udp.rs index 843efc6b..9941056b 100644 --- a/src/network/udp.rs +++ b/src/network/udp.rs @@ -34,6 +34,7 @@ impl ConnectionState { self.packets.push_back(data); } } + #[allow(dead_code)] pub fn packet_queue_len(&self) -> usize { self.packets.len() } @@ -119,8 +120,8 @@ impl UdpHandler { } Some(UdpPacket { - src_addr: addrs.0, - dst_addr: addrs.1, + src_addr: addrs.1, + dst_addr: addrs.0, payload: data, }) } @@ -264,12 +265,17 @@ impl From for SmolPacket { #[cfg(test)] mod tests { use super::*; + use crate::packet_sources::udp::UdpConf; + use crate::packet_sources::{PacketSourceConf, PacketSourceTask}; + use std::net::{IpAddr, Ipv4Addr}; + use tokio::net::UdpSocket; #[test] fn test_connection_state_recv_recv_read_read() { let mut state = ConnectionState::default(); state.add_packet(vec![1, 2, 3]); state.add_packet(vec![4, 5, 6]); + assert_eq!(state.packet_queue_len(), 2); let (tx, rx) = oneshot::channel(); state.add_reader(tx); assert_eq!(vec![1, 2, 3], rx.blocking_recv().unwrap()); @@ -307,4 +313,55 @@ mod tests { state.add_packet(vec![1, 2, 3]); assert!(rx.blocking_recv().is_err()); } + + #[tokio::test] + async fn test_udp_server_echo() -> anyhow::Result<()> { + let (commands_tx, commands_rx) = tokio::sync::mpsc::unbounded_channel(); + let (events_tx, mut events_rx) = tokio::sync::mpsc::channel(1); + let (shutdown_tx, shutdown_rx) = tokio::sync::broadcast::channel(10); + let (task, addr) = UdpConf { + host: "127.0.0.1".to_string(), + port: 0, + } + .build(events_tx, commands_rx, shutdown_rx) + .await?; + + let handle = tokio::spawn(task.run()); + + let client = UdpSocket::bind(SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0)).await?; + client.connect(addr).await?; + client.send(b"Hello World!").await?; + + let TransportEvent::ConnectionEstablished { + connection_id, + command_tx: None, + .. + } = events_rx.recv().await.unwrap() + else { + panic!("unexpected command tx needs test adjustment"); + }; + + let (data_tx, data_rx) = oneshot::channel(); + commands_tx.send(TransportCommand::ReadData(connection_id, 0, data_tx))?; + assert_eq!(data_rx.await.unwrap(), b"Hello World!"); + + commands_tx.send(TransportCommand::WriteData( + connection_id, + b"Hello back!".to_vec(), + ))?; + + let mut recv_buf = [0u8; 20]; + let n = client.recv(&mut recv_buf).await?; + assert_eq!(&recv_buf[..n], b"Hello back!"); + + commands_tx.send(TransportCommand::CloseConnection(connection_id, false))?; + let (data_tx, data_rx) = oneshot::channel(); + commands_tx.send(TransportCommand::ReadData(connection_id, 0, data_tx))?; + assert!(data_rx.await.is_err()); + + shutdown_tx.send(())?; + handle.await??; + + Ok(()) + } } From 204910cf3225d76f94ec84c9efe45539bf9b02c3 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Mon, 18 Dec 2023 20:25:08 +0100 Subject: [PATCH 19/20] fix nits --- mitmproxy-rs/src/task.rs | 12 ++++++------ mitmproxy-windows/redirector/src/main2.rs | 3 ++- src/network/tests.rs | 1 - src/packet_sources/macos.rs | 15 +++------------ src/packet_sources/udp.rs | 6 ++++-- src/packet_sources/wireguard.rs | 3 ++- 6 files changed, 17 insertions(+), 23 deletions(-) diff --git a/mitmproxy-rs/src/task.rs b/mitmproxy-rs/src/task.rs index 394603d6..75e582bf 100644 --- a/mitmproxy-rs/src/task.rs +++ b/mitmproxy-rs/src/task.rs @@ -17,7 +17,7 @@ pub struct PyInteropTask { transport_events: mpsc::Receiver, py_tcp_handler: PyObject, py_udp_handler: PyObject, - sd_watcher: broadcast::Receiver<()>, + shutdown: broadcast::Receiver<()>, } impl PyInteropTask { @@ -36,7 +36,7 @@ impl PyInteropTask { transport_events, py_tcp_handler, py_udp_handler, - sd_watcher, + shutdown: sd_watcher, } } @@ -48,9 +48,9 @@ impl PyInteropTask { })?; loop { - tokio::select!( + tokio::select! { // wait for graceful shutdown - _ = self.sd_watcher.recv() => break, + _ = self.shutdown.recv() => break, // wait for network events event = self.transport_events.recv() => { let Some(event) = event else { @@ -111,8 +111,8 @@ impl PyInteropTask { }; }, } - }, - ); + } + }; } log::debug!("Python interoperability task shutting down."); diff --git a/mitmproxy-windows/redirector/src/main2.rs b/mitmproxy-windows/redirector/src/main2.rs index 17c17723..f340d271 100644 --- a/mitmproxy-windows/redirector/src/main2.rs +++ b/mitmproxy-windows/redirector/src/main2.rs @@ -420,7 +420,8 @@ async fn handle_ipc( } } }, - Some(packet) = ipc_rx.recv() => { + r = ipc_rx.recv() => { + let Some(packet) = r else { break }; packet.encode(&mut buf.as_mut_slice())?; let len = packet.encoded_len(); diff --git a/src/network/tests.rs b/src/network/tests.rs index 9597aaa5..361ede40 100755 --- a/src/network/tests.rs +++ b/src/network/tests.rs @@ -15,7 +15,6 @@ use tokio::{ use crate::messages::{ NetworkCommand, NetworkEvent, SmolPacket, TransportCommand, TransportEvent, TunnelInfo, }; -use crate::packet_sources::PacketSourceConf; use super::task::NetworkTask; diff --git a/src/packet_sources/macos.rs b/src/packet_sources/macos.rs index d57f3486..d71efd1d 100644 --- a/src/packet_sources/macos.rs +++ b/src/packet_sources/macos.rs @@ -243,10 +243,7 @@ impl ConnectionTask { loop { tokio::select! { _ = self.shutdown.recv() => break, - packet = stream.next(), if state.packet_queue_len() < 10 => { - let Some(packet) = packet else { - break - }; + Some(packet) = stream.next(), if state.packet_queue_len() < 10 => { let packet = ipc::UdpPacket::decode( packet.context("IPC read error")? ).context("invalid IPC message")?; @@ -270,10 +267,7 @@ impl ConnectionTask { } state.add_packet(packet.data); }, - command = command_rx.recv() => { - let Some(command) = command else { - break; - }; + Some(command) = command_rx.recv() => { match command { TransportCommand::ReadData(_, _, tx) => { state.add_reader(tx); @@ -350,10 +344,7 @@ impl ConnectionTask { self.stream.read_buf(&mut data).await.context("failed to read from socket")?; tx.send(data).ok(); }, - command = command_rx.recv() => { - let Some(command) = command else { - break; - }; + Some(command) = command_rx.recv() => { match command { TransportCommand::ReadData(_, n, tx) => { assert!(read_tx.is_none()); diff --git a/src/packet_sources/udp.rs b/src/packet_sources/udp.rs index bcdab639..27ac9a0f 100644 --- a/src/packet_sources/udp.rs +++ b/src/packet_sources/udp.rs @@ -90,7 +90,8 @@ impl PacketSourceTask for UdpTask { permit = Some(p); }, // ... or process incoming packets - Ok((len, src_addr)) = self.socket.recv_from(&mut udp_buf), if py_tx_available => { + r = self.socket.recv_from(&mut udp_buf), if py_tx_available => { + let (len, src_addr) = r.context("UDP recv() failed")?; self.handler.receive_data( UdpPacket { src_addr, @@ -102,7 +103,8 @@ impl PacketSourceTask for UdpTask { ); }, // send_to is cancel safe, so we can use that for backpressure. - Ok(_) = self.socket.send_to(&packet_payload, packet_dst), if packet_needs_sending => { + r = self.socket.send_to(&packet_payload, packet_dst), if packet_needs_sending => { + r.context("UDP send_to() failed")?; packet_needs_sending = false; }, Some(command) = self.transport_commands_rx.recv(), if !packet_needs_sending => { diff --git a/src/packet_sources/wireguard.rs b/src/packet_sources/wireguard.rs index c6f94a74..5b8bdfaa 100755 --- a/src/packet_sources/wireguard.rs +++ b/src/packet_sources/wireguard.rs @@ -158,7 +158,8 @@ impl PacketSourceTask for WireGuardTask { tokio::select! { exit = &mut self.network_task_handle => break exit.context("network task panic")?.context("network task error")?, // wait for WireGuard packets incoming on the UDP socket - Ok((len, src_orig)) = self.socket.recv_from(&mut udp_buf) => { + r = self.socket.recv_from(&mut udp_buf) => { + let (len, src_orig) = r.context("UDP recv() failed")?; self.process_incoming_datagram(&udp_buf[..len], src_orig).await?; }, // wait for outgoing IP packets From eecf6cfb1c649ad61c1fd7c517e4fac8ba7bca3d Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Mon, 18 Dec 2023 20:29:12 +0100 Subject: [PATCH 20/20] fixup --- mitmproxy-windows/redirector/src/main2.rs | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/mitmproxy-windows/redirector/src/main2.rs b/mitmproxy-windows/redirector/src/main2.rs index f340d271..9a3a29f3 100644 --- a/mitmproxy-windows/redirector/src/main2.rs +++ b/mitmproxy-windows/redirector/src/main2.rs @@ -420,9 +420,7 @@ async fn handle_ipc( } } }, - r = ipc_rx.recv() => { - let Some(packet) = r else { break }; - + Some(packet) = ipc_rx.recv() => { packet.encode(&mut buf.as_mut_slice())?; let len = packet.encoded_len();