From 0e793c62d02c125fe76534e4a1ebf8cba135d225 Mon Sep 17 00:00:00 2001 From: akiroz Date: Thu, 18 Jan 2024 22:46:18 +0800 Subject: [PATCH] Refactor run function to constructor --- Cargo.lock | 2 +- Cargo.toml | 2 +- src/bin/client.rs | 3 +- src/bin/server.rs | 3 +- src/client.rs | 86 ++++++++++++++++++++------------------------ src/server.rs | 91 +++++++++++++++++++++-------------------------- 6 files changed, 82 insertions(+), 105 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 40decde..5a94e69 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1315,7 +1315,7 @@ checksum = "525b4ec142c6b68a2d10f01f7bbf6755599ca3f81ea53b8431b7dd348f5fdb2d" [[package]] name = "zika" -version = "3.2.3" +version = "3.3.0" dependencies = [ "base64", "bytes", diff --git a/Cargo.toml b/Cargo.toml index f458e1e..c5bc8fb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "zika" -version = "3.2.3" +version = "3.3.0" license = "MIT" description = "IP Tunneling over MQTT" repository = "https://github.com/akiroz/zika" diff --git a/src/bin/client.rs b/src/bin/client.rs index 364d48f..7cd1cac 100644 --- a/src/bin/client.rs +++ b/src/bin/client.rs @@ -10,6 +10,5 @@ async fn main() { .init(); let config = read_from_default_location().expect("A proper config file"); log::debug!("Config = {:?}", config); - let mut client = Client::from_config(config).await; - client.run().await; + let _ = Client::from_config(config).await; } diff --git a/src/bin/server.rs b/src/bin/server.rs index e22f75e..8e74e4b 100644 --- a/src/bin/server.rs +++ b/src/bin/server.rs @@ -10,6 +10,5 @@ async fn main() { .init(); let config = read_from_default_location().expect("A proper config file"); log::debug!("Config = {:?}", config); - let mut server = Server::from_config(config); - server.run().await; + let _ = Server::from_config(config); } diff --git a/src/client.rs b/src/client.rs index f732174..2759e9a 100644 --- a/src/client.rs +++ b/src/client.rs @@ -10,7 +10,7 @@ use rand::{thread_rng, Rng, distributions::Standard}; use rumqttc; use etherparse::Ipv4Header; use ipnetwork::Ipv4Network; -use tokio::{task, sync::{mpsc, broadcast, Mutex}}; +use tokio::{task, sync::{broadcast, Mutex}}; use tokio_util::codec::Framed; use tun::{AsyncDevice, TunPacket, TunPacketCodec}; @@ -21,24 +21,21 @@ use crate::ip_iter::SizedIpv4NetworkIterator; type TunSink = SplitSink, TunPacket>; pub struct Client { - local_addr: Ipv4Addr, + pub local_addr: Ipv4Addr, tunnels: Arc>, pub remote: Arc>, // Allow external mqtt ops - remote_recv: mpsc::Receiver<(String, Bytes)>, - remote_passthru_send: broadcast::Sender<(String, Bytes)>, pub remote_passthru_recv: broadcast::Receiver<(String, Bytes)>, - tun_sink: TunSink, } struct Tunnel { id: Vec, topic: String, topic_base: String, - bind_addr: std::net::Ipv4Addr, + bind_addr: Ipv4Addr, } impl Client { - pub async fn from_config(config: config::Config) -> Self { + pub async fn from_config(config: config::Config) -> Arc { let mqtt_options = config.mqtt.to_mqtt_options().expect("valid mqtt config"); let client_config = config.client.expect("non-null client config"); Self::new(mqtt_options, client_config).await @@ -47,7 +44,7 @@ impl Client { pub async fn new( mqtt_options: Vec, client_config: config::ClientConfig - ) -> Self { + ) -> Arc { let ip_network: Ipv4Network = client_config.bind_cidr.parse().expect("CIDR notation"); let local_addr = SizedIpv4NetworkIterator::new(ip_network).next().expect("subnet size > 1"); @@ -64,9 +61,9 @@ impl Client { tun_config.up(); let tun_dev = tun::create_as_async(&tun_config).expect("tunnel"); - let (tun_sink, mut tun_stream) = tun_dev.into_framed().split(); + let (mut tun_sink, mut tun_stream) = tun_dev.into_framed().split(); - let (remote, remote_recv) = remote::Remote::new(&mqtt_options, vec![]); + let (remote, mut remote_recv) = remote::Remote::new(&mqtt_options, vec![]); let mut tunnels = Vec::with_capacity(client_config.tunnels.len()); let mut rng = thread_rng(); for client_tunnel_config in &client_config.tunnels { @@ -99,35 +96,53 @@ impl Client { } let (remote_passthru_send, remote_passthru_recv) = broadcast::channel(1); - let client = Self { + let client = Arc::new(Self { local_addr, tunnels: Arc::new(tunnels), remote: Arc::new(Mutex::new(remote)), - remote_recv, - remote_passthru_send, remote_passthru_recv, - tun_sink, - }; + }); + let loop_client = client.clone(); let loop_remote = client.remote.clone(); - let loop_tunnels = client.tunnels.clone(); task::spawn(async move { while let Some(packet) = tun_stream.next().await { match packet { Ok(pkt) => { let mut remote = loop_remote.lock().await; - Self::handle_packet(&mut remote, &loop_tunnels, &pkt).await; + loop_client.handle_packet(&mut remote, &pkt).await; } Err(err) => panic!("Error: {:?}", err), } } }); + let loop2_client = client.clone(); + task::spawn(async move { + loop { + if let Some((topic, msg)) = remote_recv.recv().await { + let handle_result = loop2_client.handle_remote_message(&mut tun_sink, topic.as_str(), &msg).await; + match handle_result { + Err(err) => log::error!("handle_remote_message error {:?}", err), + Ok(handled) => { + if !handled { + if let Err(err) = remote_passthru_send.send((topic, msg)) { + log::warn!("remote_passthru_send error {:?}", err); + } + } + } + } + } else { + break; + } + } + }); + client } // tun -> mqtt - async fn handle_packet(remote: &mut remote::Remote, tunnels: &Vec, packet: &TunPacket) { + async fn handle_packet(&self, remote: &mut remote::Remote, packet: &TunPacket) { let dest = etherparse::IpHeader::from_slice(&packet.get_bytes()) .ok() .and_then(|header| match header { @@ -136,7 +151,7 @@ impl Client { } _ => None, }); - for tunnel in tunnels { + for tunnel in self.tunnels.iter() { if dest == Some(tunnel.bind_addr) { let payload = [&tunnel.id[..], &packet.get_bytes().to_vec()[..]].concat(); let result = remote.publish(&tunnel.topic_base, payload).await; @@ -150,13 +165,9 @@ impl Client { } // mqtt -> tun - async fn handle_remote_message( - &mut self, - topic: String, - message: Bytes, - ) -> Result { - if let Some(tunnel) = self.tunnels.as_ref().iter().find(|&t| t.topic == topic) { - match Ipv4Header::from_slice(&message) { + async fn handle_remote_message(&self, tun_sink: &mut TunSink, topic: &str, msg: &[u8]) -> Result { + if let Some(tunnel) = self.tunnels.iter().find(|&t| t.topic == topic) { + match Ipv4Header::from_slice(&msg) { Err(error) => { log::debug!("packet parse failed {:?}", error); Ok(false) @@ -167,7 +178,7 @@ impl Client { let mut cursor = Cursor::new(Vec::new()); ipv4_header.write(&mut cursor)?; cursor.write_all(rest)?; - self.tun_sink.send(TunPacket::new(cursor.into_inner())).await?; + tun_sink.send(TunPacket::new(cursor.into_inner())).await?; Ok(true) } } @@ -175,25 +186,4 @@ impl Client { Ok(false) } } - - pub async fn run(&mut self) { - loop { - if let Some((topic, message)) = self.remote_recv.recv().await { - match self.handle_remote_message(topic.clone(), message.clone()).await { - Err(err) => { - log::error!("handle_remote_message error {:?}", err); - } - Ok(handled) => { - if !handled { - if let Err(err) = self.remote_passthru_send.send((topic, message)) { - log::warn!("remote_passthru_send error {:?}", err); - } - } - } - } - } else { - break; - } - } - } } diff --git a/src/server.rs b/src/server.rs index 8e17f6e..8c38591 100644 --- a/src/server.rs +++ b/src/server.rs @@ -2,13 +2,12 @@ use std::io::{Cursor, Write}; use std::net::Ipv4Addr; use std::sync::Arc; -use bytes::Bytes; use base64::{engine::general_purpose, Engine as _}; use futures::{SinkExt, stream::{SplitSink, StreamExt}}; use etherparse::Ipv4Header; use ipnetwork::Ipv4Network; -use tokio::{task, sync::{mpsc, Mutex}}; +use tokio::{task, sync::Mutex}; use tokio_util::codec::Framed; use tun::{AsyncDevice, TunPacket, TunPacketCodec}; @@ -21,16 +20,14 @@ type TunSink = SplitSink, TunPacket>; type IpPool = LookupPool; pub struct Server { - id_length: usize, - topic: String, - local_addr: Ipv4Addr, + pub id_length: usize, + pub topic: String, + pub local_addr: Ipv4Addr, ip_pool: Arc>, - remote_recv: mpsc::Receiver<(String, Bytes)>, - tun_sink: TunSink, } impl Server { - pub fn from_config(config: config::Config) -> Self { + pub fn from_config(config: config::Config) -> Arc { let mqtt_options = config.mqtt.to_mqtt_options().expect("valid mqtt config"); let server_config = config.server.expect("non-null server config"); Self::new(mqtt_options, server_config) @@ -39,12 +36,12 @@ impl Server { pub fn new( mqtt_options: Vec, server_config: config::ServerConfig - ) -> Self { + ) -> Arc { let ip_network: Ipv4Network = server_config.bind_cidr.parse().expect("CIDR notation"); let mut ip_iter = SizedIpv4NetworkIterator::new(ip_network); let local_addr = ip_iter.next().expect("subnet size > 1"); - let (mut remote, remote_recv) = remote::Remote::new(&mqtt_options, vec![server_config.topic.clone()]); + let (mut remote, mut remote_recv) = remote::Remote::new(&mqtt_options, vec![server_config.topic.clone()]); log::info!("bind {:?}/{}", local_addr, ip_network.prefix()); @@ -61,16 +58,22 @@ impl Server { tun_config.up(); let dev = tun::create_as_async(&tun_config).expect("tunnel"); - let (tun_sink, mut tun_stream) = dev.into_framed().split(); + let (mut tun_sink, mut tun_stream) = dev.into_framed().split(); - let ip_pool_arc = Arc::new(Mutex::new(LookupPool::new(ip_iter))); - let loop_ip_pool_arc = ip_pool_arc.clone(); + let server = Arc::new(Self { + id_length: server_config.id_length, + topic: server_config.topic, + local_addr, + ip_pool: Arc::new(Mutex::new(LookupPool::new(ip_iter))), + }); + let loop_ip_pool = server.ip_pool.clone(); task::spawn(async move { while let Some(packet) = tun_stream.next().await { match packet { Ok(pkt) => { - let result = Self::handle_packet(&mut remote, loop_ip_pool_arc.clone(), &pkt).await; + let mut ip_pool = loop_ip_pool.lock().await; + let result = Self::handle_packet(&mut remote, &mut ip_pool, &pkt).await; if let Err(err) = result { log::error!("handle_packet error {:?}", err); } @@ -80,29 +83,32 @@ impl Server { } }); - Self { - id_length: server_config.id_length, - topic: server_config.topic, - local_addr, - ip_pool: ip_pool_arc, - remote_recv, - tun_sink, - } + let loop_server = server.clone(); + task::spawn(async move { + loop { + if let Some((_topic, payload)) = remote_recv.recv().await { + let (id, msg) = payload.split_at(server_config.id_length); + let handle_result = loop_server.handle_remote_message(&mut tun_sink, id, msg).await; + if let Err(err) = handle_result { + log::error!("handle_remote_message error {:?}", err); + } + } else { + break; + } + } + }); + + server } // tun -> mqtt - async fn handle_packet( - remote: &mut remote::Remote, - ip_pool: Arc>, - packet: &TunPacket, - ) -> Result<(), rumqttc::v5::ClientError> { - let dest = Ipv4Header::from_slice(&packet.get_bytes()) + async fn handle_packet(remote: &mut remote::Remote, ip_pool: &mut IpPool, pkt: &TunPacket) -> Result<(), rumqttc::v5::ClientError> { + let dest = Ipv4Header::from_slice(&pkt.get_bytes()) .ok() .map(|(ipv4_header, _)| Ipv4Addr::from(ipv4_header.destination)); if let Some(d) = dest { - let ip_pool = ip_pool.lock().await; if let Some(topic) = ip_pool.get_reverse(&d.into()) { - remote.publish(&topic, packet.get_bytes().to_vec()).await?; + remote.publish(&topic, pkt.get_bytes().to_vec()).await?; } else { log::debug!("drop packet: no tunnel for {:?}", &d); } @@ -111,11 +117,7 @@ impl Server { } // mqtt -> tun - async fn handle_remote_message( - &mut self, - id: &[u8], - message: &[u8], - ) -> Result<(), etherparse::WriteError> { + async fn handle_remote_message(&self, tun_sink: &mut TunSink, id: &[u8], msg: &[u8]) -> Result<(), etherparse::WriteError> { let base64_id = general_purpose::URL_SAFE_NO_PAD.encode(id); let topic_base = &self.topic; let topic = format!("{topic_base}/{base64_id}"); @@ -124,10 +126,10 @@ impl Server { let ip = ip_pool.get_forward(&topic).into(); ip }; - let packet_with_updated_header = match Ipv4Header::from_slice(message) { + let pkt = match Ipv4Header::from_slice(msg) { Err(error) => { log::debug!("packet parse failed {:?}", error); - message.to_vec() + msg.to_vec() } Ok((mut ipv4_header, rest)) => { ipv4_header.source = ip.octets(); @@ -138,20 +140,7 @@ impl Server { cursor.into_inner() } }; - self.tun_sink.send(TunPacket::new(packet_with_updated_header)).await?; + tun_sink.send(TunPacket::new(pkt)).await?; Ok(()) } - - pub async fn run(&mut self) { - loop { - if let Some((_topic, id_message)) = self.remote_recv.recv().await { - let (id, message) = id_message.split_at(self.id_length); - if let Err(err) = self.handle_remote_message(id, message).await { - log::error!("handle_remote_message error {:?}", err); - } - } else { - break; - } - } - } }