Skip to content

Commit

Permalink
Merge pull request #4 from b123400/rust-rewrite
Browse files Browse the repository at this point in the history
Handle topic alias
  • Loading branch information
akiroz authored Dec 27, 2023
2 parents 3594d65 + 66ea5c2 commit 1cc28c4
Show file tree
Hide file tree
Showing 6 changed files with 184 additions and 75 deletions.
6 changes: 3 additions & 3 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ impl Client {

let mqtt_options = config.broker_mqtt_options();

let (remote, remote_receiver) = remote::Remote::new(&mqtt_options, Vec::new());
let (mut remote, remote_receiver) = remote::Remote::new(&mqtt_options, Vec::new());

let mut tunnels = Vec::with_capacity(client_config.tunnels.len());
let mut rng = thread_rng();
Expand Down Expand Up @@ -114,7 +114,7 @@ impl Client {
while let Some(packet) = stream.next().await {
match packet {
Ok(pkt) => {
Self::handle_packet(&remote, &loop_tunnels, &pkt).await;
Self::handle_packet(&mut remote, &loop_tunnels, &pkt).await;
}
Err(err) => panic!("Error: {:?}", err),
}
Expand All @@ -130,7 +130,7 @@ impl Client {
}

// tun -> mqtt
async fn handle_packet(remote: &remote::Remote, tunnels: &Vec<Tunnel>, packet: &TunPacket) {
async fn handle_packet(remote: &mut remote::Remote, tunnels: &Vec<Tunnel>, packet: &TunPacket) {
let dest = etherparse::IpHeader::from_slice(&packet.get_bytes())
.ok()
.and_then(|header| match header {
Expand Down
56 changes: 39 additions & 17 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,28 +5,20 @@ use serde::Deserialize;
use std::fs;
use std::net::Ipv4Addr;

#[derive(Deserialize, Debug)]
#[derive(Deserialize, Clone, Debug)]
pub struct MqttOptions {
#[serde(default = "default_keepalive_interval")]
pub keepalive_interval: u64,
pub keepalive_interval: Option<u64>,

pub username: Option<String>,
pub password: Option<String>,

pub ca_file: Option<String>,

#[serde(default = "default_tls_insecure")]
pub tls_insecure: bool,
pub tls_insecure: Option<bool>,
pub key_file: Option<String>,
pub cert_file: Option<String>,
}

fn default_keepalive_interval() -> u64 {
return 60;
}

fn default_tls_insecure() -> bool {
return false;
pub topic_alias_max: Option<u16>,
}

#[derive(Deserialize, Debug)]
Expand Down Expand Up @@ -89,19 +81,48 @@ pub fn read_from_default_location() -> Result<Config, ConfigError> {
return Ok(deserialized);
}

impl MqttOptions {
pub fn merge_with_option(&self, another_option: MqttOptions) -> MqttOptions {
MqttOptions {
keepalive_interval: another_option
.keepalive_interval
.or(self.keepalive_interval),

username: another_option.username.or(self.username.clone()),
password: another_option.password.or(self.password.clone()),

ca_file: another_option.ca_file.or(self.ca_file.clone()),

tls_insecure: another_option.tls_insecure.or(self.tls_insecure),
key_file: another_option.key_file.or(self.key_file.clone()),
cert_file: another_option.cert_file.or(self.cert_file.clone()),

topic_alias_max: another_option.topic_alias_max.or(self.topic_alias_max),
}
}
}

impl MqttBroker {
pub fn to_mqtt_options(&self) -> rumqttc::v5::MqttOptions {
pub fn to_mqtt_options(
&self,
base_mqtt_options: &Option<MqttOptions>,
) -> rumqttc::v5::MqttOptions {
let mut rng = thread_rng();
let random_id: String = (&mut rng)
.sample_iter(Alphanumeric)
.take(7)
.map(char::from)
.collect();
let mut options = rumqttc::v5::MqttOptions::new(random_id, &self.host, self.port);
options.set_topic_alias_max(Some(5));
let mqtt_options = match (base_mqtt_options, self.options.clone()) {
(Some(b), Some(opts)) => Some(b.merge_with_option(opts)),
(None, o) => o,
(o, None) => o.clone(),
};
// TODO: more options
if let Some(opts) = &self.options {
options.set_keep_alive(Duration::new(opts.keepalive_interval, 0));
if let Some(opts) = mqtt_options {
options.set_keep_alive(Duration::new(opts.keepalive_interval.unwrap_or(60), 0));
options.set_topic_alias_max(opts.topic_alias_max);

if let (Some(u), Some(p)) = (&opts.username, &opts.password) {
options.set_credentials(u, p);
Expand All @@ -113,11 +134,12 @@ impl MqttBroker {

impl Config {
pub fn broker_mqtt_options(&self) -> Vec<rumqttc::v5::MqttOptions> {
let base_options = &self.mqtt.options;
return self
.mqtt
.brokers
.iter()
.map(|s| s.to_mqtt_options())
.map(|s| s.to_mqtt_options(base_options))
.collect::<Vec<_>>();
}
}
35 changes: 27 additions & 8 deletions src/lookup_pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ where
.pool
.pop_front()
.unwrap_or_else(|| self.iterator.next().unwrap());
// We can use insert_forward to insert anything, so we need to make sure the value from
// iterator is not already used, otherwise try to get the next one
if self.reverse.contains_key(&b) {
return self.get_forward(a);
}
if let Some((.., v)) = self.forward.push(a.clone(), b) {
if v != b {
self.reverse.remove(&v);
Expand All @@ -52,16 +57,30 @@ where
self.reverse.get(b)
}

pub fn insert_forward(&mut self, a: &A, b: B) {
if let Some((.., v)) = self.forward.push(a.clone(), b) {
if v != b {
self.reverse.remove(&v);
self.pool.push_back(v);
}
}
self.reverse.insert(b, a.clone());
}

pub fn cap(&self) -> usize {
self.forward.cap().get()
}

// pub fn resize(&mut self, range: Range<B>) {
// let size = range.len() - 1;
// self.range = range;
// self.pool.clear();
// self.reverse.clear();
// self.forward.clear();
// self.forward.resize(NonZeroUsize::new(size).unwrap());
// }
pub fn resize(&mut self, iterator: I) {
let size = iterator.len() - 1;
self.iterator = iterator;
self.pool.clear();
self.reverse.clear();
self.forward.clear();
self.forward.resize(NonZeroUsize::new(size).unwrap());
}

pub fn contains(&self, a: &A) -> bool {
self.forward.contains(a)
}
}
151 changes: 109 additions & 42 deletions src/remote.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,26 @@ use std::sync::Arc;
use tokio::sync::Mutex;
use tokio::{sync::mpsc, task};

use crate::lookup_pool::LookupPool;
use rumqttc::v5::mqttbytes::v5::PublishProperties;
use std::ops::Range;

// Context for receiving messsage from remote
struct RemoteIncomingContext {
mqtt_client: Arc<mqtt::AsyncClient>,
nth: usize,
sender: mpsc::Sender<(String, Bytes)>,
subs: Arc<Mutex<Vec<String>>>,
alias_pool: Option<LookupPool<Bytes, u16, Range<u16>>>, // alias they created
}

struct RemoteClient {
nth: usize,
mqttc: Arc<mqtt::AsyncClient>,
// alias: LookupPool<String, u16>,
alias_pool: Option<LookupPool<String, u16, Range<u16>>>, // alias we invented
}

// Used for sending message to remote
pub struct Remote {
clients: Vec<RemoteClient>,
subs: Arc<Mutex<Vec<String>>>,
Expand All @@ -33,28 +47,31 @@ impl Remote {
for (idx, opt) in broker_opts.iter().enumerate() {
let (mqtt_client, mut event_loop) = mqtt::AsyncClient::new(opt.clone(), 128);
let arc_mqtt_client = Arc::new(mqtt_client);

let alias_pool = opt
.topic_alias_max()
.filter(|n| *n > 0)
.map(|count| LookupPool::new(1..count));
let remote_client = RemoteClient {
nth: idx,
mqttc: arc_mqtt_client.clone(),
// alias: LookupPool::new(0..opt.topic_alias_max().unwrap_or(2)),
alias_pool,
};

let mut context = RemoteIncomingContext {
mqtt_client: arc_mqtt_client,
nth: idx,
sender: sender.clone(),
subs: subs.clone(),
alias_pool: None, // TODO: Don't send with topic alias by default, until getting topic_alias_max from remote
};
let loop_sender = sender.clone();
let loop_mqtt_client = arc_mqtt_client.clone();
let loop_subs = subs.clone();
task::spawn(async move {
loop {
use mqtt::Event::Incoming;
match event_loop.poll().await {
Ok(Incoming(pkt)) => {
log::debug!("Received Incoming Packet {:?}", pkt);
Self::handle_packet(
&loop_mqtt_client,
idx,
loop_sender.clone(),
pkt,
loop_subs.clone(),
)
.await;
Self::handle_packet(&mut context, pkt).await;
}
x => {
log::debug!("Received Other Packet {:?}", x);
Expand All @@ -68,45 +85,72 @@ impl Remote {
(remote, receiver)
}

async fn handle_packet(
mqtt_client: &mqtt::AsyncClient,
nth: usize,
sender: mpsc::Sender<(String, Bytes)>,
pkt: Packet,
subs_m: Arc<Mutex<Vec<String>>>,
) {
async fn handle_packet(context: &mut RemoteIncomingContext, pkt: Packet) {
use mqtt::mqttbytes::v5::{ConnAck, ConnectReturnCode::Success, Filter, Publish};
match pkt {
Packet::ConnAck(ConnAck {
code: Success,
properties: Some(_prop),
properties: Some(prop),
session_present,
}) => {
// if let Some(alias_max) = prop.topic_alias_max {
// log::info!("broker[{}] alias resize({})", nth, alias_max);
// client.alias.resize(0..alias_max);
// }
if let Some(alias_max) = prop.topic_alias_max.filter(|n| *n > 0) {
let range = 1..alias_max;
if let Some(ref mut alias_pool) = context.alias_pool {
alias_pool.resize(range);
} else {
context.alias_pool = Some(LookupPool::new(range));
}
} else {
context.alias_pool = None
}
if !session_present {
log::info!("broker[{}] !session_present", nth);
let subs_v = subs_m.lock().await;
log::info!("broker[{}] !session_present", context.nth);
let subs_v = context.subs.lock().await;
let subs = subs_v.iter().map(|path| Filter {
path: path.clone(),
qos: QoS::AtMostOnce,
nolocal: false,
preserve_retain: false,
retain_forward_rule: Default::default(),
});
if let Err(err) = mqtt_client.subscribe_many(subs).await {
log::info!("broker[{}] subscribe_many {:?}", nth, err);
if let Err(err) = context.mqtt_client.subscribe_many(subs).await {
log::info!("broker[{}] subscribe_many {:?}", context.nth, err);
}
}
}
Packet::ConnAck(ConnAck { code, .. }) => {
panic!("Refused by broker: {:?}", code);
}
Packet::Publish(Publish { topic, payload, .. }) => {
if let Ok(topic) = String::from_utf8(topic.to_vec()) {
_ = sender.send((topic, payload)).await; // What if it's not ok?
Packet::Publish(Publish {
topic,
payload,
properties,
..
}) => {
let topic_alias = properties.and_then(|p| p.topic_alias);
let topic_str = String::from_utf8(topic.to_vec())
.ok()
.filter(|n| n.len() > 0);
if let (Some(alias), Some(_)) = (topic_alias, &topic_str) {
if let Some(ref mut pool) = context.alias_pool {
pool.insert_forward(&topic, alias);
}
}
if let Some(topic) = topic_str {
_ = context.sender.send((topic, payload)).await; // What if it's not ok?
} else if let Some(alias) = topic_alias {
// No topic but we have alias
if let Some(ref mut pool) = context.alias_pool {
if let Some(t) = pool
.get_reverse(&alias)
.and_then(|t| String::from_utf8(t.to_vec()).ok())
{
log::debug!("Received message, Alias: {:?}, topic: {:?}", alias, t);
_ = context.sender.send((t, payload)).await;
} else {
log::error!("Cannot find topic for alias {:?}", alias);
}
}
} else {
log::debug!("drop packet, non utf8 topic: {:?}", topic);
}
Expand All @@ -133,22 +177,45 @@ impl Remote {
unreachable!()
}

pub async fn publish(&self, topic: &String, payload: Vec<u8>) -> Result<(), mqtt::ClientError> {
if self.clients.len() == 1 {
return self.clients[0]
.mqttc
.publish(topic, QoS::AtMostOnce, false, payload)
.await;
}
for (idx, client) in self.clients.iter().enumerate() {
pub async fn publish(
&mut self,
topic: &String,
payload: Vec<u8>,
) -> Result<(), mqtt::ClientError> {
let clients_length = self.clients.len();
for (idx, client) in self.clients.iter_mut().enumerate() {
let mut properties: PublishProperties = Default::default();
let topic_to_send = if let Some(ref mut pool) = client.alias_pool {
let already_sent_alias = pool.contains(topic);
let alias = pool.get_forward(topic);
properties.topic_alias = Some(alias);
if already_sent_alias {
""
} else {
topic
}
} else {
topic
};
log::debug!(
"Sending message with topic: {:?} props: {:?}",
topic_to_send,
properties
);
let res = client
.mqttc
.publish(topic.clone(), QoS::AtMostOnce, false, payload.clone())
.publish_with_properties(
topic_to_send,
QoS::AtMostOnce,
false,
payload.clone(),
properties,
)
.await;
if res.is_ok() {
return Ok(());
}
if idx == self.clients.len() - 1 {
if idx == clients_length - 1 {
return res;
}
}
Expand Down
Loading

0 comments on commit 1cc28c4

Please sign in to comment.