Skip to content

Commit

Permalink
Feat/querier api (#64)
Browse files Browse the repository at this point in the history
* Add Querier interface and handling of messages in plugin, Minor refactor of plugin structure

* add querier Class, RemoteQuerier, Locality, ReplyKeyExpr QuerierOptions, QuerierGetOptions

* Fix spelling error, add undeclare in typescript

* fmt

* Adding Querier Example

* Fix incorrect Enum->int functions

* Fix Incorrect ordering of deserializing enums
  • Loading branch information
Charles-Schleich authored Dec 11, 2024
1 parent e007d6f commit 3ea3779
Show file tree
Hide file tree
Showing 13 changed files with 944 additions and 158 deletions.
93 changes: 91 additions & 2 deletions zenoh-plugin-remote-api/src/handle_control_message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,20 @@

use std::{error::Error, net::SocketAddr, time::Duration};

use base64::{prelude::BASE64_STANDARD, Engine};
use tracing::{error, warn};
use uuid::Uuid;
use zenoh::{
bytes::ZBytes,
handlers::{FifoChannel, RingChannel},
key_expr::KeyExpr,
query::Selector,
};

use crate::{
interface::{
ControlMsg, DataMsg, HandlerChannel, LivelinessMsg, QueryWS, QueryableMsg, RemoteAPIMsg,
ReplyWS, SampleWS,
B64String, ControlMsg, DataMsg, HandlerChannel, LivelinessMsg, QueryWS, QueryableMsg,
RemoteAPIMsg, ReplyWS, SampleWS,
},
spawn_future, RemoteState, StateMap,
};
Expand Down Expand Up @@ -333,7 +335,94 @@ pub(crate) async fn handle_control_message(
ControlMsg::Liveliness(liveliness_msg) => {
return handle_liveliness(liveliness_msg, state_map).await;
}
ControlMsg::DeclareQuerier {
id,
key_expr,
target,
timeout,
accept_replies,
congestion_control,
priority,
consolidation,
allowed_destination,
express,
} => {
let mut querier_builder = state_map.session.declare_querier(key_expr);
let timeout = timeout.map(|millis| Duration::from_millis(millis));

add_if_some!(target, querier_builder);
add_if_some!(timeout, querier_builder);
add_if_some!(accept_replies, querier_builder);
add_if_some!(accept_replies, querier_builder);
add_if_some!(congestion_control, querier_builder);
add_if_some!(priority, querier_builder);
add_if_some!(consolidation, querier_builder);
add_if_some!(allowed_destination, querier_builder);
add_if_some!(express, querier_builder);

let querier = querier_builder.await?;
state_map.queriers.insert(id, querier);
}
ControlMsg::UndeclareQuerier(uuid) => {
if let Some(querier) = state_map.queriers.remove(&uuid) {
querier.undeclare().await?;
} else {
warn!("No Querier Found with UUID {}", uuid);
};
}
ControlMsg::QuerierGet {
get_id,
querier_id,
encoding,
payload,
attachment,
} => {
if let Some(querier) = state_map.queriers.get(&querier_id) {
let mut get_builder = querier.get();

let payload = payload
.map(|B64String(x)| BASE64_STANDARD.decode(x))
.and_then(|res_vec_bytes| {
if let Ok(vec_bytes) = res_vec_bytes {
Some(ZBytes::from(vec_bytes))
} else {
None
}
});

let attachment: Option<ZBytes> = attachment
.map(|B64String(x)| BASE64_STANDARD.decode(x))
.and_then(|res_vec_bytes| {
if let Ok(vec_bytes) = res_vec_bytes {
Some(ZBytes::from(vec_bytes))
} else {
None
}
});
add_if_some!(encoding, get_builder);
add_if_some!(payload, get_builder);
add_if_some!(attachment, get_builder);
let receiver = get_builder.await?;
let ws_tx = state_map.websocket_tx.clone();
let finish_msg = RemoteAPIMsg::Control(ControlMsg::GetFinished { id: get_id });

spawn_future(async move {
while let Ok(reply) = receiver.recv_async().await {
let reply_ws = ReplyWS::from((reply, get_id));
let remote_api_msg = RemoteAPIMsg::Data(DataMsg::GetReply(reply_ws));
if let Err(err) = ws_tx.send(remote_api_msg) {
tracing::error!("{}", err);
}
}
if let Err(err) = ws_tx.send(finish_msg) {
tracing::error!("{}", err);
}
});
} else {
// TODO: Do we want to add an error here ?
warn!("No Querier With ID {querier_id} found")
}
}
msg @ (ControlMsg::GetFinished { id: _ }
| ControlMsg::Session(_)
| ControlMsg::Subscriber(_)) => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,24 @@

use std::sync::Arc;

// mod interface::ser_de;
pub(crate) mod ser_de;
use base64::{prelude::BASE64_STANDARD, Engine};
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use ser_de::{
deserialize_congestion_control, deserialize_consolidation_mode, deserialize_locality,
deserialize_priority, deserialize_query_target, deserialize_reliability,
deserialize_reply_key_expr, serialize_congestion_control, serialize_consolidation_mode,
serialize_locality, serialize_priority, serialize_query_target, serialize_reliability,
serialize_reply_key_expr,
};
use serde::{Deserialize, Serialize};
use ts_rs::TS;
use uuid::Uuid;
use zenoh::{
key_expr::OwnedKeyExpr,
qos::{CongestionControl, Priority, Reliability},
query::{ConsolidationMode, Query, Reply, ReplyError},
sample::{Sample, SampleKind},
query::{ConsolidationMode, Query, QueryTarget, Reply, ReplyError, ReplyKeyExpr},
sample::{Locality, Sample, SampleKind},
};

// ██████ ███████ ███ ███ ██████ ████████ ███████ █████ ██████ ██ ███ ███ ███████ ███████ ███████ █████ ██████ ███████
Expand All @@ -34,7 +43,7 @@ use zenoh::{
#[derive(TS)]
#[ts(export)]
#[derive(Debug, Serialize, Deserialize)]
pub(crate) struct B64String(String);
pub(crate) struct B64String(pub String);
impl From<String> for B64String {
fn from(value: String) -> Self {
B64String(value)
Expand Down Expand Up @@ -247,6 +256,70 @@ pub enum ControlMsg {
complete: bool,
},
UndeclareQueryable(Uuid),
// Quierer
DeclareQuerier {
id: Uuid,
#[ts(as = "OwnedKeyExprWrapper")]
key_expr: OwnedKeyExpr,
#[serde(
deserialize_with = "deserialize_query_target",
serialize_with = "serialize_query_target",
default
)]
#[ts(type = "number | undefined")]
target: Option<QueryTarget>,
#[ts(type = "number | undefined")]
timeout: Option<u64>,
#[serde(
deserialize_with = "deserialize_reply_key_expr",
serialize_with = "serialize_reply_key_expr",
default
)]
#[ts(type = "number | undefined")]
accept_replies: Option<ReplyKeyExpr>,
#[serde(
deserialize_with = "deserialize_locality",
serialize_with = "serialize_locality",
default
)]
#[ts(type = "number | undefined")]
allowed_destination: Option<Locality>,
#[serde(
deserialize_with = "deserialize_congestion_control",
serialize_with = "serialize_congestion_control",
default
)]
#[ts(type = "number | undefined")]
congestion_control: Option<CongestionControl>,
#[serde(
deserialize_with = "deserialize_priority",
serialize_with = "serialize_priority",
default
)]
#[ts(type = "number | undefined")]
priority: Option<Priority>,
#[serde(
deserialize_with = "deserialize_consolidation_mode",
serialize_with = "serialize_consolidation_mode",
default
)]
#[ts(type = "number | undefined")]
consolidation: Option<ConsolidationMode>,
#[ts(type = "boolean | undefined")]
express: Option<bool>,
},
UndeclareQuerier(Uuid),
// Querier
QuerierGet {
querier_id: Uuid,
get_id: Uuid,
#[ts(type = "string | undefined")]
encoding: Option<String>,
#[ts(type = "string | undefined")]
payload: Option<B64String>,
#[ts(type = "string | undefined")]
attachment: Option<B64String>,
},

