Skip to content

Commit

Permalink
Refactor run function to constructor
Browse files Browse the repository at this point in the history
  • Loading branch information
akiroz committed Jan 18, 2024
1 parent 8990796 commit 0e793c6
Show file tree
Hide file tree
Showing 6 changed files with 82 additions and 105 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "zika"
version = "3.2.3"
version = "3.3.0"
license = "MIT"
description = "IP Tunneling over MQTT"
repository = "https://github.com/akiroz/zika"
Expand Down
3 changes: 1 addition & 2 deletions src/bin/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,5 @@ async fn main() {
.init();
let config = read_from_default_location().expect("A proper config file");
log::debug!("Config = {:?}", config);
let mut client = Client::from_config(config).await;
client.run().await;
let _ = Client::from_config(config).await;
}
3 changes: 1 addition & 2 deletions src/bin/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,5 @@ async fn main() {
.init();
let config = read_from_default_location().expect("A proper config file");
log::debug!("Config = {:?}", config);
let mut server = Server::from_config(config);
server.run().await;
let _ = Server::from_config(config);
}
86 changes: 38 additions & 48 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use rand::{thread_rng, Rng, distributions::Standard};
use rumqttc;
use etherparse::Ipv4Header;
use ipnetwork::Ipv4Network;
use tokio::{task, sync::{mpsc, broadcast, Mutex}};
use tokio::{task, sync::{broadcast, Mutex}};
use tokio_util::codec::Framed;
use tun::{AsyncDevice, TunPacket, TunPacketCodec};

Expand All @@ -21,24 +21,21 @@ use crate::ip_iter::SizedIpv4NetworkIterator;
type TunSink = SplitSink<Framed<AsyncDevice, TunPacketCodec>, TunPacket>;

pub struct Client {
local_addr: Ipv4Addr,
pub local_addr: Ipv4Addr,
tunnels: Arc<Vec<Tunnel>>,
pub remote: Arc<Mutex<remote::Remote>>, // Allow external mqtt ops
remote_recv: mpsc::Receiver<(String, Bytes)>,
remote_passthru_send: broadcast::Sender<(String, Bytes)>,
pub remote_passthru_recv: broadcast::Receiver<(String, Bytes)>,
tun_sink: TunSink,
}

struct Tunnel {
id: Vec<u8>,
topic: String,
topic_base: String,
bind_addr: std::net::Ipv4Addr,
bind_addr: Ipv4Addr,
}

