From bdbcd460cfe49f87004607a9a6a18315199f2616 Mon Sep 17 00:00:00 2001 From: Reisen Date: Wed, 17 Apr 2024 08:15:40 +0000 Subject: [PATCH] refactor: functionalize rpc --- src/agent/pythd/api/rpc.rs | 640 +++++++++++++++++++------------------ 1 file changed, 334 insertions(+), 306 deletions(-) diff --git a/src/agent/pythd/api/rpc.rs b/src/agent/pythd/api/rpc.rs index 0021e44..177aa69 100644 --- a/src/agent/pythd/api/rpc.rs +++ b/src/agent/pythd/api/rpc.rs @@ -113,348 +113,378 @@ enum ConnectionError { WebsocketConnectionClosed, } -struct Connection { - // Channel for communicating with the adapter +async fn handle_connection( + ws_conn: WebSocket, adapter_tx: mpsc::Sender, - - // Channel Websocket messages are sent and received on - ws_tx: SplitSink, - ws_rx: SplitStream, - - // Channel NotifyPrice events are sent and received on - notify_price_tx: mpsc::Sender, - notify_price_rx: mpsc::Receiver, - - // Channel NotifyPriceSched events are sent and received on - notify_price_sched_tx: mpsc::Sender, - notify_price_sched_rx: mpsc::Receiver, - + notify_price_tx_buffer: usize, + notify_price_sched_tx_buffer: usize, logger: Logger, -} +) { + // Create the channels + let (mut ws_tx, mut ws_rx) = ws_conn.split(); + let (mut notify_price_tx, mut notify_price_rx) = mpsc::channel(notify_price_tx_buffer); + let (mut notify_price_sched_tx, mut notify_price_sched_rx) = + mpsc::channel(notify_price_sched_tx_buffer); + + loop { + if let Err(err) = handle_next( + &logger, + &adapter_tx, + &mut ws_tx, + &mut ws_rx, + &mut notify_price_tx, + &mut notify_price_rx, + &mut notify_price_sched_tx, + &mut notify_price_sched_rx, + ) + .await + { + if let Some(ConnectionError::WebsocketConnectionClosed) = + err.downcast_ref::() + { + info!(logger, "websocket connection closed"); + return; + } -impl Connection { - fn new( - ws_conn: WebSocket, - adapter_tx: mpsc::Sender, - notify_price_tx_buffer: usize, - notify_price_sched_tx_buffer: usize, - logger: Logger, - ) -> Self { - // Create the channels - let (ws_tx, ws_rx) = ws_conn.split(); - let (notify_price_tx, notify_price_rx) = mpsc::channel(notify_price_tx_buffer); - let (notify_price_sched_tx, notify_price_sched_rx) = - mpsc::channel(notify_price_sched_tx_buffer); - - // Create the new connection object - Connection { - adapter_tx, - ws_tx, - ws_rx, - notify_price_tx, - notify_price_rx, - notify_price_sched_tx, - notify_price_sched_rx, - logger, + error!(logger, "{}", err); + debug!(logger, "error context"; "context" => format!("{:?}", err)); } } +} - async fn consume(&mut self) { - loop { - if let Err(err) = self.handle_next().await { - if let Some(ConnectionError::WebsocketConnectionClosed) = - err.downcast_ref::() - { - info!(self.logger, "websocket connection closed"); - return; - } - - error!(self.logger, "{}", err); - debug!(self.logger, "error context"; "context" => format!("{:?}", err)); +async fn handle_next( + logger: &Logger, + adapter_tx: &mpsc::Sender, + ws_tx: &mut SplitSink, + ws_rx: &mut SplitStream, + notify_price_tx: &mut mpsc::Sender, + notify_price_rx: &mut mpsc::Receiver, + notify_price_sched_tx: &mut mpsc::Sender, + notify_price_sched_rx: &mut mpsc::Receiver, +) -> Result<()> { + tokio::select! { + msg = ws_rx.next() => { + match msg { + Some(body) => match body { + Ok(msg) => { + handle( + logger, + ws_tx, + adapter_tx, + notify_price_tx, + notify_price_sched_tx, + msg, + ) + .await + } + Err(e) => send_error(ws_tx, e.into(), None).await, + }, + None => Err(ConnectionError::WebsocketConnectionClosed)?, } } - } - - async fn handle_next(&mut self) -> Result<()> { - tokio::select! { - msg = self.ws_rx.next() => { - match msg { - Some(body) => self.handle_ws_rx(body).await, - None => Err(ConnectionError::WebsocketConnectionClosed)?, - } - } - Some(notify_price) = self.notify_price_rx.recv() => { - self.handle_notify_price(notify_price).await - } - Some(notify_price_sched) = self.notify_price_sched_rx.recv() => { - self.handle_notify_price_sched(notify_price_sched).await - } + Some(notify_price) = notify_price_rx.recv() => { + send_notification(ws_tx, Method::NotifyPrice, Some(notify_price)) + .await } - } - - async fn handle_ws_rx(&mut self, body: Result) -> Result<()> { - match body { - Ok(msg) => self.handle(msg).await, - Err(e) => self.send_error(e.into(), None).await, + Some(notify_price_sched) = notify_price_sched_rx.recv() => { + send_notification(ws_tx, Method::NotifyPriceSched, Some(notify_price_sched)) + .await } } +} - async fn handle_notify_price(&mut self, notify_price: NotifyPrice) -> Result<()> { - self.send_notification(Method::NotifyPrice, Some(notify_price)) - .await - } - - async fn handle_notify_price_sched( - &mut self, - notify_price_sched: NotifyPriceSched, - ) -> Result<()> { - self.send_notification(Method::NotifyPriceSched, Some(notify_price_sched)) - .await +async fn handle( + logger: &Logger, + ws_tx: &mut SplitSink, + adapter_tx: &mpsc::Sender, + notify_price_tx: &mpsc::Sender, + notify_price_sched_tx: &mpsc::Sender, + msg: Message, +) -> Result<()> { + // Ignore control and binary messages + if !msg.is_text() { + debug!(logger, "JSON RPC API: skipped non-text message"); + return Ok(()); } - async fn handle(&mut self, msg: Message) -> Result<()> { - // Ignore control and binary messages - if !msg.is_text() { - debug!(self.logger, "JSON RPC API: skipped non-text message"); - return Ok(()); - } - - // Parse and dispatch the message - match self.parse(msg).await { - Ok((requests, is_batch)) => { - let mut responses = Vec::with_capacity(requests.len()); - - // Perform requests in sequence and gather responses - for request in requests { - let response = self.dispatch_and_catch_error(&request).await; - responses.push(response) - } - - // Send an array if we're handling a batch - // request, single response object otherwise - if is_batch { - self.send_text(&serde_json::to_string(&responses)?).await?; - } else { - self.send_text(&serde_json::to_string(&responses[0])?) - .await?; - } - } - // The top-level parsing errors are fine to share with client - Err(e) => { - self.send_error(e, None).await?; + // Parse and dispatch the message + match parse(msg).await { + Ok((requests, is_batch)) => { + let mut responses = Vec::with_capacity(requests.len()); + + // Perform requests in sequence and gather responses + for request in requests { + let response = dispatch_and_catch_error( + logger, + adapter_tx, + notify_price_tx, + notify_price_sched_tx, + &request, + ) + .await; + responses.push(response) } - } - - Ok(()) - } - /// Parse a JSONRPC request object or a batch of them. The - /// bool in result informs request handling whether it needs - /// to respond with a single object or an array, to prevent - /// sending unexpected - /// `[{}]` - /// array payloads. - async fn parse(&mut self, msg: Message) -> Result<(Vec>, bool)> { - let s = msg - .to_str() - .map_err(|_| anyhow!("Could not parse message as text"))?; - - let json_value: Value = serde_json::from_str(s)?; - if let Some(array) = json_value.as_array() { - // Interpret request as JSON-RPC 2.0 batch if value is an array - let mut requests = Vec::with_capacity(array.len()); - for maybe_request in array { - // Re-serialize for parse_request(), it's the only - // jrpc parsing function available and it's taking - // &str. - let maybe_request_string = serde_json::to_string(maybe_request)?; - requests.push( - parse_request::(&maybe_request_string) - .map_err(|e| anyhow!("Could not parse message: {}", e.error.message))?, - ); + // Send an array if we're handling a batch + // request, single response object otherwise + if is_batch { + send_text(ws_tx, &serde_json::to_string(&responses)?).await?; + } else { + send_text(ws_tx, &serde_json::to_string(&responses[0])?).await?; } - - Ok((requests, true)) - } else { - // Base single request case - let single = parse_request::(s) - .map_err(|e| anyhow!("Could not parse message: {}", e.error.message))?; - Ok((vec![single], false)) } - } - - async fn dispatch_and_catch_error( - &mut self, - request: &Request, - ) -> Response { - debug!(self.logger, - "JSON RPC API: handling request"; - "method" => format!("{:?}", request.method), - ); - let result = match request.method { - Method::GetProductList => self.get_product_list().await, - Method::GetProduct => self.get_product(request).await, - Method::GetAllProducts => self.get_all_products().await, - Method::SubscribePrice => self.subscribe_price(request).await, - Method::SubscribePriceSched => self.subscribe_price_sched(request).await, - Method::UpdatePrice => self.update_price(request).await, - Method::NotifyPrice | Method::NotifyPriceSched => { - Err(anyhow!("unsupported method: {:?}", request.method)) - } - }; - - // Consider errors internal, print details to logs. - match result { - Ok(payload) => { - Response::success(request.id.clone().to_id().unwrap_or(Id::from(0)), payload) - } - Err(e) => { - warn!( - self.logger, - "Error handling JSON RPC request"; - "request" => format!("{:?}", request), - "error" => format!("{}", e.to_string()), - ); - Response::error( - request.id.clone().to_id().unwrap_or(Id::from(0)), - ErrorCode::InternalError, - e.to_string(), - None, - ) - } + // The top-level parsing errors are fine to share with client + Err(e) => { + send_error(ws_tx, e, None).await?; } } - async fn get_product_list(&mut self) -> Result { - let (result_tx, result_rx) = oneshot::channel(); - self.adapter_tx - .send(adapter::Message::GetProductList { result_tx }) - .await?; - - Ok(serde_json::to_value(result_rx.await??)?) - } - - async fn get_product(&mut self, request: &Request) -> Result { - let params: GetProductParams = self.deserialize_params(request.params.clone())?; + Ok(()) +} - let (result_tx, result_rx) = oneshot::channel(); - self.adapter_tx - .send(adapter::Message::GetProduct { - account: params.account, - result_tx, - }) - .await?; +/// Parse a JSONRPC request object or a batch of them. The +/// bool in result informs request handling whether it needs +/// to respond with a single object or an array, to prevent +/// sending unexpected +/// `[{}]` +/// array payloads. +async fn parse(msg: Message) -> Result<(Vec>, bool)> { + let s = msg + .to_str() + .map_err(|_| anyhow!("Could not parse message as text"))?; + + let json_value: Value = serde_json::from_str(s)?; + if let Some(array) = json_value.as_array() { + // Interpret request as JSON-RPC 2.0 batch if value is an array + let mut requests = Vec::with_capacity(array.len()); + for maybe_request in array { + // Re-serialize for parse_request(), it's the only + // jrpc parsing function available and it's taking + // &str. + let maybe_request_string = serde_json::to_string(maybe_request)?; + requests.push( + parse_request::(&maybe_request_string) + .map_err(|e| anyhow!("Could not parse message: {}", e.error.message))?, + ); + } - Ok(serde_json::to_value(result_rx.await??)?) + Ok((requests, true)) + } else { + // Base single request case + let single = parse_request::(s) + .map_err(|e| anyhow!("Could not parse message: {}", e.error.message))?; + Ok((vec![single], false)) } +} - async fn get_all_products(&mut self) -> Result { - let (result_tx, result_rx) = oneshot::channel(); - self.adapter_tx - .send(adapter::Message::GetAllProducts { result_tx }) - .await?; - - Ok(serde_json::to_value(result_rx.await??)?) - } +async fn dispatch_and_catch_error( + logger: &Logger, + adapter_tx: &mpsc::Sender, + notify_price_tx: &mpsc::Sender, + notify_price_sched_tx: &mpsc::Sender, + request: &Request, +) -> Response { + debug!( + logger, + "JSON RPC API: handling request"; + "method" => format!("{:?}", request.method), + ); - async fn subscribe_price( - &mut self, - request: &Request, - ) -> Result { - let params: SubscribePriceParams = self.deserialize_params(request.params.clone())?; + let result = match request.method { + Method::GetProductList => get_product_list(adapter_tx).await, + Method::GetProduct => get_product(adapter_tx, request).await, + Method::GetAllProducts => get_all_products(adapter_tx).await, + Method::UpdatePrice => update_price(adapter_tx, request).await, + Method::SubscribePrice => subscribe_price(adapter_tx, notify_price_tx, request).await, + Method::SubscribePriceSched => { + subscribe_price_sched(adapter_tx, notify_price_sched_tx, request).await + } + Method::NotifyPrice | Method::NotifyPriceSched => { + Err(anyhow!("unsupported method: {:?}", request.method)) + } + }; - let (result_tx, result_rx) = oneshot::channel(); - self.adapter_tx - .send(adapter::Message::SubscribePrice { - result_tx, - account: params.account, - notify_price_tx: self.notify_price_tx.clone(), - }) - .await?; + // Consider errors internal, print details to logs. + match result { + Ok(payload) => { + Response::success(request.id.clone().to_id().unwrap_or(Id::from(0)), payload) + } + Err(e) => { + warn!( + logger, + "Error handling JSON RPC request"; + "request" => format!("{:?}", request), + "error" => format!("{}", e.to_string()), + ); - Ok(serde_json::to_value(SubscribeResult { - subscription: result_rx.await??, - })?) + Response::error( + request.id.clone().to_id().unwrap_or(Id::from(0)), + ErrorCode::InternalError, + e.to_string(), + None, + ) + } } +} - async fn subscribe_price_sched( - &mut self, - request: &Request, - ) -> Result { - let params: SubscribePriceSchedParams = self.deserialize_params(request.params.clone())?; +async fn get_product_list( + adapter_tx: &mpsc::Sender, +) -> Result { + let (result_tx, result_rx) = oneshot::channel(); + adapter_tx + .send(adapter::Message::GetProductList { result_tx }) + .await?; + Ok(serde_json::to_value(result_rx.await??)?) +} - let (result_tx, result_rx) = oneshot::channel(); - self.adapter_tx - .send(adapter::Message::SubscribePriceSched { - result_tx, - account: params.account, - notify_price_sched_tx: self.notify_price_sched_tx.clone(), - }) - .await?; +async fn get_product( + adapter_tx: &mpsc::Sender, + request: &Request, +) -> Result { + let params: GetProductParams = { + let value = request.params.clone(); + serde_json::from_value(value.ok_or_else(|| anyhow!("Missing request parameters"))?) + }?; + + let (result_tx, result_rx) = oneshot::channel(); + adapter_tx + .send(adapter::Message::GetProduct { + account: params.account, + result_tx, + }) + .await?; + Ok(serde_json::to_value(result_rx.await??)?) +} - Ok(serde_json::to_value(SubscribeResult { - subscription: result_rx.await??, - })?) - } +async fn get_all_products( + adapter_tx: &mpsc::Sender, +) -> Result { + let (result_tx, result_rx) = oneshot::channel(); + adapter_tx + .send(adapter::Message::GetAllProducts { result_tx }) + .await?; + Ok(serde_json::to_value(result_rx.await??)?) +} - async fn update_price( - &mut self, - request: &Request, - ) -> Result { - let params: UpdatePriceParams = self.deserialize_params(request.params.clone())?; - - self.adapter_tx - .send(adapter::Message::UpdatePrice { - account: params.account, - price: params.price, - conf: params.conf, - status: params.status, - }) - .await?; +async fn subscribe_price( + adapter_tx: &mpsc::Sender, + notify_price_tx: &mpsc::Sender, + request: &Request, +) -> Result { + let params: SubscribePriceParams = serde_json::from_value( + request + .params + .clone() + .ok_or_else(|| anyhow!("Missing request parameters"))?, + )?; + + let (result_tx, result_rx) = oneshot::channel(); + adapter_tx + .send(adapter::Message::SubscribePrice { + result_tx, + account: params.account, + notify_price_tx: notify_price_tx.clone(), + }) + .await?; + + Ok(serde_json::to_value(SubscribeResult { + subscription: result_rx.await??, + })?) +} - Ok(serde_json::to_value(0)?) - } +async fn subscribe_price_sched( + adapter_tx: &mpsc::Sender, + notify_price_sched_tx: &mpsc::Sender, + request: &Request, +) -> Result { + let params: SubscribePriceSchedParams = serde_json::from_value( + request + .params + .clone() + .ok_or_else(|| anyhow!("Missing request parameters"))?, + )?; + + let (result_tx, result_rx) = oneshot::channel(); + adapter_tx + .send(adapter::Message::SubscribePriceSched { + result_tx, + account: params.account, + notify_price_sched_tx: notify_price_sched_tx.clone(), + }) + .await?; + + Ok(serde_json::to_value(SubscribeResult { + subscription: result_rx.await??, + })?) +} - fn deserialize_params(&self, value: Option) -> Result - where - T: DeserializeOwned, - { - serde_json::from_value::(value.ok_or_else(|| anyhow!("Missing request parameters"))?) - .map_err(|e| e.into()) - } +async fn update_price( + adapter_tx: &mpsc::Sender, + request: &Request, +) -> Result { + let params: UpdatePriceParams = serde_json::from_value( + request + .params + .clone() + .ok_or_else(|| anyhow!("Missing request parameters"))?, + )?; + + adapter_tx + .send(adapter::Message::UpdatePrice { + account: params.account, + price: params.price, + conf: params.conf, + status: params.status, + }) + .await?; + + Ok(serde_json::to_value(0)?) +} - async fn send_error(&mut self, error: anyhow::Error, id: Option) -> Result<()> { - let response: Response = Response::error( - id.unwrap_or_else(|| Id::from(0)), - ErrorCode::InternalError, - error.to_string(), - None, - ); - self.send_text(&response.to_string()).await - } +async fn send_error( + ws_tx: &mut SplitSink, + error: anyhow::Error, + id: Option, +) -> Result<()> { + let response: Response = Response::error( + id.unwrap_or_else(|| Id::from(0)), + ErrorCode::InternalError, + error.to_string(), + None, + ); + send_text(ws_tx, &response.to_string()).await +} - async fn send_notification(&mut self, method: Method, params: Option) -> Result<()> - where - T: Sized + Serialize + DeserializeOwned, - { - self.send_request(IdReq::Notification, method, params).await - } +async fn send_notification( + ws_tx: &mut SplitSink, + method: Method, + params: Option, +) -> Result<()> +where + T: Sized + Serialize + DeserializeOwned, +{ + send_request(ws_tx, IdReq::Notification, method, params).await +} - async fn send_request(&mut self, id: I, method: Method, params: Option) -> Result<()> - where - I: Into, - T: Sized + Serialize + DeserializeOwned, - { - let request = Request::with_params(id, method, params); - self.send_text(&request.to_string()).await - } +async fn send_request( + ws_tx: &mut SplitSink, + id: I, + method: Method, + params: Option, +) -> Result<()> +where + I: Into, + T: Sized + Serialize + DeserializeOwned, +{ + let request = Request::with_params(id, method, params); + send_text(ws_tx, &request.to_string()).await +} - async fn send_text(&mut self, msg: &str) -> Result<()> { - self.ws_tx - .send(Message::text(msg.to_string())) - .await - .map_err(|e| e.into()) - } +async fn send_text(ws_tx: &mut SplitSink, msg: &str) -> Result<()> { + ws_tx + .send(Message::text(msg.to_string())) + .await + .map_err(|e| e.into()) } #[derive(Clone)] @@ -523,15 +553,13 @@ async fn serve( config: Config| { ws.on_upgrade(move |conn| async move { info!(with_logger.logger, "websocket user connected"); - - Connection::new( + handle_connection( conn, adapter_tx, config.notify_price_tx_buffer, config.notify_price_sched_tx_buffer, with_logger.logger, ) - .consume() .await }) },