From 9d3cb67af375abcf1e6e46dc66fe4321faf0dc0e Mon Sep 17 00:00:00 2001 From: Reisen Date: Mon, 22 Apr 2024 03:02:58 +0000 Subject: [PATCH] refactor: migrate adapter service to api --- Cargo.lock | 1 + Cargo.toml | 1 + src/agent.rs | 41 +- src/agent/pythd/adapter.rs | 783 +++------------- src/agent/pythd/adapter/api.rs | 430 +++++++++ src/agent/pythd/api/rpc.rs | 881 ++---------------- src/agent/pythd/api/rpc/get_all_products.rs | 18 +- src/agent/pythd/api/rpc/get_product.rs | 24 +- src/agent/pythd/api/rpc/get_product_list.rs | 18 +- src/agent/pythd/api/rpc/subscribe_price.rs | 30 +- .../pythd/api/rpc/subscribe_price_sched.rs | 30 +- src/agent/pythd/api/rpc/update_price.rs | 24 +- src/agent/store/global.rs | 62 +- 13 files changed, 743 insertions(+), 1600 deletions(-) create mode 100644 src/agent/pythd/adapter/api.rs diff --git a/Cargo.lock b/Cargo.lock index 16421fd..baede69 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3279,6 +3279,7 @@ dependencies = [ "chrono-tz", "clap 4.5.4", "config", + "futures", "futures-util", "humantime", "humantime-serde", diff --git a/Cargo.toml b/Cargo.toml index fc99bf4..330ac5a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,6 +18,7 @@ async-trait = "0.1.79" warp = { version = "0.3.6", features = ["websocket"] } tokio = { version = "1.37.0", features = ["full"] } tokio-stream = "0.1.15" +futures = { version = "0.3.30" } futures-util = { version = "0.3.30", default-features = false, features = [ "sink", ] } diff --git a/src/agent.rs b/src/agent.rs index 9ececb7..bc3a153 100644 --- a/src/agent.rs +++ b/src/agent.rs @@ -73,12 +73,16 @@ pub mod store; use { self::{ config::Config, - pythd::api::rpc, + pythd::{ + adapter::notifier, + api::rpc, + }, solana::network, }, anyhow::Result, futures_util::future::join_all, slog::Logger, + std::sync::Arc, tokio::sync::{ broadcast, mpsc, @@ -113,8 +117,7 @@ impl Agent { // Create the channels // TODO: make all components listen to shutdown signal - let (shutdown_tx, shutdown_rx) = - broadcast::channel(self.config.channel_capacities.shutdown); + let (shutdown_tx, _) = broadcast::channel(self.config.channel_capacities.shutdown); let (primary_oracle_updates_tx, primary_oracle_updates_rx) = mpsc::channel(self.config.channel_capacities.primary_oracle_updates); let (secondary_oracle_updates_tx, secondary_oracle_updates_rx) = @@ -123,8 +126,6 @@ impl Agent { mpsc::channel(self.config.channel_capacities.global_store_lookup); let (local_store_tx, local_store_rx) = mpsc::channel(self.config.channel_capacities.local_store); - let (pythd_adapter_tx, pythd_adapter_rx) = - mpsc::channel(self.config.channel_capacities.pythd_adapter); let (primary_keypair_loader_tx, primary_keypair_loader_rx) = mpsc::channel(10); let (secondary_keypair_loader_tx, secondary_keypair_loader_rx) = mpsc::channel(10); @@ -152,34 +153,38 @@ impl Agent { )?); } + // Create the Pythd Adapter. + let adapter = Arc::new(pythd::adapter::Adapter::new( + self.config.pythd_adapter.clone(), + global_store_lookup_tx.clone(), + local_store_tx.clone(), + logger.clone(), + )); + + // Create the Notifier task for the Pythd RPC. + jhs.push(tokio::spawn(notifier( + adapter.clone(), + shutdown_tx.subscribe(), + ))); + // Spawn the Global Store jhs.push(store::global::spawn_store( global_store_lookup_rx, primary_oracle_updates_rx, secondary_oracle_updates_rx, - pythd_adapter_tx.clone(), + adapter.clone(), logger.clone(), )); // Spawn the Local Store jhs.push(store::local::spawn_store(local_store_rx, logger.clone())); - // Spawn the Pythd Adapter - jhs.push(pythd::adapter::spawn_adapter( - self.config.pythd_adapter.clone(), - pythd_adapter_rx, - global_store_lookup_tx.clone(), - local_store_tx.clone(), - shutdown_tx.subscribe(), - logger.clone(), - )); - // Spawn the Pythd API Server jhs.push(tokio::spawn(rpc::run( self.config.pythd_api_server.clone(), logger.clone(), - pythd_adapter_tx, - shutdown_rx, + adapter, + shutdown_tx.subscribe(), ))); // Spawn the metrics server diff --git a/src/agent/pythd/adapter.rs b/src/agent/pythd/adapter.rs index a4b3428..939dc84 100644 --- a/src/agent/pythd/adapter.rs +++ b/src/agent/pythd/adapter.rs @@ -1,40 +1,16 @@ use { super::{ - super::{ - solana, - store::{ - global, - local, - PriceIdentifier, - }, + super::store::{ + global, + local, + PriceIdentifier, }, api::{ - self, - Conf, NotifyPrice, NotifyPriceSched, - Price, - PriceAccountMetadata, - PriceUpdate, - ProductAccount, - ProductAccountMetadata, SubscriptionID, }, }, - crate::agent::{ - solana::oracle::PriceEntry, - store::global::AllAccountsData, - }, - anyhow::{ - anyhow, - Result, - }, - chrono::Utc, - pyth_sdk::Identifier, - pyth_sdk_solana::state::{ - PriceComp, - PriceStatus, - }, serde::{ Deserialize, Serialize, @@ -42,22 +18,21 @@ use { slog::Logger, std::{ collections::HashMap, + sync::atomic::AtomicI64, time::Duration, }, - tokio::{ - sync::{ - broadcast, - mpsc, - oneshot, - }, - task::JoinHandle, - time::{ - self, - Interval, - }, + tokio::sync::{ + mpsc, + RwLock, }, }; +mod api; +pub use api::{ + notifier, + AdapterApi, +}; + #[derive(Clone, Serialize, Deserialize, Debug)] #[serde(default)] pub struct Config { @@ -79,20 +54,18 @@ impl Default for Config { /// It is responsible for implementing the business logic for responding to /// the pythd websocket API calls. pub struct Adapter { - /// Channel on which messages are received - message_rx: mpsc::Receiver, - - /// Subscription ID counter - subscription_id_count: SubscriptionID, + /// Subscription ID sequencer. + subscription_id_seq: AtomicI64, /// Notify Price Sched subscriptions - notify_price_sched_subscriptions: HashMap>, + notify_price_sched_subscriptions: + RwLock>>, // Notify Price Subscriptions - notify_price_subscriptions: HashMap>, + notify_price_subscriptions: RwLock>>, /// The fixed interval at which Notify Price Sched notifications are sent - notify_price_sched_interval: Interval, + notify_price_sched_interval_duration: Duration, /// Channel on which to communicate with the global store global_store_lookup_tx: mpsc::Sender, @@ -100,9 +73,6 @@ pub struct Adapter { /// Channel on which to communicate with the local store local_store_tx: mpsc::Sender, - /// Channel on which the shutdown is broadcast - shutdown_rx: broadcast::Receiver<()>, - /// The logger logger: Logger, } @@ -123,491 +93,33 @@ struct NotifyPriceSubscription { notify_price_tx: mpsc::Sender, } -#[derive(Debug)] -pub enum Message { - GlobalStoreUpdate { - price_identifier: PriceIdentifier, - price: i64, - conf: u64, - status: PriceStatus, - valid_slot: u64, - pub_slot: u64, - }, - GetProductList { - result_tx: oneshot::Sender>>, - }, - GetProduct { - account: api::Pubkey, - result_tx: oneshot::Sender>, - }, - GetAllProducts { - result_tx: oneshot::Sender>>, - }, - SubscribePrice { - account: api::Pubkey, - notify_price_tx: mpsc::Sender, - result_tx: oneshot::Sender>, - }, - SubscribePriceSched { - account: api::Pubkey, - notify_price_sched_tx: mpsc::Sender, - result_tx: oneshot::Sender>, - }, - UpdatePrice { - account: api::Pubkey, - price: Price, - conf: Conf, - status: String, - }, -} - -pub fn spawn_adapter( - config: Config, - message_rx: mpsc::Receiver, - global_store_lookup_tx: mpsc::Sender, - local_store_tx: mpsc::Sender, - shutdown_rx: broadcast::Receiver<()>, - logger: Logger, -) -> JoinHandle<()> { - tokio::spawn(async move { - Adapter::new( - config, - message_rx, - global_store_lookup_tx, - local_store_tx, - shutdown_rx, - logger, - ) - .run() - .await - }) -} - impl Adapter { pub fn new( config: Config, - message_rx: mpsc::Receiver, global_store_lookup_tx: mpsc::Sender, local_store_tx: mpsc::Sender, - shutdown_rx: broadcast::Receiver<()>, logger: Logger, ) -> Self { Adapter { - message_rx, - subscription_id_count: 0, - notify_price_sched_subscriptions: HashMap::new(), - notify_price_subscriptions: HashMap::new(), - notify_price_sched_interval: time::interval( - config.notify_price_sched_interval_duration, - ), + subscription_id_seq: 1.into(), + notify_price_sched_subscriptions: RwLock::new(HashMap::new()), + notify_price_subscriptions: RwLock::new(HashMap::new()), + notify_price_sched_interval_duration: config.notify_price_sched_interval_duration, global_store_lookup_tx, local_store_tx, - shutdown_rx, logger, } } - - pub async fn run(&mut self) { - loop { - self.drop_closed_subscriptions(); - tokio::select! { - Some(message) = self.message_rx.recv() => { - if let Err(err) = self.handle_message(message).await { - error!(self.logger, "{}", err); - debug!(self.logger, "error context"; "context" => format!("{:?}", err)); - } - } - _ = self.shutdown_rx.recv() => { - info!(self.logger, "shutdown signal received"); - return; - } - _ = self.notify_price_sched_interval.tick() => { - if let Err(err) = self.send_notify_price_sched().await { - error!(self.logger, "{}", err); - debug!(self.logger, "error context"; "context" => format!("{:?}", err)); - } - } - } - } - } - - async fn handle_message(&mut self, message: Message) -> Result<()> { - match message { - Message::GetProductList { result_tx } => { - self.send(result_tx, self.handle_get_product_list().await) - } - Message::GetProduct { account, result_tx } => { - self.send(result_tx, self.handle_get_product(&account.parse()?).await) - } - Message::GetAllProducts { result_tx } => { - self.send(result_tx, self.handle_get_all_products().await) - } - Message::SubscribePrice { - account, - notify_price_tx, - result_tx, - } => { - let subscription_id = self - .handle_subscribe_price(&account.parse()?, notify_price_tx) - .await; - self.send(result_tx, Ok(subscription_id)) - } - Message::SubscribePriceSched { - account, - notify_price_sched_tx, - result_tx, - } => { - let subscription_id = self - .handle_subscribe_price_sched(&account.parse()?, notify_price_sched_tx) - .await; - self.send(result_tx, Ok(subscription_id)) - } - Message::UpdatePrice { - account, - price, - conf, - status, - } => { - self.handle_update_price(&account.parse()?, price, conf, status) - .await - } - Message::GlobalStoreUpdate { - price_identifier, - price, - conf, - status, - valid_slot, - pub_slot, - } => { - self.handle_global_store_update( - price_identifier, - price, - conf, - status, - valid_slot, - pub_slot, - ) - .await - } - } - } - - fn send(&self, tx: oneshot::Sender, item: T) -> Result<()> { - tx.send(item).map_err(|_| anyhow!("sending channel full")) - } - - async fn handle_get_product_list(&self) -> Result> { - let all_accounts_metadata = self.lookup_all_accounts_metadata().await?; - - let mut result = Vec::new(); - for (product_account_key, product_account) in - all_accounts_metadata.product_accounts_metadata - { - // Transform the price accounts into the API PriceAccountMetadata structs - // the API uses. - let price_accounts_metadata = product_account - .price_accounts - .iter() - .filter_map(|price_account_key| { - all_accounts_metadata - .price_accounts_metadata - .get(price_account_key) - .map(|acc| (price_account_key, acc)) - }) - .map(|(price_account_key, price_account)| PriceAccountMetadata { - account: price_account_key.to_string(), - price_type: "price".to_owned(), - price_exponent: price_account.expo as i64, - }) - .collect(); - - // Create the product account metadata struct - result.push(ProductAccountMetadata { - account: product_account_key.to_string(), - attr_dict: product_account.attr_dict, - price: price_accounts_metadata, - }) - } - - Ok(result) - } - - // Fetches the Solana-specific Oracle data from the global store - async fn lookup_all_accounts_metadata(&self) -> Result { - let (result_tx, result_rx) = oneshot::channel(); - self.global_store_lookup_tx - .send(global::Lookup::LookupAllAccountsMetadata { result_tx }) - .await?; - result_rx.await? - } - - async fn handle_get_all_products(&self) -> Result> { - let solana_data = self.lookup_all_accounts_data().await?; - - let mut result = Vec::new(); - for (product_account_key, product_account) in &solana_data.product_accounts { - let product_account_api = Self::solana_product_account_to_pythd_api_product_account( - product_account, - &solana_data, - product_account_key, - ); - - result.push(product_account_api) - } - - Ok(result) - } - - async fn lookup_all_accounts_data(&self) -> Result { - let (result_tx, result_rx) = oneshot::channel(); - self.global_store_lookup_tx - .send(global::Lookup::LookupAllAccountsData { - network: solana::network::Network::Primary, - result_tx, - }) - .await?; - result_rx.await? - } - - fn solana_product_account_to_pythd_api_product_account( - product_account: &solana::oracle::ProductEntry, - all_accounts_data: &AllAccountsData, - product_account_key: &solana_sdk::pubkey::Pubkey, - ) -> ProductAccount { - // Extract all the price accounts from the product account - let price_accounts = product_account - .price_accounts - .iter() - .filter_map(|price_account_key| { - all_accounts_data - .price_accounts - .get(price_account_key) - .map(|acc| (price_account_key, acc)) - }) - .map(|(price_account_key, price_account)| { - Self::solana_price_account_to_pythd_api_price_account( - price_account_key, - price_account, - ) - }) - .collect(); - - // Create the product account metadata struct - ProductAccount { - account: product_account_key.to_string(), - attr_dict: product_account - .account_data - .iter() - .filter(|(key, val)| !key.is_empty() && !val.is_empty()) - .map(|(key, val)| (key.to_owned(), val.to_owned())) - .collect(), - price_accounts, - } - } - - fn solana_price_account_to_pythd_api_price_account( - price_account_key: &solana_sdk::pubkey::Pubkey, - price_account: &PriceEntry, - ) -> api::PriceAccount { - api::PriceAccount { - account: price_account_key.to_string(), - price_type: "price".to_string(), - price_exponent: price_account.expo as i64, - status: Self::price_status_to_str(price_account.agg.status), - price: price_account.agg.price, - conf: price_account.agg.conf, - twap: price_account.ema_price.val, - twac: price_account.ema_conf.val, - valid_slot: price_account.valid_slot, - pub_slot: price_account.agg.pub_slot, - prev_slot: price_account.prev_slot, - prev_price: price_account.prev_price, - prev_conf: price_account.prev_conf, - publisher_accounts: price_account - .comp - .iter() - .filter(|comp| *comp != &PriceComp::default()) - .map(|comp| api::PublisherAccount { - account: comp.publisher.to_string(), - status: Self::price_status_to_str(comp.agg.status), - price: comp.agg.price, - conf: comp.agg.conf, - slot: comp.agg.pub_slot, - }) - .collect(), - } - } - - // TODO: implement Display on PriceStatus and then just call PriceStatus::to_string - fn price_status_to_str(price_status: PriceStatus) -> String { - match price_status { - PriceStatus::Unknown => "unknown", - PriceStatus::Trading => "trading", - PriceStatus::Halted => "halted", - PriceStatus::Auction => "auction", - PriceStatus::Ignored => "ignored", - } - .to_string() - } - - async fn handle_get_product( - &self, - product_account_key: &solana_sdk::pubkey::Pubkey, - ) -> Result { - let all_accounts_data = self.lookup_all_accounts_data().await?; - - // Look up the product account - let product_account = all_accounts_data - .product_accounts - .get(product_account_key) - .ok_or_else(|| anyhow!("product account not found"))?; - - Ok(Self::solana_product_account_to_pythd_api_product_account( - product_account, - &all_accounts_data, - product_account_key, - )) - } - - async fn handle_subscribe_price_sched( - &mut self, - account_pubkey: &solana_sdk::pubkey::Pubkey, - notify_price_sched_tx: mpsc::Sender, - ) -> SubscriptionID { - let subscription_id = self.next_subscription_id(); - self.notify_price_sched_subscriptions - .entry(Identifier::new(account_pubkey.to_bytes())) - .or_default() - .push(NotifyPriceSchedSubscription { - subscription_id, - notify_price_sched_tx, - }); - subscription_id - } - - fn next_subscription_id(&mut self) -> SubscriptionID { - self.subscription_id_count += 1; - self.subscription_id_count - } - - async fn handle_subscribe_price( - &mut self, - account: &solana_sdk::pubkey::Pubkey, - notify_price_tx: mpsc::Sender, - ) -> SubscriptionID { - let subscription_id = self.next_subscription_id(); - self.notify_price_subscriptions - .entry(Identifier::new(account.to_bytes())) - .or_default() - .push(NotifyPriceSubscription { - subscription_id, - notify_price_tx, - }); - subscription_id - } - - async fn send_notify_price_sched(&self) -> Result<()> { - for subscription in self.notify_price_sched_subscriptions.values().flatten() { - // Send the notify price sched update without awaiting. This results in raising errors - // if the channel is full which normally should not happen. This is because we do not - // want to block the adapter if the channel is full. - subscription - .notify_price_sched_tx - .try_send(NotifyPriceSched { - subscription: subscription.subscription_id, - })?; - } - - Ok(()) - } - - fn drop_closed_subscriptions(&mut self) { - for subscriptions in self.notify_price_subscriptions.values_mut() { - subscriptions.retain(|subscription| !subscription.notify_price_tx.is_closed()) - } - - for subscriptions in self.notify_price_sched_subscriptions.values_mut() { - subscriptions.retain(|subscription| !subscription.notify_price_sched_tx.is_closed()) - } - } - - async fn handle_update_price( - &self, - account: &solana_sdk::pubkey::Pubkey, - price: Price, - conf: Conf, - status: String, - ) -> Result<()> { - self.local_store_tx - .send(local::Message::Update { - price_identifier: pyth_sdk::Identifier::new(account.to_bytes()), - price_info: local::PriceInfo { - status: Adapter::map_status(&status)?, - price, - conf, - timestamp: Utc::now().timestamp(), - }, - }) - .await - .map_err(|_| anyhow!("failed to send update to local store")) - } - - // TODO: implement FromStr method on PriceStatus - fn map_status(status: &str) -> Result { - match status { - "unknown" => Ok(PriceStatus::Unknown), - "trading" => Ok(PriceStatus::Trading), - "halted" => Ok(PriceStatus::Halted), - "auction" => Ok(PriceStatus::Auction), - "ignored" => Ok(PriceStatus::Ignored), - _ => Err(anyhow!("invalid price status: {:#?}", status)), - } - } - - async fn handle_global_store_update( - &self, - price_identifier: PriceIdentifier, - price: i64, - conf: u64, - status: PriceStatus, - valid_slot: u64, - pub_slot: u64, - ) -> Result<()> { - // Look up any subcriptions associated with the price identifier - let empty = Vec::new(); - let subscriptions = self - .notify_price_subscriptions - .get(&price_identifier) - .unwrap_or(&empty); - - // Send the Notify Price update to each subscription - for subscription in subscriptions { - // Send the notify price update without awaiting. This results in raising errors if the - // channel is full which normally should not happen. This is because we do not want to - // block the adapter if the channel is full. - subscription.notify_price_tx.try_send(NotifyPrice { - subscription: subscription.subscription_id, - result: PriceUpdate { - price, - conf, - status: Self::price_status_to_str(status), - valid_slot, - pub_slot, - }, - })?; - } - - Ok(()) - } } #[cfg(test)] mod tests { use { super::{ + notifier, Adapter, + AdapterApi, Config, - Message, }, crate::agent::{ pythd::{ @@ -649,56 +161,48 @@ mod tests { HashMap, }, str::FromStr, + sync::Arc, time::Duration, }, tokio::{ sync::{ broadcast, mpsc, - oneshot, }, task::JoinHandle, }, }; struct TestAdapter { - message_tx: mpsc::Sender, - shutdown_tx: broadcast::Sender<()>, + adapter: Arc, global_store_lookup_rx: mpsc::Receiver, local_store_rx: mpsc::Receiver, + shutdown_tx: broadcast::Sender<()>, jh: JoinHandle<()>, } - impl Drop for TestAdapter { - fn drop(&mut self) { - let _ = self.shutdown_tx.send(()); - self.jh.abort(); - } - } - async fn setup() -> TestAdapter { // Create and spawn an adapter - let (adapter_tx, adapter_rx) = mpsc::channel(100); let (global_store_lookup_tx, global_store_lookup_rx) = mpsc::channel(1000); let (local_store_tx, local_store_rx) = mpsc::channel(1000); let notify_price_sched_interval_duration = Duration::from_nanos(10); let logger = slog_test::new_test_logger(IoBuffer::new()); - let (shutdown_tx, shutdown_rx) = broadcast::channel(10); let config = Config { notify_price_sched_interval_duration, }; - let mut adapter = Adapter::new( + let adapter = Arc::new(Adapter::new( config, - adapter_rx, global_store_lookup_tx, local_store_tx, - shutdown_rx, logger, - ); - let jh = tokio::spawn(async move { adapter.run().await }); + )); + let (shutdown_tx, _) = broadcast::channel(1); + + // Spawn Price Notifier + let jh = tokio::spawn(notifier(adapter.clone(), shutdown_tx.subscribe())); TestAdapter { - message_tx: adapter_tx, + adapter, global_store_lookup_rx, local_store_rx, shutdown_tx, @@ -711,20 +215,14 @@ mod tests { let test_adapter = setup().await; // Send a Subscribe Price Sched message - let account = "2wrWGm63xWubz7ue4iYR3qvBbaUJhZVi4eSpNuU8k8iF".to_string(); - let (notify_price_sched_tx, mut notify_price_sched_rx) = mpsc::channel(1000); - let (result_tx, result_rx) = oneshot::channel(); - test_adapter - .message_tx - .send(Message::SubscribePriceSched { - account, - notify_price_sched_tx, - result_tx, - }) - .await + let account = "2wrWGm63xWubz7ue4iYR3qvBbaUJhZVi4eSpNuU8k8iF" + .parse::() .unwrap(); - - let subscription_id = result_rx.await.unwrap().unwrap(); + let (notify_price_sched_tx, mut notify_price_sched_rx) = mpsc::channel(1000); + let subscription_id = test_adapter + .adapter + .subscribe_price_sched(&account, notify_price_sched_tx) + .await; // Expect that we recieve several Notify Price Sched notifications for _ in 0..10 { @@ -735,6 +233,9 @@ mod tests { } ) } + + let _ = test_adapter.shutdown_tx.send(()); + test_adapter.jh.abort(); } fn get_test_all_accounts_metadata() -> global::AllAccountsMetadata { @@ -857,23 +358,23 @@ mod tests { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn test_get_product_list() { // Start the test adapter - let mut test_adapter = setup().await; - - // Send a Get Product List message - let (result_tx, result_rx) = oneshot::channel(); - test_adapter - .message_tx - .send(Message::GetProductList { result_tx }) - .await - .unwrap(); + let test_adapter = setup().await; // Return the product list to the adapter, from the global store - match test_adapter.global_store_lookup_rx.recv().await.unwrap() { - global::Lookup::LookupAllAccountsMetadata { result_tx } => result_tx - .send(Ok(get_test_all_accounts_metadata())) - .unwrap(), - _ => panic!("Uexpected message received from adapter"), - }; + { + let mut global_store_lookup_rx = test_adapter.global_store_lookup_rx; + tokio::spawn(async move { + match global_store_lookup_rx.recv().await.unwrap() { + global::Lookup::LookupAllAccountsMetadata { result_tx } => result_tx + .send(Ok(get_test_all_accounts_metadata())) + .unwrap(), + _ => panic!("Uexpected message received from adapter"), + }; + }); + } + + // Send a Get Product List message + let mut product_list = test_adapter.adapter.get_product_list().await.unwrap(); // Check that the result is what we expected let expected = vec![ @@ -941,9 +442,10 @@ mod tests { }, ]; - let mut result = result_rx.await.unwrap().unwrap(); - result.sort(); - assert_eq!(result, expected); + product_list.sort(); + assert_eq!(product_list, expected); + let _ = test_adapter.shutdown_tx.send(()); + test_adapter.jh.abort(); } fn pad_price_comps(mut inputs: Vec) -> [PriceComp; 32] { @@ -1568,24 +1070,24 @@ mod tests { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn test_get_all_products() { // Start the test adapter - let mut test_adapter = setup().await; - - // Send a Get All Products message - let (result_tx, result_rx) = oneshot::channel(); - test_adapter - .message_tx - .send(Message::GetAllProducts { result_tx }) - .await - .unwrap(); + let test_adapter = setup().await; // Return the account data to the adapter, from the global store - match test_adapter.global_store_lookup_rx.recv().await.unwrap() { - global::Lookup::LookupAllAccountsData { - network: Network::Primary, - result_tx, - } => result_tx.send(Ok(get_all_accounts_data())).unwrap(), - _ => panic!("Uexpected message received from adapter"), - }; + { + let mut global_store_lookup_rx = test_adapter.global_store_lookup_rx; + tokio::spawn(async move { + match global_store_lookup_rx.recv().await.unwrap() { + global::Lookup::LookupAllAccountsData { + network: Network::Primary, + result_tx, + } => result_tx.send(Ok(get_all_accounts_data())).unwrap(), + _ => panic!("Uexpected message received from adapter"), + }; + }); + } + + // Send a Get All Products message + let mut all_products = test_adapter.adapter.get_all_products().await.unwrap(); // Check that the result of the conversion to the Pythd API format is what we expected let expected = vec![ @@ -1782,40 +1284,40 @@ mod tests { }, ]; - let mut result = result_rx.await.unwrap().unwrap(); - result.sort(); - assert_eq!(result, expected); + all_products.sort(); + assert_eq!(all_products, expected); + let _ = test_adapter.shutdown_tx.send(()); + test_adapter.jh.abort(); } #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn test_get_product() { // Start the test adapter - let mut test_adapter = setup().await; + let test_adapter = setup().await; + + // Return the account data to the adapter, from the global store + { + let mut global_store_lookup_rx = test_adapter.global_store_lookup_rx; + tokio::spawn(async move { + match global_store_lookup_rx.recv().await.unwrap() { + global::Lookup::LookupAllAccountsData { + network: Network::Primary, + result_tx, + } => result_tx.send(Ok(get_all_accounts_data())).unwrap(), + _ => panic!("Uexpected message received from adapter"), + }; + }); + } // Send a Get Product message - let account = "CkMrDWtmFJZcmAUC11qNaWymbXQKvnRx4cq1QudLav7t".to_string(); - let (result_tx, result_rx) = oneshot::channel(); - test_adapter - .message_tx - .send(Message::GetProduct { - account: account.clone(), - result_tx, - }) - .await + let account = "CkMrDWtmFJZcmAUC11qNaWymbXQKvnRx4cq1QudLav7t" + .parse::() .unwrap(); - - // Return the account data to the adapter, from the global store - match test_adapter.global_store_lookup_rx.recv().await.unwrap() { - global::Lookup::LookupAllAccountsData { - network: Network::Primary, - result_tx, - } => result_tx.send(Ok(get_all_accounts_data())).unwrap(), - _ => panic!("Uexpected message received from adapter"), - }; + let product = test_adapter.adapter.get_product(&account).await.unwrap(); // Check that the result of the conversion to the Pythd API format is what we expected let expected = ProductAccount { - account, + account: account.to_string(), price_accounts: vec![ api::PriceAccount { account: "GVXRSBjFk6e6J3NbVPXohDJetcTjaeeuykUpbQF8UoMU".to_string(), @@ -1887,7 +1389,7 @@ mod tests { ], }, ], - attr_dict: BTreeMap::from( + attr_dict: BTreeMap::from( [ ("symbol", "Crypto.LTC/USD"), ("asset_type", "Crypto"), @@ -1900,8 +1402,9 @@ mod tests { ), }; - let result = result_rx.await.unwrap().unwrap(); - assert_eq!(result, expected); + assert_eq!(product, expected); + let _ = test_adapter.shutdown_tx.send(()); + test_adapter.jh.abort(); } #[tokio::test(flavor = "multi_thread", worker_threads = 2)] @@ -1910,17 +1413,14 @@ mod tests { let mut test_adapter = setup().await; // Send an Update Price message - let account = "CkMrDWtmFJZcmAUC11qNaWymbXQKvnRx4cq1QudLav7t".to_string(); + let account = "CkMrDWtmFJZcmAUC11qNaWymbXQKvnRx4cq1QudLav7t" + .parse::() + .unwrap(); let price = 2365; let conf = 98754; - test_adapter - .message_tx - .send(Message::UpdatePrice { - account: account.clone(), - price, - conf, - status: "trading".to_string(), - }) + let _ = test_adapter + .adapter + .update_price(&account, price, conf, "trading".to_string()) .await .unwrap(); @@ -1930,21 +1430,16 @@ mod tests { price_identifier, price_info, } => { - assert_eq!( - price_identifier, - Identifier::new( - account - .parse::() - .unwrap() - .to_bytes() - ) - ); + assert_eq!(price_identifier, Identifier::new(account.to_bytes())); assert_eq!(price_info.price, price); assert_eq!(price_info.conf, conf); assert_eq!(price_info.status, PriceStatus::Trading); } _ => panic!("Uexpected message received by local store from adapter"), }; + + let _ = test_adapter.shutdown_tx.send(()); + test_adapter.jh.abort(); } #[tokio::test(flavor = "multi_thread", worker_threads = 2)] @@ -1953,41 +1448,30 @@ mod tests { let test_adapter = setup().await; // Send a Subscribe Price message - let account = "2wrWGm63xWubz7ue4iYR3qvBbaUJhZVi4eSpNuU8k8iF".to_string(); - let (notify_price_tx, mut notify_price_rx) = mpsc::channel(1000); - let (result_tx, result_rx) = oneshot::channel(); - test_adapter - .message_tx - .send(Message::SubscribePrice { - account: account.clone(), - notify_price_tx, - result_tx, - }) - .await + let account = "2wrWGm63xWubz7ue4iYR3qvBbaUJhZVi4eSpNuU8k8iF" + .parse::() .unwrap(); - - let subscription_id = result_rx.await.unwrap().unwrap(); + let (notify_price_tx, mut notify_price_rx) = mpsc::channel(1000); + let subscription_id = test_adapter + .adapter + .subscribe_price(&account, notify_price_tx) + .await; // Send an update from the global store to the adapter let price = 52162; let conf = 1646; let valid_slot = 75684; let pub_slot = 32565; - test_adapter - .message_tx - .send(Message::GlobalStoreUpdate { - price_identifier: Identifier::new( - account - .parse::() - .unwrap() - .to_bytes(), - ), + let _ = test_adapter + .adapter + .global_store_update( + Identifier::new(account.to_bytes()), price, conf, - status: PriceStatus::Trading, + PriceStatus::Trading, valid_slot, pub_slot, - }) + ) .await .unwrap(); @@ -2005,6 +1489,9 @@ mod tests { pub_slot }, } - ) + ); + + let _ = test_adapter.shutdown_tx.send(()); + test_adapter.jh.abort(); } } diff --git a/src/agent/pythd/adapter/api.rs b/src/agent/pythd/adapter/api.rs new file mode 100644 index 0000000..acbe6da --- /dev/null +++ b/src/agent/pythd/adapter/api.rs @@ -0,0 +1,430 @@ +use { + super::{ + Adapter, + NotifyPriceSchedSubscription, + NotifyPriceSubscription, + }, + crate::agent::{ + pythd::api::{ + Conf, + NotifyPrice, + NotifyPriceSched, + Price, + PriceAccount, + PriceAccountMetadata, + PriceUpdate, + ProductAccount, + ProductAccountMetadata, + PublisherAccount, + SubscriptionID, + }, + solana::{ + self, + oracle::PriceEntry, + }, + store::{ + global::{ + self, + AllAccountsData, + AllAccountsMetadata, + }, + local, + PriceIdentifier, + }, + }, + anyhow::{ + anyhow, + Result, + }, + chrono::Utc, + pyth_sdk::Identifier, + pyth_sdk_solana::state::{ + PriceComp, + PriceStatus, + }, + std::sync::Arc, + tokio::sync::{ + broadcast, + mpsc, + oneshot, + }, +}; + +// TODO: implement Display on PriceStatus and then just call PriceStatus::to_string +fn price_status_to_str(price_status: PriceStatus) -> String { + match price_status { + PriceStatus::Unknown => "unknown", + PriceStatus::Trading => "trading", + PriceStatus::Halted => "halted", + PriceStatus::Auction => "auction", + PriceStatus::Ignored => "ignored", + } + .to_string() +} + +fn solana_product_account_to_pythd_api_product_account( + product_account: &solana::oracle::ProductEntry, + all_accounts_data: &AllAccountsData, + product_account_key: &solana_sdk::pubkey::Pubkey, +) -> ProductAccount { + // Extract all the price accounts from the product account + let price_accounts = product_account + .price_accounts + .iter() + .filter_map(|price_account_key| { + all_accounts_data + .price_accounts + .get(price_account_key) + .map(|acc| (price_account_key, acc)) + }) + .map(|(price_account_key, price_account)| { + solana_price_account_to_pythd_api_price_account(price_account_key, price_account) + }) + .collect(); + + // Create the product account metadata struct + ProductAccount { + account: product_account_key.to_string(), + attr_dict: product_account + .account_data + .iter() + .filter(|(key, val)| !key.is_empty() && !val.is_empty()) + .map(|(key, val)| (key.to_owned(), val.to_owned())) + .collect(), + price_accounts, + } +} + +fn solana_price_account_to_pythd_api_price_account( + price_account_key: &solana_sdk::pubkey::Pubkey, + price_account: &PriceEntry, +) -> PriceAccount { + PriceAccount { + account: price_account_key.to_string(), + price_type: "price".to_string(), + price_exponent: price_account.expo as i64, + status: price_status_to_str(price_account.agg.status), + price: price_account.agg.price, + conf: price_account.agg.conf, + twap: price_account.ema_price.val, + twac: price_account.ema_conf.val, + valid_slot: price_account.valid_slot, + pub_slot: price_account.agg.pub_slot, + prev_slot: price_account.prev_slot, + prev_price: price_account.prev_price, + prev_conf: price_account.prev_conf, + publisher_accounts: price_account + .comp + .iter() + .filter(|comp| *comp != &PriceComp::default()) + .map(|comp| PublisherAccount { + account: comp.publisher.to_string(), + status: price_status_to_str(comp.agg.status), + price: comp.agg.price, + conf: comp.agg.conf, + slot: comp.agg.pub_slot, + }) + .collect(), + } +} + +#[async_trait::async_trait] +pub trait AdapterApi { + async fn get_product_list(&self) -> Result>; + async fn lookup_all_accounts_metadata(&self) -> Result; + async fn get_all_products(&self) -> Result>; + async fn lookup_all_accounts_data(&self) -> Result; + async fn get_product( + &self, + product_account_key: &solana_sdk::pubkey::Pubkey, + ) -> Result; + async fn subscribe_price_sched( + &self, + account_pubkey: &solana_sdk::pubkey::Pubkey, + notify_price_sched_tx: mpsc::Sender, + ) -> SubscriptionID; + fn next_subscription_id(&self) -> SubscriptionID; + async fn subscribe_price( + &self, + account: &solana_sdk::pubkey::Pubkey, + notify_price_tx: mpsc::Sender, + ) -> SubscriptionID; + async fn send_notify_price_sched(&self) -> Result<()>; + async fn drop_closed_subscriptions(&self); + async fn update_price( + &self, + account: &solana_sdk::pubkey::Pubkey, + price: Price, + conf: Conf, + status: String, + ) -> Result<()>; + // TODO: implement FromStr method on PriceStatus + fn map_status(status: &str) -> Result; + async fn global_store_update( + &self, + price_identifier: PriceIdentifier, + price: i64, + conf: u64, + status: PriceStatus, + valid_slot: u64, + pub_slot: u64, + ) -> Result<()>; +} + +pub async fn notifier(adapter: Arc, mut shutdown_rx: broadcast::Receiver<()>) { + let mut interval = tokio::time::interval(adapter.notify_price_sched_interval_duration); + loop { + adapter.drop_closed_subscriptions().await; + tokio::select! { + _ = shutdown_rx.recv() => { + info!(adapter.logger, "shutdown signal received"); + return; + } + _ = interval.tick() => { + if let Err(err) = adapter.send_notify_price_sched().await { + error!(adapter.logger, "{}", err); + debug!(adapter.logger, "error context"; "context" => format!("{:?}", err)); + } + } + } + } +} + +#[async_trait::async_trait] +impl AdapterApi for Adapter { + async fn get_product_list(&self) -> Result> { + let all_accounts_metadata = self.lookup_all_accounts_metadata().await?; + + let mut result = Vec::new(); + for (product_account_key, product_account) in + all_accounts_metadata.product_accounts_metadata + { + // Transform the price accounts into the API PriceAccountMetadata structs + // the API uses. + let price_accounts_metadata = product_account + .price_accounts + .iter() + .filter_map(|price_account_key| { + all_accounts_metadata + .price_accounts_metadata + .get(price_account_key) + .map(|acc| (price_account_key, acc)) + }) + .map(|(price_account_key, price_account)| PriceAccountMetadata { + account: price_account_key.to_string(), + price_type: "price".to_owned(), + price_exponent: price_account.expo as i64, + }) + .collect(); + + // Create the product account metadata struct + result.push(ProductAccountMetadata { + account: product_account_key.to_string(), + attr_dict: product_account.attr_dict, + price: price_accounts_metadata, + }) + } + + Ok(result) + } + + // Fetches the Solana-specific Oracle data from the global store + async fn lookup_all_accounts_metadata(&self) -> Result { + let (result_tx, result_rx) = oneshot::channel(); + self.global_store_lookup_tx + .send(global::Lookup::LookupAllAccountsMetadata { result_tx }) + .await?; + result_rx.await? + } + + async fn get_all_products(&self) -> Result> { + let solana_data = self.lookup_all_accounts_data().await?; + + let mut result = Vec::new(); + for (product_account_key, product_account) in &solana_data.product_accounts { + let product_account_api = solana_product_account_to_pythd_api_product_account( + product_account, + &solana_data, + product_account_key, + ); + + result.push(product_account_api) + } + + Ok(result) + } + + async fn lookup_all_accounts_data(&self) -> Result { + let (result_tx, result_rx) = oneshot::channel(); + self.global_store_lookup_tx + .send(global::Lookup::LookupAllAccountsData { + network: solana::network::Network::Primary, + result_tx, + }) + .await?; + result_rx.await? + } + + async fn get_product( + &self, + product_account_key: &solana_sdk::pubkey::Pubkey, + ) -> Result { + let all_accounts_data = self.lookup_all_accounts_data().await?; + + // Look up the product account + let product_account = all_accounts_data + .product_accounts + .get(product_account_key) + .ok_or_else(|| anyhow!("product account not found"))?; + + Ok(solana_product_account_to_pythd_api_product_account( + product_account, + &all_accounts_data, + product_account_key, + )) + } + + async fn subscribe_price_sched( + &self, + account_pubkey: &solana_sdk::pubkey::Pubkey, + notify_price_sched_tx: mpsc::Sender, + ) -> SubscriptionID { + let subscription_id = self.next_subscription_id(); + self.notify_price_sched_subscriptions + .write() + .await + .entry(Identifier::new(account_pubkey.to_bytes())) + .or_default() + .push(NotifyPriceSchedSubscription { + subscription_id, + notify_price_sched_tx, + }); + subscription_id + } + + fn next_subscription_id(&self) -> SubscriptionID { + self.subscription_id_seq + .fetch_add(1, std::sync::atomic::Ordering::SeqCst) + } + + async fn subscribe_price( + &self, + account: &solana_sdk::pubkey::Pubkey, + notify_price_tx: mpsc::Sender, + ) -> SubscriptionID { + let subscription_id = self.next_subscription_id(); + self.notify_price_subscriptions + .write() + .await + .entry(Identifier::new(account.to_bytes())) + .or_default() + .push(NotifyPriceSubscription { + subscription_id, + notify_price_tx, + }); + subscription_id + } + + async fn send_notify_price_sched(&self) -> Result<()> { + for subscription in self + .notify_price_sched_subscriptions + .read() + .await + .values() + .flatten() + { + // Send the notify price sched update without awaiting. This results in raising errors + // if the channel is full which normally should not happen. This is because we do not + // want to block the adapter if the channel is full. + subscription + .notify_price_sched_tx + .try_send(NotifyPriceSched { + subscription: subscription.subscription_id, + })?; + } + + Ok(()) + } + + async fn drop_closed_subscriptions(&self) { + for subscriptions in self.notify_price_subscriptions.write().await.values_mut() { + subscriptions.retain(|subscription| !subscription.notify_price_tx.is_closed()) + } + + for subscriptions in self + .notify_price_sched_subscriptions + .write() + .await + .values_mut() + { + subscriptions.retain(|subscription| !subscription.notify_price_sched_tx.is_closed()) + } + } + + async fn update_price( + &self, + account: &solana_sdk::pubkey::Pubkey, + price: Price, + conf: Conf, + status: String, + ) -> Result<()> { + self.local_store_tx + .send(local::Message::Update { + price_identifier: pyth_sdk::Identifier::new(account.to_bytes()), + price_info: local::PriceInfo { + status: Adapter::map_status(&status)?, + price, + conf, + timestamp: Utc::now().timestamp(), + }, + }) + .await + .map_err(|_| anyhow!("failed to send update to local store")) + } + + // TODO: implement FromStr method on PriceStatus + fn map_status(status: &str) -> Result { + match status { + "unknown" => Ok(PriceStatus::Unknown), + "trading" => Ok(PriceStatus::Trading), + "halted" => Ok(PriceStatus::Halted), + "auction" => Ok(PriceStatus::Auction), + "ignored" => Ok(PriceStatus::Ignored), + _ => Err(anyhow!("invalid price status: {:#?}", status)), + } + } + + async fn global_store_update( + &self, + price_identifier: PriceIdentifier, + price: i64, + conf: u64, + status: PriceStatus, + valid_slot: u64, + pub_slot: u64, + ) -> Result<()> { + // Look up any subcriptions associated with the price identifier + let empty = Vec::new(); + let subscriptions = self.notify_price_subscriptions.read().await; + let subscriptions = subscriptions.get(&price_identifier).unwrap_or(&empty); + + // Send the Notify Price update to each subscription + for subscription in subscriptions { + // Send the notify price update without awaiting. This results in raising errors if the + // channel is full which normally should not happen. This is because we do not want to + // block the adapter if the channel is full. + subscription.notify_price_tx.try_send(NotifyPrice { + subscription: subscription.subscription_id, + result: PriceUpdate { + price, + conf, + status: price_status_to_str(status), + valid_slot, + pub_slot, + }, + })?; + } + + Ok(()) + } +} diff --git a/src/agent/pythd/api/rpc.rs b/src/agent/pythd/api/rpc.rs index 51adb54..4316ca6 100644 --- a/src/agent/pythd/api/rpc.rs +++ b/src/agent/pythd/api/rpc.rs @@ -48,6 +48,7 @@ use { std::{ fmt::Debug, net::SocketAddr, + sync::Arc, }, tokio::sync::{ broadcast, @@ -112,13 +113,18 @@ enum ConnectionError { WebsocketConnectionClosed, } -async fn handle_connection( +async fn handle_connection( ws_conn: WebSocket, - adapter_tx: mpsc::Sender, + adapter: Arc, notify_price_tx_buffer: usize, notify_price_sched_tx_buffer: usize, logger: Logger, -) { +) where + S: adapter::AdapterApi, + S: Send, + S: Sync, + S: 'static, +{ // Create the channels let (mut ws_tx, mut ws_rx) = ws_conn.split(); let (mut notify_price_tx, mut notify_price_rx) = mpsc::channel(notify_price_tx_buffer); @@ -128,7 +134,7 @@ async fn handle_connection( loop { if let Err(err) = handle_next( &logger, - &adapter_tx, + &*adapter, &mut ws_tx, &mut ws_rx, &mut notify_price_tx, @@ -151,16 +157,19 @@ async fn handle_connection( } } -async fn handle_next( +async fn handle_next( logger: &Logger, - adapter_tx: &mpsc::Sender, + adapter: &S, ws_tx: &mut SplitSink, ws_rx: &mut SplitStream, notify_price_tx: &mut mpsc::Sender, notify_price_rx: &mut mpsc::Receiver, notify_price_sched_tx: &mut mpsc::Sender, notify_price_sched_rx: &mut mpsc::Receiver, -) -> Result<()> { +) -> Result<()> +where + S: adapter::AdapterApi, +{ tokio::select! { msg = ws_rx.next() => { match msg { @@ -169,7 +178,7 @@ async fn handle_next( handle( logger, ws_tx, - adapter_tx, + adapter, notify_price_tx, notify_price_sched_tx, msg, @@ -192,14 +201,17 @@ async fn handle_next( } } -async fn handle( +async fn handle( logger: &Logger, ws_tx: &mut SplitSink, - adapter_tx: &mpsc::Sender, + adapter: &S, notify_price_tx: &mpsc::Sender, notify_price_sched_tx: &mpsc::Sender, msg: Message, -) -> Result<()> { +) -> Result<()> +where + S: adapter::AdapterApi, +{ // Ignore control and binary messages if !msg.is_text() { debug!(logger, "JSON RPC API: skipped non-text message"); @@ -215,7 +227,7 @@ async fn handle( for request in requests { let response = dispatch_and_catch_error( logger, - adapter_tx, + adapter, notify_price_tx, notify_price_sched_tx, &request, @@ -276,13 +288,16 @@ async fn parse(msg: Message) -> Result<(Vec>, bool)> { } } -async fn dispatch_and_catch_error( +async fn dispatch_and_catch_error( logger: &Logger, - adapter_tx: &mpsc::Sender, + adapter: &S, notify_price_tx: &mpsc::Sender, notify_price_sched_tx: &mpsc::Sender, request: &Request, -) -> Response { +) -> Response +where + S: adapter::AdapterApi, +{ debug!( logger, "JSON RPC API: handling request"; @@ -290,13 +305,13 @@ async fn dispatch_and_catch_error( ); let result = match request.method { - Method::GetProductList => get_product_list(adapter_tx).await, - Method::GetProduct => get_product(adapter_tx, request).await, - Method::GetAllProducts => get_all_products(adapter_tx).await, - Method::UpdatePrice => update_price(adapter_tx, request).await, - Method::SubscribePrice => subscribe_price(adapter_tx, notify_price_tx, request).await, + Method::GetProductList => get_product_list(adapter).await, + Method::GetProduct => get_product(adapter, request).await, + Method::GetAllProducts => get_all_products(adapter).await, + Method::UpdatePrice => update_price(adapter, request).await, + Method::SubscribePrice => subscribe_price(adapter, notify_price_tx, request).await, Method::SubscribePriceSched => { - subscribe_price_sched(adapter_tx, notify_price_sched_tx, request).await + subscribe_price_sched(adapter, notify_price_sched_tx, request).await } Method::NotifyPrice | Method::NotifyPriceSched => { Err(anyhow!("unsupported method: {:?}", request.method)) @@ -415,25 +430,35 @@ impl Default for Config { } } -pub async fn run( +pub async fn run( config: Config, logger: Logger, - adapter_tx: mpsc::Sender, + adapter: Arc, shutdown_rx: broadcast::Receiver<()>, -) { - if let Err(err) = serve(config, &logger, adapter_tx, shutdown_rx).await { +) where + S: adapter::AdapterApi, + S: Send, + S: Sync, + S: 'static, +{ + if let Err(err) = serve(config, &logger, adapter, shutdown_rx).await { error!(logger, "{}", err); debug!(logger, "error context"; "context" => format!("{:?}", err)); } } -async fn serve( +async fn serve( config: Config, logger: &Logger, - adapter_tx: mpsc::Sender, + adapter: Arc, mut shutdown_rx: broadcast::Receiver<()>, -) -> Result<()> { - let adapter_tx = adapter_tx.clone(); +) -> Result<()> +where + S: adapter::AdapterApi, + S: Send, + S: Sync, + S: 'static, +{ let config = config.clone(); let with_logger = WithLogger { logger: logger.clone(), @@ -443,19 +468,16 @@ async fn serve( let config = config.clone(); warp::path::end() .and(warp::ws()) - .and(warp::any().map(move || adapter_tx.clone())) + .and(warp::any().map(move || adapter.clone())) .and(warp::any().map(move || with_logger.clone())) .and(warp::any().map(move || config.clone())) .map( - |ws: Ws, - adapter_tx: mpsc::Sender, - with_logger: WithLogger, - config: Config| { + |ws: Ws, adapter: Arc, with_logger: WithLogger, config: Config| { ws.on_upgrade(move |conn| async move { info!(with_logger.logger, "websocket user connected"); handle_connection( conn, - adapter_tx, + adapter, config.notify_price_tx_buffer, config.notify_price_sched_tx_buffer, with_logger.logger, @@ -477,792 +499,3 @@ async fn serve( tokio::task::spawn(serve).await.map_err(|e| e.into()) } - -#[cfg(test)] -mod tests { - use { - super::{ - super::{ - rpc::GetProductParams, - Attrs, - PriceAccount, - PriceAccountMetadata, - ProductAccount, - ProductAccountMetadata, - Pubkey, - PublisherAccount, - SubscriptionID, - }, - Config, - }, - crate::agent::pythd::{ - adapter, - api::{ - rpc::{ - SubscribePriceParams, - SubscribePriceSchedParams, - UpdatePriceParams, - }, - NotifyPrice, - NotifyPriceSched, - PriceUpdate, - }, - }, - anyhow::anyhow, - iobuffer::IoBuffer, - jrpc::{ - Id, - Request, - }, - rand::Rng, - serde::{ - de::DeserializeOwned, - Serialize, - }, - slog_extlog::slog_test, - soketto::handshake::{ - Client, - ServerResponse, - }, - std::str::from_utf8, - tokio::{ - net::TcpStream, - sync::{ - broadcast, - mpsc, - }, - task::JoinHandle, - }, - tokio_retry::{ - strategy::FixedInterval, - Retry, - }, - tokio_util::compat::{ - Compat, - TokioAsyncReadCompatExt, - }, - }; - - struct TestAdapter { - rx: mpsc::Receiver, - } - - impl TestAdapter { - async fn recv(&mut self) -> adapter::Message { - self.rx.recv().await.unwrap() - } - } - - struct TestServer { - shutdown_tx: broadcast::Sender<()>, - jh: JoinHandle<()>, - } - - impl Drop for TestServer { - fn drop(&mut self) { - let _ = self.shutdown_tx.send(()); - self.jh.abort(); - } - } - - struct TestClient { - sender: soketto::Sender>, - receiver: soketto::Receiver>, - } - - impl TestClient { - async fn new(server_port: u16) -> Self { - // Connect to the server, retrying as the server may take some time to respond to requests initially - let socket = Retry::spawn(FixedInterval::from_millis(100).take(20), || { - TcpStream::connect(("127.0.0.1", server_port)) - }) - .await - .unwrap(); - let mut client = Client::new(socket.compat(), "...", "/"); - - // Perform the websocket handshake - let handshake = client.handshake().await.unwrap(); - assert!(matches!(handshake, ServerResponse::Accepted { .. })); - - let (sender, receiver) = client.into_builder().finish(); - TestClient { sender, receiver } - } - - async fn send(&mut self, request: Request) - where - T: Serialize + DeserializeOwned, - { - self.sender.send_text(request.to_string()).await.unwrap(); - } - - async fn send_batch(&mut self, requests: Vec>) { - let serialized = serde_json::to_string(&requests).unwrap(); - self.sender.send_text(serialized).await.unwrap() - } - - async fn recv_json(&mut self) -> String { - let bytes = self.recv_bytes().await; - from_utf8(&bytes).unwrap().to_string() - } - - async fn recv_bytes(&mut self) -> Vec { - let mut bytes = Vec::new(); - self.receiver.receive_data(&mut bytes).await.unwrap(); - bytes - } - } - - async fn start_server() -> (TestServer, TestClient, TestAdapter, IoBuffer) { - let listen_port = portpicker::pick_unused_port().unwrap(); - - // Create the test adapter - let (adapter_tx, adapter_rx) = mpsc::channel(100); - let test_adapter = TestAdapter { rx: adapter_rx }; - - // Create and spawn a server (the SUT) - let (shutdown_tx, shutdown_rx) = broadcast::channel(10); - let log_buffer = IoBuffer::new(); - let logger = slog_test::new_test_logger(log_buffer.clone()); - let config = Config { - listen_address: format!("127.0.0.1:{:}", listen_port), - ..Default::default() - }; - let jh = tokio::spawn(super::run(config, logger, adapter_tx, shutdown_rx)); - let test_server = TestServer { shutdown_tx, jh }; - - // Create a test client to interact with the server - let test_client = TestClient::new(listen_port).await; - - (test_server, test_client, test_adapter, log_buffer) - } - - #[tokio::test(flavor = "multi_thread", worker_threads = 2)] - async fn json_get_product_success_test() { - // Start and connect to the JRPC server - let (_test_server, mut test_client, mut test_adapter, _) = start_server().await; - - // Make a binary request, which should be safely ignored - let random_bytes = rand::thread_rng().gen::<[u8; 32]>(); - test_client.sender.send_binary(random_bytes).await.unwrap(); - - // Define the product account we expect to receive back - let product_account_key = "some_product_account".to_string(); - let product_account = ProductAccount { - account: product_account_key.clone(), - attr_dict: Attrs::from( - [ - ("symbol", "BTC/USD"), - ("asset_type", "Crypto"), - ("country", "US"), - ("quote_currency", "USD"), - ("tenor", "spot"), - ] - .map(|(k, v)| (k.to_string(), v.to_string())), - ), - price_accounts: vec![PriceAccount { - account: "some_price_account".to_string(), - price_type: "price".to_string(), - price_exponent: 8, - status: "trading".to_string(), - price: 536, - conf: 67, - twap: 276, - twac: 463, - valid_slot: 4628, - pub_slot: 4736, - prev_slot: 3856, - prev_price: 400, - prev_conf: 45, - publisher_accounts: vec![ - PublisherAccount { - account: "some_publisher_account".to_string(), - status: "trading".to_string(), - price: 500, - conf: 24, - slot: 3563, - }, - PublisherAccount { - account: "another_publisher_account".to_string(), - status: "halted".to_string(), - price: 300, - conf: 683, - slot: 5834, - }, - ], - }], - }; - - // Make a request - test_client - .send(Request::with_params( - Id::from(1), - "get_product".to_string(), - GetProductParams { - account: product_account_key, - }, - )) - .await; - - // Expect the adapter to receive the corresponding message and send the product account in return - if let adapter::Message::GetProduct { result_tx, .. } = test_adapter.recv().await { - result_tx.send(Ok(product_account.clone())).unwrap(); - } - - // Wait for the result to come back - let received_json = test_client.recv_json().await; - - // Check that the JSON representation is correct - let expected = serde_json::json!({ - "jsonrpc":"2.0", - "result": { - "account": "some_product_account", - "attr_dict": { - "symbol": "BTC/USD", - "asset_type": "Crypto", - "country": "US", - "quote_currency": "USD", - "tenor": "spot" - }, - "price_accounts": [ - { - "account": "some_price_account", - "price_type": "price", - "price_exponent": 8, - "status": "trading", - "price": 536, - "conf": 67, - "twap": 276, - "twac": 463, - "valid_slot": 4628, - "pub_slot": 4736, - "prev_slot": 3856, - "prev_price": 400, - "prev_conf": 45, - "publisher_accounts": [ - { - "account": "some_publisher_account", - "status": "trading", - "price": 500, - "conf": 24, - "slot": 3563 - }, - { - "account": "another_publisher_account", - "status": "halted", - "price": 300, - "conf": 683, - "slot": 5834 - } - ] - - } - ] - }, - "id": 1 - } - ); - let received: serde_json::Value = serde_json::from_str(&received_json).unwrap(); - assert_eq!(received, expected); - } - - #[tokio::test(flavor = "multi_thread", worker_threads = 2)] - async fn json_unknown_method_error_test() { - // Start and connect to the JRPC server - let (_test_server, mut test_client, _, _) = start_server().await; - - // Make a request with an unknown methid - test_client - .send(Request::with_params( - Id::from(3), - "wrong_method".to_string(), - GetProductParams { - account: "some_account".to_string(), - }, - )) - .await; - - // Wait for the result to come back - let received_json = test_client.recv_json().await; - - // Check that the result is what we expect - let expected_json = r#"{"jsonrpc":"2.0","error":{"code":-32603,"message":"Could not parse message: unknown variant `wrong_method`, expected one of `get_product_list`, `get_product`, `get_all_products`, `subscribe_price`, `notify_price`, `subscribe_price_sched`, `notify_price_sched`, `update_price`","data":null},"id":0}"#; - assert_eq!(received_json, expected_json); - } - - #[tokio::test(flavor = "multi_thread", worker_threads = 2)] - async fn json_missing_request_parameters_test() { - // Start and connect to the JRPC server - let (_test_server, mut test_client, _, _) = start_server().await; - - // Make a request with missing parameters - test_client - .send(Request::new(Id::from(5), "update_price".to_string())) - .await; - - // Wait for the result to come back - let received_json = test_client.recv_json().await; - - // Check that the result is what we expect - let expected_json = r#"{"jsonrpc":"2.0","error":{"code":-32603,"message":"Missing request parameters","data":null},"id":5}"#; - assert_eq!(received_json, expected_json); - } - - #[tokio::test(flavor = "multi_thread", worker_threads = 2)] - async fn json_internal_error() { - // Start and connect to the JRPC server - let (_test_server, mut test_client, mut test_adapter, _) = start_server().await; - - // Make a request - test_client - .send(Request::with_params( - Id::from(9), - "get_product".to_string(), - GetProductParams { - account: "some_account".to_string(), - }, - )) - .await; - - // Make the adapter throw an error in return - if let adapter::Message::GetProduct { result_tx, .. } = test_adapter.recv().await { - result_tx.send(Err(anyhow!("some internal error"))).unwrap(); - } - - // Get the result back - let received_json = test_client.recv_json().await; - - // Check that the result is what we expect - let expected_json = r#"{"jsonrpc":"2.0","error":{"code":-32603,"message":"some internal error","data":null},"id":9}"#; - assert_eq!(expected_json, received_json); - } - - #[tokio::test(flavor = "multi_thread", worker_threads = 2)] - async fn json_update_price_success() { - // Start and connect to the JRPC server - let (_test_server, mut test_client, mut test_adapter, _) = start_server().await; - - // Make a request to update the price - let status = "trading"; - let params = UpdatePriceParams { - account: Pubkey::from("some_price_account"), - price: 7467, - conf: 892, - status: status.to_string(), - }; - test_client - .send(Request::with_params( - Id::from(15), - "update_price".to_string(), - params.clone(), - )) - .await; - - // Assert that the adapter receives this - assert!(matches!( - test_adapter.recv().await, - adapter::Message::UpdatePrice { - account, - price, - conf, - status - } if account == params.account && price == params.price && conf == params.conf && status == params.status - )); - - // Get the result back - let received_json = test_client.recv_json().await; - - // Assert that the result is what we expect - let expected_json = r#"{"jsonrpc":"2.0","result":0,"id":15}"#; - assert_eq!(received_json, expected_json); - } - - #[tokio::test(flavor = "multi_thread", worker_threads = 2)] - async fn get_product_list_success_test() { - // Start and connect to the JRPC server - let (_test_server, mut test_client, mut test_adapter, _) = start_server().await; - - // Define the data we are working with - let product_account = Pubkey::from("some_product_account"); - let data = vec![ProductAccountMetadata { - account: product_account.clone(), - attr_dict: Attrs::from( - [ - ("symbol", "BTC/USD"), - ("asset_type", "Crypto"), - ("country", "US"), - ("quote_currency", "USD"), - ("tenor", "spot"), - ] - .map(|(k, v)| (k.to_string(), v.to_string())), - ), - price: vec![ - PriceAccountMetadata { - account: Pubkey::from("some_price_account"), - price_type: "price".to_string(), - price_exponent: 4, - }, - PriceAccountMetadata { - account: Pubkey::from("another_price_account"), - price_type: "special".to_string(), - price_exponent: 6, - }, - ], - }]; - - // Make a GetProductList request - test_client - .send(Request::new(Id::from(11), "get_product_list".to_string())) - .await; - - // Instruct the adapter to send our data back - if let adapter::Message::GetProductList { result_tx } = test_adapter.recv().await { - result_tx.send(Ok(data.clone())).unwrap(); - } - - // Get the result back - let bytes = test_client.recv_bytes().await; - - // Assert that the result is what we expect - let response: jrpc::Response> = - serde_json::from_slice(&bytes).unwrap(); - assert!(matches!(response, jrpc::Response::Ok(success) if success.result == data)); - } - - #[tokio::test(flavor = "multi_thread", worker_threads = 2)] - async fn get_all_products_success() { - // Start and connect to the JRPC server - let (_test_server, mut test_client, mut test_adapter, _) = start_server().await; - - // Define the data we are working with - let data = vec![ProductAccount { - account: Pubkey::from("some_product_account"), - attr_dict: Attrs::from( - [ - ("symbol", "LTC/USD"), - ("asset_type", "Crypto"), - ("country", "US"), - ("quote_currency", "USD"), - ("tenor", "spot"), - ] - .map(|(k, v)| (k.to_string(), v.to_string())), - ), - price_accounts: vec![PriceAccount { - account: Pubkey::from("some_price_account"), - price_type: "price".to_string(), - price_exponent: 7463, - status: "trading".to_string(), - price: 6453, - conf: 3434, - twap: 6454, - twac: 365, - valid_slot: 3646, - pub_slot: 2857, - prev_slot: 7463, - prev_price: 3784, - prev_conf: 9879, - publisher_accounts: vec![ - PublisherAccount { - account: Pubkey::from("some_publisher_account"), - status: "trading".to_string(), - price: 756, - conf: 8787, - slot: 2209, - }, - PublisherAccount { - account: Pubkey::from("another_publisher_account"), - status: "halted".to_string(), - price: 0, - conf: 0, - slot: 6676, - }, - ], - }], - }]; - - // Make a GetAllProducts request - test_client - .send(Request::new(Id::from(5), "get_all_products".to_string())) - .await; - - // Instruct the adapter to send our data back - if let adapter::Message::GetAllProducts { result_tx, .. } = test_adapter.recv().await { - result_tx.send(Ok(data.clone())).unwrap(); - } - - // Get the result back - let bytes = test_client.recv_bytes().await; - - // Assert that the result is what we expect - let response: jrpc::Response> = serde_json::from_slice(&bytes).unwrap(); - assert!(matches!(response, jrpc::Response::Ok(success) if success.result == data)); - } - - #[tokio::test(flavor = "multi_thread", worker_threads = 2)] - async fn subscribe_price_success() { - // Start and connect to the JRPC server - let (_test_server, mut test_client, mut test_adapter, _) = start_server().await; - - // Make a SubscribePrice request - let price_account = Pubkey::from("some_price_account"); - test_client - .send(Request::with_params( - Id::from(13), - "subscribe_price".to_string(), - SubscribePriceParams { - account: price_account, - }, - )) - .await; - - // Send a subscription ID back, and then a Notify Price update. - // Check that both are received by the client. - match test_adapter.recv().await { - adapter::Message::SubscribePrice { - account: _, - notify_price_tx, - result_tx, - } => { - // Send the subscription ID from the adapter to the server - let subscription_id = SubscriptionID::from(16); - result_tx.send(Ok(subscription_id)).unwrap(); - - // Assert that the client connection receives the subscription ID - assert_eq!( - test_client.recv_json().await, - r#"{"jsonrpc":"2.0","result":{"subscription":16},"id":13}"# - ); - - // Send a Notify Price event from the adapter to the server, with the corresponding subscription id - let notify_price_update = NotifyPrice { - subscription: subscription_id, - result: PriceUpdate { - price: 74, - conf: 24, - status: "trading".to_string(), - valid_slot: 6786, - pub_slot: 9897, - }, - }; - notify_price_tx.send(notify_price_update).await.unwrap(); - - // Assert that the client connection receives the notify_price notification - // with the subscription ID and price update. - assert_eq!( - test_client.recv_json().await, - r#"{"jsonrpc":"2.0","method":"notify_price","params":{"subscription":16,"result":{"price":74,"conf":24,"status":"trading","valid_slot":6786,"pub_slot":9897}}}"# - ) - } - _ => panic!("Uexpected message received from adapter"), - }; - } - - #[tokio::test(flavor = "multi_thread", worker_threads = 2)] - async fn subscribe_price_sched_success() { - // Start and connect to the JRPC server - let (_test_server, mut test_client, mut test_adapter, _) = start_server().await; - - // Make a SubscribePriceSched request - let price_account = Pubkey::from("some_price_account"); - test_client - .send(Request::with_params( - Id::from(19), - "subscribe_price_sched".to_string(), - SubscribePriceSchedParams { - account: price_account, - }, - )) - .await; - - // Send a subscription ID back, and then a Notify Price Sched update. - // Check that both are received by the client. - match test_adapter.recv().await { - adapter::Message::SubscribePriceSched { - account: _, - notify_price_sched_tx, - result_tx, - } => { - // Send the subscription ID from the adapter to the server - let subscription_id = SubscriptionID::from(27); - result_tx.send(Ok(subscription_id)).unwrap(); - - // Assert that the client connection receives the subscription ID - assert_eq!( - test_client.recv_json().await, - r#"{"jsonrpc":"2.0","result":{"subscription":27},"id":19}"# - ); - - // Send a Notify Price Sched event from the adapter to the server, with the corresponding subscription id - let notify_price_sched_update = NotifyPriceSched { - subscription: subscription_id, - }; - notify_price_sched_tx - .send(notify_price_sched_update) - .await - .unwrap(); - - // Assert that the client connection receives the notify_price notification - // with the correct subscription ID. - assert_eq!( - test_client.recv_json().await, - r#"{"jsonrpc":"2.0","method":"notify_price_sched","params":{"subscription":27}}"# - ) - } - _ => panic!("Uexpected message received from adapter"), - }; - } - - /// Send a batch of requests with one of them mangled. - #[tokio::test(flavor = "multi_thread", worker_threads = 2)] - async fn batch_request_partial_failure() { - // Start and connect to the JRPC server - let (_test_server, mut test_client, mut test_adapter, mut log_buffer) = - start_server().await; - - let product_account_key = "some_product_account".to_string(); - let valid_params = GetProductParams { - account: product_account_key.clone(), - }; - - test_client - .send_batch(vec![ - // Should work - Request::with_params( - Id::from(15), - "get_product".to_string(), - serde_json::to_value(&valid_params).unwrap(), - ), - // Should fail - Request::with_params( - Id::from(666), // Note: Spooky - "update_price".to_string(), - serde_json::json!({}), - ), - ]) - .await; - - let product_account = ProductAccount { - account: product_account_key, - attr_dict: Attrs::from( - [ - ("symbol", "BTC/USD"), - ("asset_type", "Crypto"), - ("country", "US"), - ("quote_currency", "USD"), - ("tenor", "spot"), - ] - .map(|(k, v)| (k.to_string(), v.to_string())), - ), - price_accounts: vec![PriceAccount { - account: "some_price_account".to_string(), - price_type: "price".to_string(), - price_exponent: 8, - status: "trading".to_string(), - price: 536, - conf: 67, - twap: 276, - twac: 463, - valid_slot: 4628, - pub_slot: 4736, - prev_slot: 3856, - prev_price: 400, - prev_conf: 45, - publisher_accounts: vec![ - PublisherAccount { - account: "some_publisher_account".to_string(), - status: "trading".to_string(), - price: 500, - conf: 24, - slot: 3563, - }, - PublisherAccount { - account: "another_publisher_account".to_string(), - status: "halted".to_string(), - price: 300, - conf: 683, - slot: 5834, - }, - ], - }], - }; - - // Handle the request in test adapter - if let adapter::Message::GetProduct { result_tx, .. } = test_adapter.recv().await { - result_tx.send(Ok(product_account.clone())).unwrap(); - } - - let received: serde_json::Value = - serde_json::from_str(&test_client.recv_json().await).unwrap(); - - let expected_product_account_json = serde_json::json!({ - "account": "some_product_account", - "attr_dict": { - "symbol": "BTC/USD", - "asset_type": "Crypto", - "country": "US", - "quote_currency": "USD", - "tenor": "spot" - }, - "price_accounts": [ - { - "account": "some_price_account", - "price_type": "price", - "price_exponent": 8, - "status": "trading", - "price": 536, - "conf": 67, - "twap": 276, - "twac": 463, - "valid_slot": 4628, - "pub_slot": 4736, - "prev_slot": 3856, - "prev_price": 400, - "prev_conf": 45, - "publisher_accounts": [ - { - "account": "some_publisher_account", - "status": "trading", - "price": 500, - "conf": 24, - "slot": 3563 - }, - { - "account": "another_publisher_account", - "status": "halted", - "price": 300, - "conf": 683, - "slot": 5834 - } - ] - - } - ] - } - ); - - let expected = serde_json::json!( - [ - { - "jsonrpc": "2.0", - "id": 15, - "result": expected_product_account_json, - }, - { - "jsonrpc": "2.0", - "error": { - "code": -32603, - "message": "missing field `account`", - "data": null - }, - "id": 666 - } - ] - ); - - println!("Log contents:"); - for line in log_buffer.lines() { - println!("{}", String::from_utf8(line).unwrap()); - } - - assert_eq!(received, expected); - } -} diff --git a/src/agent/pythd/api/rpc/get_all_products.rs b/src/agent/pythd/api/rpc/get_all_products.rs index 0d4d3bb..9806a00 100644 --- a/src/agent/pythd/api/rpc/get_all_products.rs +++ b/src/agent/pythd/api/rpc/get_all_products.rs @@ -1,18 +1,12 @@ use { crate::agent::pythd::adapter, anyhow::Result, - tokio::sync::{ - mpsc, - oneshot, - }, }; -pub async fn get_all_products( - adapter_tx: &mpsc::Sender, -) -> Result { - let (result_tx, result_rx) = oneshot::channel(); - adapter_tx - .send(adapter::Message::GetAllProducts { result_tx }) - .await?; - Ok(serde_json::to_value(result_rx.await??)?) +pub async fn get_all_products(adapter: &S) -> Result +where + S: adapter::AdapterApi, +{ + let products = adapter.get_all_products().await?; + Ok(serde_json::to_value(products)?) } diff --git a/src/agent/pythd/api/rpc/get_product.rs b/src/agent/pythd/api/rpc/get_product.rs index a46a670..2a6d8c0 100644 --- a/src/agent/pythd/api/rpc/get_product.rs +++ b/src/agent/pythd/api/rpc/get_product.rs @@ -12,27 +12,21 @@ use { Request, Value, }, - tokio::sync::{ - mpsc, - oneshot, - }, }; -pub async fn get_product( - adapter_tx: &mpsc::Sender, +pub async fn get_product( + adapter: &S, request: &Request, -) -> Result { +) -> Result +where + S: adapter::AdapterApi, +{ let params: GetProductParams = { let value = request.params.clone(); serde_json::from_value(value.ok_or_else(|| anyhow!("Missing request parameters"))?) }?; - let (result_tx, result_rx) = oneshot::channel(); - adapter_tx - .send(adapter::Message::GetProduct { - account: params.account, - result_tx, - }) - .await?; - Ok(serde_json::to_value(result_rx.await??)?) + let account = params.account.parse::()?; + let product = adapter.get_product(&account).await?; + Ok(serde_json::to_value(product)?) } diff --git a/src/agent/pythd/api/rpc/get_product_list.rs b/src/agent/pythd/api/rpc/get_product_list.rs index 2495525..ff3e498 100644 --- a/src/agent/pythd/api/rpc/get_product_list.rs +++ b/src/agent/pythd/api/rpc/get_product_list.rs @@ -1,18 +1,12 @@ use { crate::agent::pythd::adapter, anyhow::Result, - tokio::sync::{ - mpsc, - oneshot, - }, }; -pub async fn get_product_list( - adapter_tx: &mpsc::Sender, -) -> Result { - let (result_tx, result_rx) = oneshot::channel(); - adapter_tx - .send(adapter::Message::GetProductList { result_tx }) - .await?; - Ok(serde_json::to_value(result_rx.await??)?) +pub async fn get_product_list(adapter: &S) -> Result +where + S: adapter::AdapterApi, +{ + let product_list = adapter.get_product_list().await?; + Ok(serde_json::to_value(product_list)?) } diff --git a/src/agent/pythd/api/rpc/subscribe_price.rs b/src/agent/pythd/api/rpc/subscribe_price.rs index a0843e9..bed8e22 100644 --- a/src/agent/pythd/api/rpc/subscribe_price.rs +++ b/src/agent/pythd/api/rpc/subscribe_price.rs @@ -14,17 +14,17 @@ use { Request, Value, }, - tokio::sync::{ - mpsc, - oneshot, - }, + tokio::sync::mpsc, }; -pub async fn subscribe_price( - adapter_tx: &mpsc::Sender, +pub async fn subscribe_price( + adapter: &S, notify_price_tx: &mpsc::Sender, request: &Request, -) -> Result { +) -> Result +where + S: adapter::AdapterApi, +{ let params: SubscribePriceParams = serde_json::from_value( request .params @@ -32,16 +32,10 @@ pub async fn subscribe_price( .ok_or_else(|| anyhow!("Missing request parameters"))?, )?; - let (result_tx, result_rx) = oneshot::channel(); - adapter_tx - .send(adapter::Message::SubscribePrice { - result_tx, - account: params.account, - notify_price_tx: notify_price_tx.clone(), - }) - .await?; + let account = params.account.parse::()?; + let subscription = adapter + .subscribe_price(&account, notify_price_tx.clone()) + .await; - Ok(serde_json::to_value(SubscribeResult { - subscription: result_rx.await??, - })?) + Ok(serde_json::to_value(SubscribeResult { subscription })?) } diff --git a/src/agent/pythd/api/rpc/subscribe_price_sched.rs b/src/agent/pythd/api/rpc/subscribe_price_sched.rs index 7bcea3f..619f7c1 100644 --- a/src/agent/pythd/api/rpc/subscribe_price_sched.rs +++ b/src/agent/pythd/api/rpc/subscribe_price_sched.rs @@ -14,17 +14,17 @@ use { Request, Value, }, - tokio::sync::{ - mpsc, - oneshot, - }, + tokio::sync::mpsc, }; -pub async fn subscribe_price_sched( - adapter_tx: &mpsc::Sender, +pub async fn subscribe_price_sched( + adapter: &S, notify_price_sched_tx: &mpsc::Sender, request: &Request, -) -> Result { +) -> Result +where + S: adapter::AdapterApi, +{ let params: SubscribePriceSchedParams = serde_json::from_value( request .params @@ -32,16 +32,10 @@ pub async fn subscribe_price_sched( .ok_or_else(|| anyhow!("Missing request parameters"))?, )?; - let (result_tx, result_rx) = oneshot::channel(); - adapter_tx - .send(adapter::Message::SubscribePriceSched { - result_tx, - account: params.account, - notify_price_sched_tx: notify_price_sched_tx.clone(), - }) - .await?; + let account = params.account.parse::()?; + let subscription = adapter + .subscribe_price_sched(&account, notify_price_sched_tx.clone()) + .await; - Ok(serde_json::to_value(SubscribeResult { - subscription: result_rx.await??, - })?) + Ok(serde_json::to_value(SubscribeResult { subscription })?) } diff --git a/src/agent/pythd/api/rpc/update_price.rs b/src/agent/pythd/api/rpc/update_price.rs index 6c44470..8fd122f 100644 --- a/src/agent/pythd/api/rpc/update_price.rs +++ b/src/agent/pythd/api/rpc/update_price.rs @@ -12,13 +12,15 @@ use { Request, Value, }, - tokio::sync::mpsc, }; -pub async fn update_price( - adapter_tx: &mpsc::Sender, +pub async fn update_price( + adapter: &S, request: &Request, -) -> Result { +) -> Result +where + S: adapter::AdapterApi, +{ let params: UpdatePriceParams = serde_json::from_value( request .params @@ -26,13 +28,13 @@ pub async fn update_price( .ok_or_else(|| anyhow!("Missing request parameters"))?, )?; - adapter_tx - .send(adapter::Message::UpdatePrice { - account: params.account, - price: params.price, - conf: params.conf, - status: params.status, - }) + adapter + .update_price( + ¶ms.account.parse::()?, + params.price, + params.conf, + params.status, + ) .await?; Ok(serde_json::to_value(0)?) diff --git a/src/agent/store/global.rs b/src/agent/store/global.rs index 154b212..cb534a2 100644 --- a/src/agent/store/global.rs +++ b/src/agent/store/global.rs @@ -1,4 +1,3 @@ -use std::collections::HashSet; // The Global Store stores a copy of all the product and price information held in the Pyth // on-chain aggregation contracts, across both the primary and secondary networks. // This enables this data to be easily queried by other components. @@ -14,7 +13,7 @@ use { ProductGlobalMetrics, PROMETHEUS_REGISTRY, }, - pythd::adapter, + pythd::adapter::AdapterApi, solana::network::Network, }, anyhow::{ @@ -24,9 +23,13 @@ use { pyth_sdk::Identifier, slog::Logger, solana_sdk::pubkey::Pubkey, - std::collections::{ - BTreeMap, - HashMap, + std::{ + collections::{ + BTreeMap, + HashMap, + HashSet, + }, + sync::Arc, }, tokio::{ sync::{ @@ -119,7 +122,7 @@ pub enum Lookup { }, } -pub struct Store { +pub struct Store { /// The actual data on primary network account_data_primary: AllAccountsData, @@ -147,25 +150,31 @@ pub struct Store { /// Channel on which account updates are received from the secondary network secondary_updates_rx: mpsc::Receiver, - /// Channel on which to communicate with the pythd API adapter - pythd_adapter_tx: mpsc::Sender, + /// Reference the Pythd Adapter. + pythd_adapter: Arc, logger: Logger, } -pub fn spawn_store( +pub fn spawn_store( lookup_rx: mpsc::Receiver, primary_updates_rx: mpsc::Receiver, secondary_updates_rx: mpsc::Receiver, - pythd_adapter_tx: mpsc::Sender, + pythd_adapter: Arc, logger: Logger, -) -> JoinHandle<()> { +) -> JoinHandle<()> +where + S: AdapterApi, + S: Send, + S: Sync, + S: 'static, +{ tokio::spawn(async move { Store::new( lookup_rx, primary_updates_rx, secondary_updates_rx, - pythd_adapter_tx, + pythd_adapter, logger, ) .await @@ -174,12 +183,17 @@ pub fn spawn_store( }) } -impl Store { +impl Store +where + S: AdapterApi, + S: Send, + S: Sync, +{ pub async fn new( lookup_rx: mpsc::Receiver, primary_updates_rx: mpsc::Receiver, secondary_updates_rx: mpsc::Receiver, - pythd_adapter_tx: mpsc::Sender, + pythd_adapter: Arc, logger: Logger, ) -> Self { let prom_registry_ref = &mut &mut PROMETHEUS_REGISTRY.lock().await; @@ -193,7 +207,7 @@ impl Store { lookup_rx, primary_updates_rx, secondary_updates_rx, - pythd_adapter_tx, + pythd_adapter, logger, } } @@ -277,15 +291,15 @@ impl Store { // As the account data might differ between the two networks // we only notify the adapter of the primary network updates. if let Network::Primary = network { - self.pythd_adapter_tx - .send(adapter::Message::GlobalStoreUpdate { - price_identifier: Identifier::new(account_key.to_bytes()), - price: account.agg.price, - conf: account.agg.conf, - status: account.agg.status, - valid_slot: account.valid_slot, - pub_slot: account.agg.pub_slot, - }) + self.pythd_adapter + .global_store_update( + Identifier::new(account_key.to_bytes()), + account.agg.price, + account.agg.conf, + account.agg.status, + account.valid_slot, + account.agg.pub_slot, + ) .await .map_err(|_| anyhow!("failed to notify pythd adapter of account update"))?; }