impl Client {
pub async fn from_config(config: config::Config) -> Self {
pub async fn from_config(config: config::Config) -> Arc<Self> {
let mqtt_options = config.mqtt.to_mqtt_options().expect("valid mqtt config");
let client_config = config.client.expect("non-null client config");
Self::new(mqtt_options, client_config).await
Expand All @@ -47,7 +44,7 @@ impl Client {
pub async fn new(
mqtt_options: Vec<rumqttc::v5::MqttOptions>,
client_config: config::ClientConfig
) -> Self {
) -> Arc<Self> {
let ip_network: Ipv4Network = client_config.bind_cidr.parse().expect("CIDR notation");
let local_addr = SizedIpv4NetworkIterator::new(ip_network).next().expect("subnet size > 1");

Expand All @@ -64,9 +61,9 @@ impl Client {

tun_config.up();
let tun_dev = tun::create_as_async(&tun_config).expect("tunnel");
let (tun_sink, mut tun_stream) = tun_dev.into_framed().split();
let (mut tun_sink, mut tun_stream) = tun_dev.into_framed().split();

let (remote, remote_recv) = remote::Remote::new(&mqtt_options, vec![]);
let (remote, mut remote_recv) = remote::Remote::new(&mqtt_options, vec![]);
let mut tunnels = Vec::with_capacity(client_config.tunnels.len());
let mut rng = thread_rng();
for client_tunnel_config in &client_config.tunnels {
Expand Down Expand Up @@ -99,35 +96,53 @@ impl Client {
}

let (remote_passthru_send, remote_passthru_recv) = broadcast::channel(1);
let client = Self {
let client = Arc::new(Self {
local_addr,
tunnels: Arc::new(tunnels),
remote: Arc::new(Mutex::new(remote)),
remote_recv,
remote_passthru_send,
remote_passthru_recv,
tun_sink,
};
});

let loop_client = client.clone();
let loop_remote = client.remote.clone();
let loop_tunnels = client.tunnels.clone();
task::spawn(async move {
while let Some(packet) = tun_stream.next().await {
match packet {
Ok(pkt) => {
let mut remote = loop_remote.lock().await;
Self::handle_packet(&mut remote, &loop_tunnels, &pkt).await;
loop_client.handle_packet(&mut remote, &pkt).await;
}
Err(err) => panic!("Error: {:?}", err),
}
}
});

let loop2_client = client.clone();
task::spawn(async move {
loop {
if let Some((topic, msg)) = remote_recv.recv().await {
let handle_result = loop2_client.handle_remote_message(&mut tun_sink, topic.as_str(), &msg).await;
match handle_result {
Err(err) => log::error!("handle_remote_message error {:?}", err),
Ok(handled) => {
if !handled {
if let Err(err) = remote_passthru_send.send((topic, msg)) {
log::warn!("remote_passthru_send error {:?}", err);
}
}
}
}
} else {
break;
}
}
});

client
}

// tun -> mqtt
async fn handle_packet(remote: &mut remote::Remote, tunnels: &Vec<Tunnel>, packet: &TunPacket) {
async fn handle_packet(&self, remote: &mut remote::Remote, packet: &TunPacket) {
let dest = etherparse::IpHeader::from_slice(&packet.get_bytes())
.ok()
.and_then(|header| match header {
Expand All @@ -136,7 +151,7 @@ impl Client {
}
_ => None,
});
for tunnel in tunnels {
for tunnel in self.tunnels.iter() {
if dest == Some(tunnel.bind_addr) {
let payload = [&tunnel.id[..], &packet.get_bytes().to_vec()[..]].concat();
let result = remote.publish(&tunnel.topic_base, payload).await;
Expand All @@ -150,13 +165,9 @@ impl Client {
}

// mqtt -> tun
async fn handle_remote_message(
&mut self,
topic: String,
message: Bytes,
) -> Result<bool, etherparse::WriteError> {
if let Some(tunnel) = self.tunnels.as_ref().iter().find(|&t| t.topic == topic) {
match Ipv4Header::from_slice(&message) {
async fn handle_remote_message(&self, tun_sink: &mut TunSink, topic: &str, msg: &[u8]) -> Result<bool, etherparse::WriteError> {
if let Some(tunnel) = self.tunnels.iter().find(|&t| t.topic == topic) {
match Ipv4Header::from_slice(&msg) {
Err(error) => {
log::debug!("packet parse failed {:?}", error);
Ok(false)
Expand All @@ -167,33 +178,12 @@ impl Client {
let mut cursor = Cursor::new(Vec::new());
ipv4_header.write(&mut cursor)?;
cursor.write_all(rest)?;
self.tun_sink.send(TunPacket::new(cursor.into_inner())).await?;
tun_sink.send(TunPacket::new(cursor.into_inner())).await?;
Ok(true)
}
}
} else {
Ok(false)
}
}

pub async fn run(&mut self) {
loop {
if let Some((topic, message)) = self.remote_recv.recv().await {
match self.handle_remote_message(topic.clone(), message.clone()).await {
Err(err) => {
log::error!("handle_remote_message error {:?}", err);
}
Ok(handled) => {
if !handled {
if let Err(err) = self.remote_passthru_send.send((topic, message)) {
log::warn!("remote_passthru_send error {:?}", err);
}
}
}
}
} else {
break;
}
}
}
}
91 changes: 40 additions & 51 deletions src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,12 @@ use std::io::{Cursor, Write};
use std::net::Ipv4Addr;
use std::sync::Arc;

use bytes::Bytes;
use base64::{engine::general_purpose, Engine as _};
use futures::{SinkExt, stream::{SplitSink, StreamExt}};

use etherparse::Ipv4Header;
use ipnetwork::Ipv4Network;
use tokio::{task, sync::{mpsc, Mutex}};
use tokio::{task, sync::Mutex};
use tokio_util::codec::Framed;
use tun::{AsyncDevice, TunPacket, TunPacketCodec};

Expand All @@ -21,16 +20,14 @@ type TunSink = SplitSink<Framed<AsyncDevice, TunPacketCodec>, TunPacket>;
type IpPool = LookupPool<String, Ipv4Addr, SizedIpv4NetworkIterator>;

pub struct Server {
id_length: usize,
topic: String,
local_addr: Ipv4Addr,
pub id_length: usize,
pub topic: String,
pub local_addr: Ipv4Addr,
ip_pool: Arc<Mutex<IpPool>>,
remote_recv: mpsc::Receiver<(String, Bytes)>,
tun_sink: TunSink,
}

impl Server {
pub fn from_config(config: config::Config) -> Self {
pub fn from_config(config: config::Config) -> Arc<Self> {
let mqtt_options = config.mqtt.to_mqtt_options().expect("valid mqtt config");
let server_config = config.server.expect("non-null server config");
Self::new(mqtt_options, server_config)
Expand All @@ -39,12 +36,12 @@ impl Server {
pub fn new(
mqtt_options: Vec<rumqttc::v5::MqttOptions>,
server_config: config::ServerConfig
) -> Self {
) -> Arc<Self> {
let ip_network: Ipv4Network = server_config.bind_cidr.parse().expect("CIDR notation");
let mut ip_iter = SizedIpv4NetworkIterator::new(ip_network);
let local_addr = ip_iter.next().expect("subnet size > 1");

let (mut remote, remote_recv) = remote::Remote::new(&mqtt_options, vec![server_config.topic.clone()]);
let (mut remote, mut remote_recv) = remote::Remote::new(&mqtt_options, vec![server_config.topic.clone()]);

log::info!("bind {:?}/{}", local_addr, ip_network.prefix());

Expand All @@ -61,16 +58,22 @@ impl Server {
tun_config.up();

let dev = tun::create_as_async(&tun_config).expect("tunnel");
let (tun_sink, mut tun_stream) = dev.into_framed().split();
let (mut tun_sink, mut tun_stream) = dev.into_framed().split();

let ip_pool_arc = Arc::new(Mutex::new(LookupPool::new(ip_iter)));
let loop_ip_pool_arc = ip_pool_arc.clone();
let server = Arc::new(Self {
id_length: server_config.id_length,
topic: server_config.topic,
local_addr,
ip_pool: Arc::new(Mutex::new(LookupPool::new(ip_iter))),
});

let loop_ip_pool = server.ip_pool.clone();
task::spawn(async move {
while let Some(packet) = tun_stream.next().await {
match packet {
Ok(pkt) => {
let result = Self::handle_packet(&mut remote, loop_ip_pool_arc.clone(), &pkt).await;
let mut ip_pool = loop_ip_pool.lock().await;
let result = Self::handle_packet(&mut remote, &mut ip_pool, &pkt).await;
if let Err(err) = result {
log::error!("handle_packet error {:?}", err);
}
Expand All @@ -80,29 +83,32 @@ impl Server {
}
});

Self {
id_length: server_config.id_length,
topic: server_config.topic,
local_addr,
ip_pool: ip_pool_arc,
remote_recv,
tun_sink,
}
let loop_server = server.clone();
task::spawn(async move {
loop {
if let Some((_topic, payload)) = remote_recv.recv().await {
let (id, msg) = payload.split_at(server_config.id_length);
let handle_result = loop_server.handle_remote_message(&mut tun_sink, id, msg).await;
if let Err(err) = handle_result {
log::error!("handle_remote_message error {:?}", err);
}
} else {
break;
}
}
});

server
}

// tun -> mqtt
async fn handle_packet(
remote: &mut remote::Remote,
ip_pool: Arc<Mutex<IpPool>>,
packet: &TunPacket,
) -> Result<(), rumqttc::v5::ClientError> {
let dest = Ipv4Header::from_slice(&packet.get_bytes())
async fn handle_packet(remote: &mut remote::Remote, ip_pool: &mut IpPool, pkt: &TunPacket) -> Result<(), rumqttc::v5::ClientError> {
let dest = Ipv4Header::from_slice(&pkt.get_bytes())
.ok()
.map(|(ipv4_header, _)| Ipv4Addr::from(ipv4_header.destination));
if let Some(d) = dest {
let ip_pool = ip_pool.lock().await;
if let Some(topic) = ip_pool.get_reverse(&d.into()) {
remote.publish(&topic, packet.get_bytes().to_vec()).await?;
remote.publish(&topic, pkt.get_bytes().to_vec()).await?;
} else {
log::debug!("drop packet: no tunnel for {:?}", &d);
}
Expand All @@ -111,11 +117,7 @@ impl Server {
}

// mqtt -> tun
async fn handle_remote_message(
&mut self,
id: &[u8],
message: &[u8],
) -> Result<(), etherparse::WriteError> {
async fn handle_remote_message(&self, tun_sink: &mut TunSink, id: &[u8], msg: &[u8]) -> Result<(), etherparse::WriteError> {
let base64_id = general_purpose::URL_SAFE_NO_PAD.encode(id);
let topic_base = &self.topic;
let topic = format!("{topic_base}/{base64_id}");
Expand All @@ -124,10 +126,10 @@ impl Server {
let ip = ip_pool.get_forward(&topic).into();
ip
};
let packet_with_updated_header = match Ipv4Header::from_slice(message) {
let pkt = match Ipv4Header::from_slice(msg) {
Err(error) => {
log::debug!("packet parse failed {:?}", error);
message.to_vec()
msg.to_vec()
}
Ok((mut ipv4_header, rest)) => {
ipv4_header.source = ip.octets();
Expand All @@ -138,20 +140,7 @@ impl Server {
cursor.into_inner()
}
};
self.tun_sink.send(TunPacket::new(packet_with_updated_header)).await?;
tun_sink.send(TunPacket::new(pkt)).await?;
Ok(())
}

pub async fn run(&mut self) {
loop {
if let Some((_topic, id_message)) = self.remote_recv.recv().await {
let (id, message) = id_message.split_at(self.id_length);
if let Err(err) = self.handle_remote_message(id, message).await {
log::error!("handle_remote_message error {:?}", err);
}
} else {
break;
}
}
}
}

0 comments on commit 0e793c6

Please sign in to comment.