// Liveliness
Liveliness(LivelinessMsg),
Expand Down Expand Up @@ -278,151 +351,6 @@ pub enum LivelinessMsg {
},
}

fn deserialize_consolidation_mode<'de, D>(d: D) -> Result<Option<ConsolidationMode>, D::Error>
where
D: Deserializer<'de>,
{
match Option::<u8>::deserialize(d) {
Ok(Some(value)) => Ok(Some(match value {
0u8 => ConsolidationMode::Auto,
1u8 => ConsolidationMode::None,
2u8 => ConsolidationMode::Monotonic,
3u8 => ConsolidationMode::Latest,
_ => {
return Err(serde::de::Error::custom(format!(
"Value not valid for ConsolidationMode Enum {:?}",
value
)))
}
})),
Ok(None) => Ok(None),
Err(err) => Err(serde::de::Error::custom(format!(
"Value not valid for ConsolidationMode Enum {:?}",
err
))),
}
}

fn serialize_consolidation_mode<S>(
consolidation_mode: &Option<ConsolidationMode>,
s: S,
) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
match consolidation_mode {
Some(c_mode) => s.serialize_u8(*c_mode as u8),
None => s.serialize_none(),
}
}

fn deserialize_congestion_control<'de, D>(d: D) -> Result<Option<CongestionControl>, D::Error>
where
D: Deserializer<'de>,
{
match Option::<u8>::deserialize(d) {
Ok(Some(value)) => Ok(Some(match value {
0u8 => CongestionControl::Drop,
1u8 => CongestionControl::Block,
val => {
return Err(serde::de::Error::custom(format!(
"Value not valid for CongestionControl Enum {:?}",
val
)))
}
})),
Ok(None) => Ok(None),
val => Err(serde::de::Error::custom(format!(
"Value not valid for CongestionControl Enum {:?}",
val
))),
}
}

fn serialize_congestion_control<S>(
congestion_control: &Option<CongestionControl>,
s: S,
) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
match congestion_control {
Some(c_ctrl) => s.serialize_u8(*c_ctrl as u8),
None => s.serialize_none(),
}
}

fn deserialize_priority<'de, D>(d: D) -> Result<Option<Priority>, D::Error>
where
D: Deserializer<'de>,
{
match Option::<u8>::deserialize(d) {
Ok(Some(value)) => Ok(Some(match value {
1u8 => Priority::RealTime,
2u8 => Priority::InteractiveHigh,
3u8 => Priority::InteractiveLow,
4u8 => Priority::DataHigh,
5u8 => Priority::Data,
6u8 => Priority::DataLow,
7u8 => Priority::Background,
val => {
return Err(serde::de::Error::custom(format!(
"Value not valid for Priority Enum {:?}",
val
)))
}
})),
Ok(None) => Ok(None),
val => Err(serde::de::Error::custom(format!(
"Value not valid for Priority Enum {:?}",
val
))),
}
}

fn serialize_priority<S>(priority: &Option<Priority>, s: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
match priority {
Some(prio) => s.serialize_u8(*prio as u8),
None => s.serialize_none(),
}
}

fn deserialize_reliability<'de, D>(d: D) -> Result<Option<Reliability>, D::Error>
where
D: Deserializer<'de>,
{
match Option::<u8>::deserialize(d) {
Ok(Some(value)) => Ok(Some(match value {
0u8 => Reliability::Reliable,
1u8 => Reliability::BestEffort,
val => {
return Err(serde::de::Error::custom(format!(
"Value not valid for Reliability Enum {:?}",
val
)))
}
})),
Ok(None) => Ok(None),
val => Err(serde::de::Error::custom(format!(
"Value not valid for Reliability Enum {:?}",
val
))),
}
}

fn serialize_reliability<S>(reliability: &Option<Reliability>, s: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
match reliability {
Some(prio) => s.serialize_u8(*prio as u8),
None => s.serialize_none(),
}
}

#[derive(Debug, Serialize, Deserialize, TS)]
pub(crate) enum HandlerChannel {
Fifo(usize),
Expand Down
Loading

0 comments on commit 3ea3779

Please sign in to comment.